chore: Cleaning up types and names for SQLA models (#10248)

This commit is contained in:
John Bodley 2020-07-06 20:59:17 -07:00 committed by GitHub
parent 569e4a7c50
commit bacf567656
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 29 additions and 25 deletions

View File

@ -685,13 +685,13 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
return self.get_sqla_table() return self.get_sqla_table()
def adhoc_metric_to_sqla( def adhoc_metric_to_sqla(
self, metric: Dict[str, Any], cols: Dict[str, Any] self, metric: Dict[str, Any], columns_by_name: Dict[str, Any]
) -> Optional[Column]: ) -> Optional[Column]:
""" """
Turn an adhoc metric into a sqlalchemy column. Turn an adhoc metric into a sqlalchemy column.
:param dict metric: Adhoc metric definition :param dict metric: Adhoc metric definition
:param dict cols: Columns for the current table :param dict columns_by_name: Columns for the current table
:returns: The metric defined as a sqlalchemy column :returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column :rtype: sqlalchemy.sql.column
""" """
@ -700,7 +700,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]: if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]:
column_name = metric["column"].get("column_name") column_name = metric["column"].get("column_name")
table_column = cols.get(column_name) table_column = columns_by_name.get(column_name)
if table_column: if table_column:
sqla_column = table_column.get_sqla_col() sqla_column = table_column.get_sqla_col()
else: else:
@ -780,8 +780,10 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
# Database spec supports join-free timeslot grouping # Database spec supports join-free timeslot grouping
time_groupby_inline = db_engine_spec.time_groupby_inline time_groupby_inline = db_engine_spec.time_groupby_inline
cols: Dict[str, Column] = {col.column_name: col for col in self.columns} columns_by_name: Dict[str, TableColumn] = {
metrics_dict: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} col.column_name: col for col in self.columns
}
metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics}
if not granularity and is_timeseries: if not granularity and is_timeseries:
raise Exception( raise Exception(
@ -800,9 +802,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
for metric in metrics: for metric in metrics:
if utils.is_adhoc_metric(metric): if utils.is_adhoc_metric(metric):
assert isinstance(metric, dict) assert isinstance(metric, dict)
metrics_exprs.append(self.adhoc_metric_to_sqla(metric, cols)) metrics_exprs.append(self.adhoc_metric_to_sqla(metric, columns_by_name))
elif isinstance(metric, str) and metric in metrics_dict: elif isinstance(metric, str) and metric in metrics_by_name:
metrics_exprs.append(metrics_dict[metric].get_sqla_col()) metrics_exprs.append(metrics_by_name[metric].get_sqla_col())
else: else:
raise Exception(_("Metric '%(metric)s' does not exist", metric=metric)) raise Exception(_("Metric '%(metric)s' does not exist", metric=metric))
if metrics_exprs: if metrics_exprs:
@ -822,8 +824,8 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
select_exprs = [] select_exprs = []
for selected in groupby: for selected in groupby:
if selected in cols: if selected in columns_by_name:
outer = cols[selected].get_sqla_col() outer = columns_by_name[selected].get_sqla_col()
else: else:
outer = literal_column(f"({selected})") outer = literal_column(f"({selected})")
outer = self.make_sqla_column_compatible(outer, selected) outer = self.make_sqla_column_compatible(outer, selected)
@ -833,8 +835,8 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
elif columns: elif columns:
for selected in columns: for selected in columns:
select_exprs.append( select_exprs.append(
cols[selected].get_sqla_col() columns_by_name[selected].get_sqla_col()
if selected in cols if selected in columns_by_name
else self.make_sqla_column_compatible(literal_column(selected)) else self.make_sqla_column_compatible(literal_column(selected))
) )
metrics_exprs = [] metrics_exprs = []
@ -843,7 +845,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
time_range_endpoints = extras.get("time_range_endpoints") time_range_endpoints = extras.get("time_range_endpoints")
groupby_exprs_with_timestamp = OrderedDict(groupby_exprs_sans_timestamp.items()) groupby_exprs_with_timestamp = OrderedDict(groupby_exprs_sans_timestamp.items())
if granularity: if granularity:
dttm_col = cols[granularity] dttm_col = columns_by_name[granularity]
time_grain = extras.get("time_grain_sqla") time_grain = extras.get("time_grain_sqla")
time_filters = [] time_filters = []
@ -859,7 +861,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
and self.main_dttm_col != dttm_col.column_name and self.main_dttm_col != dttm_col.column_name
): ):
time_filters.append( time_filters.append(
cols[self.main_dttm_col].get_time_filter( columns_by_name[self.main_dttm_col].get_time_filter(
from_dttm, to_dttm, time_range_endpoints from_dttm, to_dttm, time_range_endpoints
) )
) )
@ -892,7 +894,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
continue continue
col = flt["col"] col = flt["col"]
op = flt["op"].upper() op = flt["op"].upper()
col_obj = cols.get(col) col_obj = columns_by_name.get(col)
if col_obj: if col_obj:
is_list_target = op in ( is_list_target = op in (
utils.FilterOperator.IN.value, utils.FilterOperator.IN.value,
@ -977,9 +979,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
for col, ascending in orderby: for col, ascending in orderby:
direction = asc if ascending else desc direction = asc if ascending else desc
if utils.is_adhoc_metric(col): if utils.is_adhoc_metric(col):
col = self.adhoc_metric_to_sqla(col, cols) col = self.adhoc_metric_to_sqla(col, columns_by_name)
elif col in cols: elif col in columns_by_name:
col = cols[col].get_sqla_col() col = columns_by_name[col].get_sqla_col()
if isinstance(col, Label): if isinstance(col, Label):
label = col._label # pylint: disable=protected-access label = col._label # pylint: disable=protected-access
@ -1026,7 +1028,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
ob = inner_main_metric_expr ob = inner_main_metric_expr
if timeseries_limit_metric: if timeseries_limit_metric:
ob = self._get_timeseries_orderby( ob = self._get_timeseries_orderby(
timeseries_limit_metric, metrics_dict, cols timeseries_limit_metric, metrics_by_name, columns_by_name
) )
direction = desc if order_desc else asc direction = desc if order_desc else asc
subq = subq.order_by(direction(ob)) subq = subq.order_by(direction(ob))
@ -1046,7 +1048,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
orderby = [ orderby = [
( (
self._get_timeseries_orderby( self._get_timeseries_orderby(
timeseries_limit_metric, metrics_dict, cols timeseries_limit_metric,
metrics_by_name,
columns_by_name,
), ),
False, False,
) )
@ -1090,17 +1094,17 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
def _get_timeseries_orderby( def _get_timeseries_orderby(
self, self,
timeseries_limit_metric: Metric, timeseries_limit_metric: Metric,
metrics_dict: Dict[str, SqlMetric], metrics_by_name: Dict[str, SqlMetric],
cols: Dict[str, Column], columns_by_name: Dict[str, TableColumn],
) -> Optional[Column]: ) -> Optional[Column]:
if utils.is_adhoc_metric(timeseries_limit_metric): if utils.is_adhoc_metric(timeseries_limit_metric):
assert isinstance(timeseries_limit_metric, dict) assert isinstance(timeseries_limit_metric, dict)
ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols) ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, columns_by_name)
elif ( elif (
isinstance(timeseries_limit_metric, str) isinstance(timeseries_limit_metric, str)
and timeseries_limit_metric in metrics_dict and timeseries_limit_metric in metrics_by_name
): ):
ob = metrics_dict[timeseries_limit_metric].get_sqla_col() ob = metrics_by_name[timeseries_limit_metric].get_sqla_col()
else: else:
raise Exception( raise Exception(
_("Metric '%(metric)s' does not exist", metric=timeseries_limit_metric) _("Metric '%(metric)s' does not exist", metric=timeseries_limit_metric)