From 63e0188f45134c25267d183f5d7391577f9a6d63 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Fri, 5 Jun 2020 08:44:11 -0700 Subject: [PATCH] style(mypy): Enforcing typing for superset.views (#9939) Co-authored-by: John Bodley --- setup.cfg | 2 +- superset/connectors/base/models.py | 18 +- superset/connectors/druid/models.py | 2 +- superset/connectors/sqla/models.py | 2 +- superset/sql_validators/base.py | 6 +- superset/sql_validators/presto_db.py | 2 +- superset/tasks/schedules.py | 6 +- superset/views/annotations.py | 8 +- superset/views/api.py | 7 +- superset/views/base.py | 56 ++-- superset/views/base_api.py | 34 +- superset/views/base_schemas.py | 25 +- superset/views/core.py | 443 ++++++++++++++------------- superset/views/datasource.py | 11 +- superset/views/filters.py | 7 +- superset/views/log/__init__.py | 4 +- superset/views/log/api.py | 2 +- superset/views/schedules.py | 35 ++- superset/views/sql_lab.py | 43 ++- superset/views/tags.py | 23 +- superset/views/utils.py | 26 +- superset/viz.py | 2 +- superset/viz_sip38.py | 16 +- 23 files changed, 440 insertions(+), 340 deletions(-) diff --git a/setup.cfg b/setup.cfg index fc94a24a44..81c7ed27a7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,7 +53,7 @@ order_by_type = false ignore_missing_imports = true no_implicit_optional = true -[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset.queries.*,superset.security.*,superset.sql_lab,superset.sql_parse,superset.sql_validators.*,superset.stats_logger,superset.tasks.*,superset.translations.*,superset.typing,superset.utils.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*,superset.viz,superset.viz_sip38] +[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset.queries.*,superset.security.*,superset.sql_lab,superset.sql_parse,superset.sql_validators.*,superset.stats_logger,superset.tasks.*,superset.translations.*,superset.typing,superset.utils.*,,superset.views.*,superset.viz,superset.viz_sip38] check_untyped_defs = true disallow_untyped_calls = true disallow_untyped_defs = true diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index 8ead67018c..0533aa1340 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -62,14 +62,26 @@ class BaseDatasource( # --------------------------------------------------------------- __tablename__: Optional[str] = None # {connector_name}_datasource baselink: Optional[str] = None # url portion pointing to ModelView endpoint - column_class: Optional[Type] = None # link to derivative of BaseColumn - metric_class: Optional[Type] = None # link to derivative of BaseMetric + + @property + def column_class(self) -> Type: + # link to derivative of BaseColumn + raise NotImplementedError() + + @property + def metric_class(self) -> Type: + # link to derivative of BaseMetric + raise NotImplementedError() + owner_class: Optional[User] = None # Used to do code highlighting when displaying the query in the UI query_language: Optional[str] = None - name = None # can be a Column or a property pointing to one + @property + def name(self) -> str: + # can be a Column or a property pointing to one + raise NotImplementedError() # --------------------------------------------------------------- diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index b0a333274e..50f163781f 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -548,7 +548,7 @@ class DruidDatasource(Model, BaseDatasource): return [c.column_name for c in self.columns if c.is_numeric] @property - def name(self) -> str: # type: ignore + def name(self) -> str: return self.datasource_name @property diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index b413ebd124..0e91bd2cc7 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -531,7 +531,7 @@ class SqlaTable(Model, BaseDatasource): return ("[{obj.database}].[{obj.table_name}]" "(id:{obj.id})").format(obj=self) @property - def name(self) -> str: # type: ignore + def name(self) -> str: if not self.schema: return self.table_name return "{}.{}".format(self.schema, self.table_name) diff --git a/superset/sql_validators/base.py b/superset/sql_validators/base.py index beed47c7ac..c477568b63 100644 --- a/superset/sql_validators/base.py +++ b/superset/sql_validators/base.py @@ -19,6 +19,8 @@ from typing import Any, Dict, List, Optional +from superset.models.core import Database + class SQLValidationAnnotation: """Represents a single annotation (error/warning) in an SQL querytext""" @@ -35,7 +37,7 @@ class SQLValidationAnnotation: self.start_column = start_column self.end_column = end_column - def to_dict(self) -> Dict: + def to_dict(self) -> Dict[str, Any]: """Return a dictionary representation of this annotation""" return { "line_number": self.line_number, @@ -53,7 +55,7 @@ class BaseSQLValidator: @classmethod def validate( - cls, sql: str, schema: str, database: Any + cls, sql: str, schema: Optional[str], database: Database ) -> List[SQLValidationAnnotation]: """Check that the given SQL querystring is valid for the given engine""" raise NotImplementedError diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index 42e7cffaa1..6c5bb309bb 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -143,7 +143,7 @@ class PrestoDBSQLValidator(BaseSQLValidator): @classmethod def validate( - cls, sql: str, schema: str, database: Any + cls, sql: str, schema: Optional[str], database: Database ) -> List[SQLValidationAnnotation]: """ Presto supports query-validation queries by running them with a diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py index 0a356ea503..2a5733ed60 100644 --- a/superset/tasks/schedules.py +++ b/superset/tasks/schedules.py @@ -225,9 +225,11 @@ def deliver_dashboard(schedule: DashboardEmailSchedule) -> None: """ dashboard = schedule.dashboard - dashboard_url = _get_url_path("Superset.dashboard", dashboard_id=dashboard.id) + dashboard_url = _get_url_path( + "Superset.dashboard", dashboard_id_or_slug=dashboard.id + ) dashboard_url_user_friendly = _get_url_path( - "Superset.dashboard", user_friendly=True, dashboard_id=dashboard.id + "Superset.dashboard", user_friendly=True, dashboard_id_or_slug=dashboard.id ) # Create a driver, fetch the page, wait for the page to render diff --git a/superset/views/annotations.py b/superset/views/annotations.py index e29883dec1..84428770d3 100644 --- a/superset/views/annotations.py +++ b/superset/views/annotations.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Any, Dict + from flask_appbuilder import CompactCRUDMixin from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import lazy_gettext as _ @@ -30,7 +32,7 @@ class StartEndDttmValidator: # pylint: disable=too-few-public-methods Validates dttm fields. """ - def __call__(self, form, field): + def __call__(self, form: Dict[str, Any], field: Any) -> None: if not form["start_dttm"].data and not form["end_dttm"].data: raise StopValidation(_("annotation start time or end time is required.")) elif ( @@ -82,13 +84,13 @@ class AnnotationModelView( validators_columns = {"start_dttm": [StartEndDttmValidator()]} - def pre_add(self, item): + def pre_add(self, item: "AnnotationModelView") -> None: if not item.start_dttm: item.start_dttm = item.end_dttm elif not item.end_dttm: item.end_dttm = item.start_dttm - def pre_update(self, item): + def pre_update(self, item: "AnnotationModelView") -> None: self.pre_add(item) diff --git a/superset/views/api.py b/superset/views/api.py index e82aa86dbc..d37059876e 100644 --- a/superset/views/api.py +++ b/superset/views/api.py @@ -24,6 +24,7 @@ from superset import db, event_logger, security_manager from superset.common.query_context import QueryContext from superset.legacy import update_time_range from superset.models.slice import Slice +from superset.typing import FlaskResponse from superset.utils import core as utils from superset.views.base import api, BaseSupersetView, handle_api_exception @@ -34,13 +35,13 @@ class Api(BaseSupersetView): @handle_api_exception @has_access_api @expose("/v1/query/", methods=["POST"]) - def query(self): + def query(self) -> FlaskResponse: """ Takes a query_obj constructed in the client and returns payload data response for the given query_obj. params: query_context: json_blob """ - query_context = QueryContext(**json.loads(request.form.get("query_context"))) + query_context = QueryContext(**json.loads(request.form["query_context"])) security_manager.assert_query_context_permission(query_context) payload_json = query_context.get_payload() return json.dumps( @@ -52,7 +53,7 @@ class Api(BaseSupersetView): @handle_api_exception @has_access_api @expose("/v1/form_data/", methods=["GET"]) - def query_form_data(self): + def query_form_data(self) -> FlaskResponse: """ Get the formdata stored in the database for existing slice. params: slice_id: integer diff --git a/superset/views/base.py b/superset/views/base.py index 5821619c1b..12382185b5 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -18,13 +18,13 @@ import functools import logging import traceback from datetime import datetime -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING, Union import dataclasses import simplejson as json import yaml from flask import abort, flash, g, get_flashed_messages, redirect, Response, session -from flask_appbuilder import BaseView, ModelView +from flask_appbuilder import BaseView, Model, ModelView from flask_appbuilder.actions import action from flask_appbuilder.forms import DynamicForm from flask_appbuilder.models.sqla.filters import BaseFilter @@ -33,7 +33,9 @@ from flask_appbuilder.widgets import ListWidget from flask_babel import get_locale, gettext as __, lazy_gettext as _ from flask_wtf.form import FlaskForm from sqlalchemy import or_ +from sqlalchemy.orm import Query from werkzeug.exceptions import HTTPException +from wtforms import Form from wtforms.fields.core import Field, UnboundField from superset import ( @@ -47,6 +49,7 @@ from superset import ( from superset.connectors.sqla import models from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetException, SupersetSecurityException +from superset.models.helpers import ImportMixin from superset.translations.utils import get_language_pack from superset.typing import FlaskResponse from superset.utils import core as utils @@ -93,7 +96,7 @@ def json_error_response( status: int = 500, payload: Optional[Dict[str, Any]] = None, link: Optional[str] = None, -) -> Response: +) -> FlaskResponse: if not payload: payload = {"error": "{}".format(msg)} if link: @@ -110,7 +113,7 @@ def json_errors_response( errors: List[SupersetError], status: int = 500, payload: Optional[Dict[str, Any]] = None, -) -> Response: +) -> FlaskResponse: if not payload: payload = {} @@ -122,11 +125,11 @@ def json_errors_response( ) -def json_success(json_msg: str, status: int = 200) -> Response: +def json_success(json_msg: str, status: int = 200) -> FlaskResponse: return Response(json_msg, status=status, mimetype="application/json") -def data_payload_response(payload_json: str, has_error: bool = False) -> Response: +def data_payload_response(payload_json: str, has_error: bool = False) -> FlaskResponse: status = 400 if has_error else 200 return json_success(payload_json, status=status) @@ -140,13 +143,13 @@ def generate_download_headers( return headers -def api(f): +def api(f: Callable) -> Callable: """ A decorator to label an endpoint as an API. Catches uncaught exceptions and return the response in the JSON format """ - def wraps(self, *args, **kwargs): + def wraps(self: "BaseSupersetView", *args: Any, **kwargs: Any) -> FlaskResponse: try: return f(self, *args, **kwargs) except Exception as ex: # pylint: disable=broad-except @@ -156,14 +159,16 @@ def api(f): return functools.update_wrapper(wraps, f) -def handle_api_exception(f): +def handle_api_exception( + f: Callable[..., FlaskResponse] +) -> Callable[..., FlaskResponse]: """ A decorator to catch superset exceptions. Use it after the @api decorator above so superset exception handler is triggered before the handler for generic exceptions. """ - def wraps(self, *args, **kwargs): + def wraps(self: "BaseSupersetView", *args: Any, **kwargs: Any) -> FlaskResponse: try: return f(self, *args, **kwargs) except SupersetSecurityException as ex: @@ -179,7 +184,7 @@ def handle_api_exception(f): except HTTPException as ex: logger.exception(ex) return json_error_response( - utils.error_msg_from_exception(ex), status=ex.code + utils.error_msg_from_exception(ex), status=cast(int, ex.code) ) except Exception as ex: # pylint: disable=broad-except logger.exception(ex) @@ -233,7 +238,9 @@ def get_user_roles() -> List[Role]: class BaseSupersetView(BaseView): @staticmethod - def json_response(obj, status=200) -> Response: # pylint: disable=no-self-use + def json_response( + obj: Any, status: int = 200 + ) -> FlaskResponse: # pylint: disable=no-self-use return Response( json.dumps(obj, default=utils.json_int_dttm_ser, ignore_nan=True), status=status, @@ -241,7 +248,7 @@ class BaseSupersetView(BaseView): ) -def menu_data(): +def menu_data() -> Dict[str, Any]: menu = appbuilder.menu.get_data() root_path = "#" logo_target_path = "" @@ -290,7 +297,7 @@ def menu_data(): } -def common_bootstrap_payload(): +def common_bootstrap_payload() -> Dict[str, Any]: """Common data always sent to the client""" messages = get_flashed_messages(with_categories=True) locale = str(get_locale()) @@ -335,7 +342,7 @@ class ListWidgetWithCheckboxes(ListWidget): # pylint: disable=too-few-public-me template = "superset/fab_overrides/list_with_checkboxes.html" -def validate_json(_form, field): +def validate_json(form: Form, field: Field) -> None: # pylint: disable=unused-argument try: json.loads(field.data) except Exception as ex: @@ -352,24 +359,23 @@ class YamlExportMixin: # pylint: disable=too-few-public-methods yaml_dict_key: Optional[str] = None @action("yaml_export", __("Export to YAML"), __("Export to YAML?"), "fa-download") - def yaml_export(self, items): + def yaml_export( + self, items: Union[ImportMixin, List[ImportMixin]] + ) -> FlaskResponse: if not isinstance(items, list): items = [items] data = [t.export_to_dict() for t in items] - if self.yaml_dict_key: - data = {self.yaml_dict_key: data} + return Response( - yaml.safe_dump(data), + yaml.safe_dump({self.yaml_dict_key: data} if self.yaml_dict_key else data), headers=generate_download_headers("yaml"), mimetype="application/text", ) class DeleteMixin: # pylint: disable=too-few-public-methods - def _delete( - self: Union[BaseView, "DeleteMixin", "DruidClusterModelView"], primary_key: int, - ) -> None: + def _delete(self: BaseView, primary_key: int,) -> None: """ Delete function logic, override to implement diferent logic deletes the record with primary_key = primary_key @@ -411,7 +417,7 @@ class DeleteMixin: # pylint: disable=too-few-public-methods @action( "muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False ) - def muldelete(self, items): + def muldelete(self: BaseView, items: List[Model]) -> FlaskResponse: if not items: abort(404) for item in items: @@ -426,7 +432,7 @@ class DeleteMixin: # pylint: disable=too-few-public-methods class DatasourceFilter(BaseFilter): # pylint: disable=too-few-public-methods - def apply(self, query, value): + def apply(self, query: Query, value: Any) -> Query: if security_manager.all_datasource_access(): return query datasource_perms = security_manager.user_view_menu_names("datasource_access") @@ -497,7 +503,7 @@ def check_ownership(obj: Any, raise_if_false: bool = True) -> bool: def bind_field( - _, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any] + _: Any, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any] ) -> Field: """ Customize how fields are bound by stripping all whitespace. diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 5675506dc5..3d40c338d7 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -16,17 +16,18 @@ # under the License. import functools import logging -from typing import Any, cast, Dict, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union from apispec import APISpec -from flask import Response -from flask_appbuilder import ModelRestApi +from flask import Blueprint, Response +from flask_appbuilder import AppBuilder, Model, ModelRestApi from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.filters import BaseFilter, Filters from flask_appbuilder.models.sqla.filters import FilterStartsWith from marshmallow import Schema from superset.stats_logger import BaseStatsLogger +from superset.typing import FlaskResponse from superset.utils.core import time_function logger = logging.getLogger(__name__) @@ -40,12 +41,12 @@ get_related_schema = { } -def statsd_metrics(f): +def statsd_metrics(f: Callable) -> Callable: """ Handle sending all statsd metrics from the REST API """ - def wraps(self, *args: Any, **kwargs: Any) -> Response: + def wraps(self: "BaseSupersetModelRestApi", *args: Any, **kwargs: Any) -> Response: duration, response = time_function(f, self, *args, **kwargs) self.send_stats_metrics(response, f.__name__, duration) return response @@ -116,6 +117,11 @@ class BaseSupersetModelRestApi(ModelRestApi): Add extra schemas to the OpenAPI component schemas section """ # pylint: disable=pointless-string-statement + add_columns: List[str] + edit_columns: List[str] + list_columns: List[str] + show_columns: List[str] + def __init__(self) -> None: super().__init__() self.stats_logger = BaseStatsLogger() @@ -128,11 +134,13 @@ class BaseSupersetModelRestApi(ModelRestApi): ) super().add_apispec_components(api_spec) - def create_blueprint(self, appbuilder, *args, **kwargs): + def create_blueprint( + self, appbuilder: AppBuilder, *args: Any, **kwargs: Any + ) -> Blueprint: self.stats_logger = self.appbuilder.get_app.config["STATS_LOGGER"] return super().create_blueprint(appbuilder, *args, **kwargs) - def _init_properties(self): + def _init_properties(self) -> None: model_id = self.datamodel.get_pk_name() if self.list_columns is None and not self.list_model_schema: self.list_columns = [model_id] @@ -144,7 +152,9 @@ class BaseSupersetModelRestApi(ModelRestApi): self.add_columns = [model_id] super()._init_properties() - def _get_related_filter(self, datamodel, column_name: str, value: str) -> Filters: + def _get_related_filter( + self, datamodel: Model, column_name: str, value: str + ) -> Filters: filter_field = self.related_field_filters.get(column_name) if isinstance(filter_field, str): filter_field = RelatedFieldFilter(cast(str, filter_field), FilterStartsWith) @@ -198,7 +208,7 @@ class BaseSupersetModelRestApi(ModelRestApi): if time_delta: self.timing_stats("time", key, time_delta) - def info_headless(self, **kwargs) -> Response: + def info_headless(self, **kwargs: Any) -> Response: """ Add statsd metrics to builtin FAB _info endpoint """ @@ -206,7 +216,7 @@ class BaseSupersetModelRestApi(ModelRestApi): self.send_stats_metrics(response, self.info.__name__, duration) return response - def get_headless(self, pk, **kwargs) -> Response: + def get_headless(self, pk: int, **kwargs: Any) -> Response: """ Add statsd metrics to builtin FAB GET endpoint """ @@ -214,7 +224,7 @@ class BaseSupersetModelRestApi(ModelRestApi): self.send_stats_metrics(response, self.get.__name__, duration) return response - def get_list_headless(self, **kwargs) -> Response: + def get_list_headless(self, **kwargs: Any) -> Response: """ Add statsd metrics to builtin FAB GET list endpoint """ @@ -227,7 +237,7 @@ class BaseSupersetModelRestApi(ModelRestApi): @safe @statsd_metrics @rison(get_related_schema) - def related(self, column_name: str, **kwargs): + def related(self, column_name: str, **kwargs: Any) -> FlaskResponse: """Get related fields data --- get: diff --git a/superset/views/base_schemas.py b/superset/views/base_schemas.py index e4795c53c0..a4436ddaa6 100644 --- a/superset/views/base_schemas.py +++ b/superset/views/base_schemas.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, List, Optional +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Union from flask import current_app, g from flask_appbuilder import Model @@ -22,7 +22,7 @@ from marshmallow import post_load, pre_load, Schema, ValidationError from sqlalchemy.orm.exc import NoResultFound -def validate_owner(value): +def validate_owner(value: int) -> None: try: ( current_app.appbuilder.get_session.query( @@ -44,18 +44,25 @@ class BaseSupersetSchema(Schema): __class_model__: Model = None - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: self.instance: Optional[Model] = None super().__init__(**kwargs) - def load( - self, data, many=None, partial=None, instance: Model = None, **kwargs - ): # pylint: disable=arguments-differ + def load( # pylint: disable=arguments-differ + self, + data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]], + many: Optional[bool] = None, + partial: Optional[Union[bool, Sequence[str], Set[str]]] = None, + instance: Optional[Model] = None, + **kwargs: Any, + ) -> Any: self.instance = instance return super().load(data, many=many, partial=partial, **kwargs) @post_load - def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model: + def make_object( + self, data: Dict[Any, Any], discard: Optional[List[str]] = None + ) -> Model: """ Creates a Model object from POST or PUT requests. PUT will use self.instance previously fetched from the endpoint handler @@ -92,13 +99,13 @@ class BaseOwnedSchema(BaseSupersetSchema): return instance @pre_load - def pre_load(self, data: Dict): + def pre_load(self, data: Dict[Any, Any]) -> None: # if PUT request don't set owners to empty list if not self.instance: data[self.owners_field_name] = data.get(self.owners_field_name, []) @staticmethod - def set_owners(instance: Model, owners: List[int]): + def set_owners(instance: Model, owners: List[int]) -> None: owner_objs = list() if g.user.id not in owners: owners.append(g.user.id) diff --git a/superset/views/core.py b/superset/views/core.py index 3d893341c9..c561222c51 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -33,6 +33,7 @@ from flask_appbuilder import expose from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.security.decorators import has_access, has_access_api from flask_appbuilder.security.sqla import models as ab_models +from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as __, lazy_gettext as _ from sqlalchemy import and_, Integer, or_, select from sqlalchemy.engine.url import make_url @@ -64,7 +65,12 @@ from superset import ( viz, ) from superset.connectors.connector_registry import ConnectorRegistry -from superset.connectors.sqla.models import AnnotationDatasource +from superset.connectors.sqla.models import ( + AnnotationDatasource, + SqlaTable, + SqlMetric, + TableColumn, +) from superset.constants import RouteMethod from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( @@ -88,12 +94,14 @@ from superset.security.analytics_db_safety import ( ) from superset.sql_parse import ParsedQuery, Table from superset.sql_validators import get_validator_by_name +from superset.typing import FlaskResponse from superset.utils import core as utils, dashboard_import_export from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes from superset.utils.dates import now_as_float from superset.utils.decorators import etag_cache, stats_timing from superset.views.database.filters import DatabaseFilter from superset.views.utils import get_dashboard_extra_filters +from superset.viz import BaseViz from .base import ( api, @@ -157,7 +165,7 @@ if not config["ENABLE_JAVASCRIPT_CONTROLS"]: FORM_DATA_KEY_BLACKLIST = ["js_tooltip", "js_onclick_href", "js_data_mutator"] -def get_database_access_error_msg(database_name): +def get_database_access_error_msg(database_name: str) -> str: return __( "This view requires the database %(name)s or " "`all_datasource_access` permission", @@ -165,13 +173,15 @@ def get_database_access_error_msg(database_name): ) -def is_owner(obj, user): +def is_owner(obj: Union[Dashboard, Slice], user: User) -> bool: """ Check if user is owner of the slice """ return obj and user in obj.owners def check_datasource_perms( - self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None + self: "Superset", + datasource_type: Optional[str] = None, + datasource_id: Optional[int] = None, ) -> None: """ Check if user can access a cached response from explore_json. @@ -218,7 +228,7 @@ def check_datasource_perms( security_manager.assert_viz_permission(viz_obj) -def check_slice_perms(self, slice_id): +def check_slice_perms(self: "Superset", slice_id: int) -> None: """ Check if user can access a cached response from slice_json. @@ -228,19 +238,20 @@ def check_slice_perms(self, slice_id): form_data, slc = get_form_data(slice_id, use_slice_data=True) - viz_obj = get_viz( - datasource_type=slc.datasource.type, - datasource_id=slc.datasource.id, - form_data=form_data, - force=False, - ) + if slc: + viz_obj = get_viz( + datasource_type=slc.datasource.type, + datasource_id=slc.datasource.id, + form_data=form_data, + force=False, + ) - security_manager.assert_viz_permission(viz_obj) + security_manager.assert_viz_permission(viz_obj) def _deserialize_results_payload( - payload: Union[bytes, str], query, use_msgpack: Optional[bool] = False -) -> dict: + payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False +) -> Dict[Any, Any]: logger.debug(f"Deserializing from msgpack: {use_msgpack}") if use_msgpack: with stats_timing( @@ -305,19 +316,19 @@ class AccessRequestsModelView(SupersetModelView, DeleteMixin): @talisman(force_https=False) @app.route("/health") -def health(): +def health() -> FlaskResponse: return "OK" @talisman(force_https=False) @app.route("/healthcheck") -def healthcheck(): +def healthcheck() -> FlaskResponse: return "OK" @talisman(force_https=False) @app.route("/ping") -def ping(): +def ping() -> FlaskResponse: return "OK" @@ -328,26 +339,26 @@ class KV(BaseSupersetView): @event_logger.log_this @has_access_api @expose("/store/", methods=["POST"]) - def store(self): + def store(self) -> FlaskResponse: try: value = request.form.get("data") obj = models.KeyValue(value=value) db.session.add(obj) db.session.commit() except Exception as ex: - return json_error_response(ex) + return json_error_response(utils.error_msg_from_exception(ex)) return Response(json.dumps({"id": obj.id}), status=200) @event_logger.log_this @has_access_api - @expose("//", methods=["GET"]) - def get_value(self, key_id): + @expose("//", methods=["GET"]) + def get_value(self, key_id: int) -> FlaskResponse: try: kv = db.session.query(models.KeyValue).filter_by(id=key_id).scalar() if not kv: return Response(status=404, content_type="text/plain") except Exception as ex: - return json_error_response(ex) + return json_error_response(utils.error_msg_from_exception(ex)) return Response(kv.value, status=200, content_type="text/plain") @@ -356,8 +367,8 @@ class R(BaseSupersetView): """used for short urls""" @event_logger.log_this - @expose("/") - def index(self, url_id): + @expose("/") + def index(self, url_id: int) -> FlaskResponse: url = db.session.query(models.Url).get(url_id) if url and url.url: explore_url = "//superset/explore/?" @@ -373,7 +384,7 @@ class R(BaseSupersetView): @event_logger.log_this @has_access_api @expose("/shortner/", methods=["POST"]) - def shortner(self): + def shortner(self) -> FlaskResponse: url = request.form.get("data") obj = models.Url(url=url) db.session.add(obj) @@ -393,15 +404,21 @@ class Superset(BaseSupersetView): @has_access_api @expose("/datasources/") - def datasources(self): - datasources = ConnectorRegistry.get_all_datasources(db.session) - datasources = [o.short_data for o in datasources if o.short_data.get("name")] - datasources = sorted(datasources, key=lambda o: o["name"]) - return self.json_response(datasources) + def datasources(self) -> FlaskResponse: + return self.json_response( + sorted( + [ + datasource.short_data + for datasource in ConnectorRegistry.get_all_datasources(db.session) + if datasource.short_data.get("name") + ], + key=lambda datasource: datasource["name"], + ) + ) @has_access_api @expose("/override_role_permissions/", methods=["POST"]) - def override_role_permissions(self): + def override_role_permissions(self) -> FlaskResponse: """Updates the role with the give datasource permissions. Permissions not in the request will be revoked. This endpoint should @@ -454,7 +471,7 @@ class Superset(BaseSupersetView): @event_logger.log_this @has_access @expose("/request_access/") - def request_access(self): + def request_access(self) -> FlaskResponse: datasources = set() dashboard_id = request.args.get("dashboard_id") if dashboard_id: @@ -462,7 +479,7 @@ class Superset(BaseSupersetView): datasources |= dash.datasources datasource_id = request.args.get("datasource_id") datasource_type = request.args.get("datasource_type") - if datasource_id: + if datasource_id and datasource_type: ds_class = ConnectorRegistry.sources.get(datasource_type) datasource = ( db.session.query(ds_class).filter_by(id=int(datasource_id)).one() @@ -497,8 +514,8 @@ class Superset(BaseSupersetView): @event_logger.log_this @has_access @expose("/approve") - def approve(self): - def clean_fulfilled_requests(session): + def approve(self) -> FlaskResponse: + def clean_fulfilled_requests(session: Session) -> None: for r in session.query(DAR).all(): datasource = ConnectorRegistry.get_datasource( r.datasource_type, r.datasource_id, session @@ -508,8 +525,8 @@ class Superset(BaseSupersetView): session.delete(r) session.commit() - datasource_type = request.args.get("datasource_type") - datasource_id = request.args.get("datasource_id") + datasource_type = request.args["datasource_type"] + datasource_id = request.args["datasource_id"] created_by_username = request.args.get("created_by") role_to_grant = request.args.get("role_to_grant") role_to_extend = request.args.get("role_to_extend") @@ -598,8 +615,8 @@ class Superset(BaseSupersetView): return redirect("/accessrequestsmodelview/list/") @has_access - @expose("/slice//") - def slice(self, slice_id): + @expose("/slice//") + def slice(self, slice_id: int) -> FlaskResponse: form_data, slc = get_form_data(slice_id, use_slice_data=True) if not slc: abort(404) @@ -611,15 +628,16 @@ class Superset(BaseSupersetView): endpoint += f"&{param}=true" return redirect(endpoint) - def get_query_string_response(self, viz_obj): + def get_query_string_response(self, viz_obj: BaseViz) -> FlaskResponse: query = None try: query_obj = viz_obj.query_obj() if query_obj: query = viz_obj.datasource.get_query_str(query_obj) except Exception as ex: - logger.exception(ex) - return json_error_response(ex) + err_msg = utils.error_msg_from_exception(ex) + logger.exception(err_msg) + return json_error_response(err_msg) if not query: query = "No query." @@ -628,15 +646,17 @@ class Superset(BaseSupersetView): {"query": query, "language": viz_obj.datasource.query_language} ) - def get_raw_results(self, viz_obj): + def get_raw_results(self, viz_obj: BaseViz) -> FlaskResponse: return self.json_response( {"data": viz_obj.get_df_payload()["df"].to_dict("records")} ) - def get_samples(self, viz_obj): + def get_samples(self, viz_obj: BaseViz) -> FlaskResponse: return self.json_response({"data": viz_obj.get_samples()}) - def generate_json(self, viz_obj, response_type: Optional[str] = None) -> Response: + def generate_json( + self, viz_obj: BaseViz, response_type: Optional[str] = None + ) -> FlaskResponse: if response_type == utils.ChartDataResultFormat.CSV: return CsvResponse( viz_obj.get_csv(), @@ -660,16 +680,16 @@ class Superset(BaseSupersetView): @event_logger.log_this @api @has_access_api - @expose("/slice_json/") + @expose("/slice_json/") @etag_cache(CACHE_DEFAULT_TIMEOUT, check_perms=check_slice_perms) - def slice_json(self, slice_id): + def slice_json(self, slice_id: int) -> FlaskResponse: form_data, slc = get_form_data(slice_id, use_slice_data=True) - datasource_type = slc.datasource.type - datasource_id = slc.datasource.id + if not slc: + return json_error_response("The slice does not exist") try: viz_obj = get_viz( - datasource_type=datasource_type, - datasource_id=datasource_id, + datasource_type=slc.datasource.type, + datasource_id=slc.datasource.id, form_data=form_data, force=False, ) @@ -680,8 +700,8 @@ class Superset(BaseSupersetView): @event_logger.log_this @api @has_access_api - @expose("/annotation_json/") - def annotation_json(self, layer_id): + @expose("/annotation_json/") + def annotation_json(self, layer_id: int) -> FlaskResponse: form_data = get_form_data()[0] form_data["layer_id"] = layer_id form_data["filters"] = [{"col": "layer_id", "op": "==", "val": layer_id}] @@ -714,11 +734,14 @@ class Superset(BaseSupersetView): @has_access_api @handle_api_exception @expose( - "/explore_json///", methods=EXPLORE_JSON_METHODS + "/explore_json///", + methods=EXPLORE_JSON_METHODS, ) @expose("/explore_json/", methods=EXPLORE_JSON_METHODS) @etag_cache(CACHE_DEFAULT_TIMEOUT, check_perms=check_datasource_perms) - def explore_json(self, datasource_type=None, datasource_id=None): + def explore_json( + self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None + ) -> FlaskResponse: """Serves all request that GET or POST form_data This endpoint evolved to be the entry point of many different @@ -729,7 +752,9 @@ class Superset(BaseSupersetView): TODO: break into one endpoint for each return shape""" response_type = utils.ChartDataResultFormat.JSON.value - responses = [resp_format for resp_format in utils.ChartDataResultFormat] + responses: List[ + Union[utils.ChartDataResultFormat, utils.ChartDataResultType] + ] = [resp_format for resp_format in utils.ChartDataResultFormat] responses.extend([resp_type for resp_type in utils.ChartDataResultType]) for response_option in responses: if request.args.get(response_option) == "true": @@ -744,7 +769,7 @@ class Superset(BaseSupersetView): ) viz_obj = get_viz( - datasource_type=datasource_type, + datasource_type=cast(str, datasource_type), datasource_id=datasource_id, form_data=form_data, force=request.args.get("force") == "true", @@ -757,7 +782,7 @@ class Superset(BaseSupersetView): @event_logger.log_this @has_access @expose("/import_dashboards", methods=["GET", "POST"]) - def import_dashboards(self): + def import_dashboards(self) -> FlaskResponse: """Overrides the dashboards using json instances from the file.""" f = request.files.get("file") if request.method == "POST" and f: @@ -788,9 +813,11 @@ class Superset(BaseSupersetView): @event_logger.log_this @has_access - @expose("/explore///", methods=["GET", "POST"]) + @expose("/explore///", methods=["GET", "POST"]) @expose("/explore/", methods=["GET", "POST"]) - def explore(self, datasource_type=None, datasource_id=None): + def explore( + self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None + ) -> FlaskResponse: user_id = g.user.get_id() if g.user else None form_data, slc = get_form_data(use_slice_data=True) @@ -834,7 +861,7 @@ class Superset(BaseSupersetView): return redirect(error_redirect) datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id, db.session + cast(str, datasource_type), datasource_id, db.session ) if not datasource: flash(DATASOURCE_MISSING_ERR, "danger") @@ -859,12 +886,12 @@ class Superset(BaseSupersetView): # slc perms slice_add_perm = security_manager.can_access("can_add", "SliceModelView") - slice_overwrite_perm = is_owner(slc, g.user) + slice_overwrite_perm = is_owner(slc, g.user) if slc else False slice_download_perm = security_manager.can_access( "can_download", "SliceModelView" ) - form_data["datasource"] = str(datasource_id) + "__" + datasource_type + form_data["datasource"] = str(datasource_id) + "__" + cast(str, datasource_type) # On explore, merge legacy and extra filters into the form data utils.convert_legacy_filters_into_adhoc(form_data) @@ -890,14 +917,16 @@ class Superset(BaseSupersetView): ) if action in ("saveas", "overwrite"): + if not slc: + return json_error_response("The slice does not exist") + return self.save_or_overwrite_slice( - request.args, slc, slice_add_perm, slice_overwrite_perm, slice_download_perm, datasource_id, - datasource_type, + cast(str, datasource_type), datasource.name, ) @@ -940,8 +969,10 @@ class Superset(BaseSupersetView): @api @handle_api_exception @has_access_api - @expose("/filter////") - def filter(self, datasource_type, datasource_id, column): + @expose("/filter////") + def filter( + self, datasource_type: str, datasource_id: int, column: str + ) -> FlaskResponse: """ Endpoint to retrieve values for specified column. @@ -965,28 +996,27 @@ class Superset(BaseSupersetView): return json_success(payload) @staticmethod - def remove_extra_filters(filters): + def remove_extra_filters(filters: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Extra filters are ones inherited from the dashboard's temporary context Those should not be saved when saving the chart""" return [f for f in filters if not f.get("isExtra")] def save_or_overwrite_slice( self, - args, - slc, - slice_add_perm, - slice_overwrite_perm, - slice_download_perm, - datasource_id, - datasource_type, - datasource_name, - ): + slc: Slice, + slice_add_perm: bool, + slice_overwrite_perm: bool, + slice_download_perm: bool, + datasource_id: int, + datasource_type: str, + datasource_name: str, + ) -> FlaskResponse: """Save or overwrite a slice""" - slice_name = args.get("slice_name") - action = args.get("action") + slice_name = request.args.get("slice_name") + action = request.args.get("action") form_data = get_form_data()[0] - if action in ("saveas"): + if action == "saveas": if "slice_id" in form_data: form_data.pop("slice_id") # don't save old slice_id slc = Slice(owners=[g.user] if g.user else []) @@ -1002,18 +1032,20 @@ class Superset(BaseSupersetView): slc.datasource_id = datasource_id slc.slice_name = slice_name - if action in ("saveas") and slice_add_perm: + if action == "saveas" and slice_add_perm: self.save_slice(slc) elif action == "overwrite" and slice_overwrite_perm: self.overwrite_slice(slc) # Adding slice to a dashboard if requested - dash = None + dash: Optional[Dashboard] = None + if request.args.get("add_to_dash") == "existing": - dash = ( + dash = cast( + Dashboard, db.session.query(Dashboard) - .filter_by(id=int(request.args.get("save_to_dashboard_id"))) - .one() + .filter_by(id=int(request.args["save_to_dashboard_id"])) + .one(), ) # check edit dashboard permissions dash_overwrite_perm = check_ownership(dash, raise_if_false=False) @@ -1066,19 +1098,19 @@ class Superset(BaseSupersetView): "dashboard_id": dash.id if dash else None, } - if request.args.get("goto_dash") == "true": + if dash and request.args.get("goto_dash") == "true": response.update({"dashboard": dash.url}) return json_success(json.dumps(response)) - def save_slice(self, slc): + def save_slice(self, slc: Slice) -> None: session = db.session() msg = _("Chart [{}] has been saved").format(slc.slice_name) session.add(slc) session.commit() flash(msg, "info") - def overwrite_slice(self, slc): + def overwrite_slice(self, slc: Slice) -> None: session = db.session() session.merge(slc) session.commit() @@ -1087,17 +1119,16 @@ class Superset(BaseSupersetView): @api @has_access_api - @expose("/schemas//") - @expose("/schemas///") - def schemas(self, db_id, force_refresh="false"): + @expose("/schemas//") + @expose("/schemas///") + def schemas(self, db_id: int, force_refresh: str = "false") -> FlaskResponse: db_id = int(db_id) - force_refresh = force_refresh.lower() == "true" database = db.session.query(models.Database).get(db_id) if database: schemas = database.get_all_schema_names( cache=database.schema_cache_enabled, cache_timeout=database.schema_cache_timeout, - force=force_refresh, + force=force_refresh.lower() == "true", ) schemas = security_manager.schemas_accessible_by_user(database, schemas) else: @@ -1111,7 +1142,7 @@ class Superset(BaseSupersetView): @expose("/tables/////") def tables( self, db_id: int, schema: str, substr: str, force_refresh: str = "false" - ): + ) -> FlaskResponse: """Endpoint to fetch the list of tables for given database""" # Guarantees database filtering by security access query = db.session.query(models.Database) @@ -1211,11 +1242,11 @@ class Superset(BaseSupersetView): @api @has_access_api - @expose("/copy_dash//", methods=["GET", "POST"]) - def copy_dash(self, dashboard_id): + @expose("/copy_dash//", methods=["GET", "POST"]) + def copy_dash(self, dashboard_id: int) -> FlaskResponse: """Copy dashboard""" session = db.session() - data = json.loads(request.form.get("data")) + data = json.loads(request.form["data"]) dash = models.Dashboard() original_dash = session.query(Dashboard).get(dashboard_id) @@ -1235,12 +1266,8 @@ class Superset(BaseSupersetView): # update chartId of layout entities for value in data["positions"].values(): - if ( - isinstance(value, dict) - and value.get("meta") - and value.get("meta").get("chartId") - ): - old_id = value.get("meta").get("chartId") + if isinstance(value, dict) and value.get("meta", {}).get("chartId"): + old_id = value["meta"]["chartId"] new_id = old_to_new_slice_ids[old_id] value["meta"]["chartId"] = new_id else: @@ -1257,13 +1284,13 @@ class Superset(BaseSupersetView): @api @has_access_api - @expose("/save_dash//", methods=["GET", "POST"]) - def save_dash(self, dashboard_id): + @expose("/save_dash//", methods=["GET", "POST"]) + def save_dash(self, dashboard_id: int) -> FlaskResponse: """Save a dashboard's metadata""" session = db.session() dash = session.query(Dashboard).get(dashboard_id) check_ownership(dash, raise_if_false=True) - data = json.loads(request.form.get("data")) + data = json.loads(request.form["data"]) self._set_dash_metadata(dash, data) session.merge(dash) session.commit() @@ -1272,8 +1299,10 @@ class Superset(BaseSupersetView): @staticmethod def _set_dash_metadata( - dashboard, data, old_to_new_slice_ids: Optional[Dict[int, int]] = None - ): + dashboard: Dashboard, + data: Dict[Any, Any], + old_to_new_slice_ids: Optional[Dict[int, int]] = None, + ) -> None: positions = data["positions"] # find slices in the position data slice_ids = [] @@ -1352,10 +1381,10 @@ class Superset(BaseSupersetView): @api @has_access_api - @expose("/add_slices//", methods=["POST"]) - def add_slices(self, dashboard_id): + @expose("/add_slices//", methods=["POST"]) + def add_slices(self, dashboard_id: int) -> FlaskResponse: """Add and save slices to a dashboard""" - data = json.loads(request.form.get("data")) + data = json.loads(request.form["data"]) session = db.session() dash = session.query(Dashboard).get(dashboard_id) check_ownership(dash, raise_if_false=True) @@ -1369,7 +1398,7 @@ class Superset(BaseSupersetView): @api @has_access_api @expose("/testconn", methods=["POST", "GET"]) - def testconn(self): + def testconn(self) -> FlaskResponse: """Tests a sqla connection""" db_name = request.json.get("name") uri = request.json.get("uri") @@ -1443,13 +1472,13 @@ class Superset(BaseSupersetView): @api @has_access_api - @expose("/recent_activity//", methods=["GET"]) - def recent_activity(self, user_id): + @expose("/recent_activity//", methods=["GET"]) + def recent_activity(self, user_id: int) -> FlaskResponse: """Recent activity (actions) for a given user""" M = models if request.args.get("limit"): - limit = int(request.args.get("limit")) + limit = int(request.args["limit"]) else: limit = 1000 @@ -1490,7 +1519,7 @@ class Superset(BaseSupersetView): @api @has_access_api @expose("/csrf_token/", methods=["GET"]) - def csrf_token(self): + def csrf_token(self) -> FlaskResponse: return Response( self.render_template("superset/csrf_token.json"), mimetype="text/json" ) @@ -1498,7 +1527,7 @@ class Superset(BaseSupersetView): @api @has_access_api @expose("/available_domains/", methods=["GET"]) - def available_domains(self): + def available_domains(self) -> FlaskResponse: """ Returns the list of available Superset Webserver domains (if any) defined in config. This enables charts embedded in other apps to @@ -1511,15 +1540,15 @@ class Superset(BaseSupersetView): @api @has_access_api @expose("/fave_dashboards_by_username//", methods=["GET"]) - def fave_dashboards_by_username(self, username): + def fave_dashboards_by_username(self, username: str) -> FlaskResponse: """This lets us use a user's username to pull favourite dashboards""" user = security_manager.find_user(username=username) return self.fave_dashboards(user.get_id()) @api @has_access_api - @expose("/fave_dashboards//", methods=["GET"]) - def fave_dashboards(self, user_id): + @expose("/fave_dashboards//", methods=["GET"]) + def fave_dashboards(self, user_id: int) -> FlaskResponse: qry = ( db.session.query(Dashboard, models.FavStar.dttm) .join( @@ -1550,8 +1579,8 @@ class Superset(BaseSupersetView): @api @has_access_api - @expose("/created_dashboards//", methods=["GET"]) - def created_dashboards(self, user_id): + @expose("/created_dashboards//", methods=["GET"]) + def created_dashboards(self, user_id: int) -> FlaskResponse: Dash = Dashboard qry = ( db.session.query(Dash) @@ -1573,8 +1602,8 @@ class Superset(BaseSupersetView): @api @has_access_api @expose("/user_slices", methods=["GET"]) - @expose("/user_slices//", methods=["GET"]) - def user_slices(self, user_id=None): + @expose("/user_slices//", methods=["GET"]) + def user_slices(self, user_id: Optional[int] = None) -> FlaskResponse: """List of slices a user created, or faved""" if not user_id: user_id = g.user.id @@ -1584,7 +1613,7 @@ class Superset(BaseSupersetView): .join( models.FavStar, and_( - models.FavStar.user_id == int(user_id), + models.FavStar.user_id == user_id, models.FavStar.class_name == "slice", Slice.id == models.FavStar.obj_id, ), @@ -1615,8 +1644,8 @@ class Superset(BaseSupersetView): @api @has_access_api @expose("/created_slices", methods=["GET"]) - @expose("/created_slices//", methods=["GET"]) - def created_slices(self, user_id=None): + @expose("/created_slices//", methods=["GET"]) + def created_slices(self, user_id: Optional[int] = None) -> FlaskResponse: """List of slices created by this user""" if not user_id: user_id = g.user.id @@ -1640,8 +1669,8 @@ class Superset(BaseSupersetView): @api @has_access_api @expose("/fave_slices", methods=["GET"]) - @expose("/fave_slices//", methods=["GET"]) - def fave_slices(self, user_id=None): + @expose("/fave_slices//", methods=["GET"]) + def fave_slices(self, user_id: Optional[int] = None) -> FlaskResponse: """Favorite slices for a user""" if not user_id: user_id = g.user.id @@ -1650,7 +1679,7 @@ class Superset(BaseSupersetView): .join( models.FavStar, and_( - models.FavStar.user_id == int(user_id), + models.FavStar.user_id == user_id, models.FavStar.class_name == "slice", Slice.id == models.FavStar.obj_id, ), @@ -1677,12 +1706,11 @@ class Superset(BaseSupersetView): @api @has_access_api @expose("/warm_up_cache/", methods=["GET"]) - def warm_up_cache(self): + def warm_up_cache(self) -> FlaskResponse: """Warms up the cache for the slice or table. Note for slices a force refresh occurs. """ - slices = None session = db.session() slice_id = request.args.get("slice_id") dashboard_id = request.args.get("dashboard_id") @@ -1704,7 +1732,6 @@ class Superset(BaseSupersetView): __("Chart %(id)s not found", id=slice_id), status=404 ) elif table_name and db_name: - SqlaTable = ConnectorRegistry.sources["table"] table = ( session.query(SqlaTable) .join(models.Database) @@ -1761,8 +1788,8 @@ class Superset(BaseSupersetView): return json_success(json.dumps(result)) @has_access_api - @expose("/favstar////") - def favstar(self, class_name, obj_id, action): + @expose("/favstar////") + def favstar(self, class_name: str, obj_id: int, action: str) -> FlaskResponse: """Toggle favorite stars on Slices and Dashboard""" session = db.session() FavStar = models.FavStar @@ -1793,8 +1820,8 @@ class Superset(BaseSupersetView): @api @has_access_api - @expose("/dashboard//published/", methods=("GET", "POST")) - def publish(self, dashboard_id): + @expose("/dashboard//published/", methods=("GET", "POST")) + def publish(self, dashboard_id: int) -> FlaskResponse: """Gets and toggles published status on dashboards""" logger.warning( "This API endpoint is deprecated and will be removed in version 1.0.0" @@ -1827,15 +1854,15 @@ class Superset(BaseSupersetView): return json_success(json.dumps({"published": dash.published})) @has_access - @expose("/dashboard//") - def dashboard(self, dashboard_id): + @expose("/dashboard//") + def dashboard(self, dashboard_id_or_slug: str) -> FlaskResponse: """Server side rendering for a dashboard""" session = db.session() qry = session.query(Dashboard) - if dashboard_id.isdigit(): - qry = qry.filter_by(id=int(dashboard_id)) + if dashboard_id_or_slug.isdigit(): + qry = qry.filter_by(id=int(dashboard_id_or_slug)) else: - qry = qry.filter_by(slug=dashboard_id) + qry = qry.filter_by(slug=dashboard_id_or_slug) dash = qry.one_or_none() if not dash: @@ -1885,7 +1912,7 @@ class Superset(BaseSupersetView): # Hack to log the dashboard_id properly, even when getting a slug @event_logger.log_this - def dashboard(**kwargs): + def dashboard(**kwargs: Any) -> None: pass dashboard( @@ -1939,13 +1966,13 @@ class Superset(BaseSupersetView): @api @event_logger.log_this @expose("/log/", methods=["POST"]) - def log(self): + def log(self) -> FlaskResponse: return Response(status=200) @has_access @expose("/sync_druid/", methods=["POST"]) @event_logger.log_this - def sync_druid_source(self): + def sync_druid_source(self) -> FlaskResponse: """Syncs the druid datasource in main db with the provided config. The endpoint takes 3 arguments: @@ -1996,14 +2023,15 @@ class Superset(BaseSupersetView): try: DruidDatasource.sync_to_db_from_config(druid_config, user, cluster) except Exception as ex: - logger.exception(utils.error_msg_from_exception(ex)) - return json_error_response(utils.error_msg_from_exception(ex)) + err_msg = utils.error_msg_from_exception(ex) + logger.exception(err_msg) + return json_error_response(err_msg) return Response(status=201) @has_access @expose("/get_or_create_table/", methods=["POST"]) @event_logger.log_this - def sqllab_table_viz(self): + def sqllab_table_viz(self) -> FlaskResponse: """ Gets or creates a table object with attributes passed to the API. It expects the json with params: @@ -2013,10 +2041,9 @@ class Superset(BaseSupersetView): * templateParams - params for the Jinja templating syntax, optional :return: Response """ - SqlaTable = ConnectorRegistry.sources["table"] - data = json.loads(request.form.get("data")) - table_name = data.get("datasourceName") - database_id = data.get("dbId") + data = json.loads(request.form["data"]) + table_name = data["datasourceName"] + database_id = data["dbId"] table = ( db.session.query(SqlaTable) .filter_by(database_id=database_id, table_name=table_name) @@ -2045,11 +2072,10 @@ class Superset(BaseSupersetView): @has_access @expose("/sqllab_viz/", methods=["POST"]) @event_logger.log_this - def sqllab_viz(self): - SqlaTable = ConnectorRegistry.sources["table"] - data = json.loads(request.form.get("data")) - table_name = data.get("datasourceName") - database_id = data.get("dbId") + def sqllab_viz(self) -> FlaskResponse: + data = json.loads(request.form["data"]) + table_name = data["datasourceName"] + database_id = data["dbId"] table = ( db.session.query(SqlaTable) .filter_by(database_id=database_id, table_name=table_name) @@ -2067,9 +2093,6 @@ class Superset(BaseSupersetView): cols = [] for config in data.get("columns"): column_name = config.get("name") - SqlaTable = ConnectorRegistry.sources["table"] - TableColumn = SqlaTable.column_class - SqlMetric = SqlaTable.metric_class col = TableColumn( column_name=column_name, filterable=True, @@ -2085,20 +2108,24 @@ class Superset(BaseSupersetView): return json_success(json.dumps({"table_id": table.id})) @has_access - @expose("/extra_table_metadata////") + @expose("/extra_table_metadata////") @event_logger.log_this - def extra_table_metadata(self, database_id, table_name, schema): - schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) - table_name = utils.parse_js_uri_path_item(table_name) + def extra_table_metadata( + self, database_id: int, table_name: str, schema: str + ) -> FlaskResponse: + schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) # type: ignore + table_name = utils.parse_js_uri_path_item(table_name) # type: ignore mydb = db.session.query(models.Database).filter_by(id=database_id).one() payload = mydb.db_engine_spec.extra_table_metadata(mydb, table_name, schema) return json_success(json.dumps(payload)) @has_access - @expose("/select_star//") - @expose("/select_star///") + @expose("/select_star//") + @expose("/select_star///") @event_logger.log_this - def select_star(self, database_id, table_name, schema=None): + def select_star( + self, database_id: int, table_name: str, schema: Optional[str] = None + ) -> FlaskResponse: logging.warning( f"{self.__class__.__name__}.select_star " "This API endpoint is deprecated and will be removed in version 1.0.0" @@ -2110,8 +2137,8 @@ class Superset(BaseSupersetView): f"deprecated.{self.__class__.__name__}.select_star.database_not_found" ) return json_error_response("Not found", 404) - schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) - table_name = utils.parse_js_uri_path_item(table_name) + schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) # type: ignore + table_name = utils.parse_js_uri_path_item(table_name) # type: ignore # Check that the user can access the datasource if not self.appbuilder.sm.can_access_datasource( database, Table(table_name, schema), schema @@ -2132,12 +2159,12 @@ class Superset(BaseSupersetView): ) @has_access_api - @expose("/estimate_query_cost//", methods=["POST"]) - @expose("/estimate_query_cost///", methods=["POST"]) + @expose("/estimate_query_cost//", methods=["POST"]) + @expose("/estimate_query_cost///", methods=["POST"]) @event_logger.log_this def estimate_query_cost( self, database_id: int, schema: Optional[str] = None - ) -> Response: + ) -> FlaskResponse: mydb = db.session.query(models.Database).get(database_id) sql = json.loads(request.form.get("sql", '""')) @@ -2157,7 +2184,7 @@ class Superset(BaseSupersetView): logger.exception(ex) return json_error_response(timeout_msg) except Exception as ex: - return json_error_response(str(ex)) + return json_error_response(utils.error_msg_from_exception(ex)) spec = mydb.db_engine_spec query_cost_formatters: Dict[str, Any] = get_feature_flags().get( @@ -2171,16 +2198,16 @@ class Superset(BaseSupersetView): return json_success(json.dumps(cost)) @expose("/theme/") - def theme(self): + def theme(self) -> FlaskResponse: return self.render_template("superset/theme.html") @has_access_api @expose("/results//") @event_logger.log_this - def results(self, key): + def results(self, key: str) -> FlaskResponse: return self.results_exec(key) - def results_exec(self, key: str): + def results_exec(self, key: str) -> FlaskResponse: """Serves a key off of the results backend It is possible to pass the `rows` query argument to limit the number @@ -2244,7 +2271,7 @@ class Superset(BaseSupersetView): on_giveup=lambda details: db.session.rollback(), max_tries=5, ) - def stop_query(self): + def stop_query(self) -> FlaskResponse: client_id = request.form.get("client_id") query = db.session.query(Query).filter_by(client_id=client_id).one() @@ -2265,12 +2292,12 @@ class Superset(BaseSupersetView): @has_access_api @expose("/validate_sql_json/", methods=["POST", "GET"]) @event_logger.log_this - def validate_sql_json(self): + def validate_sql_json(self) -> FlaskResponse: """Validates that arbitrary sql is acceptable for the given database. Returns a list of error/warning annotations as json. """ - sql = request.form.get("sql") - database_id = request.form.get("database_id") + sql = request.form["sql"] + database_id = request.form["database_id"] schema = request.form.get("schema") or None template_params = json.loads(request.form.get("templateParams") or "{}") @@ -2338,9 +2365,9 @@ class Superset(BaseSupersetView): query: Query, expand_data: bool, log_params: Optional[Dict[str, Any]] = None, - ) -> Response: + ) -> FlaskResponse: """ - Send SQL JSON query to celery workers + Send SQL JSON query to celery workers. :param session: SQLAlchemy session object :param rendered_query: the rendered query to perform by workers @@ -2389,9 +2416,9 @@ class Superset(BaseSupersetView): query: Query, expand_data: bool, log_params: Optional[Dict[str, Any]] = None, - ) -> Response: + ) -> FlaskResponse: """ - Execute SQL query (sql json) + Execute SQL query (sql json). :param rendered_query: The rendered query (included templates) :param query: The query SQL (SQLAlchemy) object @@ -2424,7 +2451,7 @@ class Superset(BaseSupersetView): ) except Exception as ex: logger.exception(f"Query {query.id}: {ex}") - return json_error_response(f"{{e}}") + return json_error_response(utils.error_msg_from_exception(ex)) if data.get("status") == QueryStatus.FAILED: return json_error_response(payload=data) return json_success(payload) @@ -2432,15 +2459,15 @@ class Superset(BaseSupersetView): @has_access_api @expose("/sql_json/", methods=["POST"]) @event_logger.log_this - def sql_json(self): + def sql_json(self) -> FlaskResponse: log_params = { "user_agent": cast(Optional[str], request.headers.get("USER_AGENT")) } return self.sql_json_exec(request.json, log_params) def sql_json_exec( - self, query_params: dict, log_params: Optional[Dict[str, Any]] = None - ): + self, query_params: Dict[str, Any], log_params: Optional[Dict[str, Any]] = None + ) -> FlaskResponse: """Runs arbitrary sql and returns data as json""" # Collect Values database_id: int = cast(int, query_params.get("database_id")) @@ -2564,7 +2591,7 @@ class Superset(BaseSupersetView): @has_access @expose("/csv/") @event_logger.log_this - def csv(self, client_id): + def csv(self, client_id: str) -> FlaskResponse: """Download the query results as csv.""" logger.info("Exporting CSV file [{}]".format(client_id)) query = db.session.query(Query).filter_by(client_id=client_id).one() @@ -2587,7 +2614,7 @@ class Superset(BaseSupersetView): blob, decode=not results_backend_use_msgpack ) obj = _deserialize_results_payload( - payload, query, results_backend_use_msgpack + payload, query, cast(bool, results_backend_use_msgpack) ) columns = [c["name"] for c in obj["columns"]] df = pd.DataFrame.from_records(obj["data"], columns=columns) @@ -2622,8 +2649,8 @@ class Superset(BaseSupersetView): @has_access @expose("/fetch_datasource_metadata") @event_logger.log_this - def fetch_datasource_metadata(self): - datasource_id, datasource_type = request.args.get("datasourceKey").split("__") + def fetch_datasource_metadata(self) -> FlaskResponse: + datasource_id, datasource_type = request.args["datasourceKey"].split("__") datasource = ConnectorRegistry.get_datasource( datasource_type, datasource_id, db.session ) @@ -2636,17 +2663,16 @@ class Superset(BaseSupersetView): return json_success(json.dumps(datasource.data)) @has_access_api - @expose("/queries/") - def queries(self, last_updated_ms): + @expose("/queries/") + def queries(self, last_updated_ms: int) -> FlaskResponse: """ Get the updated queries. :param last_updated_ms: unix time, milliseconds """ - last_updated_ms_int = int(float(last_updated_ms)) if last_updated_ms else 0 - return self.queries_exec(last_updated_ms_int) + return self.queries_exec(last_updated_ms) - def queries_exec(self, last_updated_ms_int: int): + def queries_exec(self, last_updated_ms: int) -> FlaskResponse: stats_logger.incr("queries") if not g.user.get_id(): return json_error_response( @@ -2654,7 +2680,7 @@ class Superset(BaseSupersetView): ) # UTC date time, same that is stored in the DB. - last_updated_dt = utils.EPOCH + timedelta(seconds=last_updated_ms_int / 1000) + last_updated_dt = utils.EPOCH + timedelta(seconds=last_updated_ms / 1000) sql_queries = ( db.session.query(Query) @@ -2669,7 +2695,7 @@ class Superset(BaseSupersetView): @has_access @expose("/search_queries") @event_logger.log_this - def search_queries(self) -> Response: + def search_queries(self) -> FlaskResponse: """ Search for previously run sqllab queries. Used for Sqllab Query Search page /superset/sqllab#search. @@ -2730,14 +2756,14 @@ class Superset(BaseSupersetView): ) @app.errorhandler(500) - def show_traceback(self): + def show_traceback(self) -> FlaskResponse: return ( render_template("superset/traceback.html", error_msg=get_error_msg()), 500, ) @expose("/welcome") - def welcome(self): + def welcome(self) -> FlaskResponse: """Personalized welcome page""" if not g.user or not g.user.get_id(): return redirect(appbuilder.get_url_for_login) @@ -2765,11 +2791,8 @@ class Superset(BaseSupersetView): @has_access @expose("/profile//") - def profile(self, username): + def profile(self, username: str) -> FlaskResponse: """User profile page""" - if not username and g.user: - username = g.user.username - user = ( db.session.query(ab_models.User).filter_by(username=username).one_or_none() ) @@ -2839,7 +2862,7 @@ class Superset(BaseSupersetView): @has_access @expose("/sqllab", methods=["GET", "POST"]) - def sqllab(self): + def sqllab(self) -> FlaskResponse: """SQL Editor""" payload = { "defaultDbId": config["SQLLAB_DEFAULT_DBID"], @@ -2864,7 +2887,7 @@ class Superset(BaseSupersetView): @api @has_access_api @expose("/schemas_access_for_csv_upload") - def schemas_access_for_csv_upload(self): + def schemas_access_for_csv_upload(self) -> FlaskResponse: """ This method exposes an API endpoint to get the schema access control settings for csv upload in this database @@ -2872,7 +2895,7 @@ class Superset(BaseSupersetView): if not request.args.get("db_id"): return json_error_response("No database is allowed for your csv upload") - db_id = int(request.args.get("db_id")) + db_id = int(request.args["db_id"]) database = db.session.query(models.Database).filter_by(id=db_id).one() try: schemas_allowed = database.get_schema_access_for_csv_upload() @@ -2919,11 +2942,11 @@ class CssTemplateAsyncModelView(CssTemplateModelView): @app.after_request -def apply_http_headers(response: Response): +def apply_http_headers(response: Response) -> Response: """Applies the configuration's http headers to all responses""" # HTTP_HEADERS is deprecated, this provides backwards compatibility - response.headers.extend( + response.headers.extend( # type: ignore {**config["OVERRIDE_HTTP_HEADERS"], **config["HTTP_HEADERS"]} ) diff --git a/superset/views/datasource.py b/superset/views/datasource.py index b641ee1d7f..2ce11027ce 100644 --- a/superset/views/datasource.py +++ b/superset/views/datasource.py @@ -17,7 +17,7 @@ import json from collections import Counter -from flask import request, Response +from flask import request from flask_appbuilder import expose from flask_appbuilder.security.decorators import has_access_api from sqlalchemy.orm.exc import NoResultFound @@ -25,6 +25,7 @@ from sqlalchemy.orm.exc import NoResultFound from superset import db from superset.connectors.connector_registry import ConnectorRegistry from superset.models.core import Database +from superset.typing import FlaskResponse from .base import api, BaseSupersetView, handle_api_exception, json_error_response @@ -36,7 +37,7 @@ class Datasource(BaseSupersetView): @has_access_api @api @handle_api_exception - def save(self) -> Response: + def save(self) -> FlaskResponse: data = request.form.get("data") if not isinstance(data, str): return json_error_response("Request missing data field.", status=500) @@ -78,7 +79,7 @@ class Datasource(BaseSupersetView): @has_access_api @api @handle_api_exception - def get(self, datasource_type: str, datasource_id: int) -> Response: + def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse: try: orm_datasource = ConnectorRegistry.get_datasource( datasource_type, datasource_id, db.session @@ -95,7 +96,9 @@ class Datasource(BaseSupersetView): @has_access_api @api @handle_api_exception - def external_metadata(self, datasource_type: str, datasource_id: int) -> Response: + def external_metadata( + self, datasource_type: str, datasource_id: int + ) -> FlaskResponse: """Gets column info from the source system""" if datasource_type == "druid": datasource = ConnectorRegistry.get_datasource( diff --git a/superset/views/filters.py b/superset/views/filters.py index 3e4d85a555..3594d21075 100644 --- a/superset/views/filters.py +++ b/superset/views/filters.py @@ -14,9 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Any, cast, Optional + from flask_appbuilder.models.filters import BaseFilter from flask_babel import lazy_gettext from sqlalchemy import or_ +from sqlalchemy.orm import Query from superset import security_manager @@ -36,9 +39,9 @@ class FilterRelatedOwners(BaseFilter): name = lazy_gettext("Owner") arg_name = "owners" - def apply(self, query, value): + def apply(self, query: Query, value: Optional[Any]) -> Query: user_model = security_manager.user_model - like_value = "%" + value + "%" + like_value = "%" + cast(str, value) + "%" return query.filter( or_( # could be made to handle spaces between names more gracefully diff --git a/superset/views/log/__init__.py b/superset/views/log/__init__.py index b39d6023c8..103632b275 100644 --- a/superset/views/log/__init__.py +++ b/superset/views/log/__init__.py @@ -23,8 +23,8 @@ class LogMixin: # pylint: disable=too-few-public-methods add_title = _("Add Log") edit_title = _("Edit Log") - list_columns = ("user", "action", "dttm") - edit_columns = ("user", "action", "dttm", "json") + list_columns = ["user", "action", "dttm"] + edit_columns = ["user", "action", "dttm", "json"] base_order = ("dttm", "desc") label_columns = { "user": _("User"), diff --git a/superset/views/log/api.py b/superset/views/log/api.py index d579eb0fa7..f132c349a6 100644 --- a/superset/views/log/api.py +++ b/superset/views/log/api.py @@ -28,5 +28,5 @@ class LogRestApi(LogMixin, BaseSupersetModelRestApi): class_permission_name = "LogModelView" resource_name = "log" allow_browser_login = True - list_columns = ("user.username", "action", "dttm") + list_columns = ["user.username", "action", "dttm"] show_columns = list_columns diff --git a/superset/views/schedules.py b/superset/views/schedules.py index 6f35a7a095..68ae6ffdac 100644 --- a/superset/views/schedules.py +++ b/superset/views/schedules.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import enum -from typing import Optional, Type +from typing import Type import simplejson as json from croniter import croniter @@ -24,7 +24,7 @@ from flask_appbuilder import expose from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.security.decorators import has_access from flask_babel import lazy_gettext as _ -from wtforms import BooleanField, StringField +from wtforms import BooleanField, Form, StringField from superset import db, security_manager from superset.constants import RouteMethod @@ -37,6 +37,7 @@ from superset.models.schedules import ( ) from superset.models.slice import Slice from superset.tasks.schedules import schedule_email_report +from superset.typing import FlaskResponse from superset.utils.core import get_email_address_list, json_iso_dttm_ser from superset.views.core import json_success @@ -48,8 +49,14 @@ class EmailScheduleView( ): # pylint: disable=too-many-ancestors include_route_methods = RouteMethod.CRUD_SET _extra_data = {"test_email": False, "test_email_recipients": None} - schedule_type: Optional[str] = None - schedule_type_model: Optional[Type] = None + + @property + def schedule_type(self) -> str: + raise NotImplementedError() + + @property + def schedule_type_model(self) -> Type: + raise NotImplementedError() page_size = 20 @@ -87,7 +94,7 @@ class EmailScheduleView( edit_form_extra_fields = add_form_extra_fields - def process_form(self, form, is_created): + def process_form(self, form: Form, is_created: bool) -> None: if form.test_email_recipients.data: test_email_recipients = form.test_email_recipients.data.strip() else: @@ -95,7 +102,7 @@ class EmailScheduleView( self._extra_data["test_email"] = form.test_email.data self._extra_data["test_email_recipients"] = test_email_recipients - def pre_add(self, item): + def pre_add(self, item: "EmailScheduleView") -> None: try: recipients = get_email_address_list(item.recipients) item.recipients = ", ".join(recipients) @@ -106,10 +113,10 @@ class EmailScheduleView( if not croniter.is_valid(item.crontab): raise SupersetException("Invalid crontab format") - def pre_update(self, item): + def pre_update(self, item: "EmailScheduleView") -> None: self.pre_add(item) - def post_add(self, item): + def post_add(self, item: "EmailScheduleView") -> None: # Schedule a test mail if the user requested for it. if self._extra_data["test_email"]: recipients = self._extra_data["test_email_recipients"] or item.recipients @@ -122,12 +129,12 @@ class EmailScheduleView( if item.active: flash("Schedule changes will get applied in one hour", "warning") - def post_update(self, item): + def post_update(self, item: "EmailScheduleView") -> None: self.post_add(item) @has_access @expose("/fetch//", methods=["GET"]) - def fetch_schedules(self, item_id): + def fetch_schedules(self, item_id: int) -> FlaskResponse: query = db.session.query(self.datamodel.obj) query = query.join(self.schedule_type_model).filter( @@ -147,7 +154,9 @@ class EmailScheduleView( info[col] = info[col].username info["user"] = schedule.user.username - info[self.schedule_type] = getattr(schedule, self.schedule_type).id + info[self.schedule_type] = getattr( # type: ignore + schedule, self.schedule_type + ).id schedules.append(info) return json_success(json.dumps(schedules, default=json_iso_dttm_ser)) @@ -208,7 +217,7 @@ class DashboardEmailScheduleView( "delivery_type": _("Delivery Type"), } - def pre_add(self, item): + def pre_add(self, item: "DashboardEmailScheduleView") -> None: if item.dashboard is None: raise SupersetException("Dashboard is mandatory") super(DashboardEmailScheduleView, self).pre_add(item) @@ -269,7 +278,7 @@ class SliceEmailScheduleView(EmailScheduleView): # pylint: disable=too-many-anc "email_format": _("Email Format"), } - def pre_add(self, item): + def pre_add(self, item: "SliceEmailScheduleView") -> None: if item.slice is None: raise SupersetException("Slice is mandatory") super(SliceEmailScheduleView, self).pre_add(item) diff --git a/superset/views/sql_lab.py b/superset/views/sql_lab.py index 36469657e1..3476bb3fd0 100644 --- a/superset/views/sql_lab.py +++ b/superset/views/sql_lab.py @@ -27,6 +27,7 @@ from flask_sqlalchemy import BaseQuery from superset import db, get_feature_flags, security_manager from superset.constants import RouteMethod from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState +from superset.typing import FlaskResponse from superset.utils import core as utils from .base import ( @@ -120,15 +121,15 @@ class SavedQueryView( show_template = "superset/models/savedquery/show.html" - def pre_add(self, item): + def pre_add(self, item: "SavedQueryView") -> None: item.user = g.user - def pre_update(self, item): + def pre_update(self, item: "SavedQueryView") -> None: self.pre_add(item) @has_access @expose("show/") - def show(self, pk): + def show(self, pk: int) -> FlaskResponse: pk = self._deserialize_pk_if_composite(pk) widgets = self._show(pk) query = self.datamodel.get(pk).to_json() @@ -168,18 +169,18 @@ class SavedQueryViewApi(SavedQueryView): # pylint: disable=too-many-ancestors @has_access_api @expose("show/") - def show(self, pk): + def show(self, pk: int) -> FlaskResponse: return super().show(pk) -def _get_owner_id(tab_state_id): +def _get_owner_id(tab_state_id: int) -> int: return db.session.query(TabState.user_id).filter_by(id=tab_state_id).scalar() class TabStateView(BaseSupersetView): @has_access_api @expose("/", methods=["POST"]) - def post(self): # pylint: disable=no-self-use + def post(self) -> FlaskResponse: # pylint: disable=no-self-use query_editor = json.loads(request.form["queryEditor"]) tab_state = TabState( user_id=g.user.get_id(), @@ -201,7 +202,7 @@ class TabStateView(BaseSupersetView): @has_access_api @expose("/", methods=["DELETE"]) - def delete(self, tab_state_id): # pylint: disable=no-self-use + def delete(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use if _get_owner_id(tab_state_id) != int(g.user.get_id()): return Response(status=403) @@ -216,7 +217,7 @@ class TabStateView(BaseSupersetView): @has_access_api @expose("/", methods=["GET"]) - def get(self, tab_state_id): # pylint: disable=no-self-use + def get(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use if _get_owner_id(tab_state_id) != int(g.user.get_id()): return Response(status=403) @@ -229,7 +230,9 @@ class TabStateView(BaseSupersetView): @has_access_api @expose("/activate", methods=["POST"]) - def activate(self, tab_state_id): # pylint: disable=no-self-use + def activate( # pylint: disable=no-self-use + self, tab_state_id: int + ) -> FlaskResponse: owner_id = _get_owner_id(tab_state_id) if owner_id is None: return Response(status=404) @@ -246,7 +249,7 @@ class TabStateView(BaseSupersetView): @has_access_api @expose("", methods=["PUT"]) - def put(self, tab_state_id): # pylint: disable=no-self-use + def put(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use if _get_owner_id(tab_state_id) != int(g.user.get_id()): return Response(status=403) @@ -257,7 +260,9 @@ class TabStateView(BaseSupersetView): @has_access_api @expose("/migrate_query", methods=["POST"]) - def migrate_query(self, tab_state_id): # pylint: disable=no-self-use + def migrate_query( # pylint: disable=no-self-use + self, tab_state_id: int + ) -> FlaskResponse: if _get_owner_id(tab_state_id) != int(g.user.get_id()): return Response(status=403) @@ -270,7 +275,9 @@ class TabStateView(BaseSupersetView): @has_access_api @expose("/query/", methods=["DELETE"]) - def delete_query(self, tab_state_id, client_id): # pylint: disable=no-self-use + def delete_query( # pylint: disable=no-self-use + self, tab_state_id: str, client_id: str + ) -> FlaskResponse: db.session.query(Query).filter_by( client_id=client_id, user_id=g.user.get_id(), sql_editor_id=tab_state_id ).delete(synchronize_session=False) @@ -281,7 +288,7 @@ class TabStateView(BaseSupersetView): class TableSchemaView(BaseSupersetView): @has_access_api @expose("/", methods=["POST"]) - def post(self): # pylint: disable=no-self-use + def post(self) -> FlaskResponse: # pylint: disable=no-self-use table = json.loads(request.form["table"]) # delete any existing table schema @@ -306,7 +313,9 @@ class TableSchemaView(BaseSupersetView): @has_access_api @expose("/", methods=["DELETE"]) - def delete(self, table_schema_id): # pylint: disable=no-self-use + def delete( # pylint: disable=no-self-use + self, table_schema_id: int + ) -> FlaskResponse: db.session.query(TableSchema).filter(TableSchema.id == table_schema_id).delete( synchronize_session=False ) @@ -315,7 +324,9 @@ class TableSchemaView(BaseSupersetView): @has_access_api @expose("//expanded", methods=["POST"]) - def expanded(self, table_schema_id): # pylint: disable=no-self-use + def expanded( # pylint: disable=no-self-use + self, table_schema_id: int + ) -> FlaskResponse: payload = json.loads(request.form["expanded"]) ( db.session.query(TableSchema) @@ -332,6 +343,6 @@ class SqlLab(BaseSupersetView): @expose("/my_queries/") @has_access - def my_queries(self): # pylint: disable=no-self-use + def my_queries(self) -> FlaskResponse: # pylint: disable=no-self-use """Assigns a list of found users to the given role.""" return redirect("/savedqueryview/list/?_flt_0_user={}".format(g.user.id)) diff --git a/superset/views/tags.py b/superset/views/tags.py index e12df2a7c9..2bcc0c7c99 100644 --- a/superset/views/tags.py +++ b/superset/views/tags.py @@ -16,6 +16,8 @@ # under the License. from __future__ import absolute_import, division, print_function, unicode_literals +from typing import Any, Dict, List + import simplejson as json from flask import request, Response from flask_appbuilder import expose @@ -29,11 +31,12 @@ from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.models.sql_lab import SavedQuery from superset.models.tags import ObjectTypes, Tag, TaggedObject, TagTypes +from superset.typing import FlaskResponse from .base import BaseSupersetView, json_success -def process_template(content): +def process_template(content: str) -> str: env = SandboxedEnvironment() template = env.from_string(content) context = { @@ -46,7 +49,7 @@ def process_template(content): class TagView(BaseSupersetView): @has_access_api @expose("/tags/suggestions/", methods=["GET"]) - def suggestions(self): # pylint: disable=no-self-use + def suggestions(self) -> FlaskResponse: # pylint: disable=no-self-use query = ( db.session.query(TaggedObject) .join(Tag) @@ -60,7 +63,9 @@ class TagView(BaseSupersetView): @has_access_api @expose("/tags///", methods=["GET"]) - def get(self, object_type, object_id): # pylint: disable=no-self-use + def get( # pylint: disable=no-self-use + self, object_type: ObjectTypes, object_id: int + ) -> FlaskResponse: """List all tags a given object has.""" if object_id == 0: return json_success(json.dumps([])) @@ -76,7 +81,9 @@ class TagView(BaseSupersetView): @has_access_api @expose("/tags///", methods=["POST"]) - def post(self, object_type, object_id): # pylint: disable=no-self-use + def post( # pylint: disable=no-self-use + self, object_type: ObjectTypes, object_id: int + ) -> FlaskResponse: """Add new tags to an object.""" if object_id == 0: return Response(status=404) @@ -104,7 +111,9 @@ class TagView(BaseSupersetView): @has_access_api @expose("/tags///", methods=["DELETE"]) - def delete(self, object_type, object_id): # pylint: disable=no-self-use + def delete( # pylint: disable=no-self-use + self, object_type: ObjectTypes, object_id: int + ) -> FlaskResponse: """Remove tags from an object.""" tag_names = request.get_json(force=True) if not tag_names: @@ -123,7 +132,7 @@ class TagView(BaseSupersetView): @has_access_api @expose("/tagged_objects/", methods=["GET", "POST"]) - def tagged_objects(self): # pylint: disable=no-self-use + def tagged_objects(self) -> FlaskResponse: # pylint: disable=no-self-use tags = [ process_template(tag) for tag in request.args.get("tags", "").split(",") @@ -135,7 +144,7 @@ class TagView(BaseSupersetView): # filter types types = [type_ for type_ in request.args.get("types", "").split(",") if type_] - results = [] + results: List[Dict[str, Any]] = [] # dashboards if not types or "dashboard" in types: diff --git a/superset/views/utils.py b/superset/views/utils.py index 5ed7e48383..4edd2e73ac 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -16,11 +16,12 @@ # under the License. from collections import defaultdict from datetime import date -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple from urllib import parse import simplejson as json from flask import g, request +from flask_appbuilder.security.sqla.models import User import superset.models.core as models from superset import app, db, is_feature_enabled @@ -29,7 +30,9 @@ from superset.exceptions import SupersetException from superset.legacy import update_time_range from superset.models.dashboard import Dashboard from superset.models.slice import Slice +from superset.typing import FormData from superset.utils.core import QueryStatus, TimeRangeEndpoint +from superset.viz import BaseViz if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"): from superset import viz_sip38 as viz # type: ignore @@ -42,7 +45,7 @@ if not app.config["ENABLE_JAVASCRIPT_CONTROLS"]: FORM_DATA_KEY_BLACKLIST = ["js_tooltip", "js_onclick_href", "js_data_mutator"] -def bootstrap_user_data(user, include_perms=False): +def bootstrap_user_data(user: User, include_perms: bool = False) -> Dict[str, Any]: if user.is_anonymous: return {} payload = { @@ -63,7 +66,9 @@ def bootstrap_user_data(user, include_perms=False): return payload -def get_permissions(user): +def get_permissions( + user: User, +) -> Tuple[Dict[str, List[List[str]]], DefaultDict[str, Set[str]]]: if not user.roles: raise AttributeError("User object does not have roles") @@ -86,11 +91,8 @@ def get_permissions(user): def get_viz( - form_data: Dict[str, Any], - datasource_type: str, - datasource_id: int, - force: bool = False, -): + form_data: FormData, datasource_type: str, datasource_id: int, force: bool = False, +) -> BaseViz: viz_type = form_data.get("viz_type", "table") datasource = ConnectorRegistry.get_datasource( datasource_type, datasource_id, db.session @@ -158,9 +160,7 @@ def get_form_data( def get_datasource_info( - datasource_id: Optional[int], - datasource_type: Optional[str], - form_data: Dict[str, Any], + datasource_id: Optional[int], datasource_type: Optional[str], form_data: FormData, ) -> Tuple[int, Optional[str]]: """ Compatibility layer for handling of datasource info @@ -222,9 +222,7 @@ def apply_display_max_row_limit( def get_time_range_endpoints( - form_data: Dict[str, Any], - slc: Optional[Slice] = None, - slice_id: Optional[int] = None, + form_data: FormData, slc: Optional[Slice] = None, slice_id: Optional[int] = None, ) -> Optional[Tuple[TimeRangeEndpoint, TimeRangeEndpoint]]: """ Get the slice aware time range endpoints from the form-data falling back to the SQL diff --git a/superset/viz.py b/superset/viz.py index 3fdd81ef08..d53dcf2313 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -525,7 +525,7 @@ class BaseViz: has_error = ( payload.get("status") == utils.QueryStatus.FAILED or payload.get("error") is not None - or len(payload.get("errors") or []) > 0 + or bool(payload.get("errors")) ) return self.json_dumps(payload), has_error diff --git a/superset/viz_sip38.py b/superset/viz_sip38.py index 78d15d9efb..32df001173 100644 --- a/superset/viz_sip38.py +++ b/superset/viz_sip38.py @@ -56,7 +56,7 @@ from superset.exceptions import ( SpatialException, ) from superset.models.helpers import QueryResult -from superset.typing import VizData +from superset.typing import QueryObjectDict, VizData, VizPayload from superset.utils import core as utils from superset.utils.core import ( DTTM_ALIAS, @@ -251,7 +251,7 @@ class BaseViz: df = df[min_periods:] return df - def get_samples(self): + def get_samples(self) -> List[Dict[str, Any]]: query_obj = self.query_obj() query_obj.update( { @@ -452,7 +452,7 @@ class BaseViz: json_data = self.json_dumps(cache_dict, sort_keys=True) return hashlib.md5(json_data.encode("utf-8")).hexdigest() - def get_payload(self, query_obj=None): + def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload: """Returns a payload of metadata and data""" self.run_extra_queries() payload = self.get_df_payload(query_obj) @@ -464,7 +464,9 @@ class BaseViz: del payload["df"] return payload - def get_df_payload(self, query_obj=None, **kwargs): + def get_df_payload( + self, query_obj: Optional[QueryObjectDict] = None, **kwargs: Any + ) -> Dict[str, Any]: """Handles caching around the df payload retrieval""" if not query_obj: query_obj = self.query_obj() @@ -559,11 +561,11 @@ class BaseViz: obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys ) - def payload_json_and_has_error(self, payload): + def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]: has_error = ( payload.get("status") == utils.QueryStatus.FAILED or payload.get("error") is not None - or len(payload.get("errors")) > 0 + or len(payload.get("errors", [])) > 0 ) return self.json_dumps(payload), has_error @@ -578,7 +580,7 @@ class BaseViz: } return content - def get_csv(self): + def get_csv(self) -> Optional[str]: df = self.get_df() include_index = not isinstance(df.index, pd.RangeIndex) return df.to_csv(index=include_index, **config["CSV_EXPORT"])