mirror of https://github.com/apache/superset.git
style(mypy): Spit-and-polish pass (#10001)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
parent
656cdfb867
commit
91517a56a3
|
@ -50,10 +50,12 @@ multi_line_output = 3
|
|||
order_by_type = false
|
||||
|
||||
[mypy]
|
||||
disallow_any_generics = true
|
||||
ignore_missing_imports = true
|
||||
no_implicit_optional = true
|
||||
warn_unused_ignores = true
|
||||
|
||||
[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset.queries.*,superset.security.*,superset.sql_lab,superset.sql_parse,superset.sql_validators.*,superset.stats_logger,superset.tasks.*,superset.translations.*,superset.typing,superset.utils.*,,superset.views.*,superset.viz,superset.viz_sip38]
|
||||
[mypy-superset.*]
|
||||
check_untyped_defs = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_defs = true
|
||||
|
|
|
@ -80,7 +80,7 @@ class SupersetAppInitializer:
|
|||
|
||||
self.flask_app = app
|
||||
self.config = app.config
|
||||
self.manifest: dict = {}
|
||||
self.manifest: Dict[Any, Any] = {}
|
||||
|
||||
def pre_init(self) -> None:
|
||||
"""
|
||||
|
@ -542,7 +542,7 @@ class SupersetAppInitializer:
|
|||
self.app = app
|
||||
|
||||
def __call__(
|
||||
self, environ: Dict[str, Any], start_response: Callable
|
||||
self, environ: Dict[str, Any], start_response: Callable[..., Any]
|
||||
) -> Any:
|
||||
# Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
|
||||
# content-length and read the stream till the end.
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class CreateChartCommand(BaseCommand):
|
||||
def __init__(self, user: User, data: Dict):
|
||||
def __init__(self, user: User, data: Dict[str, Any]):
|
||||
self._actor = user
|
||||
self._properties = data.copy()
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class UpdateChartCommand(BaseCommand):
|
||||
def __init__(self, user: User, model_id: int, data: Dict):
|
||||
def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
|
||||
self._actor = user
|
||||
self._model_id = model_id
|
||||
self._properties = data.copy()
|
||||
|
|
|
@ -26,6 +26,7 @@ from pandas import DataFrame
|
|||
|
||||
from superset import app, is_feature_enabled
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.typing import Metric
|
||||
from superset.utils import core as utils, pandas_postprocessing
|
||||
from superset.views.utils import get_time_range_endpoints
|
||||
|
||||
|
@ -67,11 +68,11 @@ class QueryObject:
|
|||
row_limit: int
|
||||
filter: List[Dict[str, Any]]
|
||||
timeseries_limit: int
|
||||
timeseries_limit_metric: Optional[Dict]
|
||||
timeseries_limit_metric: Optional[Metric]
|
||||
order_desc: bool
|
||||
extras: Dict
|
||||
extras: Dict[str, Any]
|
||||
columns: List[str]
|
||||
orderby: List[List]
|
||||
orderby: List[List[str]]
|
||||
post_processing: List[Dict[str, Any]]
|
||||
|
||||
def __init__(
|
||||
|
@ -85,11 +86,11 @@ class QueryObject:
|
|||
is_timeseries: bool = False,
|
||||
timeseries_limit: int = 0,
|
||||
row_limit: int = app.config["ROW_LIMIT"],
|
||||
timeseries_limit_metric: Optional[Dict] = None,
|
||||
timeseries_limit_metric: Optional[Metric] = None,
|
||||
order_desc: bool = True,
|
||||
extras: Optional[Dict] = None,
|
||||
extras: Optional[Dict[str, Any]] = None,
|
||||
columns: Optional[List[str]] = None,
|
||||
orderby: Optional[List[List]] = None,
|
||||
orderby: Optional[List[List[str]]] = None,
|
||||
post_processing: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
|
|
|
@ -33,6 +33,7 @@ from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
|
|||
from cachelib.base import BaseCache
|
||||
from celery.schedules import crontab
|
||||
from dateutil import tz
|
||||
from flask import Blueprint
|
||||
from flask_appbuilder.security.manager import AUTH_DB
|
||||
|
||||
from superset.jinja_context import ( # pylint: disable=unused-import
|
||||
|
@ -421,7 +422,7 @@ DEFAULT_MODULE_DS_MAP = OrderedDict(
|
|||
]
|
||||
)
|
||||
ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {}
|
||||
ADDITIONAL_MIDDLEWARE: List[Callable] = []
|
||||
ADDITIONAL_MIDDLEWARE: List[Callable[..., Any]] = []
|
||||
|
||||
# 1) https://docs.python-guide.org/writing/logging/
|
||||
# 2) https://docs.python.org/2/library/logging.config.html
|
||||
|
@ -624,7 +625,7 @@ ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
|
|||
# SQL Lab. The existing context gets updated with this dictionary,
|
||||
# meaning values for existing keys get overwritten by the content of this
|
||||
# dictionary.
|
||||
JINJA_CONTEXT_ADDONS: Dict[str, Callable] = {}
|
||||
JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
|
||||
|
||||
# A dictionary of macro template processors that gets merged into global
|
||||
# template processors. The existing template processors get updated with this
|
||||
|
@ -684,7 +685,7 @@ PERMISSION_INSTRUCTIONS_LINK = ""
|
|||
|
||||
# Integrate external Blueprints to the app by passing them to your
|
||||
# configuration. These blueprints will get integrated in the app
|
||||
BLUEPRINTS: List[Callable] = []
|
||||
BLUEPRINTS: List[Blueprint] = []
|
||||
|
||||
# Provide a callable that receives a tracking_url and returns another
|
||||
# URL. This is used to translate internal Hadoop job tracker URL
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import json
|
||||
from typing import Any, Dict, Hashable, List, Optional, Type
|
||||
from typing import Any, Dict, Hashable, List, Optional, Type, Union
|
||||
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
from sqlalchemy import and_, Boolean, Column, Integer, String, Text
|
||||
|
@ -64,12 +64,12 @@ class BaseDatasource(
|
|||
baselink: Optional[str] = None # url portion pointing to ModelView endpoint
|
||||
|
||||
@property
|
||||
def column_class(self) -> Type:
|
||||
def column_class(self) -> Type["BaseColumn"]:
|
||||
# link to derivative of BaseColumn
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def metric_class(self) -> Type:
|
||||
def metric_class(self) -> Type["BaseMetric"]:
|
||||
# link to derivative of BaseMetric
|
||||
raise NotImplementedError()
|
||||
|
||||
|
@ -368,7 +368,7 @@ class BaseDatasource(
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def values_for_column(self, column_name: str, limit: int = 10000) -> List:
|
||||
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
|
||||
"""Given a column, returns an iterable of distinct values
|
||||
|
||||
This is used to populate the dropdown showing a list of
|
||||
|
@ -389,7 +389,10 @@ class BaseDatasource(
|
|||
|
||||
@staticmethod
|
||||
def get_fk_many_from_list(
|
||||
object_list: List[Any], fkmany: List[Column], fkmany_class: Type, key_attr: str,
|
||||
object_list: List[Any],
|
||||
fkmany: List[Column],
|
||||
fkmany_class: Type[Union["BaseColumn", "BaseMetric"]],
|
||||
key_attr: str,
|
||||
) -> List[Column]: # pylint: disable=too-many-locals
|
||||
"""Update ORM one-to-many list from object list
|
||||
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import or_
|
||||
|
@ -22,6 +21,8 @@ from sqlalchemy.orm import Session, subqueryload
|
|||
|
||||
if TYPE_CHECKING:
|
||||
# pylint: disable=unused-import
|
||||
from collections import OrderedDict
|
||||
|
||||
from superset.models.core import Database
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
|
||||
|
@ -32,7 +33,7 @@ class ConnectorRegistry:
|
|||
sources: Dict[str, Type["BaseDatasource"]] = {}
|
||||
|
||||
@classmethod
|
||||
def register_sources(cls, datasource_config: OrderedDict) -> None:
|
||||
def register_sources(cls, datasource_config: "OrderedDict[str, List[str]]") -> None:
|
||||
for module_name, class_names in datasource_config.items():
|
||||
class_names = [str(s) for s in class_names]
|
||||
module_obj = __import__(module_name, fromlist=class_names)
|
||||
|
|
|
@ -24,18 +24,7 @@ from copy import deepcopy
|
|||
from datetime import datetime, timedelta
|
||||
from distutils.version import LooseVersion
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, cast, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
import sqlalchemy as sa
|
||||
|
@ -173,7 +162,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
|
|||
return self.__repr__()
|
||||
|
||||
@property
|
||||
def data(self) -> Dict:
|
||||
def data(self) -> Dict[str, Any]:
|
||||
return {"id": self.id, "name": self.cluster_name, "backend": "druid"}
|
||||
|
||||
@staticmethod
|
||||
|
@ -354,7 +343,7 @@ class DruidColumn(Model, BaseColumn):
|
|||
return self.dimension_spec_json
|
||||
|
||||
@property
|
||||
def dimension_spec(self) -> Optional[Dict]:
|
||||
def dimension_spec(self) -> Optional[Dict[str, Any]]:
|
||||
if self.dimension_spec_json:
|
||||
return json.loads(self.dimension_spec_json)
|
||||
return None
|
||||
|
@ -438,7 +427,7 @@ class DruidMetric(Model, BaseMetric):
|
|||
return self.json
|
||||
|
||||
@property
|
||||
def json_obj(self) -> Dict:
|
||||
def json_obj(self) -> Dict[str, Any]:
|
||||
try:
|
||||
obj = json.loads(self.json)
|
||||
except Exception:
|
||||
|
@ -614,7 +603,7 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
name = escape(self.datasource_name)
|
||||
return Markup(f'<a href="{url}">{name}</a>')
|
||||
|
||||
def get_metric_obj(self, metric_name: str) -> Dict:
|
||||
def get_metric_obj(self, metric_name: str) -> Dict[str, Any]:
|
||||
return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0]
|
||||
|
||||
@classmethod
|
||||
|
@ -705,7 +694,11 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
|
||||
@classmethod
|
||||
def sync_to_db_from_config(
|
||||
cls, druid_config: Dict, user: User, cluster: DruidCluster, refresh: bool = True
|
||||
cls,
|
||||
druid_config: Dict[str, Any],
|
||||
user: User,
|
||||
cluster: DruidCluster,
|
||||
refresh: bool = True,
|
||||
) -> None:
|
||||
"""Merges the ds config from druid_config into one stored in the db."""
|
||||
session = db.session
|
||||
|
@ -901,7 +894,7 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
return postagg_metrics
|
||||
|
||||
@staticmethod
|
||||
def recursive_get_fields(_conf: Dict) -> List[str]:
|
||||
def recursive_get_fields(_conf: Dict[str, Any]) -> List[str]:
|
||||
_type = _conf.get("type")
|
||||
_field = _conf.get("field")
|
||||
_fields = _conf.get("fields")
|
||||
|
@ -957,8 +950,8 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
|
||||
@staticmethod
|
||||
def metrics_and_post_aggs(
|
||||
metrics: List[Union[Dict, str]], metrics_dict: Dict[str, DruidMetric],
|
||||
) -> Tuple[OrderedDict, OrderedDict]:
|
||||
metrics: List[Metric], metrics_dict: Dict[str, DruidMetric],
|
||||
) -> Tuple["OrderedDict[str, Any]", "OrderedDict[str, Any]"]:
|
||||
# Separate metrics into those that are aggregations
|
||||
# and those that are post aggregations
|
||||
saved_agg_names = set()
|
||||
|
@ -987,7 +980,7 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
)
|
||||
return aggs, post_aggs
|
||||
|
||||
def values_for_column(self, column_name: str, limit: int = 10000) -> List:
|
||||
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
|
||||
"""Retrieve some values for the given column"""
|
||||
logger.info(
|
||||
"Getting values for columns [{}] limited to [{}]".format(column_name, limit)
|
||||
|
@ -1079,8 +1072,10 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
|
||||
@staticmethod
|
||||
def get_aggregations(
|
||||
metrics_dict: Dict, saved_metrics: Set[str], adhoc_metrics: List[Dict] = []
|
||||
) -> OrderedDict:
|
||||
metrics_dict: Dict[str, Any],
|
||||
saved_metrics: Set[str],
|
||||
adhoc_metrics: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> "OrderedDict[str, Any]":
|
||||
"""
|
||||
Returns a dictionary of aggregation metric names to aggregation json objects
|
||||
|
||||
|
@ -1089,7 +1084,9 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
:param adhoc_metrics: list of adhoc metric names
|
||||
:raise SupersetException: if one or more metric names are not aggregations
|
||||
"""
|
||||
aggregations: OrderedDict = OrderedDict()
|
||||
if not adhoc_metrics:
|
||||
adhoc_metrics = []
|
||||
aggregations = OrderedDict()
|
||||
invalid_metric_names = []
|
||||
for metric_name in saved_metrics:
|
||||
if metric_name in metrics_dict:
|
||||
|
@ -1115,7 +1112,7 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
|
||||
def get_dimensions(
|
||||
self, columns: List[str], columns_dict: Dict[str, DruidColumn]
|
||||
) -> List[Union[str, Dict]]:
|
||||
) -> List[Union[str, Dict[str, Any]]]:
|
||||
dimensions = []
|
||||
columns = [col for col in columns if col in columns_dict]
|
||||
for column_name in columns:
|
||||
|
@ -1433,7 +1430,7 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
df[columns] = df[columns].fillna(NULL_STRING).astype("unicode")
|
||||
return df
|
||||
|
||||
def query(self, query_obj: Dict) -> QueryResult:
|
||||
def query(self, query_obj: QueryObjectDict) -> QueryResult:
|
||||
qry_start_dttm = datetime.now()
|
||||
client = self.cluster.get_pydruid_client()
|
||||
query_str = self.get_query_str(client=client, query_obj=query_obj, phase=2)
|
||||
|
@ -1583,7 +1580,7 @@ class DruidDatasource(Model, BaseDatasource):
|
|||
dimension=col, value=eq, extraction_function=extraction_fn
|
||||
)
|
||||
elif is_list_target:
|
||||
eq = cast(list, eq)
|
||||
eq = cast(List[Any], eq)
|
||||
fields = []
|
||||
# ignore the filter if it has no value
|
||||
if not len(eq):
|
||||
|
|
|
@ -597,7 +597,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
)
|
||||
|
||||
@property
|
||||
def data(self) -> Dict:
|
||||
def data(self) -> Dict[str, Any]:
|
||||
d = super().data
|
||||
if self.type == "table":
|
||||
grains = self.database.grains() or []
|
||||
|
@ -684,7 +684,9 @@ class SqlaTable(Model, BaseDatasource):
|
|||
return TextAsFrom(sa.text(from_sql), []).alias("expr_qry")
|
||||
return self.get_sqla_table()
|
||||
|
||||
def adhoc_metric_to_sqla(self, metric: Dict, cols: Dict) -> Optional[Column]:
|
||||
def adhoc_metric_to_sqla(
|
||||
self, metric: Dict[str, Any], cols: Dict[str, Any]
|
||||
) -> Optional[Column]:
|
||||
"""
|
||||
Turn an adhoc metric into a sqlalchemy column.
|
||||
|
||||
|
@ -804,7 +806,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label)
|
||||
|
||||
select_exprs: List[Column] = []
|
||||
groupby_exprs_sans_timestamp: OrderedDict = OrderedDict()
|
||||
groupby_exprs_sans_timestamp = OrderedDict()
|
||||
|
||||
if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby):
|
||||
# dedup columns while preserving order
|
||||
|
@ -874,7 +876,7 @@ class SqlaTable(Model, BaseDatasource):
|
|||
qry = qry.group_by(*groupby_exprs_with_timestamp.values())
|
||||
|
||||
where_clause_and = []
|
||||
having_clause_and: List = []
|
||||
having_clause_and = []
|
||||
|
||||
for flt in filter: # type: ignore
|
||||
if not all([flt.get(s) for s in ["col", "op"]]):
|
||||
|
@ -1082,7 +1084,10 @@ class SqlaTable(Model, BaseDatasource):
|
|||
return ob
|
||||
|
||||
def _get_top_groups(
|
||||
self, df: pd.DataFrame, dimensions: List, groupby_exprs: OrderedDict
|
||||
self,
|
||||
df: pd.DataFrame,
|
||||
dimensions: List[str],
|
||||
groupby_exprs: "OrderedDict[str, Any]",
|
||||
) -> ColumnElement:
|
||||
groups = []
|
||||
for unused, row in df.iterrows():
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask_appbuilder.models.filters import BaseFilter
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
|
@ -75,7 +75,7 @@ class BaseDAO:
|
|||
return query.all()
|
||||
|
||||
@classmethod
|
||||
def create(cls, properties: Dict, commit: bool = True) -> Model:
|
||||
def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
|
||||
"""
|
||||
Generic for creating models
|
||||
:raises: DAOCreateFailedError
|
||||
|
@ -95,7 +95,9 @@ class BaseDAO:
|
|||
return model
|
||||
|
||||
@classmethod
|
||||
def update(cls, model: Model, properties: Dict, commit: bool = True) -> Model:
|
||||
def update(
|
||||
cls, model: Model, properties: Dict[str, Any], commit: bool = True
|
||||
) -> Model:
|
||||
"""
|
||||
Generic update a model
|
||||
:raises: DAOCreateFailedError
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class CreateDashboardCommand(BaseCommand):
|
||||
def __init__(self, user: User, data: Dict):
|
||||
def __init__(self, user: User, data: Dict[str, Any]):
|
||||
self._actor = user
|
||||
self._properties = data.copy()
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class UpdateDashboardCommand(BaseCommand):
|
||||
def __init__(self, user: User, model_id: int, data: Dict):
|
||||
def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
|
||||
self._actor = user
|
||||
self._model_id = model_id
|
||||
self._properties = data.copy()
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
@ -39,7 +39,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class CreateDatasetCommand(BaseCommand):
|
||||
def __init__(self, user: User, data: Dict):
|
||||
def __init__(self, user: User, data: Dict[str, Any]):
|
||||
self._actor = user
|
||||
self._properties = data.copy()
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
# under the License.
|
||||
import logging
|
||||
from collections import Counter
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
@ -48,7 +48,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class UpdateDatasetCommand(BaseCommand):
|
||||
def __init__(self, user: User, model_id: int, data: Dict):
|
||||
def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
|
||||
self._actor = user
|
||||
self._model_id = model_id
|
||||
self._properties = data.copy()
|
||||
|
@ -111,7 +111,7 @@ class UpdateDatasetCommand(BaseCommand):
|
|||
raise exception
|
||||
|
||||
def _validate_columns(
|
||||
self, columns: List[Dict], exceptions: List[ValidationError]
|
||||
self, columns: List[Dict[str, Any]], exceptions: List[ValidationError]
|
||||
) -> None:
|
||||
# Validate duplicates on data
|
||||
if self._get_duplicates(columns, "column_name"):
|
||||
|
@ -133,7 +133,7 @@ class UpdateDatasetCommand(BaseCommand):
|
|||
exceptions.append(DatasetColumnsExistsValidationError())
|
||||
|
||||
def _validate_metrics(
|
||||
self, metrics: List[Dict], exceptions: List[ValidationError]
|
||||
self, metrics: List[Dict[str, Any]], exceptions: List[ValidationError]
|
||||
) -> None:
|
||||
if self._get_duplicates(metrics, "metric_name"):
|
||||
exceptions.append(DatasetMetricsDuplicateValidationError())
|
||||
|
@ -152,7 +152,7 @@ class UpdateDatasetCommand(BaseCommand):
|
|||
exceptions.append(DatasetMetricsExistsValidationError())
|
||||
|
||||
@staticmethod
|
||||
def _get_duplicates(data: List[Dict], key: str) -> List[str]:
|
||||
def _get_duplicates(data: List[Dict[str, Any]], key: str) -> List[str]:
|
||||
duplicates = [
|
||||
name
|
||||
for name, count in Counter([item[key] for item in data]).items()
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
@ -116,7 +116,7 @@ class DatasetDAO(BaseDAO):
|
|||
|
||||
@classmethod
|
||||
def update(
|
||||
cls, model: SqlaTable, properties: Dict, commit: bool = True
|
||||
cls, model: SqlaTable, properties: Dict[str, Any], commit: bool = True
|
||||
) -> Optional[SqlaTable]:
|
||||
"""
|
||||
Updates a Dataset model on the metadata DB
|
||||
|
@ -151,13 +151,13 @@ class DatasetDAO(BaseDAO):
|
|||
|
||||
@classmethod
|
||||
def update_column(
|
||||
cls, model: TableColumn, properties: Dict, commit: bool = True
|
||||
cls, model: TableColumn, properties: Dict[str, Any], commit: bool = True
|
||||
) -> Optional[TableColumn]:
|
||||
return DatasetColumnDAO.update(model, properties, commit=commit)
|
||||
|
||||
@classmethod
|
||||
def create_column(
|
||||
cls, properties: Dict, commit: bool = True
|
||||
cls, properties: Dict[str, Any], commit: bool = True
|
||||
) -> Optional[TableColumn]:
|
||||
"""
|
||||
Creates a Dataset model on the metadata DB
|
||||
|
@ -166,13 +166,13 @@ class DatasetDAO(BaseDAO):
|
|||
|
||||
@classmethod
|
||||
def update_metric(
|
||||
cls, model: SqlMetric, properties: Dict, commit: bool = True
|
||||
cls, model: SqlMetric, properties: Dict[str, Any], commit: bool = True
|
||||
) -> Optional[SqlMetric]:
|
||||
return DatasetMetricDAO.update(model, properties, commit=commit)
|
||||
|
||||
@classmethod
|
||||
def create_metric(
|
||||
cls, properties: Dict, commit: bool = True
|
||||
cls, properties: Dict[str, Any], commit: bool = True
|
||||
) -> Optional[SqlMetric]:
|
||||
"""
|
||||
Creates a Dataset model on the metadata DB
|
||||
|
|
|
@ -151,7 +151,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
|
||||
|
||||
# default matching patterns for identifying column types
|
||||
db_column_types: Dict[utils.DbColumnType, Tuple[Pattern, ...]] = {
|
||||
db_column_types: Dict[utils.DbColumnType, Tuple[Pattern[Any], ...]] = {
|
||||
utils.DbColumnType.NUMERIC: (
|
||||
re.compile(r".*DOUBLE.*", re.IGNORECASE),
|
||||
re.compile(r".*FLOAT.*", re.IGNORECASE),
|
||||
|
@ -296,7 +296,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return select_exprs
|
||||
|
||||
@classmethod
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
|
||||
"""
|
||||
|
||||
:param cursor: Cursor instance
|
||||
|
@ -311,8 +311,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
|
||||
@classmethod
|
||||
def expand_data(
|
||||
cls, columns: List[dict], data: List[dict]
|
||||
) -> Tuple[List[dict], List[dict], List[dict]]:
|
||||
cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]]
|
||||
) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]:
|
||||
"""
|
||||
Some engines support expanding nested fields. See implementation in Presto
|
||||
spec for details.
|
||||
|
@ -645,7 +645,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
schema: Optional[str],
|
||||
database: "Database",
|
||||
query: Select,
|
||||
columns: Optional[List] = None,
|
||||
columns: Optional[List[Dict[str, str]]] = None,
|
||||
) -> Optional[Select]:
|
||||
"""
|
||||
Add a where clause to a query to reference only the most recent partition
|
||||
|
@ -925,7 +925,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return []
|
||||
|
||||
@staticmethod
|
||||
def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple]:
|
||||
def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple[Any, ...]]:
|
||||
"""
|
||||
Convert pyodbc.Row objects from `fetch_data` to tuples.
|
||||
|
||||
|
|
|
@ -83,7 +83,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
|
||||
data = super().fetch_data(cursor, limit)
|
||||
# Support type BigQuery Row, introduced here PR #4071
|
||||
# google.cloud.bigquery.table.Row
|
||||
|
|
|
@ -39,7 +39,7 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
|
||||
data = super().fetch_data(cursor, limit)
|
||||
# Lists of `pyodbc.Row` need to be unpacked further
|
||||
return cls.pyodbc_rows_to_tuples(data)
|
||||
|
|
|
@ -93,7 +93,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
return BaseEngineSpec.get_all_datasource_names(database, datasource_type)
|
||||
|
||||
@classmethod
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
|
||||
import pyhive
|
||||
from TCLIService import ttypes
|
||||
|
||||
|
@ -304,7 +304,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
schema: Optional[str],
|
||||
database: "Database",
|
||||
query: Select,
|
||||
columns: Optional[List] = None,
|
||||
columns: Optional[List[Dict[str, str]]] = None,
|
||||
) -> Optional[Select]:
|
||||
try:
|
||||
col_names, values = cls.latest_partition(
|
||||
|
@ -323,7 +323,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
|
||||
def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
|
||||
return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -66,7 +66,7 @@ class MssqlEngineSpec(BaseEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
|
||||
data = super().fetch_data(cursor, limit)
|
||||
# Lists of `pyodbc.Row` need to be unpacked further
|
||||
return cls.pyodbc_rows_to_tuples(data)
|
||||
|
|
|
@ -51,7 +51,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
|
||||
cursor.tzinfo_factory = FixedOffsetTimezone
|
||||
if not cursor.description:
|
||||
return []
|
||||
|
|
|
@ -164,7 +164,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
return [row[0] for row in results]
|
||||
|
||||
@classmethod
|
||||
def _create_column_info(cls, name: str, data_type: str) -> dict:
|
||||
def _create_column_info(cls, name: str, data_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Create column info object
|
||||
:param name: column name
|
||||
|
@ -213,7 +213,10 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branches
|
||||
cls, parent_column_name: str, parent_data_type: str, result: List[dict]
|
||||
cls,
|
||||
parent_column_name: str,
|
||||
parent_data_type: str,
|
||||
result: List[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""
|
||||
Parse a row or array column
|
||||
|
@ -322,7 +325,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
(i.e. column name and data type)
|
||||
"""
|
||||
columns = cls._show_columns(inspector, table_name, schema)
|
||||
result: List[dict] = []
|
||||
result: List[Dict[str, Any]] = []
|
||||
for column in columns:
|
||||
try:
|
||||
# parse column if it is a row or array
|
||||
|
@ -361,7 +364,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
return column_name.startswith('"') and column_name.endswith('"')
|
||||
|
||||
@classmethod
|
||||
def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
|
||||
def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
|
||||
"""
|
||||
Format column clauses where names are in quotes and labels are specified
|
||||
:param cols: columns
|
||||
|
@ -561,8 +564,8 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def expand_data( # pylint: disable=too-many-locals
|
||||
cls, columns: List[dict], data: List[dict]
|
||||
) -> Tuple[List[dict], List[dict], List[dict]]:
|
||||
cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]]
|
||||
) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]:
|
||||
"""
|
||||
We do not immediately display rows and arrays clearly in the data grid. This
|
||||
method separates out nested fields and data values to help clearly display
|
||||
|
@ -590,7 +593,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
# process each column, unnesting ARRAY types and
|
||||
# expanding ROW types into new columns
|
||||
to_process = deque((column, 0) for column in columns)
|
||||
all_columns: List[dict] = []
|
||||
all_columns: List[Dict[str, Any]] = []
|
||||
expanded_columns = []
|
||||
current_array_level = None
|
||||
while to_process:
|
||||
|
@ -843,7 +846,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
schema: Optional[str],
|
||||
database: "Database",
|
||||
query: Select,
|
||||
columns: Optional[List] = None,
|
||||
columns: Optional[List[Dict[str, str]]] = None,
|
||||
) -> Optional[Select]:
|
||||
try:
|
||||
col_names, values = cls.latest_partition(
|
||||
|
|
|
@ -95,7 +95,9 @@ class UIManifestProcessor:
|
|||
self.parse_manifest_json()
|
||||
|
||||
@app.context_processor
|
||||
def get_manifest() -> Dict[str, Callable]: # pylint: disable=unused-variable
|
||||
def get_manifest() -> Dict[ # pylint: disable=unused-variable
|
||||
str, Callable[[str], List[str]]
|
||||
]:
|
||||
loaded_chunks = set()
|
||||
|
||||
def get_files(bundle: str, asset_type: str = "js") -> List[str]:
|
||||
|
@ -131,7 +133,7 @@ appbuilder = AppBuilder(update_perms=False)
|
|||
cache_manager = CacheManager()
|
||||
celery_app = celery.Celery()
|
||||
db = SQLA()
|
||||
_event_logger: dict = {}
|
||||
_event_logger: Dict[str, Any] = {}
|
||||
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
|
||||
feature_flag_manager = FeatureFlagManager()
|
||||
jinja_context_manager = JinjaContextManager()
|
||||
|
|
|
@ -341,11 +341,14 @@ class Database(
|
|||
def get_reserved_words(self) -> Set[str]:
|
||||
return self.get_dialect().preparer.reserved_words
|
||||
|
||||
def get_quoter(self) -> Callable:
|
||||
def get_quoter(self) -> Callable[[str, Any], str]:
|
||||
return self.get_dialect().identifier_preparer.quote
|
||||
|
||||
def get_df( # pylint: disable=too-many-locals
|
||||
self, sql: str, schema: Optional[str] = None, mutator: Optional[Callable] = None
|
||||
self,
|
||||
sql: str,
|
||||
schema: Optional[str] = None,
|
||||
mutator: Optional[Callable[[pd.DataFrame], None]] = None,
|
||||
) -> pd.DataFrame:
|
||||
sqls = [str(s).strip(" ;") for s in sqlparse.parse(sql)]
|
||||
|
||||
|
@ -450,7 +453,7 @@ class Database(
|
|||
|
||||
@cache_util.memoized_func(
|
||||
key=lambda *args, **kwargs: "db:{}:schema:None:view_list",
|
||||
attribute_in_key="id", # type: ignore
|
||||
attribute_in_key="id",
|
||||
)
|
||||
def get_all_view_names_in_database(
|
||||
self,
|
||||
|
|
|
@ -240,7 +240,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
|
|||
self.json_metadata = value
|
||||
|
||||
@property
|
||||
def position(self) -> Dict:
|
||||
def position(self) -> Dict[str, Any]:
|
||||
if self.position_json:
|
||||
return json.loads(self.position_json)
|
||||
return {}
|
||||
|
@ -315,7 +315,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
|
|||
old_to_new_slc_id_dict: Dict[int, int] = {}
|
||||
new_timed_refresh_immune_slices = []
|
||||
new_expanded_slices = {}
|
||||
new_filter_scopes: Dict[str, Dict] = {}
|
||||
new_filter_scopes = {}
|
||||
i_params_dict = dashboard_to_import.params_dict
|
||||
remote_id_slice_map = {
|
||||
slc.params_dict["remote_id"]: slc
|
||||
|
@ -351,7 +351,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
|
|||
# are converted to filter_scopes
|
||||
# but dashboard create from import may still have old dashboard filter metadata
|
||||
# here we convert them to new filter_scopes metadata first
|
||||
filter_scopes: Dict = {}
|
||||
filter_scopes = {}
|
||||
if (
|
||||
"filter_immune_slices" in i_params_dict
|
||||
or "filter_immune_slice_fields" in i_params_dict
|
||||
|
@ -415,7 +415,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
|
|||
|
||||
@classmethod
|
||||
def export_dashboards( # pylint: disable=too-many-locals
|
||||
cls, dashboard_ids: List
|
||||
cls, dashboard_ids: List[int]
|
||||
) -> str:
|
||||
copied_dashboards = []
|
||||
datasource_ids = set()
|
||||
|
|
|
@ -81,7 +81,7 @@ class ImportMixin:
|
|||
for u in cls.__table_args__ # type: ignore
|
||||
if isinstance(u, UniqueConstraint)
|
||||
]
|
||||
unique.extend( # type: ignore
|
||||
unique.extend(
|
||||
{c.name} for c in cls.__table__.columns if c.unique # type: ignore
|
||||
)
|
||||
return unique
|
||||
|
|
|
@ -36,7 +36,7 @@ from superset.tasks.thumbnails import cache_chart_thumbnail
|
|||
from superset.utils import core as utils
|
||||
|
||||
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
|
||||
from superset.viz_sip38 import BaseViz, viz_types # type: ignore
|
||||
from superset.viz_sip38 import BaseViz, viz_types
|
||||
else:
|
||||
from superset.viz import BaseViz, viz_types # type: ignore
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.sql.sqltypes import Integer
|
||||
|
@ -29,7 +29,7 @@ class TinyInteger(Integer):
|
|||
A type for tiny ``int`` integers.
|
||||
"""
|
||||
|
||||
def python_type(self) -> Type:
|
||||
def python_type(self) -> Type[int]:
|
||||
return int
|
||||
|
||||
@classmethod
|
||||
|
@ -42,7 +42,7 @@ class Interval(TypeEngine):
|
|||
A type for intervals.
|
||||
"""
|
||||
|
||||
def python_type(self) -> Optional[Type]:
|
||||
def python_type(self) -> Optional[Type[Any]]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
|
@ -55,7 +55,7 @@ class Array(TypeEngine):
|
|||
A type for arrays.
|
||||
"""
|
||||
|
||||
def python_type(self) -> Optional[Type]:
|
||||
def python_type(self) -> Optional[Type[List[Any]]]:
|
||||
return list
|
||||
|
||||
@classmethod
|
||||
|
@ -68,7 +68,7 @@ class Map(TypeEngine):
|
|||
A type for maps.
|
||||
"""
|
||||
|
||||
def python_type(self) -> Optional[Type]:
|
||||
def python_type(self) -> Optional[Type[Dict[Any, Any]]]:
|
||||
return dict
|
||||
|
||||
@classmethod
|
||||
|
@ -81,7 +81,7 @@ class Row(TypeEngine):
|
|||
A type for rows.
|
||||
"""
|
||||
|
||||
def python_type(self) -> Optional[Type]:
|
||||
def python_type(self) -> Optional[Type[Any]]:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Callable
|
||||
from typing import Any
|
||||
|
||||
from flask import g
|
||||
from flask_sqlalchemy import BaseQuery
|
||||
|
@ -25,7 +25,7 @@ from superset.views.base import BaseFilter
|
|||
|
||||
|
||||
class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
||||
def apply(self, query: BaseQuery, value: Callable) -> BaseQuery:
|
||||
def apply(self, query: BaseQuery, value: Any) -> BaseQuery:
|
||||
"""
|
||||
Filter queries to only those owned by current user. If
|
||||
can_access_all_queries permission is set a user can list all queries
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
@ -64,7 +64,7 @@ def stringify(obj: Any) -> str:
|
|||
|
||||
|
||||
def stringify_values(array: np.ndarray) -> np.ndarray:
|
||||
vstringify: Callable = np.vectorize(stringify)
|
||||
vstringify = np.vectorize(stringify)
|
||||
return vstringify(array)
|
||||
|
||||
|
||||
|
@ -172,7 +172,7 @@ class SupersetResultSet:
|
|||
return table.to_pandas(integer_object_nulls=True)
|
||||
|
||||
@staticmethod
|
||||
def first_nonempty(items: List) -> Any:
|
||||
def first_nonempty(items: List[Any]) -> Any:
|
||||
return next((i for i in items if i), None)
|
||||
|
||||
def is_temporal(self, db_type_str: Optional[str]) -> bool:
|
||||
|
|
|
@ -21,11 +21,11 @@ from typing import Any, Callable, List, Optional, Set, Tuple, TYPE_CHECKING, Uni
|
|||
|
||||
from flask import current_app, g
|
||||
from flask_appbuilder import Model
|
||||
from flask_appbuilder.security.sqla import models as ab_models
|
||||
from flask_appbuilder.security.sqla.manager import SecurityManager
|
||||
from flask_appbuilder.security.sqla.models import (
|
||||
assoc_permissionview_role,
|
||||
assoc_user_role,
|
||||
PermissionView,
|
||||
)
|
||||
from flask_appbuilder.security.views import (
|
||||
PermissionModelView,
|
||||
|
@ -602,11 +602,8 @@ class SupersetSecurityManager(SecurityManager):
|
|||
|
||||
logger.info("Cleaning faulty perms")
|
||||
sesh = self.get_session
|
||||
pvms = sesh.query(ab_models.PermissionView).filter(
|
||||
or_(
|
||||
ab_models.PermissionView.permission == None,
|
||||
ab_models.PermissionView.view_menu == None,
|
||||
)
|
||||
pvms = sesh.query(PermissionView).filter(
|
||||
or_(PermissionView.permission == None, PermissionView.view_menu == None,)
|
||||
)
|
||||
deleted_count = pvms.delete()
|
||||
sesh.commit()
|
||||
|
@ -640,7 +637,9 @@ class SupersetSecurityManager(SecurityManager):
|
|||
self.get_session.commit()
|
||||
self.clean_perms()
|
||||
|
||||
def set_role(self, role_name: str, pvm_check: Callable) -> None:
|
||||
def set_role(
|
||||
self, role_name: str, pvm_check: Callable[[PermissionView], bool]
|
||||
) -> None:
|
||||
"""
|
||||
Set the FAB permission/views for the role.
|
||||
|
||||
|
@ -650,7 +649,7 @@ class SupersetSecurityManager(SecurityManager):
|
|||
|
||||
logger.info("Syncing {} perms".format(role_name))
|
||||
sesh = self.get_session
|
||||
pvms = sesh.query(ab_models.PermissionView).all()
|
||||
pvms = sesh.query(PermissionView).all()
|
||||
pvms = [p for p in pvms if p.permission and p.view_menu]
|
||||
role = self.add_role(role_name)
|
||||
role_pvms = [p for p in pvms if pvm_check(p)]
|
||||
|
|
|
@ -299,9 +299,10 @@ def _serialize_and_expand_data(
|
|||
db_engine_spec: BaseEngineSpec,
|
||||
use_msgpack: Optional[bool] = False,
|
||||
expand_data: bool = False,
|
||||
) -> Tuple[Union[bytes, str], list, list, list]:
|
||||
selected_columns: List[Dict] = result_set.columns
|
||||
expanded_columns: List[Dict]
|
||||
) -> Tuple[Union[bytes, str], List[Any], List[Any], List[Any]]:
|
||||
selected_columns = result_set.columns
|
||||
all_columns: List[Any]
|
||||
expanded_columns: List[Any]
|
||||
|
||||
if use_msgpack:
|
||||
with stats_timing(
|
||||
|
|
|
@ -25,7 +25,7 @@ from superset import create_app
|
|||
from superset.extensions import celery_app
|
||||
|
||||
# Init the Flask app / configure everything
|
||||
create_app() # type: ignore
|
||||
create_app()
|
||||
|
||||
# Need to import late, as the celery_app will have been setup by "create_app()"
|
||||
# pylint: disable=wrong-import-position, unused-import
|
||||
|
|
|
@ -23,7 +23,7 @@ import urllib.request
|
|||
from collections import namedtuple
|
||||
from datetime import datetime, timedelta
|
||||
from email.utils import make_msgid, parseaddr
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from urllib.error import URLError # pylint: disable=ungrouped-imports
|
||||
|
||||
import croniter
|
||||
|
@ -36,7 +36,6 @@ from flask_login import login_user
|
|||
from retry.api import retry_call
|
||||
from selenium.common.exceptions import WebDriverException
|
||||
from selenium.webdriver import chrome, firefox
|
||||
from werkzeug.datastructures import TypeConversionDict
|
||||
from werkzeug.http import parse_cookie
|
||||
|
||||
# Superset framework imports
|
||||
|
@ -53,6 +52,11 @@ from superset.models.schedules import (
|
|||
)
|
||||
from superset.utils.core import get_email_address_list, send_email_smtp
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# pylint: disable=unused-import
|
||||
from werkzeug.datastructures import TypeConversionDict
|
||||
|
||||
|
||||
# Globals
|
||||
config = app.config
|
||||
logger = logging.getLogger("tasks.email_reports")
|
||||
|
@ -131,7 +135,7 @@ def _generate_mail_content(
|
|||
return EmailContent(body, data, images)
|
||||
|
||||
|
||||
def _get_auth_cookies() -> List[TypeConversionDict]:
|
||||
def _get_auth_cookies() -> List["TypeConversionDict[Any, Any]"]:
|
||||
# Login with the user specified to get the reports
|
||||
with app.test_request_context():
|
||||
user = security_manager.find_user(config["EMAIL_REPORTS_USER"])
|
||||
|
|
|
@ -27,8 +27,9 @@ def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-
|
|||
|
||||
|
||||
def memoized_func(
|
||||
key: Callable = view_cache_key, attribute_in_key: Optional[str] = None
|
||||
) -> Callable:
|
||||
key: Callable[..., str] = view_cache_key, # pylint: disable=bad-whitespace
|
||||
attribute_in_key: Optional[str] = None,
|
||||
) -> Callable[..., Any]:
|
||||
"""Use this decorator to cache functions that have predefined first arg.
|
||||
|
||||
enable_cache is treated as True by default,
|
||||
|
@ -45,7 +46,7 @@ def memoized_func(
|
|||
returns the caching key.
|
||||
"""
|
||||
|
||||
def wrap(f: Callable) -> Callable:
|
||||
def wrap(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
if cache_manager.tables_cache:
|
||||
|
||||
def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
|
|
|
@ -85,7 +85,7 @@ from superset.exceptions import (
|
|||
SupersetException,
|
||||
SupersetTimeoutException,
|
||||
)
|
||||
from superset.typing import FormData, Metric
|
||||
from superset.typing import FlaskResponse, FormData, Metric
|
||||
from superset.utils.dates import datetime_to_epoch, EPOCH
|
||||
|
||||
try:
|
||||
|
@ -147,7 +147,9 @@ class _memoized:
|
|||
should account for instance variable changes.
|
||||
"""
|
||||
|
||||
def __init__(self, func: Callable, watch: Optional[Tuple[str, ...]] = None) -> None:
|
||||
def __init__(
|
||||
self, func: Callable[..., Any], watch: Optional[Tuple[str, ...]] = None
|
||||
) -> None:
|
||||
self.func = func
|
||||
self.cache: Dict[Any, Any] = {}
|
||||
self.is_method = False
|
||||
|
@ -173,7 +175,7 @@ class _memoized:
|
|||
"""Return the function's docstring."""
|
||||
return self.func.__doc__ or ""
|
||||
|
||||
def __get__(self, obj: Any, objtype: Type) -> functools.partial:
|
||||
def __get__(self, obj: Any, objtype: Type[Any]) -> functools.partial: # type: ignore
|
||||
if not self.is_method:
|
||||
self.is_method = True
|
||||
"""Support instance methods."""
|
||||
|
@ -181,13 +183,13 @@ class _memoized:
|
|||
|
||||
|
||||
def memoized(
|
||||
func: Optional[Callable] = None, watch: Optional[Tuple[str, ...]] = None
|
||||
) -> Callable:
|
||||
func: Optional[Callable[..., Any]] = None, watch: Optional[Tuple[str, ...]] = None
|
||||
) -> Callable[..., Any]:
|
||||
if func:
|
||||
return _memoized(func)
|
||||
else:
|
||||
|
||||
def wrapper(f: Callable) -> Callable:
|
||||
def wrapper(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
return _memoized(f, watch)
|
||||
|
||||
return wrapper
|
||||
|
@ -1241,7 +1243,9 @@ def create_ssl_cert_file(certificate: str) -> str:
|
|||
return path
|
||||
|
||||
|
||||
def time_function(func: Callable, *args: Any, **kwargs: Any) -> Tuple[float, Any]:
|
||||
def time_function(
|
||||
func: Callable[..., FlaskResponse], *args: Any, **kwargs: Any
|
||||
) -> Tuple[float, Any]:
|
||||
"""
|
||||
Measures the amount of time a function takes to execute in ms
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ def convert_filter_scopes(
|
|||
) -> Dict[int, Dict[str, Dict[str, Any]]]:
|
||||
filter_scopes = {}
|
||||
immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or []
|
||||
immuned_by_column: Dict = defaultdict(list)
|
||||
immuned_by_column: Dict[str, List[int]] = defaultdict(list)
|
||||
for slice_id, columns in json_metadata.get(
|
||||
"filter_immune_slice_fields", {}
|
||||
).items():
|
||||
|
@ -52,7 +52,7 @@ def convert_filter_scopes(
|
|||
logging.info(f"slice [{filter_id}] has invalid field: {filter_field}")
|
||||
|
||||
for filter_slice in filters:
|
||||
filter_fields: Dict = {}
|
||||
filter_fields: Dict[str, Dict[str, Any]] = {}
|
||||
filter_id = filter_slice.id
|
||||
slice_params = json.loads(filter_slice.params or "{}")
|
||||
configs = slice_params.get("filter_configs") or []
|
||||
|
@ -77,9 +77,10 @@ def convert_filter_scopes(
|
|||
|
||||
|
||||
def copy_filter_scopes(
|
||||
old_to_new_slc_id_dict: Dict[int, int], old_filter_scopes: Dict[str, Dict]
|
||||
) -> Dict:
|
||||
new_filter_scopes: Dict[str, Dict] = {}
|
||||
old_to_new_slc_id_dict: Dict[int, int],
|
||||
old_filter_scopes: Dict[int, Dict[str, Dict[str, Any]]],
|
||||
) -> Dict[str, Dict[Any, Any]]:
|
||||
new_filter_scopes: Dict[str, Dict[Any, Any]] = {}
|
||||
for (filter_id, scopes) in old_filter_scopes.items():
|
||||
new_filter_key = old_to_new_slc_id_dict.get(int(filter_id))
|
||||
if new_filter_key:
|
||||
|
|
|
@ -46,7 +46,7 @@ def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[floa
|
|||
stats_logger.timing(stats_key, now_as_float() - start_ts)
|
||||
|
||||
|
||||
def etag_cache(max_age: int, check_perms: Callable) -> Callable:
|
||||
def etag_cache(max_age: int, check_perms: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""
|
||||
A decorator for caching views and handling etag conditional requests.
|
||||
|
||||
|
@ -60,7 +60,7 @@ def etag_cache(max_age: int, check_perms: Callable) -> Callable:
|
|||
|
||||
"""
|
||||
|
||||
def decorator(f: Callable) -> Callable:
|
||||
def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@wraps(f)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
|
||||
# check if the user can access the resource
|
||||
|
|
|
@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
|
|||
def import_datasource(
|
||||
session: Session,
|
||||
i_datasource: Model,
|
||||
lookup_database: Callable,
|
||||
lookup_datasource: Callable,
|
||||
lookup_database: Callable[[Model], Model],
|
||||
lookup_datasource: Callable[[Model], Model],
|
||||
import_time: Optional[int] = None,
|
||||
) -> int:
|
||||
"""Imports the datasource from the object to the database.
|
||||
|
@ -82,7 +82,9 @@ def import_datasource(
|
|||
return datasource.id
|
||||
|
||||
|
||||
def import_simple_obj(session: Session, i_obj: Model, lookup_obj: Callable) -> Model:
|
||||
def import_simple_obj(
|
||||
session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model]
|
||||
) -> Model:
|
||||
make_transient(i_obj)
|
||||
i_obj.id = None
|
||||
i_obj.table = None
|
||||
|
|
|
@ -35,7 +35,7 @@ class AbstractEventLogger(ABC):
|
|||
) -> None:
|
||||
pass
|
||||
|
||||
def log_this(self, f: Callable) -> Callable:
|
||||
def log_this(self, f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
user_id = None
|
||||
|
@ -124,7 +124,7 @@ def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger:
|
|||
)
|
||||
)
|
||||
|
||||
event_logger_type = cast(Type, cfg_value)
|
||||
event_logger_type = cast(Type[Any], cfg_value)
|
||||
result = event_logger_type()
|
||||
|
||||
# Verify that we have a valid logger impl
|
||||
|
|
|
@ -58,7 +58,7 @@ class DefaultLoggingConfigurator(LoggingConfigurator):
|
|||
|
||||
if app_config["ENABLE_TIME_ROTATE"]:
|
||||
logging.getLogger().setLevel(app_config["TIME_ROTATE_LOG_LEVEL"])
|
||||
handler = TimedRotatingFileHandler( # type: ignore
|
||||
handler = TimedRotatingFileHandler(
|
||||
app_config["FILENAME"],
|
||||
when=app_config["ROLLOVER"],
|
||||
interval=app_config["INTERVAL"],
|
||||
|
|
|
@ -72,8 +72,8 @@ WHITELIST_CUMULATIVE_FUNCTIONS = (
|
|||
)
|
||||
|
||||
|
||||
def validate_column_args(*argnames: str) -> Callable:
|
||||
def wrapper(func: Callable) -> Callable:
|
||||
def validate_column_args(*argnames: str) -> Callable[..., Any]:
|
||||
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def wrapped(df: DataFrame, **options: Any) -> Any:
|
||||
columns = df.columns.tolist()
|
||||
for name in argnames:
|
||||
|
@ -471,7 +471,7 @@ def geodetic_parse(
|
|||
Parse a string containing a geodetic point and return latitude, longitude
|
||||
and altitude
|
||||
"""
|
||||
point = Point(location) # type: ignore
|
||||
point = Point(location)
|
||||
return point[0], point[1], point[2]
|
||||
|
||||
try:
|
||||
|
|
|
@ -51,7 +51,7 @@ SELENIUM_HEADSTART = 3
|
|||
WindowSize = Tuple[int, int]
|
||||
|
||||
|
||||
def get_auth_cookies(user: "User") -> List[Dict]:
|
||||
def get_auth_cookies(user: "User") -> List[Dict[Any, Any]]:
|
||||
# Login with the user specified to get the reports
|
||||
with current_app.test_request_context("/login"):
|
||||
login_user(user)
|
||||
|
@ -101,14 +101,14 @@ class AuthWebDriverProxy:
|
|||
self,
|
||||
driver_type: str,
|
||||
window: Optional[WindowSize] = None,
|
||||
auth_func: Optional[Callable] = None,
|
||||
auth_func: Optional[
|
||||
Callable[..., Any]
|
||||
] = None, # pylint: disable=bad-whitespace
|
||||
):
|
||||
self._driver_type = driver_type
|
||||
self._window: WindowSize = window or (800, 600)
|
||||
config_auth_func: Callable = current_app.config.get(
|
||||
"WEBDRIVER_AUTH_FUNC", auth_driver
|
||||
)
|
||||
self._auth_func: Callable = auth_func or config_auth_func
|
||||
config_auth_func = current_app.config.get("WEBDRIVER_AUTH_FUNC", auth_driver)
|
||||
self._auth_func = auth_func or config_auth_func
|
||||
|
||||
def create(self) -> WebDriver:
|
||||
if self._driver_type == "firefox":
|
||||
|
@ -123,7 +123,7 @@ class AuthWebDriverProxy:
|
|||
raise Exception(f"Webdriver name ({self._driver_type}) not supported")
|
||||
# Prepare args for the webdriver init
|
||||
options.add_argument("--headless")
|
||||
kwargs: Dict = dict(options=options)
|
||||
kwargs: Dict[Any, Any] = dict(options=options)
|
||||
kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"])
|
||||
logger.info("Init selenium driver")
|
||||
return driver_class(**kwargs)
|
||||
|
|
|
@ -143,7 +143,7 @@ def generate_download_headers(
|
|||
return headers
|
||||
|
||||
|
||||
def api(f: Callable) -> Callable:
|
||||
def api(f: Callable[..., FlaskResponse]) -> Callable[..., FlaskResponse]:
|
||||
"""
|
||||
A decorator to label an endpoint as an API. Catches uncaught exceptions and
|
||||
return the response in the JSON format
|
||||
|
@ -383,11 +383,11 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
|
|||
:param primary_key:
|
||||
record primary key to delete
|
||||
"""
|
||||
item = self.datamodel.get(primary_key, self._base_filters) # type: ignore
|
||||
item = self.datamodel.get(primary_key, self._base_filters)
|
||||
if not item:
|
||||
abort(404)
|
||||
try:
|
||||
self.pre_delete(item) # type: ignore
|
||||
self.pre_delete(item)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
flash(str(ex), "danger")
|
||||
else:
|
||||
|
@ -400,8 +400,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
|
|||
.all()
|
||||
)
|
||||
|
||||
if self.datamodel.delete(item): # type: ignore
|
||||
self.post_delete(item) # type: ignore
|
||||
if self.datamodel.delete(item):
|
||||
self.post_delete(item)
|
||||
|
||||
for pv in pvs:
|
||||
security_manager.get_session.delete(pv)
|
||||
|
@ -411,8 +411,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
|
|||
|
||||
security_manager.get_session.commit()
|
||||
|
||||
flash(*self.datamodel.message) # type: ignore
|
||||
self.update_redirect() # type: ignore
|
||||
flash(*self.datamodel.message)
|
||||
self.update_redirect()
|
||||
|
||||
@action(
|
||||
"muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False
|
||||
|
|
|
@ -41,7 +41,7 @@ get_related_schema = {
|
|||
}
|
||||
|
||||
|
||||
def statsd_metrics(f: Callable) -> Callable:
|
||||
def statsd_metrics(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""
|
||||
Handle sending all statsd metrics from the REST API
|
||||
"""
|
||||
|
|
|
@ -88,7 +88,9 @@ class BaseOwnedSchema(BaseSupersetSchema):
|
|||
owners_field_name = "owners"
|
||||
|
||||
@post_load
|
||||
def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model:
|
||||
def make_object(
|
||||
self, data: Dict[str, Any], discard: Optional[List[str]] = None
|
||||
) -> Model:
|
||||
discard = discard or []
|
||||
discard.append(self.owners_field_name)
|
||||
instance = super().make_object(data, discard)
|
||||
|
|
|
@ -251,7 +251,7 @@ def check_slice_perms(self: "Superset", slice_id: int) -> None:
|
|||
|
||||
def _deserialize_results_payload(
|
||||
payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False
|
||||
) -> Dict[Any, Any]:
|
||||
) -> Dict[str, Any]:
|
||||
logger.debug(f"Deserializing from msgpack: {use_msgpack}")
|
||||
if use_msgpack:
|
||||
with stats_timing(
|
||||
|
@ -278,7 +278,7 @@ def _deserialize_results_payload(
|
|||
with stats_timing(
|
||||
"sqllab.query.results_backend_json_deserialize", stats_logger
|
||||
):
|
||||
return json.loads(payload) # type: ignore
|
||||
return json.loads(payload)
|
||||
|
||||
|
||||
def get_cta_schema_name(
|
||||
|
@ -1343,7 +1343,7 @@ class Superset(BaseSupersetView):
|
|||
|
||||
if "timed_refresh_immune_slices" not in md:
|
||||
md["timed_refresh_immune_slices"] = []
|
||||
new_filter_scopes: Dict[str, Dict] = {}
|
||||
new_filter_scopes = {}
|
||||
if "filter_scopes" in data:
|
||||
# replace filter_id and immune ids from old slice id to new slice id:
|
||||
# and remove slice ids that are not in dash anymore
|
||||
|
@ -2137,7 +2137,7 @@ class Superset(BaseSupersetView):
|
|||
f"deprecated.{self.__class__.__name__}.select_star.database_not_found"
|
||||
)
|
||||
return json_error_response("Not found", 404)
|
||||
schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) # type: ignore
|
||||
schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
|
||||
table_name = utils.parse_js_uri_path_item(table_name) # type: ignore
|
||||
# Check that the user can access the datasource
|
||||
if not self.appbuilder.sm.can_access_datasource(
|
||||
|
@ -2245,7 +2245,7 @@ class Superset(BaseSupersetView):
|
|||
)
|
||||
|
||||
payload = utils.zlib_decompress(blob, decode=not results_backend_use_msgpack)
|
||||
obj: dict = _deserialize_results_payload(
|
||||
obj = _deserialize_results_payload(
|
||||
payload, query, cast(bool, results_backend_use_msgpack)
|
||||
)
|
||||
|
||||
|
@ -2474,9 +2474,7 @@ class Superset(BaseSupersetView):
|
|||
schema: str = cast(str, query_params.get("schema"))
|
||||
sql: str = cast(str, query_params.get("sql"))
|
||||
try:
|
||||
template_params: dict = json.loads(
|
||||
query_params.get("templateParams") or "{}"
|
||||
)
|
||||
template_params = json.loads(query_params.get("templateParams") or "{}")
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Invalid template parameter {query_params.get('templateParams')}"
|
||||
|
|
|
@ -61,7 +61,7 @@ def get_col_type(col: Dict[Any, Any]) -> str:
|
|||
|
||||
def get_table_metadata(
|
||||
database: Database, table_name: str, schema_name: Optional[str]
|
||||
) -> Dict:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get table metadata information, including type, pk, fks.
|
||||
This function raises SQLAlchemyError when a schema is not found.
|
||||
|
@ -72,7 +72,7 @@ def get_table_metadata(
|
|||
:param schema_name: schema name
|
||||
:return: Dict table metadata ready for API response
|
||||
"""
|
||||
keys: List = []
|
||||
keys = []
|
||||
columns = database.get_columns(table_name, schema_name)
|
||||
primary_key = database.get_pk_constraint(table_name, schema_name)
|
||||
if primary_key and primary_key.get("constrained_columns"):
|
||||
|
@ -82,7 +82,7 @@ def get_table_metadata(
|
|||
foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
|
||||
indexes = get_indexes_metadata(database, table_name, schema_name)
|
||||
keys += foreign_keys + indexes
|
||||
payload_columns: List[Dict] = []
|
||||
payload_columns: List[Dict[str, Any]] = []
|
||||
for col in columns:
|
||||
dtype = get_col_type(col)
|
||||
payload_columns.append(
|
||||
|
@ -90,7 +90,7 @@ def get_table_metadata(
|
|||
"name": col["name"],
|
||||
"type": dtype.split("(")[0] if "(" in dtype else dtype,
|
||||
"longType": dtype,
|
||||
"keys": [k for k in keys if col["name"] in k.get("column_names")],
|
||||
"keys": [k for k in keys if col["name"] in k["column_names"]],
|
||||
}
|
||||
)
|
||||
return {
|
||||
|
@ -270,7 +270,7 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi):
|
|||
"""
|
||||
self.incr_stats("init", self.table_metadata.__name__)
|
||||
try:
|
||||
table_info: Dict = get_table_metadata(database, table_name, schema_name)
|
||||
table_info = get_table_metadata(database, table_name, schema_name)
|
||||
except SQLAlchemyError as ex:
|
||||
self.incr_stats("error", self.table_metadata.__name__)
|
||||
return self.response_422(error_msg_from_exception(ex))
|
||||
|
|
|
@ -29,7 +29,7 @@ from superset.views.base_api import BaseSupersetModelRestApi
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_datasource_access(f: Callable) -> Callable:
|
||||
def check_datasource_access(f: Callable[..., Any]) -> Callable[..., Any]:
|
||||
"""
|
||||
A Decorator that checks if a user has datasource access
|
||||
"""
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import enum
|
||||
from typing import Type
|
||||
from typing import Type, Union
|
||||
|
||||
import simplejson as json
|
||||
from croniter import croniter
|
||||
|
@ -55,7 +55,7 @@ class EmailScheduleView(
|
|||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def schedule_type_model(self) -> Type:
|
||||
def schedule_type_model(self) -> Type[Union[Dashboard, Slice]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
page_size = 20
|
||||
|
@ -154,9 +154,7 @@ class EmailScheduleView(
|
|||
info[col] = info[col].username
|
||||
|
||||
info["user"] = schedule.user.username
|
||||
info[self.schedule_type] = getattr( # type: ignore
|
||||
schedule, self.schedule_type
|
||||
).id
|
||||
info[self.schedule_type] = getattr(schedule, self.schedule_type).id
|
||||
schedules.append(info)
|
||||
|
||||
return json_success(json.dumps(schedules, default=json_iso_dttm_ser))
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Callable
|
||||
from typing import Any
|
||||
|
||||
import simplejson as json
|
||||
from flask import g, redirect, request, Response
|
||||
|
@ -40,7 +40,7 @@ from .base import (
|
|||
|
||||
|
||||
class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
||||
def apply(self, query: BaseQuery, value: Callable) -> BaseQuery:
|
||||
def apply(self, query: BaseQuery, value: Any) -> BaseQuery:
|
||||
"""
|
||||
Filter queries to only those owned by current user. If
|
||||
can_access_all_queries permission is set a user can list all queries
|
||||
|
|
|
@ -35,7 +35,7 @@ from superset.utils.core import QueryStatus, TimeRangeEndpoint
|
|||
from superset.viz import BaseViz
|
||||
|
||||
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
|
||||
from superset import viz_sip38 as viz # type: ignore
|
||||
from superset import viz_sip38 as viz
|
||||
else:
|
||||
from superset import viz # type: ignore
|
||||
|
||||
|
@ -318,9 +318,9 @@ def get_dashboard_extra_filters(
|
|||
|
||||
|
||||
def build_extra_filters(
|
||||
layout: Dict,
|
||||
filter_scopes: Dict,
|
||||
default_filters: Dict[str, Dict[str, List]],
|
||||
layout: Dict[str, Dict[str, Any]],
|
||||
filter_scopes: Dict[str, Dict[str, Any]],
|
||||
default_filters: Dict[str, Dict[str, List[Any]]],
|
||||
slice_id: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
extra_filters = []
|
||||
|
@ -343,7 +343,9 @@ def build_extra_filters(
|
|||
return extra_filters
|
||||
|
||||
|
||||
def is_slice_in_container(layout: Dict, container_id: str, slice_id: int) -> bool:
|
||||
def is_slice_in_container(
|
||||
layout: Dict[str, Dict[str, Any]], container_id: str, slice_id: int
|
||||
) -> bool:
|
||||
if container_id == "ROOT_ID":
|
||||
return True
|
||||
|
||||
|
|
|
@ -2720,7 +2720,7 @@ class PairedTTestViz(BaseViz):
|
|||
else:
|
||||
cols.append(col)
|
||||
df.columns = cols
|
||||
data: Dict = {}
|
||||
data: Dict[str, List[Dict[str, Any]]] = {}
|
||||
series = df.to_dict("series")
|
||||
for nameSet in df.columns:
|
||||
# If no groups are defined, nameSet will be the metric name
|
||||
|
@ -2750,7 +2750,7 @@ class RoseViz(NVD3TimeSeriesViz):
|
|||
return None
|
||||
|
||||
data = super().get_data(df)
|
||||
result: Dict = {}
|
||||
result: Dict[str, List[Dict[str, str]]] = {}
|
||||
for datum in data: # type: ignore
|
||||
key = datum["key"]
|
||||
for val in datum["values"]:
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
"""Unit tests for Superset"""
|
||||
import imp
|
||||
import json
|
||||
from typing import Dict, Union, List
|
||||
from typing import Any, Dict, Union, List
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pandas as pd
|
||||
|
@ -397,7 +397,9 @@ class SupersetTestCase(TestCase):
|
|||
mock_method.assert_called_once_with("error", func_name)
|
||||
return rv
|
||||
|
||||
def post_assert_metric(self, uri: str, data: Dict, func_name: str) -> Response:
|
||||
def post_assert_metric(
|
||||
self, uri: str, data: Dict[str, Any], func_name: str
|
||||
) -> Response:
|
||||
"""
|
||||
Simple client post with an extra assertion for statsd metrics
|
||||
|
||||
|
@ -417,7 +419,9 @@ class SupersetTestCase(TestCase):
|
|||
mock_method.assert_called_once_with("error", func_name)
|
||||
return rv
|
||||
|
||||
def put_assert_metric(self, uri: str, data: Dict, func_name: str) -> Response:
|
||||
def put_assert_metric(
|
||||
self, uri: str, data: Dict[str, Any], func_name: str
|
||||
) -> Response:
|
||||
"""
|
||||
Simple client put with an extra assertion for statsd metrics
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ from copy import copy
|
|||
from cachelib.redis import RedisCache
|
||||
from flask import Flask
|
||||
|
||||
from superset.config import * # type: ignore
|
||||
from superset.config import *
|
||||
|
||||
AUTH_USER_REGISTRATION_ROLE = "alpha"
|
||||
SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")
|
||||
|
|
Loading…
Reference in New Issue