style(mypy): Spit-and-polish pass (#10001)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
Authored by John Bodley on 2020-06-07 08:53:46 -07:00, committed by GitHub
parent 656cdfb867
commit 91517a56a3
56 changed files with 243 additions and 207 deletions

View File

@ -50,10 +50,12 @@ multi_line_output = 3
order_by_type = false
[mypy]
disallow_any_generics = true
ignore_missing_imports = true
no_implicit_optional = true
warn_unused_ignores = true
[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset.queries.*,superset.security.*,superset.sql_lab,superset.sql_parse,superset.sql_validators.*,superset.stats_logger,superset.tasks.*,superset.translations.*,superset.typing,superset.utils.*,superset.views.*,superset.viz,superset.viz_sip38]
[mypy-superset.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
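
The global [mypy] block already forbids bare generics; the new [mypy-superset.*] section replaces the long module list and additionally enforces check_untyped_defs, disallow_untyped_calls, and disallow_untyped_defs across the whole package. A minimal illustration (not taken from the codebase) of what these flags reject and accept:

    from typing import Any, Dict

    # Rejected under the settings above (illustrative only):
    #   def load(payload): ...               # disallow_untyped_defs: missing annotations
    #   def parse(data: Dict) -> Dict: ...   # disallow_any_generics: bare Dict
    #
    # Accepted form: fully annotated, fully parameterized.
    def parse(data: Dict[str, Any]) -> Dict[str, Any]:
        return dict(data)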

View File

@ -80,7 +80,7 @@ class SupersetAppInitializer:
self.flask_app = app
self.config = app.config
self.manifest: dict = {}
self.manifest: Dict[Any, Any] = {}
def pre_init(self) -> None:
"""
@ -542,7 +542,7 @@ class SupersetAppInitializer:
self.app = app
def __call__(
self, environ: Dict[str, Any], start_response: Callable
self, environ: Dict[str, Any], start_response: Callable[..., Any]
) -> Any:
# Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
# content-length and read the stream till the end.

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
class CreateChartCommand(BaseCommand):
def __init__(self, user: User, data: Dict):
def __init__(self, user: User, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
class UpdateChartCommand(BaseCommand):
def __init__(self, user: User, model_id: int, data: Dict):
def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
self._actor = user
self._model_id = model_id
self._properties = data.copy()

View File

@ -26,6 +26,7 @@ from pandas import DataFrame
from superset import app, is_feature_enabled
from superset.exceptions import QueryObjectValidationError
from superset.typing import Metric
from superset.utils import core as utils, pandas_postprocessing
from superset.views.utils import get_time_range_endpoints
@ -67,11 +68,11 @@ class QueryObject:
row_limit: int
filter: List[Dict[str, Any]]
timeseries_limit: int
timeseries_limit_metric: Optional[Dict]
timeseries_limit_metric: Optional[Metric]
order_desc: bool
extras: Dict
extras: Dict[str, Any]
columns: List[str]
orderby: List[List]
orderby: List[List[str]]
post_processing: List[Dict[str, Any]]
def __init__(
@ -85,11 +86,11 @@ class QueryObject:
is_timeseries: bool = False,
timeseries_limit: int = 0,
row_limit: int = app.config["ROW_LIMIT"],
timeseries_limit_metric: Optional[Dict] = None,
timeseries_limit_metric: Optional[Metric] = None,
order_desc: bool = True,
extras: Optional[Dict] = None,
extras: Optional[Dict[str, Any]] = None,
columns: Optional[List[str]] = None,
orderby: Optional[List[List]] = None,
orderby: Optional[List[List[str]]] = None,
post_processing: Optional[List[Dict[str, Any]]] = None,
**kwargs: Any,
):

View File

@ -33,6 +33,7 @@ from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
from cachelib.base import BaseCache
from celery.schedules import crontab
from dateutil import tz
from flask import Blueprint
from flask_appbuilder.security.manager import AUTH_DB
from superset.jinja_context import ( # pylint: disable=unused-import
@ -421,7 +422,7 @@ DEFAULT_MODULE_DS_MAP = OrderedDict(
]
)
ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {}
ADDITIONAL_MIDDLEWARE: List[Callable] = []
ADDITIONAL_MIDDLEWARE: List[Callable[..., Any]] = []
# 1) https://docs.python-guide.org/writing/logging/
# 2) https://docs.python.org/2/library/logging.config.html
@ -624,7 +625,7 @@ ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
# SQL Lab. The existing context gets updated with this dictionary,
# meaning values for existing keys get overwritten by the content of this
# dictionary.
JINJA_CONTEXT_ADDONS: Dict[str, Callable] = {}
JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
# A dictionary of macro template processors that gets merged into global
# template processors. The existing template processors get updated with this
@ -684,7 +685,7 @@ PERMISSION_INSTRUCTIONS_LINK = ""
# Integrate external Blueprints to the app by passing them to your
# configuration. These blueprints will get integrated in the app
BLUEPRINTS: List[Callable] = []
BLUEPRINTS: List[Blueprint] = []
# Provide a callable that receives a tracking_url and returns another
# URL. This is used to translate internal Hadoop job tracker URL
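
A hedged sketch of how a deployment's superset_config.py might populate the two settings retyped above; the current_datetime macro and custom_bp blueprint are illustrative names, not part of this commit:

    from datetime import datetime
    from flask import Blueprint

    # Hypothetical additions in a local superset_config.py (names are examples only).
    custom_bp = Blueprint("custom_bp", __name__)

    JINJA_CONTEXT_ADDONS = {
        # Extra macro merged into the SQL Lab Jinja context.
        "current_datetime": lambda: datetime.utcnow().isoformat(),
    }

    # With the annotation change above, BLUEPRINTS now expects Blueprint
    # instances rather than arbitrary callables.
    BLUEPRINTS = [custom_bp]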

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import json
from typing import Any, Dict, Hashable, List, Optional, Type
from typing import Any, Dict, Hashable, List, Optional, Type, Union
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import and_, Boolean, Column, Integer, String, Text
@ -64,12 +64,12 @@ class BaseDatasource(
baselink: Optional[str] = None # url portion pointing to ModelView endpoint
@property
def column_class(self) -> Type:
def column_class(self) -> Type["BaseColumn"]:
# link to derivative of BaseColumn
raise NotImplementedError()
@property
def metric_class(self) -> Type:
def metric_class(self) -> Type["BaseMetric"]:
# link to derivative of BaseMetric
raise NotImplementedError()
@ -368,7 +368,7 @@ class BaseDatasource(
"""
raise NotImplementedError()
def values_for_column(self, column_name: str, limit: int = 10000) -> List:
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
"""Given a column, returns an iterable of distinct values
This is used to populate the dropdown showing a list of
@ -389,7 +389,10 @@ class BaseDatasource(
@staticmethod
def get_fk_many_from_list(
object_list: List[Any], fkmany: List[Column], fkmany_class: Type, key_attr: str,
object_list: List[Any],
fkmany: List[Column],
fkmany_class: Type[Union["BaseColumn", "BaseMetric"]],
key_attr: str,
) -> List[Column]: # pylint: disable=too-many-locals
"""Update ORM one-to-many list from object list

View File

@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import OrderedDict
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
from sqlalchemy import or_
@ -22,6 +21,8 @@ from sqlalchemy.orm import Session, subqueryload
if TYPE_CHECKING:
# pylint: disable=unused-import
from collections import OrderedDict
from superset.models.core import Database
from superset.connectors.base.models import BaseDatasource
@ -32,7 +33,7 @@ class ConnectorRegistry:
sources: Dict[str, Type["BaseDatasource"]] = {}
@classmethod
def register_sources(cls, datasource_config: OrderedDict) -> None:
def register_sources(cls, datasource_config: "OrderedDict[str, List[str]]") -> None:
for module_name, class_names in datasource_config.items():
class_names = [str(s) for s in class_names]
module_obj = __import__(module_name, fromlist=class_names)
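
The quoted "OrderedDict[str, List[str]]" form is needed because, on the Python versions targeted here, collections.OrderedDict cannot be subscripted at runtime; a string annotation keeps the full generic type visible to mypy without ever evaluating it. The same pattern in isolation, as a minimal sketch:

    from collections import OrderedDict
    from typing import List

    def register_sources_sketch(datasource_config: "OrderedDict[str, List[str]]") -> None:
        # The quotes defer evaluation, so collections.OrderedDict is never
        # subscripted at runtime while mypy still checks the full generic type.
        for module_name, class_names in datasource_config.items():
            print(module_name, class_names)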

View File

@ -24,18 +24,7 @@ from copy import deepcopy
from datetime import datetime, timedelta
from distutils.version import LooseVersion
from multiprocessing.pool import ThreadPool
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)
from typing import Any, cast, Dict, Iterable, List, Optional, Set, Tuple, Union
import pandas as pd
import sqlalchemy as sa
@ -173,7 +162,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
return self.__repr__()
@property
def data(self) -> Dict:
def data(self) -> Dict[str, Any]:
return {"id": self.id, "name": self.cluster_name, "backend": "druid"}
@staticmethod
@ -354,7 +343,7 @@ class DruidColumn(Model, BaseColumn):
return self.dimension_spec_json
@property
def dimension_spec(self) -> Optional[Dict]:
def dimension_spec(self) -> Optional[Dict[str, Any]]:
if self.dimension_spec_json:
return json.loads(self.dimension_spec_json)
return None
@ -438,7 +427,7 @@ class DruidMetric(Model, BaseMetric):
return self.json
@property
def json_obj(self) -> Dict:
def json_obj(self) -> Dict[str, Any]:
try:
obj = json.loads(self.json)
except Exception:
@ -614,7 +603,7 @@ class DruidDatasource(Model, BaseDatasource):
name = escape(self.datasource_name)
return Markup(f'<a href="{url}">{name}</a>')
def get_metric_obj(self, metric_name: str) -> Dict:
def get_metric_obj(self, metric_name: str) -> Dict[str, Any]:
return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0]
@classmethod
@ -705,7 +694,11 @@ class DruidDatasource(Model, BaseDatasource):
@classmethod
def sync_to_db_from_config(
cls, druid_config: Dict, user: User, cluster: DruidCluster, refresh: bool = True
cls,
druid_config: Dict[str, Any],
user: User,
cluster: DruidCluster,
refresh: bool = True,
) -> None:
"""Merges the ds config from druid_config into one stored in the db."""
session = db.session
@ -901,7 +894,7 @@ class DruidDatasource(Model, BaseDatasource):
return postagg_metrics
@staticmethod
def recursive_get_fields(_conf: Dict) -> List[str]:
def recursive_get_fields(_conf: Dict[str, Any]) -> List[str]:
_type = _conf.get("type")
_field = _conf.get("field")
_fields = _conf.get("fields")
@ -957,8 +950,8 @@ class DruidDatasource(Model, BaseDatasource):
@staticmethod
def metrics_and_post_aggs(
metrics: List[Union[Dict, str]], metrics_dict: Dict[str, DruidMetric],
) -> Tuple[OrderedDict, OrderedDict]:
metrics: List[Metric], metrics_dict: Dict[str, DruidMetric],
) -> Tuple["OrderedDict[str, Any]", "OrderedDict[str, Any]"]:
# Separate metrics into those that are aggregations
# and those that are post aggregations
saved_agg_names = set()
@ -987,7 +980,7 @@ class DruidDatasource(Model, BaseDatasource):
)
return aggs, post_aggs
def values_for_column(self, column_name: str, limit: int = 10000) -> List:
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
"""Retrieve some values for the given column"""
logger.info(
"Getting values for columns [{}] limited to [{}]".format(column_name, limit)
@ -1079,8 +1072,10 @@ class DruidDatasource(Model, BaseDatasource):
@staticmethod
def get_aggregations(
metrics_dict: Dict, saved_metrics: Set[str], adhoc_metrics: List[Dict] = []
) -> OrderedDict:
metrics_dict: Dict[str, Any],
saved_metrics: Set[str],
adhoc_metrics: Optional[List[Dict[str, Any]]] = None,
) -> "OrderedDict[str, Any]":
"""
Returns a dictionary of aggregation metric names to aggregation json objects
@ -1089,7 +1084,9 @@ class DruidDatasource(Model, BaseDatasource):
:param adhoc_metrics: list of adhoc metric names
:raise SupersetException: if one or more metric names are not aggregations
"""
aggregations: OrderedDict = OrderedDict()
if not adhoc_metrics:
adhoc_metrics = []
aggregations = OrderedDict()
invalid_metric_names = []
for metric_name in saved_metrics:
if metric_name in metrics_dict:
@ -1115,7 +1112,7 @@ class DruidDatasource(Model, BaseDatasource):
def get_dimensions(
self, columns: List[str], columns_dict: Dict[str, DruidColumn]
) -> List[Union[str, Dict]]:
) -> List[Union[str, Dict[str, Any]]]:
dimensions = []
columns = [col for col in columns if col in columns_dict]
for column_name in columns:
@ -1433,7 +1430,7 @@ class DruidDatasource(Model, BaseDatasource):
df[columns] = df[columns].fillna(NULL_STRING).astype("unicode")
return df
def query(self, query_obj: Dict) -> QueryResult:
def query(self, query_obj: QueryObjectDict) -> QueryResult:
qry_start_dttm = datetime.now()
client = self.cluster.get_pydruid_client()
query_str = self.get_query_str(client=client, query_obj=query_obj, phase=2)
@ -1583,7 +1580,7 @@ class DruidDatasource(Model, BaseDatasource):
dimension=col, value=eq, extraction_function=extraction_fn
)
elif is_list_target:
eq = cast(list, eq)
eq = cast(List[Any], eq)
fields = []
# ignore the filter if it has no value
if not len(eq):
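
Two idioms recur in this file's hunks: the adhoc_metrics default moves from a shared mutable [] to None with an in-body fallback, and cast(...) gains an explicit List[Any] parameter. A stripped-down sketch of both (the body is illustrative, not the Druid implementation):

    from typing import Any, Dict, List, Optional, cast

    def get_aggregations_sketch(
        saved_metrics: List[str],
        adhoc_metrics: Optional[List[Dict[str, Any]]] = None,
    ) -> List[str]:
        # A default of [] would be shared across calls; None plus a fallback is safe.
        if not adhoc_metrics:
            adhoc_metrics = []
        names = list(saved_metrics)
        names.extend(str(m.get("label")) for m in adhoc_metrics)
        return names

    value: Any = ["a", "b"]
    # cast() only informs the type checker; it performs no runtime conversion.
    items = cast(List[Any], value)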

View File

@ -597,7 +597,7 @@ class SqlaTable(Model, BaseDatasource):
)
@property
def data(self) -> Dict:
def data(self) -> Dict[str, Any]:
d = super().data
if self.type == "table":
grains = self.database.grains() or []
@ -684,7 +684,9 @@ class SqlaTable(Model, BaseDatasource):
return TextAsFrom(sa.text(from_sql), []).alias("expr_qry")
return self.get_sqla_table()
def adhoc_metric_to_sqla(self, metric: Dict, cols: Dict) -> Optional[Column]:
def adhoc_metric_to_sqla(
self, metric: Dict[str, Any], cols: Dict[str, Any]
) -> Optional[Column]:
"""
Turn an adhoc metric into a sqlalchemy column.
@ -804,7 +806,7 @@ class SqlaTable(Model, BaseDatasource):
main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label)
select_exprs: List[Column] = []
groupby_exprs_sans_timestamp: OrderedDict = OrderedDict()
groupby_exprs_sans_timestamp = OrderedDict()
if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby):
# dedup columns while preserving order
@ -874,7 +876,7 @@ class SqlaTable(Model, BaseDatasource):
qry = qry.group_by(*groupby_exprs_with_timestamp.values())
where_clause_and = []
having_clause_and: List = []
having_clause_and = []
for flt in filter: # type: ignore
if not all([flt.get(s) for s in ["col", "op"]]):
@ -1082,7 +1084,10 @@ class SqlaTable(Model, BaseDatasource):
return ob
def _get_top_groups(
self, df: pd.DataFrame, dimensions: List, groupby_exprs: OrderedDict
self,
df: pd.DataFrame,
dimensions: List[str],
groupby_exprs: "OrderedDict[str, Any]",
) -> ColumnElement:
groups = []
for unused, row in df.iterrows():

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
@ -75,7 +75,7 @@ class BaseDAO:
return query.all()
@classmethod
def create(cls, properties: Dict, commit: bool = True) -> Model:
def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
"""
Generic for creating models
:raises: DAOCreateFailedError
@ -95,7 +95,9 @@ class BaseDAO:
return model
@classmethod
def update(cls, model: Model, properties: Dict, commit: bool = True) -> Model:
def update(
cls, model: Model, properties: Dict[str, Any], commit: bool = True
) -> Model:
"""
Generic update a model
:raises: DAOCreateFailedError

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
class CreateDashboardCommand(BaseCommand):
def __init__(self, user: User, data: Dict):
def __init__(self, user: User, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
class UpdateDashboardCommand(BaseCommand):
def __init__(self, user: User, model_id: int, data: Dict):
def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
self._actor = user
self._model_id = model_id
self._properties = data.copy()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
@ -39,7 +39,7 @@ logger = logging.getLogger(__name__)
class CreateDatasetCommand(BaseCommand):
def __init__(self, user: User, data: Dict):
def __init__(self, user: User, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy()

View File

@ -16,7 +16,7 @@
# under the License.
import logging
from collections import Counter
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
@ -48,7 +48,7 @@ logger = logging.getLogger(__name__)
class UpdateDatasetCommand(BaseCommand):
def __init__(self, user: User, model_id: int, data: Dict):
def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
self._actor = user
self._model_id = model_id
self._properties = data.copy()
@ -111,7 +111,7 @@ class UpdateDatasetCommand(BaseCommand):
raise exception
def _validate_columns(
self, columns: List[Dict], exceptions: List[ValidationError]
self, columns: List[Dict[str, Any]], exceptions: List[ValidationError]
) -> None:
# Validate duplicates on data
if self._get_duplicates(columns, "column_name"):
@ -133,7 +133,7 @@ class UpdateDatasetCommand(BaseCommand):
exceptions.append(DatasetColumnsExistsValidationError())
def _validate_metrics(
self, metrics: List[Dict], exceptions: List[ValidationError]
self, metrics: List[Dict[str, Any]], exceptions: List[ValidationError]
) -> None:
if self._get_duplicates(metrics, "metric_name"):
exceptions.append(DatasetMetricsDuplicateValidationError())
@ -152,7 +152,7 @@ class UpdateDatasetCommand(BaseCommand):
exceptions.append(DatasetMetricsExistsValidationError())
@staticmethod
def _get_duplicates(data: List[Dict], key: str) -> List[str]:
def _get_duplicates(data: List[Dict[str, Any]], key: str) -> List[str]:
duplicates = [
name
for name, count in Counter([item[key] for item in data]).items()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from flask import current_app
from sqlalchemy.exc import SQLAlchemyError
@ -116,7 +116,7 @@ class DatasetDAO(BaseDAO):
@classmethod
def update(
cls, model: SqlaTable, properties: Dict, commit: bool = True
cls, model: SqlaTable, properties: Dict[str, Any], commit: bool = True
) -> Optional[SqlaTable]:
"""
Updates a Dataset model on the metadata DB
@ -151,13 +151,13 @@ class DatasetDAO(BaseDAO):
@classmethod
def update_column(
cls, model: TableColumn, properties: Dict, commit: bool = True
cls, model: TableColumn, properties: Dict[str, Any], commit: bool = True
) -> Optional[TableColumn]:
return DatasetColumnDAO.update(model, properties, commit=commit)
@classmethod
def create_column(
cls, properties: Dict, commit: bool = True
cls, properties: Dict[str, Any], commit: bool = True
) -> Optional[TableColumn]:
"""
Creates a Dataset model on the metadata DB
@ -166,13 +166,13 @@ class DatasetDAO(BaseDAO):
@classmethod
def update_metric(
cls, model: SqlMetric, properties: Dict, commit: bool = True
cls, model: SqlMetric, properties: Dict[str, Any], commit: bool = True
) -> Optional[SqlMetric]:
return DatasetMetricDAO.update(model, properties, commit=commit)
@classmethod
def create_metric(
cls, properties: Dict, commit: bool = True
cls, properties: Dict[str, Any], commit: bool = True
) -> Optional[SqlMetric]:
"""
Creates a Dataset model on the metadata DB

View File

@ -151,7 +151,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
# default matching patterns for identifying column types
db_column_types: Dict[utils.DbColumnType, Tuple[Pattern, ...]] = {
db_column_types: Dict[utils.DbColumnType, Tuple[Pattern[Any], ...]] = {
utils.DbColumnType.NUMERIC: (
re.compile(r".*DOUBLE.*", re.IGNORECASE),
re.compile(r".*FLOAT.*", re.IGNORECASE),
@ -296,7 +296,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return select_exprs
@classmethod
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
"""
:param cursor: Cursor instance
@ -311,8 +311,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def expand_data(
cls, columns: List[dict], data: List[dict]
) -> Tuple[List[dict], List[dict], List[dict]]:
cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]]
) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]:
"""
Some engines support expanding nested fields. See implementation in Presto
spec for details.
@ -645,7 +645,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
schema: Optional[str],
database: "Database",
query: Select,
columns: Optional[List] = None,
columns: Optional[List[Dict[str, str]]] = None,
) -> Optional[Select]:
"""
Add a where clause to a query to reference only the most recent partition
@ -925,7 +925,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return []
@staticmethod
def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple]:
def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple[Any, ...]]:
"""
Convert pyodbc.Row objects from `fetch_data` to tuples.
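
The repeated List[Tuple] to List[Tuple[Any, ...]] change across the engine specs follows from disallow_any_generics: a bare Tuple is an implicit Tuple[Any, ...], and the flag requires the parameters to be spelled out. A minimal sketch, not copied from the Superset sources:

    from typing import Any, List, Tuple

    def fetch_rows_sketch(cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
        # Tuple[Any, ...] means "a tuple of any length with Any elements";
        # a bare Tuple would be rejected with disallow_any_generics = true.
        rows = cursor.fetchmany(limit)
        return [tuple(row) for row in rows]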

View File

@ -83,7 +83,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
return None
@classmethod
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
data = super().fetch_data(cursor, limit)
# Support type BigQuery Row, introduced here PR #4071
# google.cloud.bigquery.table.Row

View File

@ -39,7 +39,7 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
}
@classmethod
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
data = super().fetch_data(cursor, limit)
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)

View File

@ -93,7 +93,7 @@ class HiveEngineSpec(PrestoEngineSpec):
return BaseEngineSpec.get_all_datasource_names(database, datasource_type)
@classmethod
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
import pyhive
from TCLIService import ttypes
@ -304,7 +304,7 @@ class HiveEngineSpec(PrestoEngineSpec):
schema: Optional[str],
database: "Database",
query: Select,
columns: Optional[List] = None,
columns: Optional[List[Dict[str, str]]] = None,
) -> Optional[Select]:
try:
col_names, values = cls.latest_partition(
@ -323,7 +323,7 @@ class HiveEngineSpec(PrestoEngineSpec):
return None
@classmethod
def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access
@classmethod

View File

@ -66,7 +66,7 @@ class MssqlEngineSpec(BaseEngineSpec):
return None
@classmethod
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
data = super().fetch_data(cursor, limit)
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)

View File

@ -51,7 +51,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
}
@classmethod
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]:
cursor.tzinfo_factory = FixedOffsetTimezone
if not cursor.description:
return []

View File

@ -164,7 +164,7 @@ class PrestoEngineSpec(BaseEngineSpec):
return [row[0] for row in results]
@classmethod
def _create_column_info(cls, name: str, data_type: str) -> dict:
def _create_column_info(cls, name: str, data_type: str) -> Dict[str, Any]:
"""
Create column info object
:param name: column name
@ -213,7 +213,10 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod
def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branches
cls, parent_column_name: str, parent_data_type: str, result: List[dict]
cls,
parent_column_name: str,
parent_data_type: str,
result: List[Dict[str, Any]],
) -> None:
"""
Parse a row or array column
@ -322,7 +325,7 @@ class PrestoEngineSpec(BaseEngineSpec):
(i.e. column name and data type)
"""
columns = cls._show_columns(inspector, table_name, schema)
result: List[dict] = []
result: List[Dict[str, Any]] = []
for column in columns:
try:
# parse column if it is a row or array
@ -361,7 +364,7 @@ class PrestoEngineSpec(BaseEngineSpec):
return column_name.startswith('"') and column_name.endswith('"')
@classmethod
def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]:
def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
"""
Format column clauses where names are in quotes and labels are specified
:param cols: columns
@ -561,8 +564,8 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod
def expand_data( # pylint: disable=too-many-locals
cls, columns: List[dict], data: List[dict]
) -> Tuple[List[dict], List[dict], List[dict]]:
cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]]
) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]:
"""
We do not immediately display rows and arrays clearly in the data grid. This
method separates out nested fields and data values to help clearly display
@ -590,7 +593,7 @@ class PrestoEngineSpec(BaseEngineSpec):
# process each column, unnesting ARRAY types and
# expanding ROW types into new columns
to_process = deque((column, 0) for column in columns)
all_columns: List[dict] = []
all_columns: List[Dict[str, Any]] = []
expanded_columns = []
current_array_level = None
while to_process:
@ -843,7 +846,7 @@ class PrestoEngineSpec(BaseEngineSpec):
schema: Optional[str],
database: "Database",
query: Select,
columns: Optional[List] = None,
columns: Optional[List[Dict[str, str]]] = None,
) -> Optional[Select]:
try:
col_names, values = cls.latest_partition(

View File

@ -95,7 +95,9 @@ class UIManifestProcessor:
self.parse_manifest_json()
@app.context_processor
def get_manifest() -> Dict[str, Callable]: # pylint: disable=unused-variable
def get_manifest() -> Dict[ # pylint: disable=unused-variable
str, Callable[[str], List[str]]
]:
loaded_chunks = set()
def get_files(bundle: str, asset_type: str = "js") -> List[str]:
@ -131,7 +133,7 @@ appbuilder = AppBuilder(update_perms=False)
cache_manager = CacheManager()
celery_app = celery.Celery()
db = SQLA()
_event_logger: dict = {}
_event_logger: Dict[str, Any] = {}
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
feature_flag_manager = FeatureFlagManager()
jinja_context_manager = JinjaContextManager()

View File

@ -341,11 +341,14 @@ class Database(
def get_reserved_words(self) -> Set[str]:
return self.get_dialect().preparer.reserved_words
def get_quoter(self) -> Callable:
def get_quoter(self) -> Callable[[str, Any], str]:
return self.get_dialect().identifier_preparer.quote
def get_df( # pylint: disable=too-many-locals
self, sql: str, schema: Optional[str] = None, mutator: Optional[Callable] = None
self,
sql: str,
schema: Optional[str] = None,
mutator: Optional[Callable[[pd.DataFrame], None]] = None,
) -> pd.DataFrame:
sqls = [str(s).strip(" ;") for s in sqlparse.parse(sql)]
@ -450,7 +453,7 @@ class Database(
@cache_util.memoized_func(
key=lambda *args, **kwargs: "db:{}:schema:None:view_list",
attribute_in_key="id", # type: ignore
attribute_in_key="id",
)
def get_all_view_names_in_database(
self,

View File

@ -240,7 +240,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
self.json_metadata = value
@property
def position(self) -> Dict:
def position(self) -> Dict[str, Any]:
if self.position_json:
return json.loads(self.position_json)
return {}
@ -315,7 +315,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
old_to_new_slc_id_dict: Dict[int, int] = {}
new_timed_refresh_immune_slices = []
new_expanded_slices = {}
new_filter_scopes: Dict[str, Dict] = {}
new_filter_scopes = {}
i_params_dict = dashboard_to_import.params_dict
remote_id_slice_map = {
slc.params_dict["remote_id"]: slc
@ -351,7 +351,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
# are converted to filter_scopes
# but dashboard create from import may still have old dashboard filter metadata
# here we convert them to new filter_scopes metadata first
filter_scopes: Dict = {}
filter_scopes = {}
if (
"filter_immune_slices" in i_params_dict
or "filter_immune_slice_fields" in i_params_dict
@ -415,7 +415,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
@classmethod
def export_dashboards( # pylint: disable=too-many-locals
cls, dashboard_ids: List
cls, dashboard_ids: List[int]
) -> str:
copied_dashboards = []
datasource_ids = set()

View File

@ -81,7 +81,7 @@ class ImportMixin:
for u in cls.__table_args__ # type: ignore
if isinstance(u, UniqueConstraint)
]
unique.extend( # type: ignore
unique.extend(
{c.name} for c in cls.__table__.columns if c.unique # type: ignore
)
return unique

View File

@ -36,7 +36,7 @@ from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.utils import core as utils
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
from superset.viz_sip38 import BaseViz, viz_types # type: ignore
from superset.viz_sip38 import BaseViz, viz_types
else:
from superset.viz import BaseViz, viz_types # type: ignore

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Optional, Type
from typing import Any, Dict, List, Optional, Type
from sqlalchemy import types
from sqlalchemy.sql.sqltypes import Integer
@ -29,7 +29,7 @@ class TinyInteger(Integer):
A type for tiny ``int`` integers.
"""
def python_type(self) -> Type:
def python_type(self) -> Type[int]:
return int
@classmethod
@ -42,7 +42,7 @@ class Interval(TypeEngine):
A type for intervals.
"""
def python_type(self) -> Optional[Type]:
def python_type(self) -> Optional[Type[Any]]:
return None
@classmethod
@ -55,7 +55,7 @@ class Array(TypeEngine):
A type for arrays.
"""
def python_type(self) -> Optional[Type]:
def python_type(self) -> Optional[Type[List[Any]]]:
return list
@classmethod
@ -68,7 +68,7 @@ class Map(TypeEngine):
A type for maps.
"""
def python_type(self) -> Optional[Type]:
def python_type(self) -> Optional[Type[Dict[Any, Any]]]:
return dict
@classmethod
@ -81,7 +81,7 @@ class Row(TypeEngine):
A type for rows.
"""
def python_type(self) -> Optional[Type]:
def python_type(self) -> Optional[Type[Any]]:
return None
@classmethod
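
Type on its own is another bare generic; parameterizing it states which class object the property returns. A small sketch of the distinction, independent of the SQLAlchemy types above:

    from typing import Any, Optional, Type

    def int_python_type() -> Type[int]:
        # The return value is the class itself, so the annotation is Type[int].
        return int

    def unknown_python_type() -> Optional[Type[Any]]:
        # Type[Any] keeps the "some class or None" contract explicit.
        return None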

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Callable
from typing import Any
from flask import g
from flask_sqlalchemy import BaseQuery
@ -25,7 +25,7 @@ from superset.views.base import BaseFilter
class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
def apply(self, query: BaseQuery, value: Callable) -> BaseQuery:
def apply(self, query: BaseQuery, value: Any) -> BaseQuery:
"""
Filter queries to only those owned by current user. If
can_access_all_queries permission is set a user can list all queries

View File

@ -20,7 +20,7 @@
import datetime
import json
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type
import numpy as np
import pandas as pd
@ -64,7 +64,7 @@ def stringify(obj: Any) -> str:
def stringify_values(array: np.ndarray) -> np.ndarray:
vstringify: Callable = np.vectorize(stringify)
vstringify = np.vectorize(stringify)
return vstringify(array)
@ -172,7 +172,7 @@ class SupersetResultSet:
return table.to_pandas(integer_object_nulls=True)
@staticmethod
def first_nonempty(items: List) -> Any:
def first_nonempty(items: List[Any]) -> Any:
return next((i for i in items if i), None)
def is_temporal(self, db_type_str: Optional[str]) -> bool:

View File

@ -21,11 +21,11 @@ from typing import Any, Callable, List, Optional, Set, Tuple, TYPE_CHECKING, Uni
from flask import current_app, g
from flask_appbuilder import Model
from flask_appbuilder.security.sqla import models as ab_models
from flask_appbuilder.security.sqla.manager import SecurityManager
from flask_appbuilder.security.sqla.models import (
assoc_permissionview_role,
assoc_user_role,
PermissionView,
)
from flask_appbuilder.security.views import (
PermissionModelView,
@ -602,11 +602,8 @@ class SupersetSecurityManager(SecurityManager):
logger.info("Cleaning faulty perms")
sesh = self.get_session
pvms = sesh.query(ab_models.PermissionView).filter(
or_(
ab_models.PermissionView.permission == None,
ab_models.PermissionView.view_menu == None,
)
pvms = sesh.query(PermissionView).filter(
or_(PermissionView.permission == None, PermissionView.view_menu == None,)
)
deleted_count = pvms.delete()
sesh.commit()
@ -640,7 +637,9 @@ class SupersetSecurityManager(SecurityManager):
self.get_session.commit()
self.clean_perms()
def set_role(self, role_name: str, pvm_check: Callable) -> None:
def set_role(
self, role_name: str, pvm_check: Callable[[PermissionView], bool]
) -> None:
"""
Set the FAB permission/views for the role.
@ -650,7 +649,7 @@ class SupersetSecurityManager(SecurityManager):
logger.info("Syncing {} perms".format(role_name))
sesh = self.get_session
pvms = sesh.query(ab_models.PermissionView).all()
pvms = sesh.query(PermissionView).all()
pvms = [p for p in pvms if p.permission and p.view_menu]
role = self.add_role(role_name)
role_pvms = [p for p in pvms if pvm_check(p)]
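
With pvm_check now typed as Callable[[PermissionView], bool], any predicate passed to set_role must accept a PermissionView and return a bool. A hedged example of such a predicate; the rule itself is invented for illustration:

    from flask_appbuilder.security.sqla.models import PermissionView

    def is_can_permission(pvm: PermissionView) -> bool:
        # Illustrative predicate: keep only permissions whose name starts with "can_".
        return bool(pvm.permission and pvm.permission.name.startswith("can_"))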

View File

@ -299,9 +299,10 @@ def _serialize_and_expand_data(
db_engine_spec: BaseEngineSpec,
use_msgpack: Optional[bool] = False,
expand_data: bool = False,
) -> Tuple[Union[bytes, str], list, list, list]:
selected_columns: List[Dict] = result_set.columns
expanded_columns: List[Dict]
) -> Tuple[Union[bytes, str], List[Any], List[Any], List[Any]]:
selected_columns = result_set.columns
all_columns: List[Any]
expanded_columns: List[Any]
if use_msgpack:
with stats_timing(

View File

@ -25,7 +25,7 @@ from superset import create_app
from superset.extensions import celery_app
# Init the Flask app / configure everything
create_app() # type: ignore
create_app()
# Need to import late, as the celery_app will have been setup by "create_app()"
# pylint: disable=wrong-import-position, unused-import

View File

@ -23,7 +23,7 @@ import urllib.request
from collections import namedtuple
from datetime import datetime, timedelta
from email.utils import make_msgid, parseaddr
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union
from urllib.error import URLError # pylint: disable=ungrouped-imports
import croniter
@ -36,7 +36,6 @@ from flask_login import login_user
from retry.api import retry_call
from selenium.common.exceptions import WebDriverException
from selenium.webdriver import chrome, firefox
from werkzeug.datastructures import TypeConversionDict
from werkzeug.http import parse_cookie
# Superset framework imports
@ -53,6 +52,11 @@ from superset.models.schedules import (
)
from superset.utils.core import get_email_address_list, send_email_smtp
if TYPE_CHECKING:
# pylint: disable=unused-import
from werkzeug.datastructures import TypeConversionDict
# Globals
config = app.config
logger = logging.getLogger("tasks.email_reports")
@ -131,7 +135,7 @@ def _generate_mail_content(
return EmailContent(body, data, images)
def _get_auth_cookies() -> List[TypeConversionDict]:
def _get_auth_cookies() -> List["TypeConversionDict[Any, Any]"]:
# Login with the user specified to get the reports
with app.test_request_context():
user = security_manager.find_user(config["EMAIL_REPORTS_USER"])
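
Moving the TypeConversionDict import under TYPE_CHECKING means it exists only for the type checker, so the annotation that references it must be a string. The same pattern in isolation, as a sketch:

    from typing import Any, List, TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported for annotations only; never executed at runtime.
        from werkzeug.datastructures import TypeConversionDict

    def get_cookies_sketch() -> List["TypeConversionDict[Any, Any]"]:
        # The quoted annotation is never evaluated, so the guarded import suffices.
        return []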

View File

@ -27,8 +27,9 @@ def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-
def memoized_func(
key: Callable = view_cache_key, attribute_in_key: Optional[str] = None
) -> Callable:
key: Callable[..., str] = view_cache_key, # pylint: disable=bad-whitespace
attribute_in_key: Optional[str] = None,
) -> Callable[..., Any]:
"""Use this decorator to cache functions that have predefined first arg.
enable_cache is treated as True by default,
@ -45,7 +46,7 @@ def memoized_func(
returns the caching key.
"""
def wrap(f: Callable) -> Callable:
def wrap(f: Callable[..., Any]) -> Callable[..., Any]:
if cache_manager.tables_cache:
def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any:
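
The same Callable[..., Any] parameterization appears on most decorators in this commit. A self-contained sketch of a decorator annotated this way, unrelated to the caching logic above:

    import functools
    from typing import Any, Callable

    def log_calls(f: Callable[..., Any]) -> Callable[..., Any]:
        # Callable[..., Any] accepts any argument list and return type,
        # the loosest form allowed once bare Callable is disallowed.
        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            print(f"calling {f.__name__}")
            return f(*args, **kwargs)
        return wrapper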

View File

@ -85,7 +85,7 @@ from superset.exceptions import (
SupersetException,
SupersetTimeoutException,
)
from superset.typing import FormData, Metric
from superset.typing import FlaskResponse, FormData, Metric
from superset.utils.dates import datetime_to_epoch, EPOCH
try:
@ -147,7 +147,9 @@ class _memoized:
should account for instance variable changes.
"""
def __init__(self, func: Callable, watch: Optional[Tuple[str, ...]] = None) -> None:
def __init__(
self, func: Callable[..., Any], watch: Optional[Tuple[str, ...]] = None
) -> None:
self.func = func
self.cache: Dict[Any, Any] = {}
self.is_method = False
@ -173,7 +175,7 @@ class _memoized:
"""Return the function's docstring."""
return self.func.__doc__ or ""
def __get__(self, obj: Any, objtype: Type) -> functools.partial:
def __get__(self, obj: Any, objtype: Type[Any]) -> functools.partial: # type: ignore
if not self.is_method:
self.is_method = True
"""Support instance methods."""
@ -181,13 +183,13 @@ class _memoized:
def memoized(
func: Optional[Callable] = None, watch: Optional[Tuple[str, ...]] = None
) -> Callable:
func: Optional[Callable[..., Any]] = None, watch: Optional[Tuple[str, ...]] = None
) -> Callable[..., Any]:
if func:
return _memoized(func)
else:
def wrapper(f: Callable) -> Callable:
def wrapper(f: Callable[..., Any]) -> Callable[..., Any]:
return _memoized(f, watch)
return wrapper
@ -1241,7 +1243,9 @@ def create_ssl_cert_file(certificate: str) -> str:
return path
def time_function(func: Callable, *args: Any, **kwargs: Any) -> Tuple[float, Any]:
def time_function(
func: Callable[..., FlaskResponse], *args: Any, **kwargs: Any
) -> Tuple[float, Any]:
"""
Measures the amount of time a function takes to execute in ms

View File

@ -29,7 +29,7 @@ def convert_filter_scopes(
) -> Dict[int, Dict[str, Dict[str, Any]]]:
filter_scopes = {}
immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or []
immuned_by_column: Dict = defaultdict(list)
immuned_by_column: Dict[str, List[int]] = defaultdict(list)
for slice_id, columns in json_metadata.get(
"filter_immune_slice_fields", {}
).items():
@ -52,7 +52,7 @@ def convert_filter_scopes(
logging.info(f"slice [{filter_id}] has invalid field: {filter_field}")
for filter_slice in filters:
filter_fields: Dict = {}
filter_fields: Dict[str, Dict[str, Any]] = {}
filter_id = filter_slice.id
slice_params = json.loads(filter_slice.params or "{}")
configs = slice_params.get("filter_configs") or []
@ -77,9 +77,10 @@ def convert_filter_scopes(
def copy_filter_scopes(
old_to_new_slc_id_dict: Dict[int, int], old_filter_scopes: Dict[str, Dict]
) -> Dict:
new_filter_scopes: Dict[str, Dict] = {}
old_to_new_slc_id_dict: Dict[int, int],
old_filter_scopes: Dict[int, Dict[str, Dict[str, Any]]],
) -> Dict[str, Dict[Any, Any]]:
new_filter_scopes: Dict[str, Dict[Any, Any]] = {}
for (filter_id, scopes) in old_filter_scopes.items():
new_filter_key = old_to_new_slc_id_dict.get(int(filter_id))
if new_filter_key:

View File

@ -46,7 +46,7 @@ def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[floa
stats_logger.timing(stats_key, now_as_float() - start_ts)
def etag_cache(max_age: int, check_perms: Callable) -> Callable:
def etag_cache(max_age: int, check_perms: Callable[..., Any]) -> Callable[..., Any]:
"""
A decorator for caching views and handling etag conditional requests.
@ -60,7 +60,7 @@ def etag_cache(max_age: int, check_perms: Callable) -> Callable:
"""
def decorator(f: Callable) -> Callable:
def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
@wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
# check if the user can access the resource

View File

@ -27,8 +27,8 @@ logger = logging.getLogger(__name__)
def import_datasource(
session: Session,
i_datasource: Model,
lookup_database: Callable,
lookup_datasource: Callable,
lookup_database: Callable[[Model], Model],
lookup_datasource: Callable[[Model], Model],
import_time: Optional[int] = None,
) -> int:
"""Imports the datasource from the object to the database.
@ -82,7 +82,9 @@ def import_datasource(
return datasource.id
def import_simple_obj(session: Session, i_obj: Model, lookup_obj: Callable) -> Model:
def import_simple_obj(
session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model]
) -> Model:
make_transient(i_obj)
i_obj.id = None
i_obj.table = None

View File

@ -35,7 +35,7 @@ class AbstractEventLogger(ABC):
) -> None:
pass
def log_this(self, f: Callable) -> Callable:
def log_this(self, f: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> Any:
user_id = None
@ -124,7 +124,7 @@ def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger:
)
)
event_logger_type = cast(Type, cfg_value)
event_logger_type = cast(Type[Any], cfg_value)
result = event_logger_type()
# Verify that we have a valid logger impl

View File

@ -58,7 +58,7 @@ class DefaultLoggingConfigurator(LoggingConfigurator):
if app_config["ENABLE_TIME_ROTATE"]:
logging.getLogger().setLevel(app_config["TIME_ROTATE_LOG_LEVEL"])
handler = TimedRotatingFileHandler( # type: ignore
handler = TimedRotatingFileHandler(
app_config["FILENAME"],
when=app_config["ROLLOVER"],
interval=app_config["INTERVAL"],

View File

@ -72,8 +72,8 @@ WHITELIST_CUMULATIVE_FUNCTIONS = (
)
def validate_column_args(*argnames: str) -> Callable:
def wrapper(func: Callable) -> Callable:
def validate_column_args(*argnames: str) -> Callable[..., Any]:
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapped(df: DataFrame, **options: Any) -> Any:
columns = df.columns.tolist()
for name in argnames:
@ -471,7 +471,7 @@ def geodetic_parse(
Parse a string containing a geodetic point and return latitude, longitude
and altitude
"""
point = Point(location) # type: ignore
point = Point(location)
return point[0], point[1], point[2]
try:

View File

@ -51,7 +51,7 @@ SELENIUM_HEADSTART = 3
WindowSize = Tuple[int, int]
def get_auth_cookies(user: "User") -> List[Dict]:
def get_auth_cookies(user: "User") -> List[Dict[Any, Any]]:
# Login with the user specified to get the reports
with current_app.test_request_context("/login"):
login_user(user)
@ -101,14 +101,14 @@ class AuthWebDriverProxy:
self,
driver_type: str,
window: Optional[WindowSize] = None,
auth_func: Optional[Callable] = None,
auth_func: Optional[
Callable[..., Any]
] = None, # pylint: disable=bad-whitespace
):
self._driver_type = driver_type
self._window: WindowSize = window or (800, 600)
config_auth_func: Callable = current_app.config.get(
"WEBDRIVER_AUTH_FUNC", auth_driver
)
self._auth_func: Callable = auth_func or config_auth_func
config_auth_func = current_app.config.get("WEBDRIVER_AUTH_FUNC", auth_driver)
self._auth_func = auth_func or config_auth_func
def create(self) -> WebDriver:
if self._driver_type == "firefox":
@ -123,7 +123,7 @@ class AuthWebDriverProxy:
raise Exception(f"Webdriver name ({self._driver_type}) not supported")
# Prepare args for the webdriver init
options.add_argument("--headless")
kwargs: Dict = dict(options=options)
kwargs: Dict[Any, Any] = dict(options=options)
kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"])
logger.info("Init selenium driver")
return driver_class(**kwargs)

View File

@ -143,7 +143,7 @@ def generate_download_headers(
return headers
def api(f: Callable) -> Callable:
def api(f: Callable[..., FlaskResponse]) -> Callable[..., FlaskResponse]:
"""
A decorator to label an endpoint as an API. Catches uncaught exceptions and
return the response in the JSON format
@ -383,11 +383,11 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
:param primary_key:
record primary key to delete
"""
item = self.datamodel.get(primary_key, self._base_filters) # type: ignore
item = self.datamodel.get(primary_key, self._base_filters)
if not item:
abort(404)
try:
self.pre_delete(item) # type: ignore
self.pre_delete(item)
except Exception as ex: # pylint: disable=broad-except
flash(str(ex), "danger")
else:
@ -400,8 +400,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
.all()
)
if self.datamodel.delete(item): # type: ignore
self.post_delete(item) # type: ignore
if self.datamodel.delete(item):
self.post_delete(item)
for pv in pvs:
security_manager.get_session.delete(pv)
@ -411,8 +411,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
security_manager.get_session.commit()
flash(*self.datamodel.message) # type: ignore
self.update_redirect() # type: ignore
flash(*self.datamodel.message)
self.update_redirect()
@action(
"muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False

View File

@ -41,7 +41,7 @@ get_related_schema = {
}
def statsd_metrics(f: Callable) -> Callable:
def statsd_metrics(f: Callable[..., Any]) -> Callable[..., Any]:
"""
Handle sending all statsd metrics from the REST API
"""

View File

@ -88,7 +88,9 @@ class BaseOwnedSchema(BaseSupersetSchema):
owners_field_name = "owners"
@post_load
def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model:
def make_object(
self, data: Dict[str, Any], discard: Optional[List[str]] = None
) -> Model:
discard = discard or []
discard.append(self.owners_field_name)
instance = super().make_object(data, discard)

View File

@ -251,7 +251,7 @@ def check_slice_perms(self: "Superset", slice_id: int) -> None:
def _deserialize_results_payload(
payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False
) -> Dict[Any, Any]:
) -> Dict[str, Any]:
logger.debug(f"Deserializing from msgpack: {use_msgpack}")
if use_msgpack:
with stats_timing(
@ -278,7 +278,7 @@ def _deserialize_results_payload(
with stats_timing(
"sqllab.query.results_backend_json_deserialize", stats_logger
):
return json.loads(payload) # type: ignore
return json.loads(payload)
def get_cta_schema_name(
@ -1343,7 +1343,7 @@ class Superset(BaseSupersetView):
if "timed_refresh_immune_slices" not in md:
md["timed_refresh_immune_slices"] = []
new_filter_scopes: Dict[str, Dict] = {}
new_filter_scopes = {}
if "filter_scopes" in data:
# replace filter_id and immune ids from old slice id to new slice id:
# and remove slice ids that are not in dash anymore
@ -2137,7 +2137,7 @@ class Superset(BaseSupersetView):
f"deprecated.{self.__class__.__name__}.select_star.database_not_found"
)
return json_error_response("Not found", 404)
schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) # type: ignore
schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
table_name = utils.parse_js_uri_path_item(table_name) # type: ignore
# Check that the user can access the datasource
if not self.appbuilder.sm.can_access_datasource(
@ -2245,7 +2245,7 @@ class Superset(BaseSupersetView):
)
payload = utils.zlib_decompress(blob, decode=not results_backend_use_msgpack)
obj: dict = _deserialize_results_payload(
obj = _deserialize_results_payload(
payload, query, cast(bool, results_backend_use_msgpack)
)
@ -2474,9 +2474,7 @@ class Superset(BaseSupersetView):
schema: str = cast(str, query_params.get("schema"))
sql: str = cast(str, query_params.get("sql"))
try:
template_params: dict = json.loads(
query_params.get("templateParams") or "{}"
)
template_params = json.loads(query_params.get("templateParams") or "{}")
except json.JSONDecodeError:
logger.warning(
f"Invalid template parameter {query_params.get('templateParams')}"

View File

@ -61,7 +61,7 @@ def get_col_type(col: Dict[Any, Any]) -> str:
def get_table_metadata(
database: Database, table_name: str, schema_name: Optional[str]
) -> Dict:
) -> Dict[str, Any]:
"""
Get table metadata information, including type, pk, fks.
This function raises SQLAlchemyError when a schema is not found.
@ -72,7 +72,7 @@ def get_table_metadata(
:param schema_name: schema name
:return: Dict table metadata ready for API response
"""
keys: List = []
keys = []
columns = database.get_columns(table_name, schema_name)
primary_key = database.get_pk_constraint(table_name, schema_name)
if primary_key and primary_key.get("constrained_columns"):
@ -82,7 +82,7 @@ def get_table_metadata(
foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
indexes = get_indexes_metadata(database, table_name, schema_name)
keys += foreign_keys + indexes
payload_columns: List[Dict] = []
payload_columns: List[Dict[str, Any]] = []
for col in columns:
dtype = get_col_type(col)
payload_columns.append(
@ -90,7 +90,7 @@ def get_table_metadata(
"name": col["name"],
"type": dtype.split("(")[0] if "(" in dtype else dtype,
"longType": dtype,
"keys": [k for k in keys if col["name"] in k.get("column_names")],
"keys": [k for k in keys if col["name"] in k["column_names"]],
}
)
return {
@ -270,7 +270,7 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi):
"""
self.incr_stats("init", self.table_metadata.__name__)
try:
table_info: Dict = get_table_metadata(database, table_name, schema_name)
table_info = get_table_metadata(database, table_name, schema_name)
except SQLAlchemyError as ex:
self.incr_stats("error", self.table_metadata.__name__)
return self.response_422(error_msg_from_exception(ex))

View File

@ -29,7 +29,7 @@ from superset.views.base_api import BaseSupersetModelRestApi
logger = logging.getLogger(__name__)
def check_datasource_access(f: Callable) -> Callable:
def check_datasource_access(f: Callable[..., Any]) -> Callable[..., Any]:
"""
A Decorator that checks if a user has datasource access
"""

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import enum
from typing import Type
from typing import Type, Union
import simplejson as json
from croniter import croniter
@ -55,7 +55,7 @@ class EmailScheduleView(
raise NotImplementedError()
@property
def schedule_type_model(self) -> Type:
def schedule_type_model(self) -> Type[Union[Dashboard, Slice]]:
raise NotImplementedError()
page_size = 20
@ -154,9 +154,7 @@ class EmailScheduleView(
info[col] = info[col].username
info["user"] = schedule.user.username
info[self.schedule_type] = getattr( # type: ignore
schedule, self.schedule_type
).id
info[self.schedule_type] = getattr(schedule, self.schedule_type).id
schedules.append(info)
return json_success(json.dumps(schedules, default=json_iso_dttm_ser))

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Callable
from typing import Any
import simplejson as json
from flask import g, redirect, request, Response
@ -40,7 +40,7 @@ from .base import (
class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
def apply(self, query: BaseQuery, value: Callable) -> BaseQuery:
def apply(self, query: BaseQuery, value: Any) -> BaseQuery:
"""
Filter queries to only those owned by current user. If
can_access_all_queries permission is set a user can list all queries

View File

@ -35,7 +35,7 @@ from superset.utils.core import QueryStatus, TimeRangeEndpoint
from superset.viz import BaseViz
if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"):
from superset import viz_sip38 as viz # type: ignore
from superset import viz_sip38 as viz
else:
from superset import viz # type: ignore
@ -318,9 +318,9 @@ def get_dashboard_extra_filters(
def build_extra_filters(
layout: Dict,
filter_scopes: Dict,
default_filters: Dict[str, Dict[str, List]],
layout: Dict[str, Dict[str, Any]],
filter_scopes: Dict[str, Dict[str, Any]],
default_filters: Dict[str, Dict[str, List[Any]]],
slice_id: int,
) -> List[Dict[str, Any]]:
extra_filters = []
@ -343,7 +343,9 @@ def build_extra_filters(
return extra_filters
def is_slice_in_container(layout: Dict, container_id: str, slice_id: int) -> bool:
def is_slice_in_container(
layout: Dict[str, Dict[str, Any]], container_id: str, slice_id: int
) -> bool:
if container_id == "ROOT_ID":
return True

View File

@ -2720,7 +2720,7 @@ class PairedTTestViz(BaseViz):
else:
cols.append(col)
df.columns = cols
data: Dict = {}
data: Dict[str, List[Dict[str, Any]]] = {}
series = df.to_dict("series")
for nameSet in df.columns:
# If no groups are defined, nameSet will be the metric name
@ -2750,7 +2750,7 @@ class RoseViz(NVD3TimeSeriesViz):
return None
data = super().get_data(df)
result: Dict = {}
result: Dict[str, List[Dict[str, str]]] = {}
for datum in data: # type: ignore
key = datum["key"]
for val in datum["values"]:

View File

@ -18,7 +18,7 @@
"""Unit tests for Superset"""
import imp
import json
from typing import Dict, Union, List
from typing import Any, Dict, Union, List
from unittest.mock import Mock, patch
import pandas as pd
@ -397,7 +397,9 @@ class SupersetTestCase(TestCase):
mock_method.assert_called_once_with("error", func_name)
return rv
def post_assert_metric(self, uri: str, data: Dict, func_name: str) -> Response:
def post_assert_metric(
self, uri: str, data: Dict[str, Any], func_name: str
) -> Response:
"""
Simple client post with an extra assertion for statsd metrics
@ -417,7 +419,9 @@ class SupersetTestCase(TestCase):
mock_method.assert_called_once_with("error", func_name)
return rv
def put_assert_metric(self, uri: str, data: Dict, func_name: str) -> Response:
def put_assert_metric(
self, uri: str, data: Dict[str, Any], func_name: str
) -> Response:
"""
Simple client put with an extra assertion for statsd metrics

View File

@ -20,7 +20,7 @@ from copy import copy
from cachelib.redis import RedisCache
from flask import Flask
from superset.config import * # type: ignore
from superset.config import *
AUTH_USER_REGISTRATION_ROLE = "alpha"
SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")
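
Many of the "# type: ignore" removals in this commit, including the one above, follow from warn_unused_ignores = true in setup.cfg: once the underlying code type-checks cleanly, a leftover suppression is itself reported. A sketch of the effect, not output copied from CI:

    # With warn_unused_ignores = true, a suppression that no longer hides an
    # error is flagged by mypy as an unused "type: ignore" comment, e.g.
    #
    #     from superset.config import *  # type: ignore
    #
    # which is why the comment is dropped once the star import checks cleanly.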