From bacf567656608a7542bd51f5346cd335a1cfc37a Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Mon, 6 Jul 2020 20:59:17 -0700 Subject: [PATCH] chore: Cleaning up types and names for SQLA models (#10248) --- superset/connectors/sqla/models.py | 54 ++++++++++++++++-------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 8cfae24222..cc2a7b2035 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -685,13 +685,13 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at return self.get_sqla_table() def adhoc_metric_to_sqla( - self, metric: Dict[str, Any], cols: Dict[str, Any] + self, metric: Dict[str, Any], columns_by_name: Dict[str, Any] ) -> Optional[Column]: """ Turn an adhoc metric into a sqlalchemy column. :param dict metric: Adhoc metric definition - :param dict cols: Columns for the current table + :param dict columns_by_name: Columns for the current table :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ @@ -700,7 +700,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]: column_name = metric["column"].get("column_name") - table_column = cols.get(column_name) + table_column = columns_by_name.get(column_name) if table_column: sqla_column = table_column.get_sqla_col() else: @@ -780,8 +780,10 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at # Database spec supports join-free timeslot grouping time_groupby_inline = db_engine_spec.time_groupby_inline - cols: Dict[str, Column] = {col.column_name: col for col in self.columns} - metrics_dict: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} + columns_by_name: Dict[str, TableColumn] = { + col.column_name: col for col in self.columns + } + metrics_by_name: Dict[str, SqlMetric] = {m.metric_name: m for m in self.metrics} if not granularity and is_timeseries: raise Exception( @@ -800,9 +802,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at for metric in metrics: if utils.is_adhoc_metric(metric): assert isinstance(metric, dict) - metrics_exprs.append(self.adhoc_metric_to_sqla(metric, cols)) - elif isinstance(metric, str) and metric in metrics_dict: - metrics_exprs.append(metrics_dict[metric].get_sqla_col()) + metrics_exprs.append(self.adhoc_metric_to_sqla(metric, columns_by_name)) + elif isinstance(metric, str) and metric in metrics_by_name: + metrics_exprs.append(metrics_by_name[metric].get_sqla_col()) else: raise Exception(_("Metric '%(metric)s' does not exist", metric=metric)) if metrics_exprs: @@ -822,8 +824,8 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at select_exprs = [] for selected in groupby: - if selected in cols: - outer = cols[selected].get_sqla_col() + if selected in columns_by_name: + outer = columns_by_name[selected].get_sqla_col() else: outer = literal_column(f"({selected})") outer = self.make_sqla_column_compatible(outer, selected) @@ -833,8 +835,8 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at elif columns: for selected in columns: select_exprs.append( - cols[selected].get_sqla_col() - if selected in cols + columns_by_name[selected].get_sqla_col() + if selected in columns_by_name else self.make_sqla_column_compatible(literal_column(selected)) ) metrics_exprs = [] @@ -843,7 +845,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at time_range_endpoints = extras.get("time_range_endpoints") groupby_exprs_with_timestamp = OrderedDict(groupby_exprs_sans_timestamp.items()) if granularity: - dttm_col = cols[granularity] + dttm_col = columns_by_name[granularity] time_grain = extras.get("time_grain_sqla") time_filters = [] @@ -859,7 +861,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at and self.main_dttm_col != dttm_col.column_name ): time_filters.append( - cols[self.main_dttm_col].get_time_filter( + columns_by_name[self.main_dttm_col].get_time_filter( from_dttm, to_dttm, time_range_endpoints ) ) @@ -892,7 +894,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at continue col = flt["col"] op = flt["op"].upper() - col_obj = cols.get(col) + col_obj = columns_by_name.get(col) if col_obj: is_list_target = op in ( utils.FilterOperator.IN.value, @@ -977,9 +979,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at for col, ascending in orderby: direction = asc if ascending else desc if utils.is_adhoc_metric(col): - col = self.adhoc_metric_to_sqla(col, cols) - elif col in cols: - col = cols[col].get_sqla_col() + col = self.adhoc_metric_to_sqla(col, columns_by_name) + elif col in columns_by_name: + col = columns_by_name[col].get_sqla_col() if isinstance(col, Label): label = col._label # pylint: disable=protected-access @@ -1026,7 +1028,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at ob = inner_main_metric_expr if timeseries_limit_metric: ob = self._get_timeseries_orderby( - timeseries_limit_metric, metrics_dict, cols + timeseries_limit_metric, metrics_by_name, columns_by_name ) direction = desc if order_desc else asc subq = subq.order_by(direction(ob)) @@ -1046,7 +1048,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at orderby = [ ( self._get_timeseries_orderby( - timeseries_limit_metric, metrics_dict, cols + timeseries_limit_metric, + metrics_by_name, + columns_by_name, ), False, ) @@ -1090,17 +1094,17 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at def _get_timeseries_orderby( self, timeseries_limit_metric: Metric, - metrics_dict: Dict[str, SqlMetric], - cols: Dict[str, Column], + metrics_by_name: Dict[str, SqlMetric], + columns_by_name: Dict[str, TableColumn], ) -> Optional[Column]: if utils.is_adhoc_metric(timeseries_limit_metric): assert isinstance(timeseries_limit_metric, dict) - ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols) + ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, columns_by_name) elif ( isinstance(timeseries_limit_metric, str) - and timeseries_limit_metric in metrics_dict + and timeseries_limit_metric in metrics_by_name ): - ob = metrics_dict[timeseries_limit_metric].get_sqla_col() + ob = metrics_by_name[timeseries_limit_metric].get_sqla_col() else: raise Exception( _("Metric '%(metric)s' does not exist", metric=timeseries_limit_metric)