From 77fe9ef130a383d6902b63f89743446b5578d3b5 Mon Sep 17 00:00:00 2001
From: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
Date: Tue, 4 Sep 2018 08:49:58 +0300
Subject: [PATCH] Force quoted column aliases for Oracle-like databases (#5686)

* Replace dataframe label override logic with table column override

* Add mutation to any_date_col

* Linting

* Add mutation to oracle and redshift

* Fine tune how and which labels are mutated

* Implement alias quoting logic for oracle-like databases

* Fix and align column and metric sqla_col methods

* Clean up typos and redundant logic

* Move new attribute to old location

* Linting

* Replace old sqla_col property references with function calls

* Remove redundant calls to mutate_column_label

* Move duplicated logic to common function

* Add db_engine_specs to all sqla_col calls

* Add missing mydb

* Add note about snowflake-sqlalchemy regression

* Make db_engine_spec mandatory in sqla_col

* Small refactoring and cleanup

* Remove db_engine_spec from get_from_clause call

* Make db_engine_spec mandatory in adhoc_metric_to_sa

* Remove redundant mutate_expression_label call

* Add missing db_engine_specs to adhoc_metric_to_sa

* Rename arg label_name to label in get_column_label()

* Rename label function and add docstring

* Remove redundant db_engine_spec args

* Rename col_label to label

* Remove get_column_name wrapper and make direct calls to db_engine_spec

* Remove unneeded db_engine_specs

* Rename sa_ vars to sqla_
---
 docs/installation.rst              |   4 ++
 superset/connectors/sqla/models.py | 104 ++++++++++++++++-------------
 superset/dataframe.py              |   4 +-
 superset/db_engine_specs.py        |  66 ++++--------------
 superset/viz.py                    |   4 --
 5 files changed, 75 insertions(+), 107 deletions(-)

diff --git a/docs/installation.rst b/docs/installation.rst
index 6b08b82ab3..75293233fb 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -393,6 +393,10 @@ Make sure the user has privileges to access and use all required
 databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does
 not test for user rights during engine creation.
 
+*Note*: At the time of writing, there is a regression in the current stable version (1.1.2) of
+the snowflake-sqlalchemy package that causes problems when used with Superset. It is recommended
+to use version 1.1.0, or to try a newer release.
+
 See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_.
 
 Caching
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index f410279d0f..037eb779a0 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -99,13 +99,13 @@ class TableColumn(Model, BaseColumn):
         s for s in export_fields if s not in ('table_id',)]
     export_parent = 'table'
 
-    @property
-    def sqla_col(self):
-        name = self.column_name
+    def get_sqla_col(self, label=None):
+        db_engine_spec = self.table.database.db_engine_spec
+        label = db_engine_spec.make_label_compatible(label if label else self.column_name)
         if not self.expression:
-            col = column(self.column_name).label(name)
+            col = column(self.column_name).label(label)
         else:
-            col = literal_column(self.expression).label(name)
+            col = literal_column(self.expression).label(label)
         return col
 
     @property
@@ -113,7 +113,7 @@ class TableColumn(Model, BaseColumn):
         return self.table
 
     def get_time_filter(self, start_dttm, end_dttm):
-        col = self.sqla_col.label('__time')
+        col = self.get_sqla_col(label='__time')
         l = []  # noqa: E741
         if start_dttm:
             l.append(col >= text(self.dttm_sql_literal(start_dttm)))
@@ -231,10 +231,10 @@ class SqlMetric(Model, BaseMetric):
             s for s in export_fields if s not in ('table_id', )])
     export_parent = 'table'
 
-    @property
-    def sqla_col(self):
-        name = self.metric_name
-        return literal_column(self.expression).label(name)
+    def get_sqla_col(self, label=None):
+        db_engine_spec = self.table.database.db_engine_spec
+        label = db_engine_spec.make_label_compatible(label if label else self.metric_name)
+        return literal_column(self.expression).label(label)
 
     @property
     def perm(self):
@@ -421,11 +421,10 @@ class SqlaTable(Model, BaseDatasource):
         cols = {col.column_name: col for col in self.columns}
         target_col = cols[column_name]
         tp = self.get_template_processor()
-        db_engine_spec = self.database.db_engine_spec
 
         qry = (
-            select([target_col.sqla_col])
-            .select_from(self.get_from_clause(tp, db_engine_spec))
+            select([target_col.get_sqla_col()])
+            .select_from(self.get_from_clause(tp))
             .distinct()
         )
         if limit:
@@ -474,7 +473,7 @@ class SqlaTable(Model, BaseDatasource):
             tbl.schema = self.schema
         return tbl
 
-    def get_from_clause(self, template_processor=None, db_engine_spec=None):
+    def get_from_clause(self, template_processor=None):
         # Supporting arbitrary SQL statements in place of tables
         if self.sql:
             from_sql = self.sql
@@ -484,7 +483,7 @@ class SqlaTable(Model, BaseDatasource):
             return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
         return self.get_sqla_table()
 
-    def adhoc_metric_to_sa(self, metric, cols):
+    def adhoc_metric_to_sqla(self, metric, cols):
         """
         Turn an adhoc metric into a sqlalchemy column.
 
@@ -493,22 +492,25 @@ class SqlaTable(Model, BaseDatasource):
         :returns: The metric defined as a sqlalchemy column
         :rtype: sqlalchemy.sql.column
         """
-        expressionType = metric.get('expressionType')
-        if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
+        expression_type = metric.get('expressionType')
+        db_engine_spec = self.database.db_engine_spec
+        label = db_engine_spec.make_label_compatible(metric.get('label'))
+
+        if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
             column_name = metric.get('column').get('column_name')
-            sa_column = column(column_name)
+            sqla_column = column(column_name)
             table_column = cols.get(column_name)
 
             if table_column:
-                sa_column = table_column.sqla_col
+                sqla_column = table_column.get_sqla_col()
 
-            sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column)
-            sa_metric = sa_metric.label(metric.get('label'))
-            return sa_metric
-        elif expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
-            sa_metric = literal_column(metric.get('sqlExpression'))
-            sa_metric = sa_metric.label(metric.get('label'))
-            return sa_metric
+            sqla_metric = self.sqla_aggregations[metric.get('aggregate')](sqla_column)
+            sqla_metric = sqla_metric.label(label)
+            return sqla_metric
+        elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
+            sqla_metric = literal_column(metric.get('sqlExpression'))
+            sqla_metric = sqla_metric.label(label)
+            return sqla_metric
         else:
             return None
@@ -566,15 +568,16 @@ class SqlaTable(Model, BaseDatasource):
         metrics_exprs = []
         for m in metrics:
             if utils.is_adhoc_metric(m):
-                metrics_exprs.append(self.adhoc_metric_to_sa(m, cols))
+                metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
             elif m in metrics_dict:
-                metrics_exprs.append(metrics_dict.get(m).sqla_col)
+                metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
             else:
                 raise Exception(_("Metric '{}' is not valid".format(m)))
         if metrics_exprs:
             main_metric_expr = metrics_exprs[0]
         else:
-            main_metric_expr = literal_column('COUNT(*)').label('ccount')
+            main_metric_expr = literal_column('COUNT(*)').label(
+                db_engine_spec.make_label_compatible('count'))
 
         select_exprs = []
         groupby_exprs = []
@@ -585,8 +588,8 @@ class SqlaTable(Model, BaseDatasource):
             inner_groupby_exprs = []
             for s in groupby:
                 col = cols[s]
-                outer = col.sqla_col
-                inner = col.sqla_col.label(col.column_name + '__')
+                outer = col.get_sqla_col()
+                inner = col.get_sqla_col(col.column_name + '__')
 
                 groupby_exprs.append(outer)
                 select_exprs.append(outer)
@@ -594,7 +597,7 @@ class SqlaTable(Model, BaseDatasource):
                 inner_select_exprs.append(inner)
         elif columns:
             for s in columns:
-                select_exprs.append(cols[s].sqla_col)
+                select_exprs.append(cols[s].get_sqla_col())
             metrics_exprs = []
 
         if granularity:
@@ -618,7 +621,7 @@ class SqlaTable(Model, BaseDatasource):
             select_exprs += metrics_exprs
         qry = sa.select(select_exprs)
 
-        tbl = self.get_from_clause(template_processor, db_engine_spec)
+        tbl = self.get_from_clause(template_processor)
 
         if not columns:
             qry = qry.group_by(*groupby_exprs)
@@ -638,33 +641,34 @@ class SqlaTable(Model, BaseDatasource):
                     target_column_is_numeric=col_obj.is_num,
                     is_list_target=is_list_target)
                 if op in ('in', 'not in'):
-                    cond = col_obj.sqla_col.in_(eq)
+                    cond = col_obj.get_sqla_col().in_(eq)
                     if '' in eq:
-                        cond = or_(cond, col_obj.sqla_col == None)  # noqa
+                        cond = or_(cond, col_obj.get_sqla_col() == None)  # noqa
                     if op == 'not in':
                         cond = ~cond
                     where_clause_and.append(cond)
                 else:
                     if col_obj.is_num:
                         eq = utils.string_to_num(flt['val'])
                     if op == '==':
-                        where_clause_and.append(col_obj.sqla_col == eq)
+                        where_clause_and.append(col_obj.get_sqla_col() == eq)
                     elif op == '!=':
-                        where_clause_and.append(col_obj.sqla_col != eq)
+                        where_clause_and.append(col_obj.get_sqla_col() != eq)
                     elif op == '>':
-                        where_clause_and.append(col_obj.sqla_col > eq)
+                        where_clause_and.append(col_obj.get_sqla_col() > eq)
                     elif op == '<':
-                        where_clause_and.append(col_obj.sqla_col < eq)
+                        where_clause_and.append(col_obj.get_sqla_col() < eq)
                     elif op == '>=':
-                        where_clause_and.append(col_obj.sqla_col >= eq)
+                        where_clause_and.append(col_obj.get_sqla_col() >= eq)
                     elif op == '<=':
-                        where_clause_and.append(col_obj.sqla_col <= eq)
+                        where_clause_and.append(col_obj.get_sqla_col() <= eq)
                     elif op == 'LIKE':
-                        where_clause_and.append(col_obj.sqla_col.like(eq))
+                        where_clause_and.append(col_obj.get_sqla_col().like(eq))
                     elif op == 'IS NULL':
-                        where_clause_and.append(col_obj.sqla_col == None)  # noqa
+                        where_clause_and.append(col_obj.get_sqla_col() == None)  # noqa
                     elif op == 'IS NOT NULL':
-                        where_clause_and.append(col_obj.sqla_col != None)  # noqa
+                        where_clause_and.append(
+                            col_obj.get_sqla_col() != None)  # noqa
         if extras:
             where = extras.get('where')
             if where:
@@ -686,7 +690,7 @@ class SqlaTable(Model, BaseDatasource):
             for col, ascending in orderby:
                 direction = asc if ascending else desc
                 if utils.is_adhoc_metric(col):
-                    col = self.adhoc_metric_to_sa(col, cols)
+                    col = self.adhoc_metric_to_sqla(col, cols)
                 qry = qry.order_by(direction(col))
 
         if row_limit:
@@ -712,12 +716,12 @@ class SqlaTable(Model, BaseDatasource):
                 ob = inner_main_metric_expr
                 if timeseries_limit_metric:
                     if utils.is_adhoc_metric(timeseries_limit_metric):
-                        ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols)
+                        ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
                     elif timeseries_limit_metric in metrics_dict:
                         timeseries_limit_metric = metrics_dict.get(
                             timeseries_limit_metric,
                         )
-                        ob = timeseries_limit_metric.sqla_col
+                        ob = timeseries_limit_metric.get_sqla_col()
                     else:
                         raise Exception(_("Metric '{}' is not valid".format(m)))
                 direction = desc if order_desc else asc
@@ -762,7 +766,7 @@ class SqlaTable(Model, BaseDatasource):
             group = []
             for dimension in dimensions:
                 col_obj = cols.get(dimension)
-                group.append(col_obj.sqla_col == row[dimension])
+                group.append(col_obj.get_sqla_col() == row[dimension])
             groups.append(and_(*group))
 
         return or_(*groups)
@@ -816,6 +820,7 @@ class SqlaTable(Model, BaseDatasource):
             .filter(or_(TableColumn.column_name == col.name
                         for col in table.columns)))
         dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
+        db_engine_spec = self.database.db_engine_spec
 
         for col in table.columns:
             try:
@@ -848,6 +853,9 @@ class SqlaTable(Model, BaseDatasource):
             ))
         if not self.main_dttm_col:
             self.main_dttm_col = any_date_col
+        for metric in metrics:
+            metric.metric_name = db_engine_spec.mutate_expression_label(
+                metric.metric_name)
         self.add_missing_metrics(metrics)
         db.session.merge(self)
         db.session.commit()
diff --git a/superset/dataframe.py b/superset/dataframe.py
index 834f118047..1678dd97f7 100644
--- a/superset/dataframe.py
+++ b/superset/dataframe.py
@@ -73,9 +73,7 @@ class SupersetDataFrame(object):
         if cursor_description:
             column_names = [col[0] for col in cursor_description]
 
-        case_sensitive = db_engine_spec.consistent_case_sensitivity
-        self.column_names = dedup(column_names,
-                                  case_sensitive=case_sensitive)
+        self.column_names = dedup(column_names)
 
         data = data or []
         self.df = (
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 2ce7db06bf..a8a9faabb1 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -35,7 +35,7 @@ import sqlalchemy as sqla
 from sqlalchemy import select
 from sqlalchemy.engine import create_engine
 from sqlalchemy.engine.url import make_url
-from sqlalchemy.sql import text
+from sqlalchemy.sql import quoted_name, text
 from sqlalchemy.sql.expression import TextAsFrom
 import sqlparse
 from tableschema import Table
@@ -101,7 +101,7 @@ class BaseEngineSpec(object):
     time_secondary_columns = False
     inner_joins = True
     allows_subquery = True
-    consistent_case_sensitivity = True  # do results have same case as qry for col names?
+    force_column_alias_quotes = False
     arraysize = None
 
     @classmethod
@@ -376,55 +376,15 @@ class BaseEngineSpec(object):
         cursor.execute(query)
 
     @classmethod
-    def adjust_df_column_names(cls, df, fd):
-        """Based of fields in form_data, return dataframe with new column names
-
-        Usually sqla engines return column names whose case matches that of the
-        original query. For example:
-        SELECT 1 as col1, 2 as COL2, 3 as Col_3
-        will usually result in the following df.columns:
-        ['col1', 'COL2', 'Col_3'].
-        For these engines there is no need to adjust the dataframe column names
-        (default behavior). However, some engines (at least Snowflake, Oracle and
-        Redshift) return column names with different case than in the original query,
-        usually all uppercase. For these the column names need to be adjusted to
-        correspond to the case of the fields specified in the form data for Viz
-        to work properly. This adjustment can be done here.
+    def make_label_compatible(cls, label):
         """
-        if cls.consistent_case_sensitivity:
-            return df
-        else:
-            return cls.align_df_col_names_with_form_data(df, fd)
-
-    @staticmethod
-    def align_df_col_names_with_form_data(df, fd):
-        """Helper function to rename columns that have changed case during query.
-
-        Returns a dataframe where column names have been adjusted to correspond with
-        column names in form data (case insensitive). Examples:
-        dataframe: 'col1', form_data: 'col1' -> no change
-        dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
-        dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'
+        Return a sqlalchemy.sql.elements.quoted_name if the engine requires
+        quoting of aliases to ensure that the select query and query results
+        have the same case.
""" - - columns = set() - lowercase_mapping = {} - - metrics = utils.get_metric_names(fd.get('metrics', [])) - groupby = fd.get('groupby', []) - other_cols = [utils.DTTM_ALIAS] - for col in metrics + groupby + other_cols: - columns.add(col) - lowercase_mapping[col.lower()] = col - - rename_cols = {} - for col in df.columns: - if col not in columns: - orig_col = lowercase_mapping.get(col.lower()) - if orig_col: - rename_cols[col] = orig_col - - return df.rename(index=str, columns=rename_cols) + if cls.force_column_alias_quotes is True: + return quoted_name(label, True) + return label @staticmethod def mutate_expression_label(label): @@ -478,7 +438,8 @@ class PostgresEngineSpec(PostgresBaseEngineSpec): class SnowflakeEngineSpec(PostgresBaseEngineSpec): engine = 'snowflake' - consistent_case_sensitivity = False + force_column_alias_quotes = True + time_grain_functions = { None: '{col}', 'PT1S': "DATE_TRUNC('SECOND', {col})", @@ -515,13 +476,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec): class RedshiftEngineSpec(PostgresBaseEngineSpec): engine = 'redshift' - consistent_case_sensitivity = False + force_column_alias_quotes = True class OracleEngineSpec(PostgresBaseEngineSpec): engine = 'oracle' limit_method = LimitMethod.WRAP_SQL - consistent_case_sensitivity = False + force_column_alias_quotes = True time_grain_functions = { None: '{col}', @@ -545,6 +506,7 @@ class OracleEngineSpec(PostgresBaseEngineSpec): class Db2EngineSpec(BaseEngineSpec): engine = 'ibm_db_sa' limit_method = LimitMethod.WRAP_SQL + force_column_alias_quotes = True time_grain_functions = { None: '{col}', 'PT1S': 'CAST({col} as TIMESTAMP)' diff --git a/superset/viz.py b/superset/viz.py index 90c209efe8..6a18dfab17 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -391,10 +391,6 @@ class BaseViz(object): if query_obj and not is_loaded: try: df = self.get_df(query_obj) - if hasattr(self.datasource, 'database') and \ - hasattr(self.datasource.database, 'db_engine_spec'): - db_engine_spec = self.datasource.database.db_engine_spec - df = db_engine_spec.adjust_df_column_names(df, self.form_data) if self.status != utils.QueryStatus.FAILED: stats_logger.incr('loaded_from_source') is_loaded = True