[typing] add typing for superset/connectors and superset/common (#8138)

serenajiang 2019-09-19 16:51:01 -07:00 committed by Erik Ritter
parent 8bc5cd7dc0
commit dfb3bf69a0
7 changed files with 329 additions and 242 deletions

View File: superset/common/query_context.py

@ -18,20 +18,22 @@
from datetime import datetime, timedelta
import logging
import pickle as pkl
from typing import Dict, List
from typing import Any, Dict, List, Optional
import numpy as np
import pandas as pd
from superset import app, cache
from superset import db
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
from superset.stats_logger import BaseStatsLogger
from superset.utils import core as utils
from superset.utils.core import DTTM_ALIAS
from .query_object import QueryObject
config = app.config
stats_logger = config.get("STATS_LOGGER")
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
class QueryContext:
@ -40,8 +42,13 @@ class QueryContext:
to retrieve the data payload for a given viz.
"""
cache_type = "df"
enforce_numerical_metrics = True
cache_type: str = "df"
enforce_numerical_metrics: bool = True
datasource: BaseDatasource
queries: List[QueryObject]
force: bool
custom_cache_timeout: Optional[int]
# TODO: Type datasource and query_object dictionary with TypedDict when it becomes
# a vanilla python type https://github.com/python/mypy/issues/5288
@ -50,8 +57,8 @@ class QueryContext:
datasource: Dict,
queries: List[Dict],
force: bool = False,
custom_cache_timeout: int = None,
):
custom_cache_timeout: Optional[int] = None,
) -> None:
self.datasource = ConnectorRegistry.get_datasource(
datasource.get("type"), int(datasource.get("id")), db.session # noqa: T400
)
@ -61,9 +68,7 @@ class QueryContext:
self.custom_cache_timeout = custom_cache_timeout
self.enforce_numerical_metrics = True
def get_query_result(self, query_object):
def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
"""Returns a pandas dataframe based on the query object"""
# Here, we assume that all the queries will use the same datasource, which is
@ -109,23 +114,23 @@ class QueryContext:
"df": df,
}
def df_metrics_to_num(self, df, query_object):
def df_metrics_to_num(self, df: pd.DataFrame, query_object: QueryObject) -> None:
"""Converting metrics to numeric when pandas.read_sql cannot"""
metrics = [metric for metric in query_object.metrics]
for col, dtype in df.dtypes.items():
if dtype.type == np.object_ and col in metrics:
df[col] = pd.to_numeric(df[col], errors="coerce")
def get_data(self, df):
def get_data(self, df: pd.DataFrame) -> List[Dict]:
return df.to_dict(orient="records")
def get_single_payload(self, query_obj: QueryObject):
def get_single_payload(self, query_obj: QueryObject) -> Dict[str, Any]:
"""Returns a payload of metadata and data"""
payload = self.get_df_payload(query_obj)
df = payload.get("df")
status = payload.get("status")
if status != utils.QueryStatus.FAILED:
if df is not None and df.empty:
if df is None or df.empty:
payload["error"] = "No data"
else:
payload["data"] = self.get_data(df)
@ -133,12 +138,12 @@ class QueryContext:
del payload["df"]
return payload
def get_payload(self):
def get_payload(self) -> List[Dict[str, Any]]:
"""Get all the payloads from the arrays"""
return [self.get_single_payload(query_object) for query_object in self.queries]
@property
def cache_timeout(self):
def cache_timeout(self) -> int:
if self.custom_cache_timeout is not None:
return self.custom_cache_timeout
if self.datasource.cache_timeout is not None:
@ -148,10 +153,10 @@ class QueryContext:
and self.datasource.database.cache_timeout
) is not None:
return self.datasource.database.cache_timeout
return config.get("CACHE_DEFAULT_TIMEOUT")
return config["CACHE_DEFAULT_TIMEOUT"]
def get_df_payload(self, query_obj: QueryObject, **kwargs):
"""Handles caching around the df paylod retrieval"""
def get_df_payload(self, query_obj: QueryObject, **kwargs) -> Dict[str, Any]:
"""Handles caching around the df payload retrieval"""
extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj.to_dict())
cache_key = (
query_obj.cache_key(
@ -207,9 +212,7 @@ class QueryContext:
if is_loaded and cache_key and cache and status != utils.QueryStatus.FAILED:
try:
cache_value = dict(
dttm=cached_dttm, df=df if df is not None else None, query=query
)
cache_value = dict(dttm=cached_dttm, df=df, query=query)
cache_binary = pkl.dumps(cache_value, protocol=pkl.HIGHEST_PROTOCOL)
logging.info(
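
Sketch only, not part of the commit: how the now-typed QueryContext is typically driven. It assumes a running Flask app context and a registered "table" datasource with id 1; all values, including the metric label, are illustrative.

from superset.common.query_context import QueryContext

query_context = QueryContext(
    datasource={"type": "table", "id": 1},
    queries=[{
        "granularity": "ds",
        "metrics": [{"label": "count"}],
        "groupby": ["gender"],
        "time_range": "Last week",
    }],
)
# get_payload() -> List[Dict[str, Any]]; each payload carries status, query and data.
for payload in query_context.get_payload():
    print(payload.get("status"), len(payload.get("data", [])))
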

View File: superset/common/query_object.py

@ -15,8 +15,9 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=R
from datetime import datetime, timedelta
import hashlib
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
import simplejson as json
@ -34,12 +35,28 @@ class QueryObject:
and druid. The query objects are constructed on the client.
"""
granularity: str
from_dttm: datetime
to_dttm: datetime
is_timeseries: bool
time_shift: timedelta
groupby: List[str]
metrics: List[Union[Dict, str]]
row_limit: int
filter: List[str]
timeseries_limit: int
timeseries_limit_metric: Optional[Dict]
order_desc: bool
extras: Dict
columns: List[str]
orderby: List[List]
def __init__(
self,
granularity: str,
metrics: List[Union[Dict, str]],
groupby: List[str] = None,
filters: List[str] = None,
groupby: Optional[List[str]] = None,
filters: Optional[List[str]] = None,
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
is_timeseries: bool = False,
@ -48,8 +65,8 @@ class QueryObject:
timeseries_limit_metric: Optional[Dict] = None,
order_desc: bool = True,
extras: Optional[Dict] = None,
columns: List[str] = None,
orderby: List[List] = None,
columns: Optional[List[str]] = None,
orderby: Optional[List[List]] = None,
relative_start: str = app.config.get("DEFAULT_RELATIVE_START_TIME", "today"),
relative_end: str = app.config.get("DEFAULT_RELATIVE_END_TIME", "today"),
):
@ -63,7 +80,7 @@ class QueryObject:
self.is_timeseries = is_timeseries
self.time_range = time_range
self.time_shift = utils.parse_human_timedelta(time_shift)
self.groupby = groupby if groupby is not None else []
self.groupby = groupby or []
# Temporary solution for backward compatibility issue
# due to the new format of non-ad-hoc metric.
@ -72,15 +89,15 @@ class QueryObject:
for metric in metrics
]
self.row_limit = row_limit
self.filter = filters if filters is not None else []
self.filter = filters or []
self.timeseries_limit = timeseries_limit
self.timeseries_limit_metric = timeseries_limit_metric
self.order_desc = order_desc
self.extras = extras if extras is not None else {}
self.columns = columns if columns is not None else []
self.orderby = orderby if orderby is not None else []
self.extras = extras or {}
self.columns = columns or []
self.orderby = orderby or []
def to_dict(self):
def to_dict(self) -> Dict[str, Any]:
query_object_dict = {
"granularity": self.granularity,
"from_dttm": self.from_dttm,
@ -99,7 +116,7 @@ class QueryObject:
}
return query_object_dict
def cache_key(self, **extra):
def cache_key(self, **extra) -> str:
"""
The cache key is made out of the key/values from to_dict(), plus any
other key/values in `extra`
@ -117,7 +134,7 @@ class QueryObject:
json_data = self.json_dumps(cache_dict, sort_keys=True)
return hashlib.md5(json_data.encode("utf-8")).hexdigest()
def json_dumps(self, obj, sort_keys=False):
def json_dumps(self, obj: Any, sort_keys: bool = False) -> str:
return json.dumps(
obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
)
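
Sketch only: cache_key() hashes the sorted JSON of to_dict() plus any extra key/values, so passing a different extra changes the key while identical inputs stay stable; the extra entry below is hypothetical.

from superset.common.query_object import QueryObject

qobj = QueryObject(granularity="ds", metrics=[{"label": "count"}], time_range="Last week")
base_key = qobj.cache_key()
assert base_key == qobj.cache_key()                          # deterministic for the same inputs
assert base_key != qobj.cache_key(time_shift="1 week ago")   # extras are folded into the hash
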

View File: superset/connectors/base/models.py

@ -16,14 +16,15 @@
# under the License.
# pylint: disable=C,R,W
import json
from typing import Any, List
from typing import Any, Dict, List, Optional
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import and_, Boolean, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import foreign, relationship
from sqlalchemy.orm import foreign, Query, relationship
from superset.models.core import Slice
from superset.models.helpers import AuditMixinNullable, ImportMixin
from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
from superset.utils import core as utils
@ -59,9 +60,9 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
params = Column(String(1000))
perm = Column(String(1000))
sql = None
owners = None
update_from_object_fields = None
sql: Optional[str] = None
owners: List[User]
update_from_object_fields: List[str]
@declared_attr
def slices(self):
@ -79,20 +80,20 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
metrics: List[Any] = []
@property
def uid(self):
def uid(self) -> str:
"""Unique id across datasource types"""
return f"{self.id}__{self.type}"
@property
def column_names(self):
def column_names(self) -> List[str]:
return sorted([c.column_name for c in self.columns], key=lambda x: x or "")
@property
def columns_types(self):
def columns_types(self) -> Dict:
return {c.column_name: c.type for c in self.columns}
@property
def main_dttm_col(self):
def main_dttm_col(self) -> str:
return "timestamp"
@property
@ -100,47 +101,47 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
raise NotImplementedError()
@property
def connection(self):
def connection(self) -> Optional[str]:
"""String representing the context of the Datasource"""
return None
@property
def schema(self):
def schema(self) -> Optional[str]:
"""String representing the schema of the Datasource (if it applies)"""
return None
@property
def filterable_column_names(self):
def filterable_column_names(self) -> List[str]:
return sorted([c.column_name for c in self.columns if c.filterable])
@property
def dttm_cols(self):
def dttm_cols(self) -> List:
return []
@property
def url(self):
def url(self) -> str:
return "/{}/edit/{}".format(self.baselink, self.id)
@property
def explore_url(self):
def explore_url(self) -> str:
if self.default_endpoint:
return self.default_endpoint
else:
return "/superset/explore/{obj.type}/{obj.id}/".format(obj=self)
@property
def column_formats(self):
def column_formats(self) -> Dict[str, Optional[str]]:
return {m.metric_name: m.d3format for m in self.metrics if m.d3format}
def add_missing_metrics(self, metrics):
exisiting_metrics = {m.metric_name for m in self.metrics}
def add_missing_metrics(self, metrics: List["BaseMetric"]) -> None:
existing_metrics = {m.metric_name for m in self.metrics}
for metric in metrics:
if metric.metric_name not in exisiting_metrics:
if metric.metric_name not in existing_metrics:
metric.table_id = self.id
self.metrics += [metric]
self.metrics.append(metric)
@property
def short_data(self):
def short_data(self) -> Dict[str, Any]:
"""Data representation of the datasource sent to the frontend"""
return {
"edit_url": self.url,
@ -158,7 +159,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
pass
@property
def data(self):
def data(self) -> Dict[str, Any]:
"""Data representation of the datasource sent to the frontend"""
order_by_choices = []
# self.column_names return sorted column_names
@ -239,14 +240,14 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
"""Returns column information from the external system"""
raise NotImplementedError()
def get_query_str(self, query_obj):
def get_query_str(self, query_obj) -> str:
"""Returns a query as a string
This is used to be displayed to the user so that she/he can
understand what is taking place behind the scenes"""
raise NotImplementedError()
def query(self, query_obj):
def query(self, query_obj) -> QueryResult:
"""Executes the query and returns a dataframe
query_obj is a dictionary representing Superset's query interface.
@ -254,7 +255,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
"""
raise NotImplementedError()
def values_for_column(self, column_name, limit=10000):
def values_for_column(self, column_name: str, limit: int = 10000) -> List:
"""Given a column, returns an iterable of distinct values
This is used to populate the dropdown showing a list of
@ -262,13 +263,14 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
raise NotImplementedError()
@staticmethod
def default_query(qry):
def default_query(qry) -> Query:
return qry
def get_column(self, column_name):
def get_column(self, column_name: str) -> Optional["BaseColumn"]:
for col in self.columns:
if col.column_name == column_name:
return col
return None
def get_fk_many_from_list(self, object_list, fkmany, fkmany_class, key_attr):
"""Update ORM one-to-many list from object list
@ -276,10 +278,9 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
Used for syncing metrics and columns using the same code"""
object_dict = {o.get(key_attr): o for o in object_list}
object_keys = [o.get(key_attr) for o in object_list]
# delete fks that have been removed
fkmany = [o for o in fkmany if getattr(o, key_attr) in object_keys]
fkmany = [o for o in fkmany if getattr(o, key_attr) in object_dict]
# sync existing fks
for fk in fkmany:
@ -303,7 +304,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
fkmany += new_fks
return fkmany
def update_from_object(self, obj):
def update_from_object(self, obj) -> None:
"""Update datasource from a data structure
The UI's table editor crafts a complex data structure that
@ -330,7 +331,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
obj.get("columns"), self.columns, self.column_class, "column_name"
)
def get_extra_cache_keys(self, query_obj) -> List[Any]:
def get_extra_cache_keys(self, query_obj: Dict) -> List[Any]:
""" If a datasource needs to provide additional keys for calculation of
cache keys, those can be provided via this method
"""
@ -374,23 +375,23 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
str_types = ("VARCHAR", "STRING", "CHAR")
@property
def is_num(self):
return self.type and any([t in self.type.upper() for t in self.num_types])
def is_num(self) -> bool:
return self.type and any(map(lambda t: t in self.type.upper(), self.num_types))
@property
def is_time(self):
return self.type and any([t in self.type.upper() for t in self.date_types])
def is_time(self) -> bool:
return self.type and any(map(lambda t: t in self.type.upper(), self.date_types))
@property
def is_string(self):
return self.type and any([t in self.type.upper() for t in self.str_types])
def is_string(self) -> bool:
return self.type and any(map(lambda t: t in self.type.upper(), self.str_types))
@property
def expression(self):
raise NotImplementedError()
@property
def data(self):
def data(self) -> Dict[str, Any]:
attrs = (
"id",
"column_name",
@ -443,7 +444,7 @@ class BaseMetric(AuditMixinNullable, ImportMixin):
raise NotImplementedError()
@property
def data(self):
def data(self) -> Dict[str, Any]:
attrs = (
"id",
"metric_name",

View File: superset/connectors/connector_registry.py

@ -15,16 +15,23 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
from sqlalchemy.orm import subqueryload
from collections import OrderedDict
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
from sqlalchemy.orm import Session, subqueryload
if TYPE_CHECKING:
from superset.models.core import Database
from superset.connectors.base.models import BaseDatasource
class ConnectorRegistry(object):
""" Central Registry for all available datasource engines"""
sources = {}
sources: Dict[str, Type["BaseDatasource"]] = {}
@classmethod
def register_sources(cls, datasource_config):
def register_sources(cls, datasource_config: OrderedDict) -> None:
for module_name, class_names in datasource_config.items():
class_names = [str(s) for s in class_names]
module_obj = __import__(module_name, fromlist=class_names)
@ -33,7 +40,9 @@ class ConnectorRegistry(object):
cls.sources[source_class.type] = source_class
@classmethod
def get_datasource(cls, datasource_type, datasource_id, session):
def get_datasource(
cls, datasource_type: str, datasource_id: int, session: Session
) -> Optional["BaseDatasource"]:
return (
session.query(cls.sources[datasource_type])
.filter_by(id=datasource_id)
@ -41,8 +50,8 @@ class ConnectorRegistry(object):
)
@classmethod
def get_all_datasources(cls, session):
datasources = []
def get_all_datasources(cls, session: Session) -> List["BaseDatasource"]:
datasources: List["BaseDatasource"] = []
for source_type in ConnectorRegistry.sources:
source_class = ConnectorRegistry.sources[source_type]
qry = session.query(source_class)
@ -52,15 +61,22 @@ class ConnectorRegistry(object):
@classmethod
def get_datasource_by_name(
cls, session, datasource_type, datasource_name, schema, database_name
):
cls,
session: Session,
datasource_type: str,
datasource_name: str,
schema: str,
database_name: str,
) -> Optional["BaseDatasource"]:
datasource_class = ConnectorRegistry.sources[datasource_type]
return datasource_class.get_datasource_by_name(
session, datasource_name, schema, database_name
)
@classmethod
def query_datasources_by_permissions(cls, session, database, permissions):
def query_datasources_by_permissions(
cls, session: Session, database: "Database", permissions: Set[str]
) -> List["BaseDatasource"]:
datasource_class = ConnectorRegistry.sources[database.type]
return (
session.query(datasource_class)
@ -70,7 +86,9 @@ class ConnectorRegistry(object):
)
@classmethod
def get_eager_datasource(cls, session, datasource_type, datasource_id):
def get_eager_datasource(
cls, session: Session, datasource_type: str, datasource_id: int
) -> "BaseDatasource":
"""Returns datasource with columns and metrics."""
datasource_class = ConnectorRegistry.sources[datasource_type]
return (
@ -84,7 +102,13 @@ class ConnectorRegistry(object):
)
@classmethod
def query_datasources_by_name(cls, session, database, datasource_name, schema=None):
def query_datasources_by_name(
cls,
session: Session,
database: "Database",
datasource_name: str,
schema: Optional[str] = None,
) -> List["BaseDatasource"]:
datasource_class = ConnectorRegistry.sources[database.type]
return datasource_class.query_datasources_by_name(
session, database, datasource_name, schema=None
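
Sketch only: how the registry is populated from config and then queried. The module/class map mirrors the shape of Superset's DEFAULT_MODULE_DS_MAP, the datasource type/id passed to get_datasource are illustrative, and a Flask app context is assumed for db.session.

from collections import OrderedDict
from superset import db
from superset.connectors.connector_registry import ConnectorRegistry

ConnectorRegistry.register_sources(OrderedDict([
    ("superset.connectors.sqla.models", ["SqlaTable"]),
    ("superset.connectors.druid.models", ["DruidDatasource"]),
]))
datasource = ConnectorRegistry.get_datasource("table", 1, db.session)  # Optional[BaseDatasource]
if datasource is not None:
    print(datasource.uid, datasource.column_names)
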

View File: superset/connectors/druid/models.py

@ -24,13 +24,15 @@ import json
import logging
from multiprocessing.pool import ThreadPool
import re
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
from dateutil.parser import parse as dparse
from flask import escape, Markup
from flask_appbuilder import Model
from flask_appbuilder.models.decorators import renders
from flask_appbuilder.security.sqla.models import User
from flask_babel import lazy_gettext as _
import pandas
import pandas as pd
try:
from pydruid.client import PyDruid
@ -41,7 +43,7 @@ try:
RegisteredLookupExtraction,
)
from pydruid.utils.filters import Dimension, Filter
from pydruid.utils.having import Aggregation
from pydruid.utils.having import Aggregation, Having
from pydruid.utils.postaggregator import (
Const,
Field,
@ -65,12 +67,13 @@ from sqlalchemy import (
Text,
UniqueConstraint,
)
from sqlalchemy.orm import backref, relationship
from sqlalchemy.orm import backref, relationship, RelationshipProperty, Session
from sqlalchemy_utils import EncryptedType
from superset import conf, db, security_manager
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
from superset.exceptions import SupersetException
from superset.models.core import Database
from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
from superset.utils import core as utils, import_datasource
@ -78,6 +81,8 @@ try:
from superset.utils.core import DimSelector, DTTM_ALIAS, flasher
except ImportError:
pass
DRUID_TZ = conf.get("DRUID_TZ")
POST_AGG_TYPE = "postagg"
metadata = Model.metadata # pylint: disable=no-member
@ -150,22 +155,22 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
return self.__repr__()
@property
def data(self):
def data(self) -> Dict:
return {"id": self.id, "name": self.cluster_name, "backend": "druid"}
@staticmethod
def get_base_url(host, port):
def get_base_url(host, port) -> str:
if not re.match("http(s)?://", host):
host = "http://" + host
url = "{0}:{1}".format(host, port) if port else host
return url
def get_base_broker_url(self):
def get_base_broker_url(self) -> str:
base_url = self.get_base_url(self.broker_host, self.broker_port)
return f"{base_url}/{self.broker_endpoint}"
def get_pydruid_client(self):
def get_pydruid_client(self) -> PyDruid:
cli = PyDruid(
self.get_base_url(self.broker_host, self.broker_port), self.broker_endpoint
)
@ -173,39 +178,44 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
cli.set_basic_auth_credentials(self.broker_user, self.broker_pass)
return cli
def get_datasources(self):
def get_datasources(self) -> List[str]:
endpoint = self.get_base_broker_url() + "/datasources"
auth = requests.auth.HTTPBasicAuth(self.broker_user, self.broker_pass)
return json.loads(requests.get(endpoint, auth=auth).text)
def get_druid_version(self):
def get_druid_version(self) -> str:
endpoint = self.get_base_url(self.broker_host, self.broker_port) + "/status"
auth = requests.auth.HTTPBasicAuth(self.broker_user, self.broker_pass)
return json.loads(requests.get(endpoint, auth=auth).text)["version"]
@property
@property # noqa: T484
@utils.memoized
def druid_version(self):
def druid_version(self) -> str:
return self.get_druid_version()
def refresh_datasources(
self, datasource_name=None, merge_flag=True, refreshAll=True
):
self,
datasource_name: Optional[str] = None,
merge_flag: bool = True,
refresh_all: bool = True,
) -> None:
"""Refresh metadata of all datasources in the cluster
If ``datasource_name`` is specified, only that datasource is updated
"""
ds_list = self.get_datasources()
blacklist = conf.get("DRUID_DATA_SOURCE_BLACKLIST", [])
ds_refresh = []
ds_refresh: List[str] = []
if not datasource_name:
ds_refresh = list(filter(lambda ds: ds not in blacklist, ds_list))
elif datasource_name not in blacklist and datasource_name in ds_list:
ds_refresh.append(datasource_name)
else:
return
self.refresh(ds_refresh, merge_flag, refreshAll)
self.refresh(ds_refresh, merge_flag, refresh_all)
def refresh(self, datasource_names, merge_flag, refreshAll):
def refresh(
self, datasource_names: List[str], merge_flag: bool, refresh_all: bool
) -> None:
"""
Fetches metadata for the specified datasources and
merges to the Superset database
@ -225,7 +235,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
session.add(datasource)
flasher(_("Adding new datasource [{}]").format(ds_name), "success")
ds_map[ds_name] = datasource
elif refreshAll:
elif refresh_all:
flasher(_("Refreshing datasource [{}]").format(ds_name), "info")
else:
del ds_map[ds_name]
@ -270,19 +280,19 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
session.commit()
@property
def perm(self):
def perm(self) -> str:
return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self)
def get_perm(self):
def get_perm(self) -> str:
return self.perm
@property
def name(self):
return self.verbose_name if self.verbose_name else self.cluster_name
def name(self) -> str:
return self.verbose_name or self.cluster_name
@property
def unique_name(self):
return self.verbose_name if self.verbose_name else self.cluster_name
def unique_name(self) -> str:
return self.verbose_name or self.cluster_name
class DruidColumn(Model, BaseColumn):
@ -318,25 +328,26 @@ class DruidColumn(Model, BaseColumn):
return self.column_name or str(self.id)
@property
def expression(self):
def expression(self) -> str:
return self.dimension_spec_json
@property
def dimension_spec(self):
def dimension_spec(self) -> Optional[Dict]: # noqa: T484
if self.dimension_spec_json:
return json.loads(self.dimension_spec_json)
def get_metrics(self):
metrics = {}
metrics["count"] = DruidMetric(
metric_name="count",
verbose_name="COUNT(*)",
metric_type="count",
json=json.dumps({"type": "count", "name": "count"}),
)
def get_metrics(self) -> Dict[str, "DruidMetric"]:
metrics = {
"count": DruidMetric(
metric_name="count",
verbose_name="COUNT(*)",
metric_type="count",
json=json.dumps({"type": "count", "name": "count"}),
)
}
return metrics
def refresh_metrics(self):
def refresh_metrics(self) -> None:
"""Refresh metrics based on the column metadata"""
metrics = self.get_metrics()
dbmetrics = (
@ -356,8 +367,8 @@ class DruidColumn(Model, BaseColumn):
db.session.add(metric)
@classmethod
def import_obj(cls, i_column):
def lookup_obj(lookup_column):
def import_obj(cls, i_column: "DruidColumn") -> "DruidColumn":
def lookup_obj(lookup_column: "DruidColumn") -> Optional["DruidColumn"]:
return (
db.session.query(DruidColumn)
.filter(
@ -404,7 +415,7 @@ class DruidMetric(Model, BaseMetric):
return self.json
@property
def json_obj(self):
def json_obj(self) -> Dict:
try:
obj = json.loads(self.json)
except Exception:
@ -412,7 +423,7 @@ class DruidMetric(Model, BaseMetric):
return obj
@property
def perm(self):
def perm(self) -> Optional[str]:
return (
("{parent_name}.[{obj.metric_name}](id:{obj.id})").format(
obj=self, parent_name=self.datasource.full_name
@ -421,12 +432,12 @@ class DruidMetric(Model, BaseMetric):
else None
)
def get_perm(self):
def get_perm(self) -> Optional[str]:
return self.perm
@classmethod
def import_obj(cls, i_metric):
def lookup_obj(lookup_metric):
def import_obj(cls, i_metric: "DruidMetric") -> "DruidMetric":
def lookup_obj(lookup_metric: DruidMetric) -> Optional[DruidMetric]:
return (
db.session.query(DruidMetric)
.filter(
@ -494,23 +505,23 @@ class DruidDatasource(Model, BaseDatasource):
export_children = ["columns", "metrics"]
@property
def database(self):
def database(self) -> RelationshipProperty:
return self.cluster
@property
def connection(self):
def connection(self) -> str:
return str(self.database)
@property
def num_cols(self):
def num_cols(self) -> List[str]:
return [c.column_name for c in self.columns if c.is_num]
@property
def name(self):
def name(self) -> str:
return self.datasource_name
@property
def schema(self):
def schema(self) -> Optional[str]:
ds_name = self.datasource_name or ""
name_pieces = ds_name.split(".")
if len(name_pieces) > 1:
@ -519,11 +530,11 @@ class DruidDatasource(Model, BaseDatasource):
return None
@property
def schema_perm(self):
def schema_perm(self) -> Optional[str]:
"""Returns schema permission if present, cluster one otherwise."""
return security_manager.get_schema_perm(self.cluster, self.schema)
def get_perm(self):
def get_perm(self) -> str:
return ("[{obj.cluster_name}].[{obj.datasource_name}]" "(id:{obj.id})").format(
obj=self
)
@ -532,16 +543,16 @@ class DruidDatasource(Model, BaseDatasource):
return NotImplementedError()
@property
def link(self):
def link(self) -> Markup:
name = escape(self.datasource_name)
return Markup(f'<a href="{self.url}">{name}</a>')
@property
def full_name(self):
def full_name(self) -> str:
return utils.get_datasource_full_name(self.cluster_name, self.datasource_name)
@property
def time_column_grains(self):
def time_column_grains(self) -> Dict[str, List[str]]:
return {
"time_columns": [
"all",
@ -568,16 +579,18 @@ class DruidDatasource(Model, BaseDatasource):
return self.datasource_name
@renders("datasource_name")
def datasource_link(self):
def datasource_link(self) -> str:
url = f"/superset/explore/{self.type}/{self.id}/"
name = escape(self.datasource_name)
return Markup(f'<a href="{url}">{name}</a>')
def get_metric_obj(self, metric_name):
def get_metric_obj(self, metric_name: str) -> Dict:
return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0]
@classmethod
def import_obj(cls, i_datasource, import_time=None):
def import_obj(
cls, i_datasource: "DruidDatasource", import_time: Optional[int] = None
) -> int:
"""Imports the datasource from the object to the database.
Metrics, columns and the datasource will be overridden if they exist.
@ -585,7 +598,7 @@ class DruidDatasource(Model, BaseDatasource):
superset instances. Audit metadata isn't copied over.
"""
def lookup_datasource(d):
def lookup_datasource(d: DruidDatasource) -> Optional[DruidDatasource]:
return (
db.session.query(DruidDatasource)
.filter(
@ -595,7 +608,7 @@ class DruidDatasource(Model, BaseDatasource):
.first()
)
def lookup_cluster(d):
def lookup_cluster(d: DruidDatasource) -> Optional[DruidCluster]:
return (
db.session.query(DruidCluster)
.filter_by(cluster_name=d.cluster_name)
@ -659,12 +672,14 @@ class DruidDatasource(Model, BaseDatasource):
if segment_metadata:
return segment_metadata[-1]["columns"]
def refresh_metrics(self):
def refresh_metrics(self) -> None:
for col in self.columns:
col.refresh_metrics()
@classmethod
def sync_to_db_from_config(cls, druid_config, user, cluster, refresh=True):
def sync_to_db_from_config(
cls, druid_config: Dict, user: User, cluster: DruidCluster, refresh: bool = True
) -> None:
"""Merges the ds config from druid_config into one stored in the db."""
session = db.session
datasource = (
@ -742,13 +757,15 @@ class DruidDatasource(Model, BaseDatasource):
session.commit()
@staticmethod
def time_offset(granularity):
def time_offset(granularity: Union[str, Dict]) -> int:
if granularity == "week_ending_saturday":
return 6 * 24 * 3600 * 1000 # 6 days
return 0
@classmethod
def get_datasource_by_name(cls, session, datasource_name, schema, database_name):
def get_datasource_by_name(
cls, session: Session, datasource_name: str, schema: str, database_name: str
) -> Optional["DruidDatasource"]:
query = (
session.query(cls)
.join(DruidCluster)
@ -761,7 +778,9 @@ class DruidDatasource(Model, BaseDatasource):
# http://druid.io/docs/0.8.0/querying/granularities.html
# TODO: pass origin from the UI
@staticmethod
def granularity(period_name, timezone=None, origin=None):
def granularity(
period_name: str, timezone: Optional[str] = None, origin: Optional[str] = None
) -> Union[str, Dict]:
if not period_name or period_name == "all":
return "all"
iso_8601_dict = {
@ -810,7 +829,7 @@ class DruidDatasource(Model, BaseDatasource):
return granularity
@staticmethod
def get_post_agg(mconf):
def get_post_agg(mconf: Dict) -> Postaggregator:
"""
For a metric specified as `postagg` returns the
kind of post aggregation for pydruid.
@ -839,7 +858,9 @@ class DruidDatasource(Model, BaseDatasource):
return CustomPostAggregator(mconf.get("name", ""), mconf)
@staticmethod
def find_postaggs_for(postagg_names, metrics_dict):
def find_postaggs_for(
postagg_names: Set[str], metrics_dict: Dict[str, DruidMetric]
) -> List[DruidMetric]:
"""Return a list of metrics that are post aggregations"""
postagg_metrics = [
metrics_dict[name]
@ -852,7 +873,7 @@ class DruidDatasource(Model, BaseDatasource):
return postagg_metrics
@staticmethod
def recursive_get_fields(_conf):
def recursive_get_fields(_conf: Dict) -> List[str]:
_type = _conf.get("type")
_field = _conf.get("field")
_fields = _conf.get("fields")
@ -875,11 +896,9 @@ class DruidDatasource(Model, BaseDatasource):
# Check if the fields are already in aggs
# or is a previous postagg
required_fields = set(
[
field
for field in required_fields
if field not in visited_postaggs and field not in agg_names
]
field
for field in required_fields
if field not in visited_postaggs and field not in agg_names
)
# First try to find postaggs that match
if len(required_fields) > 0:
@ -903,7 +922,11 @@ class DruidDatasource(Model, BaseDatasource):
post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj)
@staticmethod
def metrics_and_post_aggs(metrics, metrics_dict, druid_version=None):
def metrics_and_post_aggs(
metrics: List[Union[Dict, str]],
metrics_dict: Dict[str, DruidMetric],
druid_version=None,
) -> Tuple[OrderedDict, OrderedDict]: # noqa: T484
# Separate metrics into those that are aggregations
# and those that are post aggregations
saved_agg_names = set()
@ -912,26 +935,26 @@ class DruidDatasource(Model, BaseDatasource):
for metric in metrics:
if utils.is_adhoc_metric(metric):
adhoc_agg_configs.append(metric)
elif metrics_dict[metric].metric_type != POST_AGG_TYPE:
elif metrics_dict[metric].metric_type != POST_AGG_TYPE: # noqa: T484
saved_agg_names.add(metric)
else:
postagg_names.append(metric)
# Create the post aggregations, maintain order since postaggs
# may depend on previous ones
post_aggs = OrderedDict()
post_aggs = OrderedDict() # noqa: T484
visited_postaggs = set()
for postagg_name in postagg_names:
postagg = metrics_dict[postagg_name]
postagg = metrics_dict[postagg_name] # noqa: T484
visited_postaggs.add(postagg_name)
DruidDatasource.resolve_postagg(
postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict
)
aggs = DruidDatasource.get_aggregations(
aggs = DruidDatasource.get_aggregations( # noqa: T484
metrics_dict, saved_agg_names, adhoc_agg_configs
)
return aggs, post_aggs
def values_for_column(self, column_name, limit=10000):
def values_for_column(self, column_name: str, limit: int = 10000) -> List:
"""Retrieve some values for the given column"""
logging.info(
"Getting values for columns [{}] limited to [{}]".format(column_name, limit)
@ -955,12 +978,14 @@ class DruidDatasource(Model, BaseDatasource):
client = self.cluster.get_pydruid_client()
client.topn(**qry)
df = client.export_pandas()
return [row[column_name] for row in df.to_records(index=False)]
return df[column_name].to_list()
def get_query_str(self, query_obj, phase=1, client=None):
return self.run_query(client=client, phase=phase, **query_obj)
def _add_filter_from_pre_query_data(self, df, dimensions, dim_filter):
def _add_filter_from_pre_query_data(
self, df: Optional[pd.DataFrame], dimensions, dim_filter
):
ret = dim_filter
if df is not None and not df.empty:
new_filters = []
@ -1002,7 +1027,7 @@ class DruidDatasource(Model, BaseDatasource):
return ret
@staticmethod
def druid_type_from_adhoc_metric(adhoc_metric):
def druid_type_from_adhoc_metric(adhoc_metric: Dict) -> str:
column_type = adhoc_metric["column"]["type"].lower()
aggregate = adhoc_metric["aggregate"].lower()
@ -1014,7 +1039,9 @@ class DruidDatasource(Model, BaseDatasource):
return column_type + aggregate.capitalize()
@staticmethod
def get_aggregations(metrics_dict, saved_metrics, adhoc_metrics=[]):
def get_aggregations(
metrics_dict: Dict, saved_metrics: Iterable[str], adhoc_metrics: List[Dict] = []
) -> OrderedDict:
"""
Returns a dictionary of aggregation metric names to aggregation json objects
@ -1023,7 +1050,7 @@ class DruidDatasource(Model, BaseDatasource):
:param adhoc_metrics: list of adhoc metric names
:raise SupersetException: if one or more metric names are not aggregations
"""
aggregations = OrderedDict()
aggregations: OrderedDict = OrderedDict()
invalid_metric_names = []
for metric_name in saved_metrics:
if metric_name in metrics_dict:
@ -1047,19 +1074,18 @@ class DruidDatasource(Model, BaseDatasource):
}
return aggregations
def get_dimensions(self, groupby, columns_dict):
def get_dimensions(
self, groupby: List[str], columns_dict: Dict[str, DruidColumn]
) -> List[Union[str, Dict]]:
dimensions = []
groupby = [gb for gb in groupby if gb in columns_dict]
for column_name in groupby:
col = columns_dict.get(column_name)
dim_spec = col.dimension_spec if col else None
if dim_spec:
dimensions.append(dim_spec)
else:
dimensions.append(column_name)
dimensions.append(dim_spec or column_name)
return dimensions
def intervals_from_dttms(self, from_dttm, to_dttm):
def intervals_from_dttms(self, from_dttm: datetime, to_dttm: datetime) -> str:
# Couldn't find a way to just not filter on time...
from_dttm = from_dttm or datetime(1901, 1, 1)
to_dttm = to_dttm or datetime(2101, 1, 1)
@ -1091,7 +1117,7 @@ class DruidDatasource(Model, BaseDatasource):
return values
@staticmethod
def sanitize_metric_object(metric):
def sanitize_metric_object(metric: Dict) -> None:
"""
Update a metric with the correct type if necessary.
:param dict metric: The metric to sanitize
@ -1122,7 +1148,7 @@ class DruidDatasource(Model, BaseDatasource):
phase=2,
client=None,
order_desc=True,
):
) -> str:
"""Runs a query against Druid and returns a dataframe.
"""
# TODO refactor into using a TBD Query object
@ -1193,7 +1219,7 @@ class DruidDatasource(Model, BaseDatasource):
del qry["dimensions"]
client.timeseries(**qry)
elif not having_filters and len(groupby) == 1 and order_desc:
dim = list(qry.get("dimensions"))[0]
dim = list(qry.get("dimensions"))[0] # noqa: T484
logging.info("Running two-phase topn query for dimension [{}]".format(dim))
pre_qry = deepcopy(qry)
if timeseries_limit_metric:
@ -1324,7 +1350,7 @@ class DruidDatasource(Model, BaseDatasource):
return query_str
@staticmethod
def homogenize_types(df, groupby_cols):
def homogenize_types(df: pd.DataFrame, groupby_cols: Iterable[str]) -> pd.DataFrame:
"""Converting all GROUPBY columns to strings
When grouping by a numeric (say FLOAT) column, pydruid returns
@ -1334,11 +1360,10 @@ class DruidDatasource(Model, BaseDatasource):
Here we replace None with <NULL> and make the whole series a
str instead of an object.
"""
for col in groupby_cols:
df[col] = df[col].fillna("<NULL>").astype("unicode")
df[groupby_cols] = df[groupby_cols].fillna("<NULL>").astype("unicode")
return df
def query(self, query_obj):
def query(self, query_obj: Dict) -> QueryResult:
qry_start_dttm = datetime.now()
client = self.cluster.get_pydruid_client()
query_str = self.get_query_str(client=client, query_obj=query_obj, phase=2)
@ -1346,7 +1371,7 @@ class DruidDatasource(Model, BaseDatasource):
if df is None or df.size == 0:
return QueryResult(
df=pandas.DataFrame([]),
df=pd.DataFrame([]),
query=query_str,
duration=datetime.now() - qry_start_dttm,
)
@ -1363,7 +1388,7 @@ class DruidDatasource(Model, BaseDatasource):
del df[DTTM_ALIAS]
# Reordering columns
cols = []
cols: List[str] = []
if DTTM_ALIAS in df.columns:
cols += [DTTM_ALIAS]
cols += query_obj.get("groupby") or []
@ -1413,7 +1438,7 @@ class DruidDatasource(Model, BaseDatasource):
return (col, extraction_fn)
@classmethod
def get_filters(cls, raw_filters, num_cols, columns_dict): # noqa
def get_filters(cls, raw_filters, num_cols, columns_dict) -> Filter: # noqa: T484
"""Given Superset filter data structure, returns pydruid Filter(s)"""
filters = None
for flt in raw_filters:
@ -1542,7 +1567,7 @@ class DruidDatasource(Model, BaseDatasource):
return filters
def _get_having_obj(self, col, op, eq):
def _get_having_obj(self, col: str, op: str, eq: str) -> Having:
cond = None
if op == "==":
if col in self.column_names:
@ -1556,7 +1581,7 @@ class DruidDatasource(Model, BaseDatasource):
return cond
def get_having_filters(self, raw_filters):
def get_having_filters(self, raw_filters: List[Dict]) -> Having:
filters = None
reversed_op_map = {"!=": "==", ">=": "<", "<=": ">"}
@ -1579,7 +1604,9 @@ class DruidDatasource(Model, BaseDatasource):
return filters
@classmethod
def query_datasources_by_name(cls, session, database, datasource_name, schema=None):
def query_datasources_by_name(
cls, session: Session, database: Database, datasource_name: str, schema=None
) -> List["DruidDatasource"]:
return (
session.query(cls)
.filter_by(cluster_name=database.id)
@ -1587,7 +1614,7 @@ class DruidDatasource(Model, BaseDatasource):
.all()
)
def external_metadata(self):
def external_metadata(self) -> List[Dict]:
self.merge_flag = True
return [
{"name": k, "type": v.get("type")}

View File: superset/connectors/druid/views.py

@ -395,7 +395,7 @@ class Druid(BaseSupersetView):
@has_access
@expose("/refresh_datasources/")
def refresh_datasources(self, refreshAll=True):
def refresh_datasources(self, refresh_all=True):
"""endpoint that refreshes druid datasources metadata"""
session = db.session()
DruidCluster = ConnectorRegistry.sources["druid"].cluster_class
@ -403,7 +403,7 @@ class Druid(BaseSupersetView):
cluster_name = cluster.cluster_name
valid_cluster = True
try:
cluster.refresh_datasources(refreshAll=refreshAll)
cluster.refresh_datasources(refresh_all=refresh_all)
except Exception as e:
valid_cluster = False
flash(
@ -432,7 +432,7 @@ class Druid(BaseSupersetView):
Calling this endpoint will cause a scan for new
datasources only and add them.
"""
return self.refresh_datasources(refreshAll=False)
return self.refresh_datasources(refresh_all=False)
appbuilder.add_view_no_menu(Druid)
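
Sketch only: exercising the refresh endpoint through Flask's test client. The URL prefix is assumed from the view's default route base, and because of @has_access an authenticated session is required in practice.

from superset import app

with app.test_client() as client:
    resp = client.get("/druid/refresh_datasources/", follow_redirects=True)
    print(resp.status_code)
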

View File: superset/connectors/sqla/models.py

@ -42,10 +42,10 @@ from sqlalchemy import (
Text,
)
from sqlalchemy.exc import CompileError
from sqlalchemy.orm import backref, relationship
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, literal_column, table, text
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
import sqlparse
@ -83,7 +83,7 @@ class AnnotationDatasource(BaseDatasource):
cache_timeout = 0
def query(self, query_obj):
def query(self, query_obj: Dict) -> QueryResult:
df = None
error_message = None
qry = db.session.query(Annotation)
@ -143,7 +143,7 @@ class TableColumn(Model, BaseColumn):
update_from_object_fields = [s for s in export_fields if s not in ("table_id",)]
export_parent = "table"
def get_sqla_col(self, label=None):
def get_sqla_col(self, label: Optional[str] = None) -> Column:
label = label or self.column_name
if not self.expression:
db_engine_spec = self.table.database.db_engine_spec
@ -155,10 +155,12 @@ class TableColumn(Model, BaseColumn):
return col
@property
def datasource(self):
def datasource(self) -> RelationshipProperty:
return self.table
def get_time_filter(self, start_dttm, end_dttm):
def get_time_filter(
self, start_dttm: DateTime, end_dttm: DateTime
) -> ColumnElement:
col = self.get_sqla_col(label="__time")
l = [] # noqa: E741
if start_dttm:
@ -205,7 +207,7 @@ class TableColumn(Model, BaseColumn):
return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
def dttm_sql_literal(self, dttm):
def dttm_sql_literal(self, dttm: DateTime) -> str:
"""Convert datetime object to a SQL expression string"""
tf = self.python_date_format
if tf:
@ -249,13 +251,13 @@ class SqlMetric(Model, BaseMetric):
)
export_parent = "table"
def get_sqla_col(self, label=None):
def get_sqla_col(self, label: Optional[str] = None) -> Column:
label = label or self.metric_name
sqla_col = literal_column(self.expression)
return self.table.make_sqla_column_compatible(sqla_col, label)
@property
def perm(self):
def perm(self) -> Optional[str]:
return (
("{parent_name}.[{obj.metric_name}](id:{obj.id})").format(
obj=self, parent_name=self.table.full_name
@ -264,7 +266,7 @@ class SqlMetric(Model, BaseMetric):
else None
)
def get_perm(self):
def get_perm(self) -> Optional[str]:
return self.perm
@classmethod
@ -351,7 +353,9 @@ class SqlaTable(Model, BaseDatasource):
"MAX": sa.func.MAX,
}
def make_sqla_column_compatible(self, sqla_col, label=None):
def make_sqla_column_compatible(
self, sqla_col: Column, label: Optional[str] = None
) -> Column:
"""Takes a sql alchemy column object and adds label info if supported by engine.
:param sqla_col: sql alchemy column instance
:param label: alias/label that column is expected to have
@ -369,23 +373,29 @@ class SqlaTable(Model, BaseDatasource):
return self.name
@property
def connection(self):
def connection(self) -> str:
return str(self.database)
@property
def description_markeddown(self):
def description_markeddown(self) -> str:
return utils.markdown(self.description)
@property
def datasource_name(self):
def datasource_name(self) -> str:
return self.table_name
@property
def database_name(self):
def database_name(self) -> str:
return self.database.name
@classmethod
def get_datasource_by_name(cls, session, datasource_name, schema, database_name):
def get_datasource_by_name(
cls,
session: Session,
datasource_name: str,
schema: Optional[str],
database_name: str,
) -> Optional["SqlaTable"]:
schema = schema or None
query = (
session.query(cls)
@ -398,52 +408,52 @@ class SqlaTable(Model, BaseDatasource):
for tbl in query.all():
if schema == (tbl.schema or None):
return tbl
return None
@property
def link(self):
def link(self) -> Markup:
name = escape(self.name)
anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>'
return Markup(anchor)
@property
def schema_perm(self):
def schema_perm(self) -> Optional[str]:
"""Returns schema permission if present, database one otherwise."""
return security_manager.get_schema_perm(self.database, self.schema)
def get_perm(self):
def get_perm(self) -> str:
return ("[{obj.database}].[{obj.table_name}]" "(id:{obj.id})").format(obj=self)
@property
def name(self):
def name(self) -> str:
if not self.schema:
return self.table_name
return "{}.{}".format(self.schema, self.table_name)
@property
def full_name(self):
def full_name(self) -> str:
return utils.get_datasource_full_name(
self.database, self.table_name, schema=self.schema
)
@property
def dttm_cols(self):
def dttm_cols(self) -> List:
l = [c.column_name for c in self.columns if c.is_dttm] # noqa: E741
if self.main_dttm_col and self.main_dttm_col not in l:
l.append(self.main_dttm_col)
return l
@property
def num_cols(self):
def num_cols(self) -> List:
return [c.column_name for c in self.columns if c.is_num]
@property
def any_dttm_col(self):
def any_dttm_col(self) -> Optional[str]:
cols = self.dttm_cols
if cols:
return cols[0]
return cols[0] if cols else None
@property
def html(self):
def html(self) -> str:
t = ((c.column_name, c.type) for c in self.columns)
df = pd.DataFrame(t)
df.columns = ["field", "type"]
@ -453,7 +463,7 @@ class SqlaTable(Model, BaseDatasource):
)
@property
def sql_url(self):
def sql_url(self) -> str:
return self.database.sql_url + "?table_name=" + str(self.table_name)
def external_metadata(self):
@ -466,28 +476,29 @@ class SqlaTable(Model, BaseDatasource):
return cols
@property
def time_column_grains(self):
def time_column_grains(self) -> Dict[str, Any]:
return {
"time_columns": self.dttm_cols,
"time_grains": [grain.name for grain in self.database.grains()],
}
@property
def select_star(self):
def select_star(self) -> str:
# show_cols and latest_partition set to false to avoid
# the expensive cost of inspecting the DB
return self.database.select_star(
self.table_name, schema=self.schema, show_cols=False, latest_partition=False
)
def get_col(self, col_name):
def get_col(self, col_name: str) -> Optional[Column]:
columns = self.columns
for col in columns:
if col_name == col.column_name:
return col
return None
@property
def data(self):
def data(self) -> Dict:
d = super(SqlaTable, self).data
if self.type == "table":
grains = self.database.grains() or []
@ -500,7 +511,7 @@ class SqlaTable(Model, BaseDatasource):
d["template_params"] = self.template_params
return d
def values_for_column(self, column_name, limit=10000):
def values_for_column(self, column_name: str, limit: int = 10000) -> List:
"""Runs query against sqla to retrieve some
sample values for the given column.
"""
@ -525,9 +536,9 @@ class SqlaTable(Model, BaseDatasource):
sql = self.mutate_query_from_config(sql)
df = pd.read_sql_query(sql=sql, con=engine)
return [row[0] for row in df.to_records(index=False)]
return df[column_name].to_list()
def mutate_query_from_config(self, sql):
def mutate_query_from_config(self, sql: str) -> str:
"""Apply config's SQL_QUERY_MUTATOR
Typically adds comments to the query with context"""
@ -540,7 +551,7 @@ class SqlaTable(Model, BaseDatasource):
def get_template_processor(self, **kwargs):
return get_template_processor(table=self, database=self.database, **kwargs)
def get_query_str_extended(self, query_obj) -> QueryStringExtended:
def get_query_str_extended(self, query_obj: Dict) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
logging.info(sql)
@ -550,7 +561,7 @@ class SqlaTable(Model, BaseDatasource):
labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries
)
def get_query_str(self, query_obj):
def get_query_str(self, query_obj: Dict) -> str:
query_str_ext = self.get_query_str_extended(query_obj)
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
return ";\n\n".join(all_queries) + ";"
@ -571,7 +582,7 @@ class SqlaTable(Model, BaseDatasource):
return TextAsFrom(sa.text(from_sql), []).alias("expr_qry")
return self.get_sqla_table()
def adhoc_metric_to_sqla(self, metric, cols):
def adhoc_metric_to_sqla(self, metric: Dict, cols: Dict) -> Optional[Column]:
"""
Turn an adhoc metric into a sqlalchemy column.
@ -584,13 +595,13 @@ class SqlaTable(Model, BaseDatasource):
label = utils.get_metric_name(metric)
if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]:
column_name = metric.get("column").get("column_name")
column_name = metric["column"].get("column_name")
table_column = cols.get(column_name)
if table_column:
sqla_column = table_column.get_sqla_col()
else:
sqla_column = column(column_name)
sqla_metric = self.sqla_aggregations[metric.get("aggregate")](sqla_column)
sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SQL"]:
sqla_metric = literal_column(metric.get("sqlExpression"))
else:
@ -616,7 +627,7 @@ class SqlaTable(Model, BaseDatasource):
extras=None,
columns=None,
order_desc=True,
):
) -> SqlaQuery:
"""Querying any sqla table from this common interface"""
template_kwargs = {
"from_dttm": from_dttm,
@ -643,8 +654,8 @@ class SqlaTable(Model, BaseDatasource):
# Database spec supports join-free timeslot grouping
time_groupby_inline = db_engine_spec.time_groupby_inline
cols = {col.column_name: col for col in self.columns}
metrics_dict = {m.metric_name: m for m in self.metrics}
cols: Dict[str, Column] = {col.column_name: col for col in self.columns}
metrics_dict: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics}
if not granularity and is_timeseries:
raise Exception(
@ -660,7 +671,7 @@ class SqlaTable(Model, BaseDatasource):
if utils.is_adhoc_metric(m):
metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
elif m in metrics_dict:
metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
metrics_exprs.append(metrics_dict[m].get_sqla_col())
else:
raise Exception(_("Metric '%(metric)s' does not exist", metric=m))
if metrics_exprs:
@ -669,8 +680,8 @@ class SqlaTable(Model, BaseDatasource):
main_metric_expr, label = literal_column("COUNT(*)"), "ccount"
main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label)
select_exprs = []
groupby_exprs_sans_timestamp = OrderedDict()
select_exprs: List[Column] = []
groupby_exprs_sans_timestamp: OrderedDict = OrderedDict()
if groupby:
select_exprs = []
@ -729,7 +740,7 @@ class SqlaTable(Model, BaseDatasource):
qry = qry.group_by(*groupby_exprs_with_timestamp.values())
where_clause_and = []
having_clause_and = []
having_clause_and: List = []
for flt in filter:
if not all([flt.get(s) for s in ["col", "op"]]):
continue
@ -899,7 +910,9 @@ class SqlaTable(Model, BaseDatasource):
return ob
def _get_top_groups(self, df, dimensions, groupby_exprs):
def _get_top_groups(
self, df: pd.DataFrame, dimensions: List, groupby_exprs: OrderedDict
) -> ColumnElement:
groups = []
for unused, row in df.iterrows():
group = []
@ -909,7 +922,7 @@ class SqlaTable(Model, BaseDatasource):
return or_(*groups)
def query(self, query_obj):
def query(self, query_obj: Dict) -> QueryResult:
qry_start_dttm = datetime.now()
query_str_ext = self.get_query_str_extended(query_obj)
sql = query_str_ext.sql
@ -945,10 +958,10 @@ class SqlaTable(Model, BaseDatasource):
error_message=error_message,
)
def get_sqla_table_object(self):
def get_sqla_table_object(self) -> Table:
return self.database.get_table(self.table_name, schema=self.schema)
def fetch_metadata(self):
def fetch_metadata(self) -> None:
"""Fetches the metadata for the table and merges it in"""
try:
table = self.get_sqla_table_object()
@ -1012,7 +1025,7 @@ class SqlaTable(Model, BaseDatasource):
db.session.commit()
@classmethod
def import_obj(cls, i_datasource, import_time=None):
def import_obj(cls, i_datasource, import_time=None) -> int:
"""Imports the datasource from the object to the database.
Metrics, columns and the datasource will be overridden if they exist.
@ -1052,7 +1065,9 @@ class SqlaTable(Model, BaseDatasource):
)
@classmethod
def query_datasources_by_name(cls, session, database, datasource_name, schema=None):
def query_datasources_by_name(
cls, session: Session, database: Database, datasource_name: str, schema=None
) -> List["SqlaTable"]:
query = (
session.query(cls)
.filter_by(database_id=database.id)
@ -1063,7 +1078,7 @@ class SqlaTable(Model, BaseDatasource):
return query.all()
@staticmethod
def default_query(qry):
def default_query(qry) -> Query:
return qry.filter_by(is_sqllab_view=False)
def has_extra_cache_keys(self, query_obj: Dict) -> bool:
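
Sketch only (table, column and metric names are illustrative): the Dict-based query interface that SqlaTable.get_query_str() and query() accept; every key below maps onto a get_sqla_query() parameter, and a Flask app context is assumed.

from superset import db
from superset.connectors.sqla.models import SqlaTable

tbl = db.session.query(SqlaTable).filter_by(table_name="birth_names").one_or_none()
if tbl is not None:
    query_obj = {
        "granularity": "ds",
        "from_dttm": None,
        "to_dttm": None,
        "groupby": ["gender"],
        "metrics": ["count"],
        "is_timeseries": False,
        "filter": [],
    }
    print(tbl.get_query_str(query_obj))  # the compiled SQL as a string
    result = tbl.query(query_obj)        # QueryResult wrapping a pandas DataFrame
    print(result.df.head())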