[mypy] Enforcing typing for charts (#9411)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2020-03-29 13:39:36 -07:00 committed by GitHub
parent 2e81e27272
commit ec795a4711
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 41 additions and 29 deletions

View File

@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true ignore_missing_imports = true
no_implicit_optional = true no_implicit_optional = true
[mypy-superset.db_engine_specs.*] [mypy-superset.charts.*,superset.db_engine_specs.*]
check_untyped_defs = true check_untyped_defs = true
disallow_untyped_calls = true disallow_untyped_calls = true
disallow_untyped_defs = true disallow_untyped_defs = true

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from typing import Any
from flask import g, request, Response from flask import g, request, Response
from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.api import expose, protect, rison, safe
@ -287,7 +288,9 @@ class ChartRestApi(BaseSupersetModelRestApi):
@protect() @protect()
@safe @safe
@rison(get_delete_ids_schema) @rison(get_delete_ids_schema)
def bulk_delete(self, **kwargs) -> Response: # pylint: disable=arguments-differ def bulk_delete(
self, **kwargs: Any
) -> Response: # pylint: disable=arguments-differ
"""Delete bulk Charts """Delete bulk Charts
--- ---
delete: delete:

View File

@ -40,7 +40,7 @@ class BulkDeleteChartCommand(BaseCommand):
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[List[Slice]] = None self._models: Optional[List[Slice]] = None
def run(self): def run(self) -> None:
self.validate() self.validate()
try: try:
ChartDAO.bulk_delete(self._models) ChartDAO.bulk_delete(self._models)

View File

@ -17,6 +17,7 @@
import logging import logging
from typing import Dict, List, Optional from typing import Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
@ -39,7 +40,7 @@ class CreateChartCommand(BaseCommand):
self._actor = user self._actor = user
self._properties = data.copy() self._properties = data.copy()
def run(self): def run(self) -> Model:
self.validate() self.validate()
try: try:
chart = ChartDAO.create(self._properties) chart = ChartDAO.create(self._properties)

View File

@ -17,6 +17,7 @@
import logging import logging
from typing import Optional from typing import Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import User
from superset.charts.commands.exceptions import ( from superset.charts.commands.exceptions import (
@ -40,7 +41,7 @@ class DeleteChartCommand(BaseCommand):
self._model_id = model_id self._model_id = model_id
self._model: Optional[SqlaTable] = None self._model: Optional[SqlaTable] = None
def run(self): def run(self) -> Model:
self.validate() self.validate()
try: try:
chart = ChartDAO.delete(self._model) chart = ChartDAO.delete(self._model)

View File

@ -32,7 +32,7 @@ class DatabaseNotFoundValidationError(ValidationError):
Marshmallow validation error for database does not exist Marshmallow validation error for database does not exist
""" """
def __init__(self): def __init__(self) -> None:
super().__init__(_("Database does not exist"), field_names=["database"]) super().__init__(_("Database does not exist"), field_names=["database"])
@ -41,7 +41,7 @@ class DashboardsNotFoundValidationError(ValidationError):
Marshmallow validation error for dashboards don't exist Marshmallow validation error for dashboards don't exist
""" """
def __init__(self): def __init__(self) -> None:
super().__init__(_("Dashboards do not exist"), field_names=["dashboards"]) super().__init__(_("Dashboards do not exist"), field_names=["dashboards"])
@ -50,7 +50,7 @@ class DatasourceTypeUpdateRequiredValidationError(ValidationError):
Marshmallow validation error for dashboards don't exist Marshmallow validation error for dashboards don't exist
""" """
def __init__(self): def __init__(self) -> None:
super().__init__( super().__init__(
_("Datasource type is required when datasource_id is given"), _("Datasource type is required when datasource_id is given"),
field_names=["datasource_type"], field_names=["datasource_type"],

View File

@ -17,6 +17,7 @@
import logging import logging
from typing import Dict, List, Optional from typing import Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
@ -47,7 +48,7 @@ class UpdateChartCommand(BaseCommand):
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[SqlaTable] = None self._model: Optional[SqlaTable] = None
def run(self): def run(self) -> Model:
self.validate() self.validate()
try: try:
chart = ChartDAO.update(self._model, self._properties) chart = ChartDAO.update(self._model, self._properties)

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from typing import List from typing import List, Optional
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
@ -32,13 +32,14 @@ class ChartDAO(BaseDAO):
base_filter = ChartFilter base_filter = ChartFilter
@staticmethod @staticmethod
def bulk_delete(models: List[Slice], commit=True): def bulk_delete(models: Optional[List[Slice]], commit: bool = True) -> None:
item_ids = [model.id for model in models] item_ids = [model.id for model in models] if models else []
# bulk delete, first delete related data # bulk delete, first delete related data
for model in models: if models:
model.owners = [] for model in models:
model.dashboards = [] model.owners = []
db.session.merge(model) model.dashboards = []
db.session.merge(model)
# bulk delete itself # bulk delete itself
try: try:
db.session.query(Slice).filter(Slice.id.in_(item_ids)).delete( db.session.query(Slice).filter(Slice.id.in_(item_ids)).delete(

View File

@ -14,14 +14,17 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from typing import Any
from sqlalchemy import or_ from sqlalchemy import or_
from sqlalchemy.orm.query import Query
from superset import security_manager from superset import security_manager
from superset.views.base import BaseFilter from superset.views.base import BaseFilter
class ChartFilter(BaseFilter): # pylint: disable=too-few-public-methods class ChartFilter(BaseFilter): # pylint: disable=too-few-public-methods
def apply(self, query, value): def apply(self, query: Query, value: Any) -> Query:
if security_manager.all_datasource_access(): if security_manager.all_datasource_access():
return query return query
perms = security_manager.user_view_menu_names("datasource_access") perms = security_manager.user_view_menu_names("datasource_access")

View File

@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from typing import Union
from marshmallow import fields, Schema, ValidationError from marshmallow import fields, Schema, ValidationError
from marshmallow.validate import Length from marshmallow.validate import Length
@ -24,7 +25,7 @@ from superset.utils import core as utils
get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}} get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}}
def validate_json(value): def validate_json(value: Union[bytes, bytearray, str]) -> None:
try: try:
utils.validate_json(value) utils.validate_json(value)
except SupersetException: except SupersetException:

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from typing import List from typing import Any, Dict, List
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
from marshmallow import ValidationError from marshmallow import ValidationError
@ -36,8 +36,8 @@ class CommandInvalidError(CommandException):
status = 422 status = 422
def __init__(self, message=""): def __init__(self, message="") -> None:
self._invalid_exceptions = list() self._invalid_exceptions: List[ValidationError] = []
super().__init__(self.message) super().__init__(self.message)
def add(self, exception: ValidationError): def add(self, exception: ValidationError):
@ -46,8 +46,8 @@ class CommandInvalidError(CommandException):
def add_list(self, exceptions: List[ValidationError]): def add_list(self, exceptions: List[ValidationError]):
self._invalid_exceptions.extend(exceptions) self._invalid_exceptions.extend(exceptions)
def normalized_messages(self): def normalized_messages(self) -> Dict[Any, Any]:
errors = {} errors: Dict[Any, Any] = {}
for exception in self._invalid_exceptions: for exception in self._invalid_exceptions:
errors.update(exception.normalized_messages()) errors.update(exception.normalized_messages())
return errors return errors

View File

@ -75,7 +75,7 @@ class BaseDAO:
return query.all() return query.all()
@classmethod @classmethod
def create(cls, properties: Dict, commit=True) -> Optional[Model]: def create(cls, properties: Dict, commit: bool = True) -> Model:
""" """
Generic for creating models Generic for creating models
:raises: DAOCreateFailedError :raises: DAOCreateFailedError
@ -95,7 +95,7 @@ class BaseDAO:
return model return model
@classmethod @classmethod
def update(cls, model: Model, properties: Dict, commit=True) -> Optional[Model]: def update(cls, model: Model, properties: Dict, commit: bool = True) -> Model:
""" """
Generic update a model Generic update a model
:raises: DAOCreateFailedError :raises: DAOCreateFailedError

View File

@ -547,7 +547,7 @@ def get_datasource_full_name(database_name, datasource_name, schema=None):
return "[{}].[{}].[{}]".format(database_name, schema, datasource_name) return "[{}].[{}].[{}]".format(database_name, schema, datasource_name)
def validate_json(obj): def validate_json(obj: Union[bytes, bytearray, str]) -> None:
if obj: if obj:
try: try:
json.loads(obj) json.loads(obj)

View File

@ -18,7 +18,7 @@ import functools
import logging import logging
import traceback import traceback
from datetime import datetime from datetime import datetime
from typing import Any, Dict, Optional from typing import Any, Dict, List, Optional
import simplejson as json import simplejson as json
import yaml import yaml
@ -27,6 +27,7 @@ from flask_appbuilder import BaseView, ModelView
from flask_appbuilder.actions import action from flask_appbuilder.actions import action
from flask_appbuilder.forms import DynamicForm from flask_appbuilder.forms import DynamicForm
from flask_appbuilder.models.sqla.filters import BaseFilter from flask_appbuilder.models.sqla.filters import BaseFilter
from flask_appbuilder.security.sqla.models import User
from flask_appbuilder.widgets import ListWidget from flask_appbuilder.widgets import ListWidget
from flask_babel import get_locale, gettext as __, lazy_gettext as _ from flask_babel import get_locale, gettext as __, lazy_gettext as _
from flask_wtf.form import FlaskForm from flask_wtf.form import FlaskForm
@ -365,7 +366,7 @@ class CsvResponse(Response): # pylint: disable=too-many-ancestors
charset = conf["CSV_EXPORT"].get("encoding", "utf-8") charset = conf["CSV_EXPORT"].get("encoding", "utf-8")
def check_ownership(obj, raise_if_false=True): def check_ownership(obj: Any, raise_if_false: bool = True) -> bool:
"""Meant to be used in `pre_update` hooks on models to enforce ownership """Meant to be used in `pre_update` hooks on models to enforce ownership
Admin have all access, and other users need to be referenced on either Admin have all access, and other users need to be referenced on either
@ -392,7 +393,7 @@ def check_ownership(obj, raise_if_false=True):
orig_obj = scoped_session.query(obj.__class__).filter_by(id=obj.id).first() orig_obj = scoped_session.query(obj.__class__).filter_by(id=obj.id).first()
# Making a list of owners that works across ORM models # Making a list of owners that works across ORM models
owners = [] owners: List[User] = []
if hasattr(orig_obj, "owners"): if hasattr(orig_obj, "owners"):
owners += orig_obj.owners owners += orig_obj.owners
if hasattr(orig_obj, "owner"): if hasattr(orig_obj, "owner"):