mirror of https://github.com/apache/superset.git
[typing] add typing for superset/connectors and superset/common (#8138)
This commit is contained in: parent 8bc5cd7dc0, commit dfb3bf69a0
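Every hunk below applies the same few patterns: implicit `None` defaults become explicit `Optional[...]` parameters, return types are added, and some `config.get(...)` lookups become direct indexing so the checker infers a non-optional value. A minimal sketch of the core pattern (illustrative names, not from this commit):

from typing import List, Optional

# Before this style of change, a signature like
#     def load(names: List[str] = None): ...
# is rejected by mypy, because None is not a List[str].
def load(names: Optional[List[str]] = None) -> List[str]:
    # Optional[List[str]] states the None default explicitly.
    return names or []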
superset/common/query_context.py

@@ -18,20 +18,22 @@
 from datetime import datetime, timedelta
 import logging
 import pickle as pkl
-from typing import Dict, List
+from typing import Any, Dict, List, Optional
 
 import numpy as np
 import pandas as pd
 
 from superset import app, cache
 from superset import db
+from superset.connectors.base.models import BaseDatasource
 from superset.connectors.connector_registry import ConnectorRegistry
+from superset.stats_logger import BaseStatsLogger
 from superset.utils import core as utils
 from superset.utils.core import DTTM_ALIAS
 from .query_object import QueryObject
 
 config = app.config
-stats_logger = config.get("STATS_LOGGER")
+stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
 
 
 class QueryContext:
@@ -40,8 +42,13 @@ class QueryContext:
     to retrieve the data payload for a given viz.
     """
 
-    cache_type = "df"
-    enforce_numerical_metrics = True
+    cache_type: str = "df"
+    enforce_numerical_metrics: bool = True
+
+    datasource: BaseDatasource
+    queries: List[QueryObject]
+    force: bool
+    custom_cache_timeout: Optional[int]
 
     # TODO: Type datasource and query_object dictionary with TypedDict when it becomes
     # a vanilla python type https://github.com/python/mypy/issues/5288
@@ -50,8 +57,8 @@ class QueryContext:
         datasource: Dict,
         queries: List[Dict],
         force: bool = False,
-        custom_cache_timeout: int = None,
-    ):
+        custom_cache_timeout: Optional[int] = None,
+    ) -> None:
         self.datasource = ConnectorRegistry.get_datasource(
             datasource.get("type"), int(datasource.get("id")), db.session  # noqa: T400
         )
@@ -61,9 +68,7 @@ class QueryContext:
 
         self.custom_cache_timeout = custom_cache_timeout
 
-        self.enforce_numerical_metrics = True
-
-    def get_query_result(self, query_object):
+    def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
         """Returns a pandas dataframe based on the query object"""
 
         # Here, we assume that all the queries will use the same datasource, which is
@@ -109,23 +114,23 @@ class QueryContext:
             "df": df,
         }
 
-    def df_metrics_to_num(self, df, query_object):
+    def df_metrics_to_num(self, df: pd.DataFrame, query_object: QueryObject) -> None:
         """Converting metrics to numeric when pandas.read_sql cannot"""
         metrics = [metric for metric in query_object.metrics]
         for col, dtype in df.dtypes.items():
             if dtype.type == np.object_ and col in metrics:
                 df[col] = pd.to_numeric(df[col], errors="coerce")
 
-    def get_data(self, df):
+    def get_data(self, df: pd.DataFrame) -> List[Dict]:
         return df.to_dict(orient="records")
 
-    def get_single_payload(self, query_obj: QueryObject):
+    def get_single_payload(self, query_obj: QueryObject) -> Dict[str, Any]:
         """Returns a payload of metadata and data"""
         payload = self.get_df_payload(query_obj)
         df = payload.get("df")
         status = payload.get("status")
         if status != utils.QueryStatus.FAILED:
-            if df is not None and df.empty:
+            if df is None or df.empty:
                 payload["error"] = "No data"
             else:
                 payload["data"] = self.get_data(df)
@@ -133,12 +138,12 @@ class QueryContext:
         del payload["df"]
         return payload
 
-    def get_payload(self):
+    def get_payload(self) -> List[Dict[str, Any]]:
         """Get all the payloads from the arrays"""
         return [self.get_single_payload(query_object) for query_object in self.queries]
 
     @property
-    def cache_timeout(self):
+    def cache_timeout(self) -> int:
         if self.custom_cache_timeout is not None:
             return self.custom_cache_timeout
         if self.datasource.cache_timeout is not None:
@@ -148,10 +153,10 @@ class QueryContext:
             and self.datasource.database.cache_timeout
         ) is not None:
             return self.datasource.database.cache_timeout
-        return config.get("CACHE_DEFAULT_TIMEOUT")
+        return config["CACHE_DEFAULT_TIMEOUT"]
 
-    def get_df_payload(self, query_obj: QueryObject, **kwargs):
-        """Handles caching around the df paylod retrieval"""
+    def get_df_payload(self, query_obj: QueryObject, **kwargs) -> Dict[str, Any]:
+        """Handles caching around the df payload retrieval"""
         extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj.to_dict())
         cache_key = (
             query_obj.cache_key(
@@ -207,9 +212,7 @@ class QueryContext:
 
         if is_loaded and cache_key and cache and status != utils.QueryStatus.FAILED:
             try:
-                cache_value = dict(
-                    dttm=cached_dttm, df=df if df is not None else None, query=query
-                )
+                cache_value = dict(dttm=cached_dttm, df=df, query=query)
                 cache_binary = pkl.dumps(cache_value, protocol=pkl.HIGHEST_PROTOCOL)
 
                 logging.info(
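query_context.py also declares its instance attributes at class level (`datasource: BaseDatasource`, `queries: List[QueryObject]`, ...). These PEP 526 annotations create no attribute at runtime; they only tell the checker what `__init__` will assign. A hypothetical stand-in showing the shape:

from typing import Optional

class ContextSketch:  # hypothetical stand-in, not the Superset class
    force: bool                          # annotation only, no runtime value
    custom_cache_timeout: Optional[int]

    def __init__(
        self, force: bool = False, custom_cache_timeout: Optional[int] = None
    ) -> None:
        self.force = force
        self.custom_cache_timeout = custom_cache_timeout

The switch from `config.get("STATS_LOGGER")` to `config["STATS_LOGGER"]` is part of the same discipline: `.get()` types as Optional, while indexing lets the `stats_logger: BaseStatsLogger` annotation hold without a None check.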
superset/common/query_object.py

@@ -15,8 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=R
+from datetime import datetime, timedelta
 import hashlib
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import simplejson as json
 
@@ -34,12 +35,28 @@ class QueryObject:
     and druid. The query objects are constructed on the client.
     """
 
+    granularity: str
+    from_dttm: datetime
+    to_dttm: datetime
+    is_timeseries: bool
+    time_shift: timedelta
+    groupby: List[str]
+    metrics: List[Union[Dict, str]]
+    row_limit: int
+    filter: List[str]
+    timeseries_limit: int
+    timeseries_limit_metric: Optional[Dict]
+    order_desc: bool
+    extras: Dict
+    columns: List[str]
+    orderby: List[List]
+
     def __init__(
         self,
         granularity: str,
         metrics: List[Union[Dict, str]],
-        groupby: List[str] = None,
-        filters: List[str] = None,
+        groupby: Optional[List[str]] = None,
+        filters: Optional[List[str]] = None,
         time_range: Optional[str] = None,
         time_shift: Optional[str] = None,
         is_timeseries: bool = False,
@@ -48,8 +65,8 @@ class QueryObject:
         timeseries_limit_metric: Optional[Dict] = None,
         order_desc: bool = True,
         extras: Optional[Dict] = None,
-        columns: List[str] = None,
-        orderby: List[List] = None,
+        columns: Optional[List[str]] = None,
+        orderby: Optional[List[List]] = None,
         relative_start: str = app.config.get("DEFAULT_RELATIVE_START_TIME", "today"),
         relative_end: str = app.config.get("DEFAULT_RELATIVE_END_TIME", "today"),
     ):
@@ -63,7 +80,7 @@ class QueryObject:
         self.is_timeseries = is_timeseries
         self.time_range = time_range
         self.time_shift = utils.parse_human_timedelta(time_shift)
-        self.groupby = groupby if groupby is not None else []
+        self.groupby = groupby or []
 
         # Temporal solution for backward compatability issue
         # due the new format of non-ad-hoc metric.
@@ -72,15 +89,15 @@ class QueryObject:
             for metric in metrics
         ]
         self.row_limit = row_limit
-        self.filter = filters if filters is not None else []
+        self.filter = filters or []
         self.timeseries_limit = timeseries_limit
         self.timeseries_limit_metric = timeseries_limit_metric
         self.order_desc = order_desc
-        self.extras = extras if extras is not None else {}
-        self.columns = columns if columns is not None else []
-        self.orderby = orderby if orderby is not None else []
+        self.extras = extras or {}
+        self.columns = columns or []
+        self.orderby = orderby or []
 
-    def to_dict(self):
+    def to_dict(self) -> Dict[str, Any]:
         query_object_dict = {
             "granularity": self.granularity,
             "from_dttm": self.from_dttm,
@@ -99,7 +116,7 @@ class QueryObject:
         }
         return query_object_dict
 
-    def cache_key(self, **extra):
+    def cache_key(self, **extra) -> str:
         """
         The cache key is made out of the key/values from to_dict(), plus any
         other key/values in `extra`
@@ -117,7 +134,7 @@ class QueryObject:
         json_data = self.json_dumps(cache_dict, sort_keys=True)
         return hashlib.md5(json_data.encode("utf-8")).hexdigest()
 
-    def json_dumps(self, obj, sort_keys=False):
+    def json_dumps(self, obj: Any, sort_keys: bool = False) -> str:
         return json.dumps(
             obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
         )
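query_object.py shortens `x if x is not None else []` to `x or []`. For these arguments the two are equivalent: the only falsy values a caller can pass are None and the empty collection, and both should normalize to the empty default. A small check of that equivalence:

from typing import List, Optional

def normalize(groupby: Optional[List[str]] = None) -> List[str]:
    # None and [] both fall through to the fresh empty list.
    return groupby or []

assert normalize() == []
assert normalize([]) == []
assert normalize(["country"]) == ["country"]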
superset/connectors/base/models.py

@@ -16,14 +16,15 @@
 # under the License.
 # pylint: disable=C,R,W
 import json
-from typing import Any, List
+from typing import Any, Dict, List, Optional
 
+from flask_appbuilder.security.sqla.models import User
 from sqlalchemy import and_, Boolean, Column, Integer, String, Text
 from sqlalchemy.ext.declarative import declared_attr
-from sqlalchemy.orm import foreign, relationship
+from sqlalchemy.orm import foreign, Query, relationship
 
 from superset.models.core import Slice
-from superset.models.helpers import AuditMixinNullable, ImportMixin
+from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
 from superset.utils import core as utils
 
 
@@ -59,9 +60,9 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
     params = Column(String(1000))
     perm = Column(String(1000))
 
-    sql = None
-    owners = None
-    update_from_object_fields = None
+    sql: Optional[str] = None
+    owners: List[User]
+    update_from_object_fields: List[str]
 
     @declared_attr
     def slices(self):
@@ -79,20 +80,20 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
     metrics: List[Any] = []
 
     @property
-    def uid(self):
+    def uid(self) -> str:
         """Unique id across datasource types"""
         return f"{self.id}__{self.type}"
 
     @property
-    def column_names(self):
+    def column_names(self) -> List[str]:
         return sorted([c.column_name for c in self.columns], key=lambda x: x or "")
 
     @property
-    def columns_types(self):
+    def columns_types(self) -> Dict:
         return {c.column_name: c.type for c in self.columns}
 
     @property
-    def main_dttm_col(self):
+    def main_dttm_col(self) -> str:
         return "timestamp"
 
     @property
@@ -100,47 +101,47 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
         raise NotImplementedError()
 
     @property
-    def connection(self):
+    def connection(self) -> Optional[str]:
         """String representing the context of the Datasource"""
         return None
 
     @property
-    def schema(self):
+    def schema(self) -> Optional[str]:
         """String representing the schema of the Datasource (if it applies)"""
         return None
 
     @property
-    def filterable_column_names(self):
+    def filterable_column_names(self) -> List[str]:
         return sorted([c.column_name for c in self.columns if c.filterable])
 
     @property
-    def dttm_cols(self):
+    def dttm_cols(self) -> List:
         return []
 
     @property
-    def url(self):
+    def url(self) -> str:
         return "/{}/edit/{}".format(self.baselink, self.id)
 
     @property
-    def explore_url(self):
+    def explore_url(self) -> str:
         if self.default_endpoint:
             return self.default_endpoint
         else:
             return "/superset/explore/{obj.type}/{obj.id}/".format(obj=self)
 
     @property
-    def column_formats(self):
+    def column_formats(self) -> Dict[str, Optional[str]]:
         return {m.metric_name: m.d3format for m in self.metrics if m.d3format}
 
-    def add_missing_metrics(self, metrics):
-        exisiting_metrics = {m.metric_name for m in self.metrics}
+    def add_missing_metrics(self, metrics: List["BaseMetric"]) -> None:
+        existing_metrics = {m.metric_name for m in self.metrics}
         for metric in metrics:
-            if metric.metric_name not in exisiting_metrics:
+            if metric.metric_name not in existing_metrics:
                 metric.table_id = self.id
-                self.metrics += [metric]
+                self.metrics.append(metric)
 
     @property
-    def short_data(self):
+    def short_data(self) -> Dict[str, Any]:
         """Data representation of the datasource sent to the frontend"""
         return {
             "edit_url": self.url,
@@ -158,7 +159,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
         pass
 
     @property
-    def data(self):
+    def data(self) -> Dict[str, Any]:
         """Data representation of the datasource sent to the frontend"""
         order_by_choices = []
         # self.column_names return sorted column_names
@@ -239,14 +240,14 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
         """Returns column information from the external system"""
         raise NotImplementedError()
 
-    def get_query_str(self, query_obj):
+    def get_query_str(self, query_obj) -> str:
         """Returns a query as a string
 
         This is used to be displayed to the user so that she/he can
         understand what is taking place behind the scene"""
         raise NotImplementedError()
 
-    def query(self, query_obj):
+    def query(self, query_obj) -> QueryResult:
         """Executes the query and returns a dataframe
 
         query_obj is a dictionary representing Superset's query interface.
@@ -254,7 +255,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
         """
         raise NotImplementedError()
 
-    def values_for_column(self, column_name, limit=10000):
+    def values_for_column(self, column_name: str, limit: int = 10000) -> List:
         """Given a column, returns an iterable of distinct values
 
         This is used to populate the dropdown showing a list of
@@ -262,13 +263,14 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
         raise NotImplementedError()
 
     @staticmethod
-    def default_query(qry):
+    def default_query(qry) -> Query:
         return qry
 
-    def get_column(self, column_name):
+    def get_column(self, column_name: str) -> Optional["BaseColumn"]:
         for col in self.columns:
             if col.column_name == column_name:
                 return col
+        return None
 
     def get_fk_many_from_list(self, object_list, fkmany, fkmany_class, key_attr):
         """Update ORM one-to-many list from object list
@@ -276,10 +278,9 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
         Used for syncing metrics and columns using the same code"""
 
         object_dict = {o.get(key_attr): o for o in object_list}
-        object_keys = [o.get(key_attr) for o in object_list]
 
         # delete fks that have been removed
-        fkmany = [o for o in fkmany if getattr(o, key_attr) in object_keys]
+        fkmany = [o for o in fkmany if getattr(o, key_attr) in object_dict]
 
         # sync existing fks
         for fk in fkmany:
@@ -303,7 +304,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
         fkmany += new_fks
         return fkmany
 
-    def update_from_object(self, obj):
+    def update_from_object(self, obj) -> None:
         """Update datasource from a data structure
 
         The UI's table editor crafts a complex data structure that
@@ -330,7 +331,7 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
             obj.get("columns"), self.columns, self.column_class, "column_name"
         )
 
-    def get_extra_cache_keys(self, query_obj) -> List[Any]:
+    def get_extra_cache_keys(self, query_obj: Dict) -> List[Any]:
         """ If a datasource needs to provide additional keys for calculation of
         cache keys, those can be provided via this method
         """
@@ -374,23 +375,23 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
     str_types = ("VARCHAR", "STRING", "CHAR")
 
     @property
-    def is_num(self):
-        return self.type and any([t in self.type.upper() for t in self.num_types])
+    def is_num(self) -> bool:
+        return self.type and any(map(lambda t: t in self.type.upper(), self.num_types))
 
     @property
-    def is_time(self):
-        return self.type and any([t in self.type.upper() for t in self.date_types])
+    def is_time(self) -> bool:
+        return self.type and any(map(lambda t: t in self.type.upper(), self.date_types))
 
     @property
-    def is_string(self):
-        return self.type and any([t in self.type.upper() for t in self.str_types])
+    def is_string(self) -> bool:
+        return self.type and any(map(lambda t: t in self.type.upper(), self.str_types))
 
     @property
     def expression(self):
         raise NotImplementedError()
 
     @property
-    def data(self):
+    def data(self) -> Dict[str, Any]:
         attrs = (
             "id",
             "column_name",
@@ -443,7 +444,7 @@ class BaseMetric(AuditMixinNullable, ImportMixin):
         raise NotImplementedError()
 
     @property
-    def data(self):
+    def data(self) -> Dict[str, Any]:
         attrs = (
             "id",
             "metric_name",
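base/models.py annotates against classes defined later in the same module, e.g. `Optional["BaseColumn"]` and `List["BaseMetric"]`. Quoting the name makes it a forward reference that the checker resolves after the whole module has been read. A self-contained sketch of the pattern:

from typing import List, Optional

class Datasource:
    columns: List["Column"]  # "Column" is quoted: the class comes later

    def get_column(self, name: str) -> Optional["Column"]:
        for col in self.columns:
            if col.name == name:
                return col
        return None  # the diff above adds the same explicit fall-through

class Column:
    def __init__(self, name: str) -> None:
        self.name = name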
superset/connectors/connector_registry.py

@@ -15,16 +15,23 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=C,R,W
-from sqlalchemy.orm import subqueryload
+from collections import OrderedDict
+from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
+
+from sqlalchemy.orm import Session, subqueryload
+
+if TYPE_CHECKING:
+    from superset.models.core import Database
+    from superset.connectors.base.models import BaseDatasource
 
 
 class ConnectorRegistry(object):
     """ Central Registry for all available datasource engines"""
 
-    sources = {}
+    sources: Dict[str, Type["BaseDatasource"]] = {}
 
     @classmethod
-    def register_sources(cls, datasource_config):
+    def register_sources(cls, datasource_config: OrderedDict) -> None:
         for module_name, class_names in datasource_config.items():
             class_names = [str(s) for s in class_names]
             module_obj = __import__(module_name, fromlist=class_names)
@@ -33,7 +40,9 @@ class ConnectorRegistry(object):
             cls.sources[source_class.type] = source_class
 
     @classmethod
-    def get_datasource(cls, datasource_type, datasource_id, session):
+    def get_datasource(
+        cls, datasource_type: str, datasource_id: int, session: Session
+    ) -> Optional["BaseDatasource"]:
         return (
             session.query(cls.sources[datasource_type])
            .filter_by(id=datasource_id)
@@ -41,8 +50,8 @@ class ConnectorRegistry(object):
         )
 
     @classmethod
-    def get_all_datasources(cls, session):
-        datasources = []
+    def get_all_datasources(cls, session: Session) -> List["BaseDatasource"]:
+        datasources: List["BaseDatasource"] = []
         for source_type in ConnectorRegistry.sources:
             source_class = ConnectorRegistry.sources[source_type]
             qry = session.query(source_class)
@@ -52,15 +61,22 @@ class ConnectorRegistry(object):
 
     @classmethod
     def get_datasource_by_name(
-        cls, session, datasource_type, datasource_name, schema, database_name
-    ):
+        cls,
+        session: Session,
+        datasource_type: str,
+        datasource_name: str,
+        schema: str,
+        database_name: str,
+    ) -> Optional["BaseDatasource"]:
         datasource_class = ConnectorRegistry.sources[datasource_type]
         return datasource_class.get_datasource_by_name(
             session, datasource_name, schema, database_name
         )
 
     @classmethod
-    def query_datasources_by_permissions(cls, session, database, permissions):
+    def query_datasources_by_permissions(
+        cls, session: Session, database: "Database", permissions: Set[str]
+    ) -> List["BaseDatasource"]:
         datasource_class = ConnectorRegistry.sources[database.type]
         return (
             session.query(datasource_class)
@@ -70,7 +86,9 @@ class ConnectorRegistry(object):
         )
 
     @classmethod
-    def get_eager_datasource(cls, session, datasource_type, datasource_id):
+    def get_eager_datasource(
+        cls, session: Session, datasource_type: str, datasource_id: int
+    ) -> "BaseDatasource":
         """Returns datasource with columns and metrics."""
         datasource_class = ConnectorRegistry.sources[datasource_type]
         return (
@@ -84,7 +102,13 @@ class ConnectorRegistry(object):
         )
 
     @classmethod
-    def query_datasources_by_name(cls, session, database, datasource_name, schema=None):
+    def query_datasources_by_name(
+        cls,
+        session: Session,
+        database: "Database",
+        datasource_name: str,
+        schema: Optional[str] = None,
+    ) -> List["BaseDatasource"]:
         datasource_class = ConnectorRegistry.sources[database.type]
         return datasource_class.query_datasources_by_name(
             session, database, datasource_name, schema=None
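The registry's new imports sit behind `typing.TYPE_CHECKING`, a constant that is False at runtime and True only while a type checker runs, so `Database` and `BaseDatasource` are available to annotations without creating an import cycle between the modules. The annotations must then name the classes as strings. Sketch of the guard:

from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    # Executed only by the type checker, never at runtime, so a
    # circular import between the two modules cannot occur.
    from superset.connectors.base.models import BaseDatasource

def lookup(name: str) -> Optional["BaseDatasource"]:
    ...  # quoted, since the class is unavailable at runtime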
superset/connectors/druid/models.py

@@ -24,13 +24,15 @@ import json
 import logging
 from multiprocessing.pool import ThreadPool
 import re
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
 
 from dateutil.parser import parse as dparse
 from flask import escape, Markup
 from flask_appbuilder import Model
 from flask_appbuilder.models.decorators import renders
+from flask_appbuilder.security.sqla.models import User
 from flask_babel import lazy_gettext as _
-import pandas
+import pandas as pd
 
 try:
     from pydruid.client import PyDruid
@@ -41,7 +43,7 @@ try:
         RegisteredLookupExtraction,
     )
     from pydruid.utils.filters import Dimension, Filter
-    from pydruid.utils.having import Aggregation
+    from pydruid.utils.having import Aggregation, Having
    from pydruid.utils.postaggregator import (
        Const,
        Field,
@@ -65,12 +67,13 @@ from sqlalchemy import (
     Text,
     UniqueConstraint,
 )
-from sqlalchemy.orm import backref, relationship
+from sqlalchemy.orm import backref, relationship, RelationshipProperty, Session
 from sqlalchemy_utils import EncryptedType
 
 from superset import conf, db, security_manager
 from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
 from superset.exceptions import SupersetException
+from superset.models.core import Database
 from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
 from superset.utils import core as utils, import_datasource
 
@@ -78,6 +81,8 @@ try:
     from superset.utils.core import DimSelector, DTTM_ALIAS, flasher
 except ImportError:
     pass
 
+
 DRUID_TZ = conf.get("DRUID_TZ")
 POST_AGG_TYPE = "postagg"
+metadata = Model.metadata  # pylint: disable=no-member
@@ -150,22 +155,22 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
         return self.__repr__()
 
     @property
-    def data(self):
+    def data(self) -> Dict:
         return {"id": self.id, "name": self.cluster_name, "backend": "druid"}
 
     @staticmethod
-    def get_base_url(host, port):
+    def get_base_url(host, port) -> str:
         if not re.match("http(s)?://", host):
             host = "http://" + host
 
         url = "{0}:{1}".format(host, port) if port else host
         return url
 
-    def get_base_broker_url(self):
+    def get_base_broker_url(self) -> str:
         base_url = self.get_base_url(self.broker_host, self.broker_port)
         return f"{base_url}/{self.broker_endpoint}"
 
-    def get_pydruid_client(self):
+    def get_pydruid_client(self) -> PyDruid:
         cli = PyDruid(
             self.get_base_url(self.broker_host, self.broker_port), self.broker_endpoint
         )
@@ -173,39 +178,44 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
             cli.set_basic_auth_credentials(self.broker_user, self.broker_pass)
         return cli
 
-    def get_datasources(self):
+    def get_datasources(self) -> List[str]:
         endpoint = self.get_base_broker_url() + "/datasources"
         auth = requests.auth.HTTPBasicAuth(self.broker_user, self.broker_pass)
         return json.loads(requests.get(endpoint, auth=auth).text)
 
-    def get_druid_version(self):
+    def get_druid_version(self) -> str:
         endpoint = self.get_base_url(self.broker_host, self.broker_port) + "/status"
         auth = requests.auth.HTTPBasicAuth(self.broker_user, self.broker_pass)
         return json.loads(requests.get(endpoint, auth=auth).text)["version"]
 
-    @property
+    @property  # noqa: T484
     @utils.memoized
-    def druid_version(self):
+    def druid_version(self) -> str:
         return self.get_druid_version()
 
     def refresh_datasources(
-        self, datasource_name=None, merge_flag=True, refreshAll=True
-    ):
+        self,
+        datasource_name: Optional[str] = None,
+        merge_flag: bool = True,
+        refresh_all: bool = True,
+    ) -> None:
         """Refresh metadata of all datasources in the cluster
         If ``datasource_name`` is specified, only that datasource is updated
         """
         ds_list = self.get_datasources()
         blacklist = conf.get("DRUID_DATA_SOURCE_BLACKLIST", [])
-        ds_refresh = []
+        ds_refresh: List[str] = []
         if not datasource_name:
             ds_refresh = list(filter(lambda ds: ds not in blacklist, ds_list))
         elif datasource_name not in blacklist and datasource_name in ds_list:
             ds_refresh.append(datasource_name)
         else:
             return
-        self.refresh(ds_refresh, merge_flag, refreshAll)
+        self.refresh(ds_refresh, merge_flag, refresh_all)
 
-    def refresh(self, datasource_names, merge_flag, refreshAll):
+    def refresh(
+        self, datasource_names: List[str], merge_flag: bool, refresh_all: bool
+    ) -> None:
         """
         Fetches metadata for the specified datasources and
         merges to the Superset database
@@ -225,7 +235,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
                 session.add(datasource)
                 flasher(_("Adding new datasource [{}]").format(ds_name), "success")
                 ds_map[ds_name] = datasource
-            elif refreshAll:
+            elif refresh_all:
                 flasher(_("Refreshing datasource [{}]").format(ds_name), "info")
             else:
                 del ds_map[ds_name]
@@ -270,19 +280,19 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
         session.commit()
 
     @property
-    def perm(self):
+    def perm(self) -> str:
         return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self)
 
-    def get_perm(self):
+    def get_perm(self) -> str:
         return self.perm
 
     @property
-    def name(self):
-        return self.verbose_name if self.verbose_name else self.cluster_name
+    def name(self) -> str:
+        return self.verbose_name or self.cluster_name
 
     @property
-    def unique_name(self):
-        return self.verbose_name if self.verbose_name else self.cluster_name
+    def unique_name(self) -> str:
+        return self.verbose_name or self.cluster_name
 
 
 class DruidColumn(Model, BaseColumn):
@@ -318,25 +328,26 @@ class DruidColumn(Model, BaseColumn):
         return self.column_name or str(self.id)
 
     @property
-    def expression(self):
+    def expression(self) -> str:
         return self.dimension_spec_json
 
     @property
-    def dimension_spec(self):
+    def dimension_spec(self) -> Optional[Dict]:  # noqa: T484
         if self.dimension_spec_json:
             return json.loads(self.dimension_spec_json)
 
-    def get_metrics(self):
-        metrics = {}
-        metrics["count"] = DruidMetric(
+    def get_metrics(self) -> Dict[str, "DruidMetric"]:
+        metrics = {
+            "count": DruidMetric(
                 metric_name="count",
                 verbose_name="COUNT(*)",
                 metric_type="count",
                 json=json.dumps({"type": "count", "name": "count"}),
             )
+        }
         return metrics
 
-    def refresh_metrics(self):
+    def refresh_metrics(self) -> None:
         """Refresh metrics based on the column metadata"""
         metrics = self.get_metrics()
         dbmetrics = (
@@ -356,8 +367,8 @@ class DruidColumn(Model, BaseColumn):
             db.session.add(metric)
 
     @classmethod
-    def import_obj(cls, i_column):
-        def lookup_obj(lookup_column):
+    def import_obj(cls, i_column: "DruidColumn") -> "DruidColumn":
+        def lookup_obj(lookup_column: "DruidColumn") -> Optional["DruidColumn"]:
             return (
                 db.session.query(DruidColumn)
                 .filter(
@@ -404,7 +415,7 @@ class DruidMetric(Model, BaseMetric):
         return self.json
 
     @property
-    def json_obj(self):
+    def json_obj(self) -> Dict:
         try:
             obj = json.loads(self.json)
         except Exception:
@@ -412,7 +423,7 @@ class DruidMetric(Model, BaseMetric):
         return obj
 
     @property
-    def perm(self):
+    def perm(self) -> Optional[str]:
         return (
             ("{parent_name}.[{obj.metric_name}](id:{obj.id})").format(
                 obj=self, parent_name=self.datasource.full_name
@@ -421,12 +432,12 @@ class DruidMetric(Model, BaseMetric):
             else None
         )
 
-    def get_perm(self):
+    def get_perm(self) -> Optional[str]:
         return self.perm
 
     @classmethod
-    def import_obj(cls, i_metric):
-        def lookup_obj(lookup_metric):
+    def import_obj(cls, i_metric: "DruidMetric") -> "DruidMetric":
+        def lookup_obj(lookup_metric: DruidMetric) -> Optional[DruidMetric]:
             return (
                 db.session.query(DruidMetric)
                 .filter(
@@ -494,23 +505,23 @@ class DruidDatasource(Model, BaseDatasource):
     export_children = ["columns", "metrics"]
 
     @property
-    def database(self):
+    def database(self) -> RelationshipProperty:
         return self.cluster
 
     @property
-    def connection(self):
+    def connection(self) -> str:
         return str(self.database)
 
     @property
-    def num_cols(self):
+    def num_cols(self) -> List[str]:
         return [c.column_name for c in self.columns if c.is_num]
 
     @property
-    def name(self):
+    def name(self) -> str:
         return self.datasource_name
 
     @property
-    def schema(self):
+    def schema(self) -> Optional[str]:
         ds_name = self.datasource_name or ""
         name_pieces = ds_name.split(".")
         if len(name_pieces) > 1:
@@ -519,11 +530,11 @@ class DruidDatasource(Model, BaseDatasource):
         return None
 
     @property
-    def schema_perm(self):
+    def schema_perm(self) -> Optional[str]:
         """Returns schema permission if present, cluster one otherwise."""
         return security_manager.get_schema_perm(self.cluster, self.schema)
 
-    def get_perm(self):
+    def get_perm(self) -> str:
         return ("[{obj.cluster_name}].[{obj.datasource_name}]" "(id:{obj.id})").format(
             obj=self
         )
@@ -532,16 +543,16 @@ class DruidDatasource(Model, BaseDatasource):
         return NotImplementedError()
 
     @property
-    def link(self):
+    def link(self) -> Markup:
         name = escape(self.datasource_name)
         return Markup(f'<a href="{self.url}">{name}</a>')
 
     @property
-    def full_name(self):
+    def full_name(self) -> str:
         return utils.get_datasource_full_name(self.cluster_name, self.datasource_name)
 
     @property
-    def time_column_grains(self):
+    def time_column_grains(self) -> Dict[str, List[str]]:
         return {
             "time_columns": [
                 "all",
@@ -568,16 +579,18 @@ class DruidDatasource(Model, BaseDatasource):
         return self.datasource_name
 
     @renders("datasource_name")
-    def datasource_link(self):
+    def datasource_link(self) -> str:
         url = f"/superset/explore/{self.type}/{self.id}/"
         name = escape(self.datasource_name)
         return Markup(f'<a href="{url}">{name}</a>')
 
-    def get_metric_obj(self, metric_name):
+    def get_metric_obj(self, metric_name: str) -> Dict:
         return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0]
 
     @classmethod
-    def import_obj(cls, i_datasource, import_time=None):
+    def import_obj(
+        cls, i_datasource: "DruidDatasource", import_time: Optional[int] = None
+    ) -> int:
         """Imports the datasource from the object to the database.
 
         Metrics and columns and datasource will be overridden if exists.
@@ -585,7 +598,7 @@ class DruidDatasource(Model, BaseDatasource):
         superset instances. Audit metadata isn't copies over.
         """
 
-        def lookup_datasource(d):
+        def lookup_datasource(d: DruidDatasource) -> Optional[DruidDatasource]:
             return (
                 db.session.query(DruidDatasource)
                 .filter(
@@ -595,7 +608,7 @@ class DruidDatasource(Model, BaseDatasource):
                 .first()
             )
 
-        def lookup_cluster(d):
+        def lookup_cluster(d: DruidDatasource) -> Optional[DruidCluster]:
             return (
                 db.session.query(DruidCluster)
                 .filter_by(cluster_name=d.cluster_name)
@@ -659,12 +672,14 @@ class DruidDatasource(Model, BaseDatasource):
         if segment_metadata:
             return segment_metadata[-1]["columns"]
 
-    def refresh_metrics(self):
+    def refresh_metrics(self) -> None:
         for col in self.columns:
             col.refresh_metrics()
 
     @classmethod
-    def sync_to_db_from_config(cls, druid_config, user, cluster, refresh=True):
+    def sync_to_db_from_config(
+        cls, druid_config: Dict, user: User, cluster: DruidCluster, refresh: bool = True
+    ) -> None:
         """Merges the ds config from druid_config into one stored in the db."""
         session = db.session
         datasource = (
@@ -742,13 +757,15 @@ class DruidDatasource(Model, BaseDatasource):
         session.commit()
 
     @staticmethod
-    def time_offset(granularity):
+    def time_offset(granularity: Union[str, Dict]) -> int:
         if granularity == "week_ending_saturday":
             return 6 * 24 * 3600 * 1000  # 6 days
         return 0
 
     @classmethod
-    def get_datasource_by_name(cls, session, datasource_name, schema, database_name):
+    def get_datasource_by_name(
+        cls, session: Session, datasource_name: str, schema: str, database_name: str
+    ) -> Optional["DruidDatasource"]:
         query = (
             session.query(cls)
             .join(DruidCluster)
@@ -761,7 +778,9 @@ class DruidDatasource(Model, BaseDatasource):
     # http://druid.io/docs/0.8.0/querying/granularities.html
     # TODO: pass origin from the UI
     @staticmethod
-    def granularity(period_name, timezone=None, origin=None):
+    def granularity(
+        period_name: str, timezone: Optional[str] = None, origin: Optional[str] = None
+    ) -> Union[str, Dict]:
         if not period_name or period_name == "all":
             return "all"
         iso_8601_dict = {
@@ -810,7 +829,7 @@ class DruidDatasource(Model, BaseDatasource):
         return granularity
 
     @staticmethod
-    def get_post_agg(mconf):
+    def get_post_agg(mconf: Dict) -> Postaggregator:
         """
         For a metric specified as `postagg` returns the
         kind of post aggregation for pydruid.
@@ -839,7 +858,9 @@ class DruidDatasource(Model, BaseDatasource):
             return CustomPostAggregator(mconf.get("name", ""), mconf)
 
     @staticmethod
-    def find_postaggs_for(postagg_names, metrics_dict):
+    def find_postaggs_for(
+        postagg_names: Set[str], metrics_dict: Dict[str, DruidMetric]
+    ) -> List[DruidMetric]:
         """Return a list of metrics that are post aggregations"""
         postagg_metrics = [
             metrics_dict[name]
@@ -852,7 +873,7 @@ class DruidDatasource(Model, BaseDatasource):
         return postagg_metrics
 
     @staticmethod
-    def recursive_get_fields(_conf):
+    def recursive_get_fields(_conf: Dict) -> List[str]:
         _type = _conf.get("type")
         _field = _conf.get("field")
         _fields = _conf.get("fields")
@@ -875,11 +896,9 @@ class DruidDatasource(Model, BaseDatasource):
         # Check if the fields are already in aggs
         # or is a previous postagg
         required_fields = set(
-            [
             field
             for field in required_fields
             if field not in visited_postaggs and field not in agg_names
-            ]
         )
         # First try to find postaggs that match
         if len(required_fields) > 0:
@@ -903,7 +922,11 @@ class DruidDatasource(Model, BaseDatasource):
         post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj)
 
     @staticmethod
-    def metrics_and_post_aggs(metrics, metrics_dict, druid_version=None):
+    def metrics_and_post_aggs(
+        metrics: List[Union[Dict, str]],
+        metrics_dict: Dict[str, DruidMetric],
+        druid_version=None,
+    ) -> Tuple[OrderedDict, OrderedDict]:  # noqa: T484
         # Separate metrics into those that are aggregations
         # and those that are post aggregations
         saved_agg_names = set()
@@ -912,26 +935,26 @@ class DruidDatasource(Model, BaseDatasource):
         for metric in metrics:
             if utils.is_adhoc_metric(metric):
                 adhoc_agg_configs.append(metric)
-            elif metrics_dict[metric].metric_type != POST_AGG_TYPE:
+            elif metrics_dict[metric].metric_type != POST_AGG_TYPE:  # noqa: T484
                 saved_agg_names.add(metric)
             else:
                 postagg_names.append(metric)
         # Create the post aggregations, maintain order since postaggs
         # may depend on previous ones
-        post_aggs = OrderedDict()
+        post_aggs = OrderedDict()  # noqa: T484
         visited_postaggs = set()
         for postagg_name in postagg_names:
-            postagg = metrics_dict[postagg_name]
+            postagg = metrics_dict[postagg_name]  # noqa: T484
             visited_postaggs.add(postagg_name)
             DruidDatasource.resolve_postagg(
                 postagg, post_aggs, saved_agg_names, visited_postaggs, metrics_dict
             )
-        aggs = DruidDatasource.get_aggregations(
+        aggs = DruidDatasource.get_aggregations(  # noqa: T484
            metrics_dict, saved_agg_names, adhoc_agg_configs
        )
         return aggs, post_aggs
 
-    def values_for_column(self, column_name, limit=10000):
+    def values_for_column(self, column_name: str, limit: int = 10000) -> List:
         """Retrieve some values for the given column"""
         logging.info(
             "Getting values for columns [{}] limited to [{}]".format(column_name, limit)
@@ -955,12 +978,14 @@ class DruidDatasource(Model, BaseDatasource):
         client = self.cluster.get_pydruid_client()
         client.topn(**qry)
         df = client.export_pandas()
-        return [row[column_name] for row in df.to_records(index=False)]
+        return df[column_name].to_list()
 
     def get_query_str(self, query_obj, phase=1, client=None):
         return self.run_query(client=client, phase=phase, **query_obj)
 
-    def _add_filter_from_pre_query_data(self, df, dimensions, dim_filter):
+    def _add_filter_from_pre_query_data(
+        self, df: Optional[pd.DataFrame], dimensions, dim_filter
+    ):
         ret = dim_filter
         if df is not None and not df.empty:
             new_filters = []
@@ -1002,7 +1027,7 @@ class DruidDatasource(Model, BaseDatasource):
         return ret
 
     @staticmethod
-    def druid_type_from_adhoc_metric(adhoc_metric):
+    def druid_type_from_adhoc_metric(adhoc_metric: Dict) -> str:
         column_type = adhoc_metric["column"]["type"].lower()
         aggregate = adhoc_metric["aggregate"].lower()
 
@@ -1014,7 +1039,9 @@ class DruidDatasource(Model, BaseDatasource):
             return column_type + aggregate.capitalize()
 
     @staticmethod
-    def get_aggregations(metrics_dict, saved_metrics, adhoc_metrics=[]):
+    def get_aggregations(
+        metrics_dict: Dict, saved_metrics: Iterable[str], adhoc_metrics: List[Dict] = []
+    ) -> OrderedDict:
         """
         Returns a dictionary of aggregation metric names to aggregation json objects
 
@@ -1023,7 +1050,7 @@ class DruidDatasource(Model, BaseDatasource):
         :param adhoc_metrics: list of adhoc metric names
         :raise SupersetException: if one or more metric names are not aggregations
         """
-        aggregations = OrderedDict()
+        aggregations: OrderedDict = OrderedDict()
         invalid_metric_names = []
         for metric_name in saved_metrics:
             if metric_name in metrics_dict:
@@ -1047,19 +1074,18 @@ class DruidDatasource(Model, BaseDatasource):
         }
         return aggregations
 
-    def get_dimensions(self, groupby, columns_dict):
+    def get_dimensions(
+        self, groupby: List[str], columns_dict: Dict[str, DruidColumn]
+    ) -> List[Union[str, Dict]]:
         dimensions = []
         groupby = [gb for gb in groupby if gb in columns_dict]
         for column_name in groupby:
             col = columns_dict.get(column_name)
             dim_spec = col.dimension_spec if col else None
-            if dim_spec:
-                dimensions.append(dim_spec)
-            else:
-                dimensions.append(column_name)
+            dimensions.append(dim_spec or column_name)
         return dimensions
 
-    def intervals_from_dttms(self, from_dttm, to_dttm):
+    def intervals_from_dttms(self, from_dttm: datetime, to_dttm: datetime) -> str:
         # Couldn't find a way to just not filter on time...
         from_dttm = from_dttm or datetime(1901, 1, 1)
         to_dttm = to_dttm or datetime(2101, 1, 1)
@@ -1091,7 +1117,7 @@ class DruidDatasource(Model, BaseDatasource):
         return values
 
     @staticmethod
-    def sanitize_metric_object(metric):
+    def sanitize_metric_object(metric: Dict) -> None:
         """
         Update a metric with the correct type if necessary.
         :param dict metric: The metric to sanitize
@@ -1122,7 +1148,7 @@ class DruidDatasource(Model, BaseDatasource):
         phase=2,
         client=None,
         order_desc=True,
-    ):
+    ) -> str:
         """Runs a query against Druid and returns a dataframe.
         """
         # TODO refactor into using a TBD Query object
@@ -1193,7 +1219,7 @@ class DruidDatasource(Model, BaseDatasource):
             del qry["dimensions"]
             client.timeseries(**qry)
         elif not having_filters and len(groupby) == 1 and order_desc:
-            dim = list(qry.get("dimensions"))[0]
+            dim = list(qry.get("dimensions"))[0]  # noqa: T484
             logging.info("Running two-phase topn query for dimension [{}]".format(dim))
             pre_qry = deepcopy(qry)
             if timeseries_limit_metric:
@@ -1324,7 +1350,7 @@ class DruidDatasource(Model, BaseDatasource):
         return query_str
 
     @staticmethod
-    def homogenize_types(df, groupby_cols):
+    def homogenize_types(df: pd.DataFrame, groupby_cols: Iterable[str]) -> pd.DataFrame:
         """Converting all GROUPBY columns to strings
 
         When grouping by a numeric (say FLOAT) column, pydruid returns
@@ -1334,11 +1360,10 @@ class DruidDatasource(Model, BaseDatasource):
         Here we replace None with <NULL> and make the whole series a
         str instead of an object.
         """
-        for col in groupby_cols:
-            df[col] = df[col].fillna("<NULL>").astype("unicode")
+        df[groupby_cols] = df[groupby_cols].fillna("<NULL>").astype("unicode")
         return df
 
-    def query(self, query_obj):
+    def query(self, query_obj: Dict) -> QueryResult:
         qry_start_dttm = datetime.now()
         client = self.cluster.get_pydruid_client()
         query_str = self.get_query_str(client=client, query_obj=query_obj, phase=2)
@@ -1346,7 +1371,7 @@ class DruidDatasource(Model, BaseDatasource):
 
         if df is None or df.size == 0:
             return QueryResult(
-                df=pandas.DataFrame([]),
+                df=pd.DataFrame([]),
                 query=query_str,
                 duration=datetime.now() - qry_start_dttm,
             )
@@ -1363,7 +1388,7 @@ class DruidDatasource(Model, BaseDatasource):
             del df[DTTM_ALIAS]
 
         # Reordering columns
-        cols = []
+        cols: List[str] = []
         if DTTM_ALIAS in df.columns:
             cols += [DTTM_ALIAS]
         cols += query_obj.get("groupby") or []
@@ -1413,7 +1438,7 @@ class DruidDatasource(Model, BaseDatasource):
         return (col, extraction_fn)
 
     @classmethod
-    def get_filters(cls, raw_filters, num_cols, columns_dict):  # noqa
+    def get_filters(cls, raw_filters, num_cols, columns_dict) -> Filter:  # noqa: T484
         """Given Superset filter data structure, returns pydruid Filter(s)"""
         filters = None
         for flt in raw_filters:
@@ -1542,7 +1567,7 @@ class DruidDatasource(Model, BaseDatasource):
 
         return filters
 
-    def _get_having_obj(self, col, op, eq):
+    def _get_having_obj(self, col: str, op: str, eq: str) -> Having:
         cond = None
         if op == "==":
             if col in self.column_names:
@@ -1556,7 +1581,7 @@ class DruidDatasource(Model, BaseDatasource):
 
         return cond
 
-    def get_having_filters(self, raw_filters):
+    def get_having_filters(self, raw_filters: List[Dict]) -> Having:
         filters = None
         reversed_op_map = {"!=": "==", ">=": "<", "<=": ">"}
 
@@ -1579,7 +1604,9 @@ class DruidDatasource(Model, BaseDatasource):
         return filters
 
     @classmethod
-    def query_datasources_by_name(cls, session, database, datasource_name, schema=None):
+    def query_datasources_by_name(
+        cls, session: Session, database: Database, datasource_name: str, schema=None
+    ) -> List["DruidDatasource"]:
         return (
             session.query(cls)
             .filter_by(cluster_name=database.id)
@@ -1587,7 +1614,7 @@ class DruidDatasource(Model, BaseDatasource):
             .all()
         )
 
-    def external_metadata(self):
+    def external_metadata(self) -> List[Dict]:
         self.merge_flag = True
         return [
             {"name": k, "type": v.get("type")}
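Where the Druid module cannot yet express a precise type — e.g. indexing `metrics_dict` with a `Union[Dict, str]` key — the commit suppresses the single error with `# noqa: T484` (the flake8-mypy error code) rather than weakening the annotations. A sketch of that trade-off, assuming the same flake8-mypy setup:

from typing import Dict, Union

def metric_type(metrics: Dict[str, "Metric"], metric: Union[Dict, str]) -> str:
    # mypy flags this line because `metric` may be a Dict, not a str;
    # at runtime the branch is only reached for str keys, so the
    # commit silences the checker instead of restructuring the code.
    return metrics[metric].metric_type  # noqa: T484

class Metric:
    metric_type = "count"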
superset/connectors/druid/views.py

@@ -395,7 +395,7 @@ class Druid(BaseSupersetView):
 
     @has_access
     @expose("/refresh_datasources/")
-    def refresh_datasources(self, refreshAll=True):
+    def refresh_datasources(self, refresh_all=True):
         """endpoint that refreshes druid datasources metadata"""
         session = db.session()
         DruidCluster = ConnectorRegistry.sources["druid"].cluster_class
@@ -403,7 +403,7 @@ class Druid(BaseSupersetView):
             cluster_name = cluster.cluster_name
             valid_cluster = True
             try:
-                cluster.refresh_datasources(refreshAll=refreshAll)
+                cluster.refresh_datasources(refresh_all=refresh_all)
             except Exception as e:
                 valid_cluster = False
                 flash(
@@ -432,7 +432,7 @@ class Druid(BaseSupersetView):
         Calling this endpoint will cause a scan for new
         datasources only and add them.
         """
-        return self.refresh_datasources(refreshAll=False)
+        return self.refresh_datasources(refresh_all=False)
 
 
 appbuilder.add_view_no_menu(Druid)
superset/connectors/sqla/models.py

@@ -42,10 +42,10 @@ from sqlalchemy import (
     Text,
 )
 from sqlalchemy.exc import CompileError
-from sqlalchemy.orm import backref, relationship
+from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
 from sqlalchemy.orm.exc import NoResultFound
 from sqlalchemy.schema import UniqueConstraint
-from sqlalchemy.sql import column, literal_column, table, text
+from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
 from sqlalchemy.sql.expression import Label, Select, TextAsFrom
 import sqlparse
 
@@ -83,7 +83,7 @@ class AnnotationDatasource(BaseDatasource):
 
     cache_timeout = 0
 
-    def query(self, query_obj):
+    def query(self, query_obj: Dict) -> QueryResult:
         df = None
         error_message = None
         qry = db.session.query(Annotation)
@@ -143,7 +143,7 @@ class TableColumn(Model, BaseColumn):
     update_from_object_fields = [s for s in export_fields if s not in ("table_id",)]
     export_parent = "table"
 
-    def get_sqla_col(self, label=None):
+    def get_sqla_col(self, label: Optional[str] = None) -> Column:
         label = label or self.column_name
         if not self.expression:
             db_engine_spec = self.table.database.db_engine_spec
@@ -155,10 +155,12 @@ class TableColumn(Model, BaseColumn):
         return col
 
     @property
-    def datasource(self):
+    def datasource(self) -> RelationshipProperty:
         return self.table
 
-    def get_time_filter(self, start_dttm, end_dttm):
+    def get_time_filter(
+        self, start_dttm: DateTime, end_dttm: DateTime
+    ) -> ColumnElement:
         col = self.get_sqla_col(label="__time")
         l = []  # noqa: E741
         if start_dttm:
@@ -205,7 +207,7 @@ class TableColumn(Model, BaseColumn):
 
         return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
 
-    def dttm_sql_literal(self, dttm):
+    def dttm_sql_literal(self, dttm: DateTime) -> str:
         """Convert datetime object to a SQL expression string"""
         tf = self.python_date_format
         if tf:
@@ -249,13 +251,13 @@ class SqlMetric(Model, BaseMetric):
     )
     export_parent = "table"
 
-    def get_sqla_col(self, label=None):
+    def get_sqla_col(self, label: Optional[str] = None) -> Column:
         label = label or self.metric_name
         sqla_col = literal_column(self.expression)
         return self.table.make_sqla_column_compatible(sqla_col, label)
 
     @property
-    def perm(self):
+    def perm(self) -> Optional[str]:
         return (
             ("{parent_name}.[{obj.metric_name}](id:{obj.id})").format(
                 obj=self, parent_name=self.table.full_name
@@ -264,7 +266,7 @@ class SqlMetric(Model, BaseMetric):
             else None
         )
 
-    def get_perm(self):
+    def get_perm(self) -> Optional[str]:
         return self.perm
 
     @classmethod
@@ -351,7 +353,9 @@ class SqlaTable(Model, BaseDatasource):
         "MAX": sa.func.MAX,
     }
 
-    def make_sqla_column_compatible(self, sqla_col, label=None):
+    def make_sqla_column_compatible(
+        self, sqla_col: Column, label: Optional[str] = None
+    ) -> Column:
         """Takes a sql alchemy column object and adds label info if supported by engine.
         :param sqla_col: sql alchemy column instance
         :param label: alias/label that column is expected to have
@@ -369,23 +373,29 @@ class SqlaTable(Model, BaseDatasource):
         return self.name
 
     @property
-    def connection(self):
+    def connection(self) -> str:
         return str(self.database)
 
     @property
-    def description_markeddown(self):
+    def description_markeddown(self) -> str:
         return utils.markdown(self.description)
 
     @property
-    def datasource_name(self):
+    def datasource_name(self) -> str:
         return self.table_name
 
     @property
-    def database_name(self):
+    def database_name(self) -> str:
         return self.database.name
 
     @classmethod
-    def get_datasource_by_name(cls, session, datasource_name, schema, database_name):
+    def get_datasource_by_name(
+        cls,
+        session: Session,
+        datasource_name: str,
+        schema: Optional[str],
+        database_name: str,
+    ) -> Optional["SqlaTable"]:
         schema = schema or None
         query = (
             session.query(cls)
@@ -398,52 +408,52 @@ class SqlaTable(Model, BaseDatasource):
         for tbl in query.all():
             if schema == (tbl.schema or None):
                 return tbl
+        return None
 
     @property
-    def link(self):
+    def link(self) -> Markup:
         name = escape(self.name)
         anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>'
         return Markup(anchor)
 
     @property
-    def schema_perm(self):
+    def schema_perm(self) -> Optional[str]:
         """Returns schema permission if present, database one otherwise."""
         return security_manager.get_schema_perm(self.database, self.schema)
 
-    def get_perm(self):
+    def get_perm(self) -> str:
         return ("[{obj.database}].[{obj.table_name}]" "(id:{obj.id})").format(obj=self)
 
     @property
-    def name(self):
+    def name(self) -> str:
         if not self.schema:
             return self.table_name
         return "{}.{}".format(self.schema, self.table_name)
 
     @property
-    def full_name(self):
+    def full_name(self) -> str:
         return utils.get_datasource_full_name(
             self.database, self.table_name, schema=self.schema
         )
 
     @property
-    def dttm_cols(self):
+    def dttm_cols(self) -> List:
         l = [c.column_name for c in self.columns if c.is_dttm]  # noqa: E741
         if self.main_dttm_col and self.main_dttm_col not in l:
             l.append(self.main_dttm_col)
         return l
 
     @property
-    def num_cols(self):
+    def num_cols(self) -> List:
         return [c.column_name for c in self.columns if c.is_num]
 
     @property
-    def any_dttm_col(self):
+    def any_dttm_col(self) -> Optional[str]:
         cols = self.dttm_cols
-        if cols:
-            return cols[0]
+        return cols[0] if cols else None
 
     @property
-    def html(self):
+    def html(self) -> str:
         t = ((c.column_name, c.type) for c in self.columns)
         df = pd.DataFrame(t)
         df.columns = ["field", "type"]
@@ -453,7 +463,7 @@ class SqlaTable(Model, BaseDatasource):
         )
 
     @property
-    def sql_url(self):
+    def sql_url(self) -> str:
         return self.database.sql_url + "?table_name=" + str(self.table_name)
 
     def external_metadata(self):
@@ -466,28 +476,29 @@ class SqlaTable(Model, BaseDatasource):
         return cols
 
     @property
-    def time_column_grains(self):
+    def time_column_grains(self) -> Dict[str, Any]:
         return {
             "time_columns": self.dttm_cols,
             "time_grains": [grain.name for grain in self.database.grains()],
         }
 
     @property
-    def select_star(self):
+    def select_star(self) -> str:
         # show_cols and latest_partition set to false to avoid
         # the expensive cost of inspecting the DB
         return self.database.select_star(
             self.table_name, schema=self.schema, show_cols=False, latest_partition=False
         )
 
-    def get_col(self, col_name):
+    def get_col(self, col_name: str) -> Optional[Column]:
         columns = self.columns
         for col in columns:
             if col_name == col.column_name:
                 return col
+        return None
 
     @property
-    def data(self):
+    def data(self) -> Dict:
         d = super(SqlaTable, self).data
         if self.type == "table":
             grains = self.database.grains() or []
@@ -500,7 +511,7 @@ class SqlaTable(Model, BaseDatasource):
             d["template_params"] = self.template_params
         return d
 
-    def values_for_column(self, column_name, limit=10000):
+    def values_for_column(self, column_name: str, limit: int = 10000) -> List:
         """Runs query against sqla to retrieve some
         sample values for the given column.
         """
@@ -525,9 +536,9 @@ class SqlaTable(Model, BaseDatasource):
             sql = self.mutate_query_from_config(sql)
 
         df = pd.read_sql_query(sql=sql, con=engine)
-        return [row[0] for row in df.to_records(index=False)]
+        return df[column_name].to_list()
 
-    def mutate_query_from_config(self, sql):
+    def mutate_query_from_config(self, sql: str) -> str:
         """Apply config's SQL_QUERY_MUTATOR
 
         Typically adds comments to the query with context"""
@@ -540,7 +551,7 @@ class SqlaTable(Model, BaseDatasource):
     def get_template_processor(self, **kwargs):
         return get_template_processor(table=self, database=self.database, **kwargs)
 
-    def get_query_str_extended(self, query_obj) -> QueryStringExtended:
+    def get_query_str_extended(self, query_obj: Dict) -> QueryStringExtended:
         sqlaq = self.get_sqla_query(**query_obj)
         sql = self.database.compile_sqla_query(sqlaq.sqla_query)
         logging.info(sql)
@@ -550,7 +561,7 @@ class SqlaTable(Model, BaseDatasource):
             labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries
         )
 
-    def get_query_str(self, query_obj):
+    def get_query_str(self, query_obj: Dict) -> str:
         query_str_ext = self.get_query_str_extended(query_obj)
         all_queries = query_str_ext.prequeries + [query_str_ext.sql]
         return ";\n\n".join(all_queries) + ";"
@@ -571,7 +582,7 @@ class SqlaTable(Model, BaseDatasource):
             return TextAsFrom(sa.text(from_sql), []).alias("expr_qry")
         return self.get_sqla_table()
 
-    def adhoc_metric_to_sqla(self, metric, cols):
+    def adhoc_metric_to_sqla(self, metric: Dict, cols: Dict) -> Optional[Column]:
         """
         Turn an adhoc metric into a sqlalchemy column.
 
@@ -584,13 +595,13 @@ class SqlaTable(Model, BaseDatasource):
         label = utils.get_metric_name(metric)
 
         if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]:
-            column_name = metric.get("column").get("column_name")
+            column_name = metric["column"].get("column_name")
             table_column = cols.get(column_name)
             if table_column:
                 sqla_column = table_column.get_sqla_col()
             else:
                 sqla_column = column(column_name)
-            sqla_metric = self.sqla_aggregations[metric.get("aggregate")](sqla_column)
+            sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
         elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SQL"]:
             sqla_metric = literal_column(metric.get("sqlExpression"))
         else:
@@ -616,7 +627,7 @@ class SqlaTable(Model, BaseDatasource):
         extras=None,
         columns=None,
         order_desc=True,
-    ):
+    ) -> SqlaQuery:
         """Querying any sqla table from this common interface"""
         template_kwargs = {
             "from_dttm": from_dttm,
@@ -643,8 +654,8 @@ class SqlaTable(Model, BaseDatasource):
         # Database spec supports join-free timeslot grouping
         time_groupby_inline = db_engine_spec.time_groupby_inline
 
-        cols = {col.column_name: col for col in self.columns}
-        metrics_dict = {m.metric_name: m for m in self.metrics}
+        cols: Dict[str, Column] = {col.column_name: col for col in self.columns}
+        metrics_dict: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics}
 
         if not granularity and is_timeseries:
             raise Exception(
@@ -660,7 +671,7 @@ class SqlaTable(Model, BaseDatasource):
             if utils.is_adhoc_metric(m):
                 metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
             elif m in metrics_dict:
-                metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
+                metrics_exprs.append(metrics_dict[m].get_sqla_col())
             else:
                 raise Exception(_("Metric '%(metric)s' does not exist", metric=m))
         if metrics_exprs:
@@ -669,8 +680,8 @@ class SqlaTable(Model, BaseDatasource):
             main_metric_expr, label = literal_column("COUNT(*)"), "ccount"
         main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label)
 
-        select_exprs = []
-        groupby_exprs_sans_timestamp = OrderedDict()
+        select_exprs: List[Column] = []
+        groupby_exprs_sans_timestamp: OrderedDict = OrderedDict()
 
         if groupby:
             select_exprs = []
@@ -729,7 +740,7 @@ class SqlaTable(Model, BaseDatasource):
             qry = qry.group_by(*groupby_exprs_with_timestamp.values())
 
         where_clause_and = []
-        having_clause_and = []
+        having_clause_and: List = []
         for flt in filter:
             if not all([flt.get(s) for s in ["col", "op"]]):
                 continue
@@ -899,7 +910,9 @@ class SqlaTable(Model, BaseDatasource):
 
         return ob
 
-    def _get_top_groups(self, df, dimensions, groupby_exprs):
+    def _get_top_groups(
+        self, df: pd.DataFrame, dimensions: List, groupby_exprs: OrderedDict
+    ) -> ColumnElement:
         groups = []
         for unused, row in df.iterrows():
             group = []
@@ -909,7 +922,7 @@ class SqlaTable(Model, BaseDatasource):
 
         return or_(*groups)
 
-    def query(self, query_obj):
+    def query(self, query_obj: Dict) -> QueryResult:
         qry_start_dttm = datetime.now()
         query_str_ext = self.get_query_str_extended(query_obj)
         sql = query_str_ext.sql
@@ -945,10 +958,10 @@ class SqlaTable(Model, BaseDatasource):
             error_message=error_message,
         )
 
-    def get_sqla_table_object(self):
+    def get_sqla_table_object(self) -> Table:
         return self.database.get_table(self.table_name, schema=self.schema)
 
-    def fetch_metadata(self):
+    def fetch_metadata(self) -> None:
         """Fetches the metadata for the table and merges it in"""
         try:
             table = self.get_sqla_table_object()
@@ -1012,7 +1025,7 @@ class SqlaTable(Model, BaseDatasource):
         db.session.commit()
 
     @classmethod
-    def import_obj(cls, i_datasource, import_time=None):
+    def import_obj(cls, i_datasource, import_time=None) -> int:
         """Imports the datasource from the object to the database.
 
         Metrics and columns and datasource will be overrided if exists.
@@ -1052,7 +1065,9 @@ class SqlaTable(Model, BaseDatasource):
         )
 
     @classmethod
-    def query_datasources_by_name(cls, session, database, datasource_name, schema=None):
+    def query_datasources_by_name(
+        cls, session: Session, database: Database, datasource_name: str, schema=None
+    ) -> List["SqlaTable"]:
         query = (
             session.query(cls)
             .filter_by(database_id=database.id)
@@ -1063,7 +1078,7 @@ class SqlaTable(Model, BaseDatasource):
         return query.all()
 
     @staticmethod
-    def default_query(qry):
+    def default_query(qry) -> Query:
         return qry.filter_by(is_sqllab_view=False)
 
     def has_extra_cache_keys(self, query_obj: Dict) -> bool:
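Both connector modules also swap positional record iteration for a direct column selection when collecting distinct values (`df[column_name].to_list()`), which reads better and gives the checker a plain list. A quick equivalence check on illustrative data:

import pandas as pd

df = pd.DataFrame({"country": ["US", "FR"], "cnt": [1, 2]})

old = [row[0] for row in df.to_records(index=False)]  # positional
new = df["country"].to_list()                         # by name

assert old == new == ["US", "FR"]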