chore: Cleaning up types and names for SQLA models (#10248)

This commit is contained in:
John Bodley 2020-07-06 20:59:17 -07:00 committed by GitHub
parent 569e4a7c50
commit bacf567656
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 29 additions and 25 deletions

View File

@ -685,13 +685,13 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
return self.get_sqla_table()
def adhoc_metric_to_sqla(
self, metric: Dict[str, Any], cols: Dict[str, Any]
self, metric: Dict[str, Any], columns_by_name: Dict[str, Any]
) -> Optional[Column]:
"""
Turn an adhoc metric into a sqlalchemy column.
:param dict metric: Adhoc metric definition
:param dict cols: Columns for the current table
:param dict columns_by_name: Columns for the current table
:returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column
"""
@ -700,7 +700,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]:
column_name = metric["column"].get("column_name")
table_column = cols.get(column_name)
table_column = columns_by_name.get(column_name)
if table_column:
sqla_column = table_column.get_sqla_col()
else:
@ -780,8 +780,10 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
# Database spec supports join-free timeslot grouping
time_groupby_inline = db_engine_spec.time_groupby_inline
cols: Dict[str, Column] = {col.column_name: col for col in self.columns}
metrics_dict: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics}
columns_by_name: Dict[str, TableColumn] = {
col.column_name: col for col in self.columns
}
metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics}
if not granularity and is_timeseries:
raise Exception(
@ -800,9 +802,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
for metric in metrics:
if utils.is_adhoc_metric(metric):
assert isinstance(metric, dict)
metrics_exprs.append(self.adhoc_metric_to_sqla(metric, cols))
elif isinstance(metric, str) and metric in metrics_dict:
metrics_exprs.append(metrics_dict[metric].get_sqla_col())
metrics_exprs.append(self.adhoc_metric_to_sqla(metric, columns_by_name))
elif isinstance(metric, str) and metric in metrics_by_name:
metrics_exprs.append(metrics_by_name[metric].get_sqla_col())
else:
raise Exception(_("Metric '%(metric)s' does not exist", metric=metric))
if metrics_exprs:
@ -822,8 +824,8 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
select_exprs = []
for selected in groupby:
if selected in cols:
outer = cols[selected].get_sqla_col()
if selected in columns_by_name:
outer = columns_by_name[selected].get_sqla_col()
else:
outer = literal_column(f"({selected})")
outer = self.make_sqla_column_compatible(outer, selected)
@ -833,8 +835,8 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
elif columns:
for selected in columns:
select_exprs.append(
cols[selected].get_sqla_col()
if selected in cols
columns_by_name[selected].get_sqla_col()
if selected in columns_by_name
else self.make_sqla_column_compatible(literal_column(selected))
)
metrics_exprs = []
@ -843,7 +845,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
time_range_endpoints = extras.get("time_range_endpoints")
groupby_exprs_with_timestamp = OrderedDict(groupby_exprs_sans_timestamp.items())
if granularity:
dttm_col = cols[granularity]
dttm_col = columns_by_name[granularity]
time_grain = extras.get("time_grain_sqla")
time_filters = []
@ -859,7 +861,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
and self.main_dttm_col != dttm_col.column_name
):
time_filters.append(
cols[self.main_dttm_col].get_time_filter(
columns_by_name[self.main_dttm_col].get_time_filter(
from_dttm, to_dttm, time_range_endpoints
)
)
@ -892,7 +894,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
continue
col = flt["col"]
op = flt["op"].upper()
col_obj = cols.get(col)
col_obj = columns_by_name.get(col)
if col_obj:
is_list_target = op in (
utils.FilterOperator.IN.value,
@ -977,9 +979,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
for col, ascending in orderby:
direction = asc if ascending else desc
if utils.is_adhoc_metric(col):
col = self.adhoc_metric_to_sqla(col, cols)
elif col in cols:
col = cols[col].get_sqla_col()
col = self.adhoc_metric_to_sqla(col, columns_by_name)
elif col in columns_by_name:
col = columns_by_name[col].get_sqla_col()
if isinstance(col, Label):
label = col._label # pylint: disable=protected-access
@ -1026,7 +1028,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
ob = inner_main_metric_expr
if timeseries_limit_metric:
ob = self._get_timeseries_orderby(
timeseries_limit_metric, metrics_dict, cols
timeseries_limit_metric, metrics_by_name, columns_by_name
)
direction = desc if order_desc else asc
subq = subq.order_by(direction(ob))
@ -1046,7 +1048,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
orderby = [
(
self._get_timeseries_orderby(
timeseries_limit_metric, metrics_dict, cols
timeseries_limit_metric,
metrics_by_name,
columns_by_name,
),
False,
)
@ -1090,17 +1094,17 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
def _get_timeseries_orderby(
self,
timeseries_limit_metric: Metric,
metrics_dict: Dict[str, SqlMetric],
cols: Dict[str, Column],
metrics_by_name: Dict[str, SqlMetric],
columns_by_name: Dict[str, TableColumn],
) -> Optional[Column]:
if utils.is_adhoc_metric(timeseries_limit_metric):
assert isinstance(timeseries_limit_metric, dict)
ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, columns_by_name)
elif (
isinstance(timeseries_limit_metric, str)
and timeseries_limit_metric in metrics_dict
and timeseries_limit_metric in metrics_by_name
):
ob = metrics_dict[timeseries_limit_metric].get_sqla_col()
ob = metrics_by_name[timeseries_limit_metric].get_sqla_col()
else:
raise Exception(
_("Metric '%(metric)s' does not exist", metric=timeseries_limit_metric)