mirror of https://github.com/apache/superset.git
[mypy] Enforcing typing for charts (#9411)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
parent
2e81e27272
commit
ec795a4711
|
@ -53,7 +53,7 @@ order_by_type = false
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
no_implicit_optional = true
|
no_implicit_optional = true
|
||||||
|
|
||||||
[mypy-superset.db_engine_specs.*]
|
[mypy-superset.charts.*,superset.db_engine_specs.*]
|
||||||
check_untyped_defs = true
|
check_untyped_defs = true
|
||||||
disallow_untyped_calls = true
|
disallow_untyped_calls = true
|
||||||
disallow_untyped_defs = true
|
disallow_untyped_defs = true
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from flask import g, request, Response
|
from flask import g, request, Response
|
||||||
from flask_appbuilder.api import expose, protect, rison, safe
|
from flask_appbuilder.api import expose, protect, rison, safe
|
||||||
|
@ -287,7 +288,9 @@ class ChartRestApi(BaseSupersetModelRestApi):
|
||||||
@protect()
|
@protect()
|
||||||
@safe
|
@safe
|
||||||
@rison(get_delete_ids_schema)
|
@rison(get_delete_ids_schema)
|
||||||
def bulk_delete(self, **kwargs) -> Response: # pylint: disable=arguments-differ
|
def bulk_delete(
|
||||||
|
self, **kwargs: Any
|
||||||
|
) -> Response: # pylint: disable=arguments-differ
|
||||||
"""Delete bulk Charts
|
"""Delete bulk Charts
|
||||||
---
|
---
|
||||||
delete:
|
delete:
|
||||||
|
|
|
@ -40,7 +40,7 @@ class BulkDeleteChartCommand(BaseCommand):
|
||||||
self._model_ids = model_ids
|
self._model_ids = model_ids
|
||||||
self._models: Optional[List[Slice]] = None
|
self._models: Optional[List[Slice]] = None
|
||||||
|
|
||||||
def run(self):
|
def run(self) -> None:
|
||||||
self.validate()
|
self.validate()
|
||||||
try:
|
try:
|
||||||
ChartDAO.bulk_delete(self._models)
|
ChartDAO.bulk_delete(self._models)
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from flask_appbuilder.models.sqla import Model
|
||||||
from flask_appbuilder.security.sqla.models import User
|
from flask_appbuilder.security.sqla.models import User
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
|
||||||
|
@ -39,7 +40,7 @@ class CreateChartCommand(BaseCommand):
|
||||||
self._actor = user
|
self._actor = user
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
|
|
||||||
def run(self):
|
def run(self) -> Model:
|
||||||
self.validate()
|
self.validate()
|
||||||
try:
|
try:
|
||||||
chart = ChartDAO.create(self._properties)
|
chart = ChartDAO.create(self._properties)
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from flask_appbuilder.models.sqla import Model
|
||||||
from flask_appbuilder.security.sqla.models import User
|
from flask_appbuilder.security.sqla.models import User
|
||||||
|
|
||||||
from superset.charts.commands.exceptions import (
|
from superset.charts.commands.exceptions import (
|
||||||
|
@ -40,7 +41,7 @@ class DeleteChartCommand(BaseCommand):
|
||||||
self._model_id = model_id
|
self._model_id = model_id
|
||||||
self._model: Optional[SqlaTable] = None
|
self._model: Optional[SqlaTable] = None
|
||||||
|
|
||||||
def run(self):
|
def run(self) -> Model:
|
||||||
self.validate()
|
self.validate()
|
||||||
try:
|
try:
|
||||||
chart = ChartDAO.delete(self._model)
|
chart = ChartDAO.delete(self._model)
|
||||||
|
|
|
@ -32,7 +32,7 @@ class DatabaseNotFoundValidationError(ValidationError):
|
||||||
Marshmallow validation error for database does not exist
|
Marshmallow validation error for database does not exist
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__(_("Database does not exist"), field_names=["database"])
|
super().__init__(_("Database does not exist"), field_names=["database"])
|
||||||
|
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ class DashboardsNotFoundValidationError(ValidationError):
|
||||||
Marshmallow validation error for dashboards don't exist
|
Marshmallow validation error for dashboards don't exist
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__(_("Dashboards do not exist"), field_names=["dashboards"])
|
super().__init__(_("Dashboards do not exist"), field_names=["dashboards"])
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ class DatasourceTypeUpdateRequiredValidationError(ValidationError):
|
||||||
Marshmallow validation error for dashboards don't exist
|
Marshmallow validation error for dashboards don't exist
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
_("Datasource type is required when datasource_id is given"),
|
_("Datasource type is required when datasource_id is given"),
|
||||||
field_names=["datasource_type"],
|
field_names=["datasource_type"],
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from flask_appbuilder.models.sqla import Model
|
||||||
from flask_appbuilder.security.sqla.models import User
|
from flask_appbuilder.security.sqla.models import User
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
|
||||||
|
@ -47,7 +48,7 @@ class UpdateChartCommand(BaseCommand):
|
||||||
self._properties = data.copy()
|
self._properties = data.copy()
|
||||||
self._model: Optional[SqlaTable] = None
|
self._model: Optional[SqlaTable] = None
|
||||||
|
|
||||||
def run(self):
|
def run(self) -> Model:
|
||||||
self.validate()
|
self.validate()
|
||||||
try:
|
try:
|
||||||
chart = ChartDAO.update(self._model, self._properties)
|
chart = ChartDAO.update(self._model, self._properties)
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
@ -32,13 +32,14 @@ class ChartDAO(BaseDAO):
|
||||||
base_filter = ChartFilter
|
base_filter = ChartFilter
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bulk_delete(models: List[Slice], commit=True):
|
def bulk_delete(models: Optional[List[Slice]], commit: bool = True) -> None:
|
||||||
item_ids = [model.id for model in models]
|
item_ids = [model.id for model in models] if models else []
|
||||||
# bulk delete, first delete related data
|
# bulk delete, first delete related data
|
||||||
for model in models:
|
if models:
|
||||||
model.owners = []
|
for model in models:
|
||||||
model.dashboards = []
|
model.owners = []
|
||||||
db.session.merge(model)
|
model.dashboards = []
|
||||||
|
db.session.merge(model)
|
||||||
# bulk delete itself
|
# bulk delete itself
|
||||||
try:
|
try:
|
||||||
db.session.query(Slice).filter(Slice.id.in_(item_ids)).delete(
|
db.session.query(Slice).filter(Slice.id.in_(item_ids)).delete(
|
||||||
|
|
|
@ -14,14 +14,17 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import or_
|
from sqlalchemy import or_
|
||||||
|
from sqlalchemy.orm.query import Query
|
||||||
|
|
||||||
from superset import security_manager
|
from superset import security_manager
|
||||||
from superset.views.base import BaseFilter
|
from superset.views.base import BaseFilter
|
||||||
|
|
||||||
|
|
||||||
class ChartFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
class ChartFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
||||||
def apply(self, query, value):
|
def apply(self, query: Query, value: Any) -> Query:
|
||||||
if security_manager.all_datasource_access():
|
if security_manager.all_datasource_access():
|
||||||
return query
|
return query
|
||||||
perms = security_manager.user_view_menu_names("datasource_access")
|
perms = security_manager.user_view_menu_names("datasource_access")
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from marshmallow import fields, Schema, ValidationError
|
from marshmallow import fields, Schema, ValidationError
|
||||||
from marshmallow.validate import Length
|
from marshmallow.validate import Length
|
||||||
|
@ -24,7 +25,7 @@ from superset.utils import core as utils
|
||||||
get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}}
|
get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}}
|
||||||
|
|
||||||
|
|
||||||
def validate_json(value):
|
def validate_json(value: Union[bytes, bytearray, str]) -> None:
|
||||||
try:
|
try:
|
||||||
utils.validate_json(value)
|
utils.validate_json(value)
|
||||||
except SupersetException:
|
except SupersetException:
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
from typing import List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from flask_babel import lazy_gettext as _
|
from flask_babel import lazy_gettext as _
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
@ -36,8 +36,8 @@ class CommandInvalidError(CommandException):
|
||||||
|
|
||||||
status = 422
|
status = 422
|
||||||
|
|
||||||
def __init__(self, message=""):
|
def __init__(self, message="") -> None:
|
||||||
self._invalid_exceptions = list()
|
self._invalid_exceptions: List[ValidationError] = []
|
||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
|
|
||||||
def add(self, exception: ValidationError):
|
def add(self, exception: ValidationError):
|
||||||
|
@ -46,8 +46,8 @@ class CommandInvalidError(CommandException):
|
||||||
def add_list(self, exceptions: List[ValidationError]):
|
def add_list(self, exceptions: List[ValidationError]):
|
||||||
self._invalid_exceptions.extend(exceptions)
|
self._invalid_exceptions.extend(exceptions)
|
||||||
|
|
||||||
def normalized_messages(self):
|
def normalized_messages(self) -> Dict[Any, Any]:
|
||||||
errors = {}
|
errors: Dict[Any, Any] = {}
|
||||||
for exception in self._invalid_exceptions:
|
for exception in self._invalid_exceptions:
|
||||||
errors.update(exception.normalized_messages())
|
errors.update(exception.normalized_messages())
|
||||||
return errors
|
return errors
|
||||||
|
|
|
@ -75,7 +75,7 @@ class BaseDAO:
|
||||||
return query.all()
|
return query.all()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, properties: Dict, commit=True) -> Optional[Model]:
|
def create(cls, properties: Dict, commit: bool = True) -> Model:
|
||||||
"""
|
"""
|
||||||
Generic for creating models
|
Generic for creating models
|
||||||
:raises: DAOCreateFailedError
|
:raises: DAOCreateFailedError
|
||||||
|
@ -95,7 +95,7 @@ class BaseDAO:
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update(cls, model: Model, properties: Dict, commit=True) -> Optional[Model]:
|
def update(cls, model: Model, properties: Dict, commit: bool = True) -> Model:
|
||||||
"""
|
"""
|
||||||
Generic update a model
|
Generic update a model
|
||||||
:raises: DAOCreateFailedError
|
:raises: DAOCreateFailedError
|
||||||
|
|
|
@ -547,7 +547,7 @@ def get_datasource_full_name(database_name, datasource_name, schema=None):
|
||||||
return "[{}].[{}].[{}]".format(database_name, schema, datasource_name)
|
return "[{}].[{}].[{}]".format(database_name, schema, datasource_name)
|
||||||
|
|
||||||
|
|
||||||
def validate_json(obj):
|
def validate_json(obj: Union[bytes, bytearray, str]) -> None:
|
||||||
if obj:
|
if obj:
|
||||||
try:
|
try:
|
||||||
json.loads(obj)
|
json.loads(obj)
|
||||||
|
|
|
@ -18,7 +18,7 @@ import functools
|
||||||
import logging
|
import logging
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -27,6 +27,7 @@ from flask_appbuilder import BaseView, ModelView
|
||||||
from flask_appbuilder.actions import action
|
from flask_appbuilder.actions import action
|
||||||
from flask_appbuilder.forms import DynamicForm
|
from flask_appbuilder.forms import DynamicForm
|
||||||
from flask_appbuilder.models.sqla.filters import BaseFilter
|
from flask_appbuilder.models.sqla.filters import BaseFilter
|
||||||
|
from flask_appbuilder.security.sqla.models import User
|
||||||
from flask_appbuilder.widgets import ListWidget
|
from flask_appbuilder.widgets import ListWidget
|
||||||
from flask_babel import get_locale, gettext as __, lazy_gettext as _
|
from flask_babel import get_locale, gettext as __, lazy_gettext as _
|
||||||
from flask_wtf.form import FlaskForm
|
from flask_wtf.form import FlaskForm
|
||||||
|
@ -365,7 +366,7 @@ class CsvResponse(Response): # pylint: disable=too-many-ancestors
|
||||||
charset = conf["CSV_EXPORT"].get("encoding", "utf-8")
|
charset = conf["CSV_EXPORT"].get("encoding", "utf-8")
|
||||||
|
|
||||||
|
|
||||||
def check_ownership(obj, raise_if_false=True):
|
def check_ownership(obj: Any, raise_if_false: bool = True) -> bool:
|
||||||
"""Meant to be used in `pre_update` hooks on models to enforce ownership
|
"""Meant to be used in `pre_update` hooks on models to enforce ownership
|
||||||
|
|
||||||
Admin have all access, and other users need to be referenced on either
|
Admin have all access, and other users need to be referenced on either
|
||||||
|
@ -392,7 +393,7 @@ def check_ownership(obj, raise_if_false=True):
|
||||||
orig_obj = scoped_session.query(obj.__class__).filter_by(id=obj.id).first()
|
orig_obj = scoped_session.query(obj.__class__).filter_by(id=obj.id).first()
|
||||||
|
|
||||||
# Making a list of owners that works across ORM models
|
# Making a list of owners that works across ORM models
|
||||||
owners = []
|
owners: List[User] = []
|
||||||
if hasattr(orig_obj, "owners"):
|
if hasattr(orig_obj, "owners"):
|
||||||
owners += orig_obj.owners
|
owners += orig_obj.owners
|
||||||
if hasattr(orig_obj, "owner"):
|
if hasattr(orig_obj, "owner"):
|
||||||
|
|
Loading…
Reference in New Issue