style(mypy): Enforcing typing for superset.views (#9939)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2020-06-05 08:44:11 -07:00 committed by GitHub
parent 5c4d4f16b3
commit 63e0188f45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 440 additions and 340 deletions

View File

@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true
[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset.queries.*,superset.security.*,superset.sql_lab,superset.sql_parse,superset.sql_validators.*,superset.stats_logger,superset.tasks.*,superset.translations.*,superset.typing,superset.utils.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*,superset.viz,superset.viz_sip38]
[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset.queries.*,superset.security.*,superset.sql_lab,superset.sql_parse,superset.sql_validators.*,superset.stats_logger,superset.tasks.*,superset.translations.*,superset.typing,superset.utils.*,,superset.views.*,superset.viz,superset.viz_sip38]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true

View File

@ -62,14 +62,26 @@ class BaseDatasource(
# ---------------------------------------------------------------
__tablename__: Optional[str] = None # {connector_name}_datasource
baselink: Optional[str] = None # url portion pointing to ModelView endpoint
column_class: Optional[Type] = None # link to derivative of BaseColumn
metric_class: Optional[Type] = None # link to derivative of BaseMetric
@property
def column_class(self) -> Type:
# link to derivative of BaseColumn
raise NotImplementedError()
@property
def metric_class(self) -> Type:
# link to derivative of BaseMetric
raise NotImplementedError()
owner_class: Optional[User] = None
# Used to do code highlighting when displaying the query in the UI
query_language: Optional[str] = None
name = None # can be a Column or a property pointing to one
@property
def name(self) -> str:
# can be a Column or a property pointing to one
raise NotImplementedError()
# ---------------------------------------------------------------

View File

@ -548,7 +548,7 @@ class DruidDatasource(Model, BaseDatasource):
return [c.column_name for c in self.columns if c.is_numeric]
@property
def name(self) -> str: # type: ignore
def name(self) -> str:
return self.datasource_name
@property

View File

@ -531,7 +531,7 @@ class SqlaTable(Model, BaseDatasource):
return ("[{obj.database}].[{obj.table_name}]" "(id:{obj.id})").format(obj=self)
@property
def name(self) -> str: # type: ignore
def name(self) -> str:
if not self.schema:
return self.table_name
return "{}.{}".format(self.schema, self.table_name)

View File

@ -19,6 +19,8 @@
from typing import Any, Dict, List, Optional
from superset.models.core import Database
class SQLValidationAnnotation:
"""Represents a single annotation (error/warning) in an SQL querytext"""
@ -35,7 +37,7 @@ class SQLValidationAnnotation:
self.start_column = start_column
self.end_column = end_column
def to_dict(self) -> Dict:
def to_dict(self) -> Dict[str, Any]:
"""Return a dictionary representation of this annotation"""
return {
"line_number": self.line_number,
@ -53,7 +55,7 @@ class BaseSQLValidator:
@classmethod
def validate(
cls, sql: str, schema: str, database: Any
cls, sql: str, schema: Optional[str], database: Database
) -> List[SQLValidationAnnotation]:
"""Check that the given SQL querystring is valid for the given engine"""
raise NotImplementedError

View File

@ -143,7 +143,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
@classmethod
def validate(
cls, sql: str, schema: str, database: Any
cls, sql: str, schema: Optional[str], database: Database
) -> List[SQLValidationAnnotation]:
"""
Presto supports query-validation queries by running them with a

View File

@ -225,9 +225,11 @@ def deliver_dashboard(schedule: DashboardEmailSchedule) -> None:
"""
dashboard = schedule.dashboard
dashboard_url = _get_url_path("Superset.dashboard", dashboard_id=dashboard.id)
dashboard_url = _get_url_path(
"Superset.dashboard", dashboard_id_or_slug=dashboard.id
)
dashboard_url_user_friendly = _get_url_path(
"Superset.dashboard", user_friendly=True, dashboard_id=dashboard.id
"Superset.dashboard", user_friendly=True, dashboard_id_or_slug=dashboard.id
)
# Create a driver, fetch the page, wait for the page to render

View File

@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict
from flask_appbuilder import CompactCRUDMixin
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import lazy_gettext as _
@ -30,7 +32,7 @@ class StartEndDttmValidator: # pylint: disable=too-few-public-methods
Validates dttm fields.
"""
def __call__(self, form, field):
def __call__(self, form: Dict[str, Any], field: Any) -> None:
if not form["start_dttm"].data and not form["end_dttm"].data:
raise StopValidation(_("annotation start time or end time is required."))
elif (
@ -82,13 +84,13 @@ class AnnotationModelView(
validators_columns = {"start_dttm": [StartEndDttmValidator()]}
def pre_add(self, item):
def pre_add(self, item: "AnnotationModelView") -> None:
if not item.start_dttm:
item.start_dttm = item.end_dttm
elif not item.end_dttm:
item.end_dttm = item.start_dttm
def pre_update(self, item):
def pre_update(self, item: "AnnotationModelView") -> None:
self.pre_add(item)

View File

@ -24,6 +24,7 @@ from superset import db, event_logger, security_manager
from superset.common.query_context import QueryContext
from superset.legacy import update_time_range
from superset.models.slice import Slice
from superset.typing import FlaskResponse
from superset.utils import core as utils
from superset.views.base import api, BaseSupersetView, handle_api_exception
@ -34,13 +35,13 @@ class Api(BaseSupersetView):
@handle_api_exception
@has_access_api
@expose("/v1/query/", methods=["POST"])
def query(self):
def query(self) -> FlaskResponse:
"""
Takes a query_obj constructed in the client and returns payload data response
for the given query_obj.
params: query_context: json_blob
"""
query_context = QueryContext(**json.loads(request.form.get("query_context")))
query_context = QueryContext(**json.loads(request.form["query_context"]))
security_manager.assert_query_context_permission(query_context)
payload_json = query_context.get_payload()
return json.dumps(
@ -52,7 +53,7 @@ class Api(BaseSupersetView):
@handle_api_exception
@has_access_api
@expose("/v1/form_data/", methods=["GET"])
def query_form_data(self):
def query_form_data(self) -> FlaskResponse:
"""
Get the formdata stored in the database for existing slice.
params: slice_id: integer

View File

@ -18,13 +18,13 @@ import functools
import logging
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING, Union
import dataclasses
import simplejson as json
import yaml
from flask import abort, flash, g, get_flashed_messages, redirect, Response, session
from flask_appbuilder import BaseView, ModelView
from flask_appbuilder import BaseView, Model, ModelView
from flask_appbuilder.actions import action
from flask_appbuilder.forms import DynamicForm
from flask_appbuilder.models.sqla.filters import BaseFilter
@ -33,7 +33,9 @@ from flask_appbuilder.widgets import ListWidget
from flask_babel import get_locale, gettext as __, lazy_gettext as _
from flask_wtf.form import FlaskForm
from sqlalchemy import or_
from sqlalchemy.orm import Query
from werkzeug.exceptions import HTTPException
from wtforms import Form
from wtforms.fields.core import Field, UnboundField
from superset import (
@ -47,6 +49,7 @@ from superset import (
from superset.connectors.sqla import models
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.models.helpers import ImportMixin
from superset.translations.utils import get_language_pack
from superset.typing import FlaskResponse
from superset.utils import core as utils
@ -93,7 +96,7 @@ def json_error_response(
status: int = 500,
payload: Optional[Dict[str, Any]] = None,
link: Optional[str] = None,
) -> Response:
) -> FlaskResponse:
if not payload:
payload = {"error": "{}".format(msg)}
if link:
@ -110,7 +113,7 @@ def json_errors_response(
errors: List[SupersetError],
status: int = 500,
payload: Optional[Dict[str, Any]] = None,
) -> Response:
) -> FlaskResponse:
if not payload:
payload = {}
@ -122,11 +125,11 @@ def json_errors_response(
)
def json_success(json_msg: str, status: int = 200) -> Response:
def json_success(json_msg: str, status: int = 200) -> FlaskResponse:
return Response(json_msg, status=status, mimetype="application/json")
def data_payload_response(payload_json: str, has_error: bool = False) -> Response:
def data_payload_response(payload_json: str, has_error: bool = False) -> FlaskResponse:
status = 400 if has_error else 200
return json_success(payload_json, status=status)
@ -140,13 +143,13 @@ def generate_download_headers(
return headers
def api(f):
def api(f: Callable) -> Callable:
"""
A decorator to label an endpoint as an API. Catches uncaught exceptions and
return the response in the JSON format
"""
def wraps(self, *args, **kwargs):
def wraps(self: "BaseSupersetView", *args: Any, **kwargs: Any) -> FlaskResponse:
try:
return f(self, *args, **kwargs)
except Exception as ex: # pylint: disable=broad-except
@ -156,14 +159,16 @@ def api(f):
return functools.update_wrapper(wraps, f)
def handle_api_exception(f):
def handle_api_exception(
f: Callable[..., FlaskResponse]
) -> Callable[..., FlaskResponse]:
"""
A decorator to catch superset exceptions. Use it after the @api decorator above
so superset exception handler is triggered before the handler for generic
exceptions.
"""
def wraps(self, *args, **kwargs):
def wraps(self: "BaseSupersetView", *args: Any, **kwargs: Any) -> FlaskResponse:
try:
return f(self, *args, **kwargs)
except SupersetSecurityException as ex:
@ -179,7 +184,7 @@ def handle_api_exception(f):
except HTTPException as ex:
logger.exception(ex)
return json_error_response(
utils.error_msg_from_exception(ex), status=ex.code
utils.error_msg_from_exception(ex), status=cast(int, ex.code)
)
except Exception as ex: # pylint: disable=broad-except
logger.exception(ex)
@ -233,7 +238,9 @@ def get_user_roles() -> List[Role]:
class BaseSupersetView(BaseView):
@staticmethod
def json_response(obj, status=200) -> Response: # pylint: disable=no-self-use
def json_response(
obj: Any, status: int = 200
) -> FlaskResponse: # pylint: disable=no-self-use
return Response(
json.dumps(obj, default=utils.json_int_dttm_ser, ignore_nan=True),
status=status,
@ -241,7 +248,7 @@ class BaseSupersetView(BaseView):
)
def menu_data():
def menu_data() -> Dict[str, Any]:
menu = appbuilder.menu.get_data()
root_path = "#"
logo_target_path = ""
@ -290,7 +297,7 @@ def menu_data():
}
def common_bootstrap_payload():
def common_bootstrap_payload() -> Dict[str, Any]:
"""Common data always sent to the client"""
messages = get_flashed_messages(with_categories=True)
locale = str(get_locale())
@ -335,7 +342,7 @@ class ListWidgetWithCheckboxes(ListWidget): # pylint: disable=too-few-public-me
template = "superset/fab_overrides/list_with_checkboxes.html"
def validate_json(_form, field):
def validate_json(form: Form, field: Field) -> None: # pylint: disable=unused-argument
try:
json.loads(field.data)
except Exception as ex:
@ -352,24 +359,23 @@ class YamlExportMixin: # pylint: disable=too-few-public-methods
yaml_dict_key: Optional[str] = None
@action("yaml_export", __("Export to YAML"), __("Export to YAML?"), "fa-download")
def yaml_export(self, items):
def yaml_export(
self, items: Union[ImportMixin, List[ImportMixin]]
) -> FlaskResponse:
if not isinstance(items, list):
items = [items]
data = [t.export_to_dict() for t in items]
if self.yaml_dict_key:
data = {self.yaml_dict_key: data}
return Response(
yaml.safe_dump(data),
yaml.safe_dump({self.yaml_dict_key: data} if self.yaml_dict_key else data),
headers=generate_download_headers("yaml"),
mimetype="application/text",
)
class DeleteMixin: # pylint: disable=too-few-public-methods
def _delete(
self: Union[BaseView, "DeleteMixin", "DruidClusterModelView"], primary_key: int,
) -> None:
def _delete(self: BaseView, primary_key: int,) -> None:
"""
Delete function logic, override to implement diferent logic
deletes the record with primary_key = primary_key
@ -411,7 +417,7 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
@action(
"muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False
)
def muldelete(self, items):
def muldelete(self: BaseView, items: List[Model]) -> FlaskResponse:
if not items:
abort(404)
for item in items:
@ -426,7 +432,7 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
class DatasourceFilter(BaseFilter): # pylint: disable=too-few-public-methods
def apply(self, query, value):
def apply(self, query: Query, value: Any) -> Query:
if security_manager.all_datasource_access():
return query
datasource_perms = security_manager.user_view_menu_names("datasource_access")
@ -497,7 +503,7 @@ def check_ownership(obj: Any, raise_if_false: bool = True) -> bool:
def bind_field(
_, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any]
_: Any, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any]
) -> Field:
"""
Customize how fields are bound by stripping all whitespace.

View File

@ -16,17 +16,18 @@
# under the License.
import functools
import logging
from typing import Any, cast, Dict, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
from apispec import APISpec
from flask import Response
from flask_appbuilder import ModelRestApi
from flask import Blueprint, Response
from flask_appbuilder import AppBuilder, Model, ModelRestApi
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.filters import BaseFilter, Filters
from flask_appbuilder.models.sqla.filters import FilterStartsWith
from marshmallow import Schema
from superset.stats_logger import BaseStatsLogger
from superset.typing import FlaskResponse
from superset.utils.core import time_function
logger = logging.getLogger(__name__)
@ -40,12 +41,12 @@ get_related_schema = {
}
def statsd_metrics(f):
def statsd_metrics(f: Callable) -> Callable:
"""
Handle sending all statsd metrics from the REST API
"""
def wraps(self, *args: Any, **kwargs: Any) -> Response:
def wraps(self: "BaseSupersetModelRestApi", *args: Any, **kwargs: Any) -> Response:
duration, response = time_function(f, self, *args, **kwargs)
self.send_stats_metrics(response, f.__name__, duration)
return response
@ -116,6 +117,11 @@ class BaseSupersetModelRestApi(ModelRestApi):
Add extra schemas to the OpenAPI component schemas section
""" # pylint: disable=pointless-string-statement
add_columns: List[str]
edit_columns: List[str]
list_columns: List[str]
show_columns: List[str]
def __init__(self) -> None:
super().__init__()
self.stats_logger = BaseStatsLogger()
@ -128,11 +134,13 @@ class BaseSupersetModelRestApi(ModelRestApi):
)
super().add_apispec_components(api_spec)
def create_blueprint(self, appbuilder, *args, **kwargs):
def create_blueprint(
self, appbuilder: AppBuilder, *args: Any, **kwargs: Any
) -> Blueprint:
self.stats_logger = self.appbuilder.get_app.config["STATS_LOGGER"]
return super().create_blueprint(appbuilder, *args, **kwargs)
def _init_properties(self):
def _init_properties(self) -> None:
model_id = self.datamodel.get_pk_name()
if self.list_columns is None and not self.list_model_schema:
self.list_columns = [model_id]
@ -144,7 +152,9 @@ class BaseSupersetModelRestApi(ModelRestApi):
self.add_columns = [model_id]
super()._init_properties()
def _get_related_filter(self, datamodel, column_name: str, value: str) -> Filters:
def _get_related_filter(
self, datamodel: Model, column_name: str, value: str
) -> Filters:
filter_field = self.related_field_filters.get(column_name)
if isinstance(filter_field, str):
filter_field = RelatedFieldFilter(cast(str, filter_field), FilterStartsWith)
@ -198,7 +208,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
if time_delta:
self.timing_stats("time", key, time_delta)
def info_headless(self, **kwargs) -> Response:
def info_headless(self, **kwargs: Any) -> Response:
"""
Add statsd metrics to builtin FAB _info endpoint
"""
@ -206,7 +216,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
self.send_stats_metrics(response, self.info.__name__, duration)
return response
def get_headless(self, pk, **kwargs) -> Response:
def get_headless(self, pk: int, **kwargs: Any) -> Response:
"""
Add statsd metrics to builtin FAB GET endpoint
"""
@ -214,7 +224,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
self.send_stats_metrics(response, self.get.__name__, duration)
return response
def get_list_headless(self, **kwargs) -> Response:
def get_list_headless(self, **kwargs: Any) -> Response:
"""
Add statsd metrics to builtin FAB GET list endpoint
"""
@ -227,7 +237,7 @@ class BaseSupersetModelRestApi(ModelRestApi):
@safe
@statsd_metrics
@rison(get_related_schema)
def related(self, column_name: str, **kwargs):
def related(self, column_name: str, **kwargs: Any) -> FlaskResponse:
"""Get related fields data
---
get:

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, List, Optional
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Union
from flask import current_app, g
from flask_appbuilder import Model
@ -22,7 +22,7 @@ from marshmallow import post_load, pre_load, Schema, ValidationError
from sqlalchemy.orm.exc import NoResultFound
def validate_owner(value):
def validate_owner(value: int) -> None:
try:
(
current_app.appbuilder.get_session.query(
@ -44,18 +44,25 @@ class BaseSupersetSchema(Schema):
__class_model__: Model = None
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
self.instance: Optional[Model] = None
super().__init__(**kwargs)
def load(
self, data, many=None, partial=None, instance: Model = None, **kwargs
): # pylint: disable=arguments-differ
def load( # pylint: disable=arguments-differ
self,
data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]],
many: Optional[bool] = None,
partial: Optional[Union[bool, Sequence[str], Set[str]]] = None,
instance: Optional[Model] = None,
**kwargs: Any,
) -> Any:
self.instance = instance
return super().load(data, many=many, partial=partial, **kwargs)
@post_load
def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model:
def make_object(
self, data: Dict[Any, Any], discard: Optional[List[str]] = None
) -> Model:
"""
Creates a Model object from POST or PUT requests. PUT will use self.instance
previously fetched from the endpoint handler
@ -92,13 +99,13 @@ class BaseOwnedSchema(BaseSupersetSchema):
return instance
@pre_load
def pre_load(self, data: Dict):
def pre_load(self, data: Dict[Any, Any]) -> None:
# if PUT request don't set owners to empty list
if not self.instance:
data[self.owners_field_name] = data.get(self.owners_field_name, [])
@staticmethod
def set_owners(instance: Model, owners: List[int]):
def set_owners(instance: Model, owners: List[int]) -> None:
owner_objs = list()
if g.user.id not in owners:
owners.append(g.user.id)

File diff suppressed because it is too large Load Diff

View File

@ -17,7 +17,7 @@
import json
from collections import Counter
from flask import request, Response
from flask import request
from flask_appbuilder import expose
from flask_appbuilder.security.decorators import has_access_api
from sqlalchemy.orm.exc import NoResultFound
@ -25,6 +25,7 @@ from sqlalchemy.orm.exc import NoResultFound
from superset import db
from superset.connectors.connector_registry import ConnectorRegistry
from superset.models.core import Database
from superset.typing import FlaskResponse
from .base import api, BaseSupersetView, handle_api_exception, json_error_response
@ -36,7 +37,7 @@ class Datasource(BaseSupersetView):
@has_access_api
@api
@handle_api_exception
def save(self) -> Response:
def save(self) -> FlaskResponse:
data = request.form.get("data")
if not isinstance(data, str):
return json_error_response("Request missing data field.", status=500)
@ -78,7 +79,7 @@ class Datasource(BaseSupersetView):
@has_access_api
@api
@handle_api_exception
def get(self, datasource_type: str, datasource_id: int) -> Response:
def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse:
try:
orm_datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
@ -95,7 +96,9 @@ class Datasource(BaseSupersetView):
@has_access_api
@api
@handle_api_exception
def external_metadata(self, datasource_type: str, datasource_id: int) -> Response:
def external_metadata(
self, datasource_type: str, datasource_id: int
) -> FlaskResponse:
"""Gets column info from the source system"""
if datasource_type == "druid":
datasource = ConnectorRegistry.get_datasource(

View File

@ -14,9 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, cast, Optional
from flask_appbuilder.models.filters import BaseFilter
from flask_babel import lazy_gettext
from sqlalchemy import or_
from sqlalchemy.orm import Query
from superset import security_manager
@ -36,9 +39,9 @@ class FilterRelatedOwners(BaseFilter):
name = lazy_gettext("Owner")
arg_name = "owners"
def apply(self, query, value):
def apply(self, query: Query, value: Optional[Any]) -> Query:
user_model = security_manager.user_model
like_value = "%" + value + "%"
like_value = "%" + cast(str, value) + "%"
return query.filter(
or_(
# could be made to handle spaces between names more gracefully

View File

@ -23,8 +23,8 @@ class LogMixin: # pylint: disable=too-few-public-methods
add_title = _("Add Log")
edit_title = _("Edit Log")
list_columns = ("user", "action", "dttm")
edit_columns = ("user", "action", "dttm", "json")
list_columns = ["user", "action", "dttm"]
edit_columns = ["user", "action", "dttm", "json"]
base_order = ("dttm", "desc")
label_columns = {
"user": _("User"),

View File

@ -28,5 +28,5 @@ class LogRestApi(LogMixin, BaseSupersetModelRestApi):
class_permission_name = "LogModelView"
resource_name = "log"
allow_browser_login = True
list_columns = ("user.username", "action", "dttm")
list_columns = ["user.username", "action", "dttm"]
show_columns = list_columns

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import enum
from typing import Optional, Type
from typing import Type
import simplejson as json
from croniter import croniter
@ -24,7 +24,7 @@ from flask_appbuilder import expose
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_appbuilder.security.decorators import has_access
from flask_babel import lazy_gettext as _
from wtforms import BooleanField, StringField
from wtforms import BooleanField, Form, StringField
from superset import db, security_manager
from superset.constants import RouteMethod
@ -37,6 +37,7 @@ from superset.models.schedules import (
)
from superset.models.slice import Slice
from superset.tasks.schedules import schedule_email_report
from superset.typing import FlaskResponse
from superset.utils.core import get_email_address_list, json_iso_dttm_ser
from superset.views.core import json_success
@ -48,8 +49,14 @@ class EmailScheduleView(
): # pylint: disable=too-many-ancestors
include_route_methods = RouteMethod.CRUD_SET
_extra_data = {"test_email": False, "test_email_recipients": None}
schedule_type: Optional[str] = None
schedule_type_model: Optional[Type] = None
@property
def schedule_type(self) -> str:
raise NotImplementedError()
@property
def schedule_type_model(self) -> Type:
raise NotImplementedError()
page_size = 20
@ -87,7 +94,7 @@ class EmailScheduleView(
edit_form_extra_fields = add_form_extra_fields
def process_form(self, form, is_created):
def process_form(self, form: Form, is_created: bool) -> None:
if form.test_email_recipients.data:
test_email_recipients = form.test_email_recipients.data.strip()
else:
@ -95,7 +102,7 @@ class EmailScheduleView(
self._extra_data["test_email"] = form.test_email.data
self._extra_data["test_email_recipients"] = test_email_recipients
def pre_add(self, item):
def pre_add(self, item: "EmailScheduleView") -> None:
try:
recipients = get_email_address_list(item.recipients)
item.recipients = ", ".join(recipients)
@ -106,10 +113,10 @@ class EmailScheduleView(
if not croniter.is_valid(item.crontab):
raise SupersetException("Invalid crontab format")
def pre_update(self, item):
def pre_update(self, item: "EmailScheduleView") -> None:
self.pre_add(item)
def post_add(self, item):
def post_add(self, item: "EmailScheduleView") -> None:
# Schedule a test mail if the user requested for it.
if self._extra_data["test_email"]:
recipients = self._extra_data["test_email_recipients"] or item.recipients
@ -122,12 +129,12 @@ class EmailScheduleView(
if item.active:
flash("Schedule changes will get applied in one hour", "warning")
def post_update(self, item):
def post_update(self, item: "EmailScheduleView") -> None:
self.post_add(item)
@has_access
@expose("/fetch/<int:item_id>/", methods=["GET"])
def fetch_schedules(self, item_id):
def fetch_schedules(self, item_id: int) -> FlaskResponse:
query = db.session.query(self.datamodel.obj)
query = query.join(self.schedule_type_model).filter(
@ -147,7 +154,9 @@ class EmailScheduleView(
info[col] = info[col].username
info["user"] = schedule.user.username
info[self.schedule_type] = getattr(schedule, self.schedule_type).id
info[self.schedule_type] = getattr( # type: ignore
schedule, self.schedule_type
).id
schedules.append(info)
return json_success(json.dumps(schedules, default=json_iso_dttm_ser))
@ -208,7 +217,7 @@ class DashboardEmailScheduleView(
"delivery_type": _("Delivery Type"),
}
def pre_add(self, item):
def pre_add(self, item: "DashboardEmailScheduleView") -> None:
if item.dashboard is None:
raise SupersetException("Dashboard is mandatory")
super(DashboardEmailScheduleView, self).pre_add(item)
@ -269,7 +278,7 @@ class SliceEmailScheduleView(EmailScheduleView): # pylint: disable=too-many-anc
"email_format": _("Email Format"),
}
def pre_add(self, item):
def pre_add(self, item: "SliceEmailScheduleView") -> None:
if item.slice is None:
raise SupersetException("Slice is mandatory")
super(SliceEmailScheduleView, self).pre_add(item)

View File

@ -27,6 +27,7 @@ from flask_sqlalchemy import BaseQuery
from superset import db, get_feature_flags, security_manager
from superset.constants import RouteMethod
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
from superset.typing import FlaskResponse
from superset.utils import core as utils
from .base import (
@ -120,15 +121,15 @@ class SavedQueryView(
show_template = "superset/models/savedquery/show.html"
def pre_add(self, item):
def pre_add(self, item: "SavedQueryView") -> None:
item.user = g.user
def pre_update(self, item):
def pre_update(self, item: "SavedQueryView") -> None:
self.pre_add(item)
@has_access
@expose("show/<pk>")
def show(self, pk):
def show(self, pk: int) -> FlaskResponse:
pk = self._deserialize_pk_if_composite(pk)
widgets = self._show(pk)
query = self.datamodel.get(pk).to_json()
@ -168,18 +169,18 @@ class SavedQueryViewApi(SavedQueryView): # pylint: disable=too-many-ancestors
@has_access_api
@expose("show/<pk>")
def show(self, pk):
def show(self, pk: int) -> FlaskResponse:
return super().show(pk)
def _get_owner_id(tab_state_id):
def _get_owner_id(tab_state_id: int) -> int:
return db.session.query(TabState.user_id).filter_by(id=tab_state_id).scalar()
class TabStateView(BaseSupersetView):
@has_access_api
@expose("/", methods=["POST"])
def post(self): # pylint: disable=no-self-use
def post(self) -> FlaskResponse: # pylint: disable=no-self-use
query_editor = json.loads(request.form["queryEditor"])
tab_state = TabState(
user_id=g.user.get_id(),
@ -201,7 +202,7 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("/<int:tab_state_id>", methods=["DELETE"])
def delete(self, tab_state_id): # pylint: disable=no-self-use
def delete(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
return Response(status=403)
@ -216,7 +217,7 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("/<int:tab_state_id>", methods=["GET"])
def get(self, tab_state_id): # pylint: disable=no-self-use
def get(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
return Response(status=403)
@ -229,7 +230,9 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("<int:tab_state_id>/activate", methods=["POST"])
def activate(self, tab_state_id): # pylint: disable=no-self-use
def activate( # pylint: disable=no-self-use
self, tab_state_id: int
) -> FlaskResponse:
owner_id = _get_owner_id(tab_state_id)
if owner_id is None:
return Response(status=404)
@ -246,7 +249,7 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("<int:tab_state_id>", methods=["PUT"])
def put(self, tab_state_id): # pylint: disable=no-self-use
def put(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
return Response(status=403)
@ -257,7 +260,9 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("<int:tab_state_id>/migrate_query", methods=["POST"])
def migrate_query(self, tab_state_id): # pylint: disable=no-self-use
def migrate_query( # pylint: disable=no-self-use
self, tab_state_id: int
) -> FlaskResponse:
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
return Response(status=403)
@ -270,7 +275,9 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("<int:tab_state_id>/query/<client_id>", methods=["DELETE"])
def delete_query(self, tab_state_id, client_id): # pylint: disable=no-self-use
def delete_query( # pylint: disable=no-self-use
self, tab_state_id: str, client_id: str
) -> FlaskResponse:
db.session.query(Query).filter_by(
client_id=client_id, user_id=g.user.get_id(), sql_editor_id=tab_state_id
).delete(synchronize_session=False)
@ -281,7 +288,7 @@ class TabStateView(BaseSupersetView):
class TableSchemaView(BaseSupersetView):
@has_access_api
@expose("/", methods=["POST"])
def post(self): # pylint: disable=no-self-use
def post(self) -> FlaskResponse: # pylint: disable=no-self-use
table = json.loads(request.form["table"])
# delete any existing table schema
@ -306,7 +313,9 @@ class TableSchemaView(BaseSupersetView):
@has_access_api
@expose("/<int:table_schema_id>", methods=["DELETE"])
def delete(self, table_schema_id): # pylint: disable=no-self-use
def delete( # pylint: disable=no-self-use
self, table_schema_id: int
) -> FlaskResponse:
db.session.query(TableSchema).filter(TableSchema.id == table_schema_id).delete(
synchronize_session=False
)
@ -315,7 +324,9 @@ class TableSchemaView(BaseSupersetView):
@has_access_api
@expose("/<int:table_schema_id>/expanded", methods=["POST"])
def expanded(self, table_schema_id): # pylint: disable=no-self-use
def expanded( # pylint: disable=no-self-use
self, table_schema_id: int
) -> FlaskResponse:
payload = json.loads(request.form["expanded"])
(
db.session.query(TableSchema)
@ -332,6 +343,6 @@ class SqlLab(BaseSupersetView):
@expose("/my_queries/")
@has_access
def my_queries(self): # pylint: disable=no-self-use
def my_queries(self) -> FlaskResponse: # pylint: disable=no-self-use
"""Assigns a list of found users to the given role."""
return redirect("/savedqueryview/list/?_flt_0_user={}".format(g.user.id))

View File

@ -16,6 +16,8 @@
# under the License.
from __future__ import absolute_import, division, print_function, unicode_literals
from typing import Any, Dict, List
import simplejson as json
from flask import request, Response
from flask_appbuilder import expose
@ -29,11 +31,12 @@ from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import SavedQuery
from superset.models.tags import ObjectTypes, Tag, TaggedObject, TagTypes
from superset.typing import FlaskResponse
from .base import BaseSupersetView, json_success
def process_template(content):
def process_template(content: str) -> str:
env = SandboxedEnvironment()
template = env.from_string(content)
context = {
@ -46,7 +49,7 @@ def process_template(content):
class TagView(BaseSupersetView):
@has_access_api
@expose("/tags/suggestions/", methods=["GET"])
def suggestions(self): # pylint: disable=no-self-use
def suggestions(self) -> FlaskResponse: # pylint: disable=no-self-use
query = (
db.session.query(TaggedObject)
.join(Tag)
@ -60,7 +63,9 @@ class TagView(BaseSupersetView):
@has_access_api
@expose("/tags/<object_type:object_type>/<int:object_id>/", methods=["GET"])
def get(self, object_type, object_id): # pylint: disable=no-self-use
def get( # pylint: disable=no-self-use
self, object_type: ObjectTypes, object_id: int
) -> FlaskResponse:
"""List all tags a given object has."""
if object_id == 0:
return json_success(json.dumps([]))
@ -76,7 +81,9 @@ class TagView(BaseSupersetView):
@has_access_api
@expose("/tags/<object_type:object_type>/<int:object_id>/", methods=["POST"])
def post(self, object_type, object_id): # pylint: disable=no-self-use
def post( # pylint: disable=no-self-use
self, object_type: ObjectTypes, object_id: int
) -> FlaskResponse:
"""Add new tags to an object."""
if object_id == 0:
return Response(status=404)
@ -104,7 +111,9 @@ class TagView(BaseSupersetView):
@has_access_api
@expose("/tags/<object_type:object_type>/<int:object_id>/", methods=["DELETE"])
def delete(self, object_type, object_id): # pylint: disable=no-self-use
def delete( # pylint: disable=no-self-use
self, object_type: ObjectTypes, object_id: int
) -> FlaskResponse:
"""Remove tags from an object."""
tag_names = request.get_json(force=True)
if not tag_names:
@ -123,7 +132,7 @@ class TagView(BaseSupersetView):
@has_access_api
@expose("/tagged_objects/", methods=["GET", "POST"])
def tagged_objects(self): # pylint: disable=no-self-use
def tagged_objects(self) -> FlaskResponse: # pylint: disable=no-self-use
tags = [
process_template(tag)
for tag in request.args.get("tags", "").split(",")
@ -135,7 +144,7 @@ class TagView(BaseSupersetView):
# filter types
types = [type_ for type_ in request.args.get("types", "").split(",") if type_]
results = []
results: List[Dict[str, Any]] = []
# dashboards
if not types or "dashboard" in types:

View File

@ -16,11 +16,12 @@
# under the License.
from collections import defaultdict
from datetime import date
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple
from urllib import parse
import simplejson as json
from flask import g, request
from flask_appbuilder.security.sqla.models import User
import superset.models.core as models
from superset import app, db, is_feature_enabled
@ -29,7 +30,9 @@ from superset.exceptions import SupersetException
from superset.legacy import update_time_range
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.typing import FormData
from superset.utils.core import QueryStatus, TimeRangeEndpoint
from superset.viz import BaseViz
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
from superset import viz_sip38 as viz # type: ignore
@ -42,7 +45,7 @@ if not app.config["ENABLE_JAVASCRIPT_CONTROLS"]:
FORM_DATA_KEY_BLACKLIST = ["js_tooltip", "js_onclick_href", "js_data_mutator"]
def bootstrap_user_data(user, include_perms=False):
def bootstrap_user_data(user: User, include_perms: bool = False) -> Dict[str, Any]:
if user.is_anonymous:
return {}
payload = {
@ -63,7 +66,9 @@ def bootstrap_user_data(user, include_perms=False):
return payload
def get_permissions(user):
def get_permissions(
user: User,
) -> Tuple[Dict[str, List[List[str]]], DefaultDict[str, Set[str]]]:
if not user.roles:
raise AttributeError("User object does not have roles")
@ -86,11 +91,8 @@ def get_permissions(user):
def get_viz(
form_data: Dict[str, Any],
datasource_type: str,
datasource_id: int,
force: bool = False,
):
form_data: FormData, datasource_type: str, datasource_id: int, force: bool = False,
) -> BaseViz:
viz_type = form_data.get("viz_type", "table")
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
@ -158,9 +160,7 @@ def get_form_data(
def get_datasource_info(
datasource_id: Optional[int],
datasource_type: Optional[str],
form_data: Dict[str, Any],
datasource_id: Optional[int], datasource_type: Optional[str], form_data: FormData,
) -> Tuple[int, Optional[str]]:
"""
Compatibility layer for handling of datasource info
@ -222,9 +222,7 @@ def apply_display_max_row_limit(
def get_time_range_endpoints(
form_data: Dict[str, Any],
slc: Optional[Slice] = None,
slice_id: Optional[int] = None,
form_data: FormData, slc: Optional[Slice] = None, slice_id: Optional[int] = None,
) -> Optional[Tuple[TimeRangeEndpoint, TimeRangeEndpoint]]:
"""
Get the slice aware time range endpoints from the form-data falling back to the SQL

View File

@ -525,7 +525,7 @@ class BaseViz:
has_error = (
payload.get("status") == utils.QueryStatus.FAILED
or payload.get("error") is not None
or len(payload.get("errors") or []) > 0
or bool(payload.get("errors"))
)
return self.json_dumps(payload), has_error

View File

@ -56,7 +56,7 @@ from superset.exceptions import (
SpatialException,
)
from superset.models.helpers import QueryResult
from superset.typing import VizData
from superset.typing import QueryObjectDict, VizData, VizPayload
from superset.utils import core as utils
from superset.utils.core import (
DTTM_ALIAS,
@ -251,7 +251,7 @@ class BaseViz:
df = df[min_periods:]
return df
def get_samples(self):
def get_samples(self) -> List[Dict[str, Any]]:
query_obj = self.query_obj()
query_obj.update(
{
@ -452,7 +452,7 @@ class BaseViz:
json_data = self.json_dumps(cache_dict, sort_keys=True)
return hashlib.md5(json_data.encode("utf-8")).hexdigest()
def get_payload(self, query_obj=None):
def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload:
"""Returns a payload of metadata and data"""
self.run_extra_queries()
payload = self.get_df_payload(query_obj)
@ -464,7 +464,9 @@ class BaseViz:
del payload["df"]
return payload
def get_df_payload(self, query_obj=None, **kwargs):
def get_df_payload(
self, query_obj: Optional[QueryObjectDict] = None, **kwargs: Any
) -> Dict[str, Any]:
"""Handles caching around the df payload retrieval"""
if not query_obj:
query_obj = self.query_obj()
@ -559,11 +561,11 @@ class BaseViz:
obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
)
def payload_json_and_has_error(self, payload):
def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]:
has_error = (
payload.get("status") == utils.QueryStatus.FAILED
or payload.get("error") is not None
or len(payload.get("errors")) > 0
or len(payload.get("errors", [])) > 0
)
return self.json_dumps(payload), has_error
@ -578,7 +580,7 @@ class BaseViz:
}
return content
def get_csv(self):
def get_csv(self) -> Optional[str]:
df = self.get_df()
include_index = not isinstance(df.index, pd.RangeIndex)
return df.to_csv(index=include_index, **config["CSV_EXPORT"])