Force quoted column aliases for Oracle-like databases (#5686)

* Replace dataframe label override logic with table column override

* Add mutation to any_date_col

* Linting

* Add mutation to oracle and redshift

* Fine tune how and which labels are mutated

* Implement alias quoting logic for oracle-like databases

* Fix and align column and metric sqla_col methods

* Clean up typos and redundant logic

* Move new attribute to old location

* Linting

* Replace old sqla_col property references with function calls

* Remove redundant calls to mutate_column_label

* Move duplicated logic to common function

* Add db_engine_specs to all sqla_col calls

* Add missing mydb

* Add note about snowflake-sqlalchemy regression

* Make db_engine_spec mandatory in sqla_col

* Small refactoring and cleanup

* Remove db_engine_spec from get_from_clause call

* Make db_engine_spec mandatory in adhoc_metric_to_sa

* Remove redundant mutate_expression_label call

* Add missing db_engine_specs to adhoc_metric_to_sa

* Rename arg label_name to label in get_column_label()

* Rename label function and add docstring

* Remove redundant db_engine_spec args

* Rename col_label to label

* Remove get_column_name wrapper and make direct calls to db_engine_spec

* Remove unneeded db_engine_specs

* Rename sa_ vars to sqla_
This commit is contained in:
Ville Brofeldt 2018-09-04 08:49:58 +03:00 committed by Maxime Beauchemin
parent 8a4b1b7c25
commit 77fe9ef130
5 changed files with 75 additions and 107 deletions

View File

@ -393,6 +393,10 @@ Make sure the user has privileges to access and use all required
databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does
not test for user rights during engine creation. not test for user rights during engine creation.
*Note*: At the time of writing, there is a regression in the current stable version (1.1.2) of
snowflake-sqlalchemy package that causes problems when used with Superset. It is recommended to
use version 1.1.0 or try a newer version.
See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_. See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_.
Caching Caching

View File

@ -99,13 +99,13 @@ class TableColumn(Model, BaseColumn):
s for s in export_fields if s not in ('table_id',)] s for s in export_fields if s not in ('table_id',)]
export_parent = 'table' export_parent = 'table'
@property def get_sqla_col(self, label=None):
def sqla_col(self): db_engine_spec = self.table.database.db_engine_spec
name = self.column_name label = db_engine_spec.make_label_compatible(label if label else self.column_name)
if not self.expression: if not self.expression:
col = column(self.column_name).label(name) col = column(self.column_name).label(label)
else: else:
col = literal_column(self.expression).label(name) col = literal_column(self.expression).label(label)
return col return col
@property @property
@ -113,7 +113,7 @@ class TableColumn(Model, BaseColumn):
return self.table return self.table
def get_time_filter(self, start_dttm, end_dttm): def get_time_filter(self, start_dttm, end_dttm):
col = self.sqla_col.label('__time') col = self.get_sqla_col(label='__time')
l = [] # noqa: E741 l = [] # noqa: E741
if start_dttm: if start_dttm:
l.append(col >= text(self.dttm_sql_literal(start_dttm))) l.append(col >= text(self.dttm_sql_literal(start_dttm)))
@ -231,10 +231,10 @@ class SqlMetric(Model, BaseMetric):
s for s in export_fields if s not in ('table_id', )]) s for s in export_fields if s not in ('table_id', )])
export_parent = 'table' export_parent = 'table'
@property def get_sqla_col(self, label=None):
def sqla_col(self): db_engine_spec = self.table.database.db_engine_spec
name = self.metric_name label = db_engine_spec.make_label_compatible(label if label else self.metric_name)
return literal_column(self.expression).label(name) return literal_column(self.expression).label(label)
@property @property
def perm(self): def perm(self):
@ -421,11 +421,10 @@ class SqlaTable(Model, BaseDatasource):
cols = {col.column_name: col for col in self.columns} cols = {col.column_name: col for col in self.columns}
target_col = cols[column_name] target_col = cols[column_name]
tp = self.get_template_processor() tp = self.get_template_processor()
db_engine_spec = self.database.db_engine_spec
qry = ( qry = (
select([target_col.sqla_col]) select([target_col.get_sqla_col()])
.select_from(self.get_from_clause(tp, db_engine_spec)) .select_from(self.get_from_clause(tp))
.distinct() .distinct()
) )
if limit: if limit:
@ -474,7 +473,7 @@ class SqlaTable(Model, BaseDatasource):
tbl.schema = self.schema tbl.schema = self.schema
return tbl return tbl
def get_from_clause(self, template_processor=None, db_engine_spec=None): def get_from_clause(self, template_processor=None):
# Supporting arbitrary SQL statements in place of tables # Supporting arbitrary SQL statements in place of tables
if self.sql: if self.sql:
from_sql = self.sql from_sql = self.sql
@ -484,7 +483,7 @@ class SqlaTable(Model, BaseDatasource):
return TextAsFrom(sa.text(from_sql), []).alias('expr_qry') return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
return self.get_sqla_table() return self.get_sqla_table()
def adhoc_metric_to_sa(self, metric, cols): def adhoc_metric_to_sqla(self, metric, cols):
""" """
Turn an adhoc metric into a sqlalchemy column. Turn an adhoc metric into a sqlalchemy column.
@ -493,22 +492,25 @@ class SqlaTable(Model, BaseDatasource):
:returns: The metric defined as a sqlalchemy column :returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column :rtype: sqlalchemy.sql.column
""" """
expressionType = metric.get('expressionType') expression_type = metric.get('expressionType')
if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']: db_engine_spec = self.database.db_engine_spec
label = db_engine_spec.make_label_compatible(metric.get('label'))
if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
column_name = metric.get('column').get('column_name') column_name = metric.get('column').get('column_name')
sa_column = column(column_name) sqla_column = column(column_name)
table_column = cols.get(column_name) table_column = cols.get(column_name)
if table_column: if table_column:
sa_column = table_column.sqla_col sqla_column = table_column.get_sqla_col()
sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column) sqla_metric = self.sqla_aggregations[metric.get('aggregate')](sqla_column)
sa_metric = sa_metric.label(metric.get('label')) sqla_metric = sqla_metric.label(label)
return sa_metric return sqla_metric
elif expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']: elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
sa_metric = literal_column(metric.get('sqlExpression')) sqla_metric = literal_column(metric.get('sqlExpression'))
sa_metric = sa_metric.label(metric.get('label')) sqla_metric = sqla_metric.label(label)
return sa_metric return sqla_metric
else: else:
return None return None
@ -566,15 +568,16 @@ class SqlaTable(Model, BaseDatasource):
metrics_exprs = [] metrics_exprs = []
for m in metrics: for m in metrics:
if utils.is_adhoc_metric(m): if utils.is_adhoc_metric(m):
metrics_exprs.append(self.adhoc_metric_to_sa(m, cols)) metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
elif m in metrics_dict: elif m in metrics_dict:
metrics_exprs.append(metrics_dict.get(m).sqla_col) metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
else: else:
raise Exception(_("Metric '{}' is not valid".format(m))) raise Exception(_("Metric '{}' is not valid".format(m)))
if metrics_exprs: if metrics_exprs:
main_metric_expr = metrics_exprs[0] main_metric_expr = metrics_exprs[0]
else: else:
main_metric_expr = literal_column('COUNT(*)').label('ccount') main_metric_expr = literal_column('COUNT(*)').label(
db_engine_spec.make_label_compatible('count'))
select_exprs = [] select_exprs = []
groupby_exprs = [] groupby_exprs = []
@ -585,8 +588,8 @@ class SqlaTable(Model, BaseDatasource):
inner_groupby_exprs = [] inner_groupby_exprs = []
for s in groupby: for s in groupby:
col = cols[s] col = cols[s]
outer = col.sqla_col outer = col.get_sqla_col()
inner = col.sqla_col.label(col.column_name + '__') inner = col.get_sqla_col(col.column_name + '__')
groupby_exprs.append(outer) groupby_exprs.append(outer)
select_exprs.append(outer) select_exprs.append(outer)
@ -594,7 +597,7 @@ class SqlaTable(Model, BaseDatasource):
inner_select_exprs.append(inner) inner_select_exprs.append(inner)
elif columns: elif columns:
for s in columns: for s in columns:
select_exprs.append(cols[s].sqla_col) select_exprs.append(cols[s].get_sqla_col())
metrics_exprs = [] metrics_exprs = []
if granularity: if granularity:
@ -618,7 +621,7 @@ class SqlaTable(Model, BaseDatasource):
select_exprs += metrics_exprs select_exprs += metrics_exprs
qry = sa.select(select_exprs) qry = sa.select(select_exprs)
tbl = self.get_from_clause(template_processor, db_engine_spec) tbl = self.get_from_clause(template_processor)
if not columns: if not columns:
qry = qry.group_by(*groupby_exprs) qry = qry.group_by(*groupby_exprs)
@ -638,9 +641,9 @@ class SqlaTable(Model, BaseDatasource):
target_column_is_numeric=col_obj.is_num, target_column_is_numeric=col_obj.is_num,
is_list_target=is_list_target) is_list_target=is_list_target)
if op in ('in', 'not in'): if op in ('in', 'not in'):
cond = col_obj.sqla_col.in_(eq) cond = col_obj.get_sqla_col().in_(eq)
if '<NULL>' in eq: if '<NULL>' in eq:
cond = or_(cond, col_obj.sqla_col == None) # noqa cond = or_(cond, col_obj.get_sqla_col() == None) # noqa
if op == 'not in': if op == 'not in':
cond = ~cond cond = ~cond
where_clause_and.append(cond) where_clause_and.append(cond)
@ -648,23 +651,24 @@ class SqlaTable(Model, BaseDatasource):
if col_obj.is_num: if col_obj.is_num:
eq = utils.string_to_num(flt['val']) eq = utils.string_to_num(flt['val'])
if op == '==': if op == '==':
where_clause_and.append(col_obj.sqla_col == eq) where_clause_and.append(col_obj.get_sqla_col() == eq)
elif op == '!=': elif op == '!=':
where_clause_and.append(col_obj.sqla_col != eq) where_clause_and.append(col_obj.get_sqla_col() != eq)
elif op == '>': elif op == '>':
where_clause_and.append(col_obj.sqla_col > eq) where_clause_and.append(col_obj.get_sqla_col() > eq)
elif op == '<': elif op == '<':
where_clause_and.append(col_obj.sqla_col < eq) where_clause_and.append(col_obj.get_sqla_col() < eq)
elif op == '>=': elif op == '>=':
where_clause_and.append(col_obj.sqla_col >= eq) where_clause_and.append(col_obj.get_sqla_col() >= eq)
elif op == '<=': elif op == '<=':
where_clause_and.append(col_obj.sqla_col <= eq) where_clause_and.append(col_obj.get_sqla_col() <= eq)
elif op == 'LIKE': elif op == 'LIKE':
where_clause_and.append(col_obj.sqla_col.like(eq)) where_clause_and.append(col_obj.get_sqla_col().like(eq))
elif op == 'IS NULL': elif op == 'IS NULL':
where_clause_and.append(col_obj.sqla_col == None) # noqa where_clause_and.append(col_obj.get_sqla_col() == None) # noqa
elif op == 'IS NOT NULL': elif op == 'IS NOT NULL':
where_clause_and.append(col_obj.sqla_col != None) # noqa where_clause_and.append(
col_obj.get_sqla_col() != None) # noqa
if extras: if extras:
where = extras.get('where') where = extras.get('where')
if where: if where:
@ -686,7 +690,7 @@ class SqlaTable(Model, BaseDatasource):
for col, ascending in orderby: for col, ascending in orderby:
direction = asc if ascending else desc direction = asc if ascending else desc
if utils.is_adhoc_metric(col): if utils.is_adhoc_metric(col):
col = self.adhoc_metric_to_sa(col, cols) col = self.adhoc_metric_to_sqla(col, cols)
qry = qry.order_by(direction(col)) qry = qry.order_by(direction(col))
if row_limit: if row_limit:
@ -712,12 +716,12 @@ class SqlaTable(Model, BaseDatasource):
ob = inner_main_metric_expr ob = inner_main_metric_expr
if timeseries_limit_metric: if timeseries_limit_metric:
if utils.is_adhoc_metric(timeseries_limit_metric): if utils.is_adhoc_metric(timeseries_limit_metric):
ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols) ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
elif timeseries_limit_metric in metrics_dict: elif timeseries_limit_metric in metrics_dict:
timeseries_limit_metric = metrics_dict.get( timeseries_limit_metric = metrics_dict.get(
timeseries_limit_metric, timeseries_limit_metric,
) )
ob = timeseries_limit_metric.sqla_col ob = timeseries_limit_metric.get_sqla_col()
else: else:
raise Exception(_("Metric '{}' is not valid".format(m))) raise Exception(_("Metric '{}' is not valid".format(m)))
direction = desc if order_desc else asc direction = desc if order_desc else asc
@ -762,7 +766,7 @@ class SqlaTable(Model, BaseDatasource):
group = [] group = []
for dimension in dimensions: for dimension in dimensions:
col_obj = cols.get(dimension) col_obj = cols.get(dimension)
group.append(col_obj.sqla_col == row[dimension]) group.append(col_obj.get_sqla_col() == row[dimension])
groups.append(and_(*group)) groups.append(and_(*group))
return or_(*groups) return or_(*groups)
@ -816,6 +820,7 @@ class SqlaTable(Model, BaseDatasource):
.filter(or_(TableColumn.column_name == col.name .filter(or_(TableColumn.column_name == col.name
for col in table.columns))) for col in table.columns)))
dbcols = {dbcol.column_name: dbcol for dbcol in dbcols} dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
db_engine_spec = self.database.db_engine_spec
for col in table.columns: for col in table.columns:
try: try:
@ -848,6 +853,9 @@ class SqlaTable(Model, BaseDatasource):
)) ))
if not self.main_dttm_col: if not self.main_dttm_col:
self.main_dttm_col = any_date_col self.main_dttm_col = any_date_col
for metric in metrics:
metric.metric_name = db_engine_spec.mutate_expression_label(
metric.metric_name)
self.add_missing_metrics(metrics) self.add_missing_metrics(metrics)
db.session.merge(self) db.session.merge(self)
db.session.commit() db.session.commit()

View File

@ -73,9 +73,7 @@ class SupersetDataFrame(object):
if cursor_description: if cursor_description:
column_names = [col[0] for col in cursor_description] column_names = [col[0] for col in cursor_description]
case_sensitive = db_engine_spec.consistent_case_sensitivity self.column_names = dedup(column_names)
self.column_names = dedup(column_names,
case_sensitive=case_sensitive)
data = data or [] data = data or []
self.df = ( self.df = (

View File

@ -35,7 +35,7 @@ import sqlalchemy as sqla
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.engine import create_engine from sqlalchemy.engine import create_engine
from sqlalchemy.engine.url import make_url from sqlalchemy.engine.url import make_url
from sqlalchemy.sql import text from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import TextAsFrom from sqlalchemy.sql.expression import TextAsFrom
import sqlparse import sqlparse
from tableschema import Table from tableschema import Table
@ -101,7 +101,7 @@ class BaseEngineSpec(object):
time_secondary_columns = False time_secondary_columns = False
inner_joins = True inner_joins = True
allows_subquery = True allows_subquery = True
consistent_case_sensitivity = True # do results have same case as qry for col names? force_column_alias_quotes = False
arraysize = None arraysize = None
@classmethod @classmethod
@ -376,55 +376,15 @@ class BaseEngineSpec(object):
cursor.execute(query) cursor.execute(query)
@classmethod @classmethod
def adjust_df_column_names(cls, df, fd): def make_label_compatible(cls, label):
"""Based of fields in form_data, return dataframe with new column names
Usually sqla engines return column names whose case matches that of the
original query. For example:
SELECT 1 as col1, 2 as COL2, 3 as Col_3
will usually result in the following df.columns:
['col1', 'COL2', 'Col_3'].
For these engines there is no need to adjust the dataframe column names
(default behavior). However, some engines (at least Snowflake, Oracle and
Redshift) return column names with different case than in the original query,
usually all uppercase. For these the column names need to be adjusted to
correspond to the case of the fields specified in the form data for Viz
to work properly. This adjustment can be done here.
""" """
if cls.consistent_case_sensitivity: Return a sqlalchemy.sql.elements.quoted_name if the engine requires
return df quoting of aliases to ensure that select query and query results
else: have same case.
return cls.align_df_col_names_with_form_data(df, fd)
@staticmethod
def align_df_col_names_with_form_data(df, fd):
"""Helper function to rename columns that have changed case during query.
Returns a dataframe where column names have been adjusted to correspond with
column names in form data (case insensitive). Examples:
dataframe: 'col1', form_data: 'col1' -> no change
dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'
""" """
if cls.force_column_alias_quotes is True:
columns = set() return quoted_name(label, True)
lowercase_mapping = {} return label
metrics = utils.get_metric_names(fd.get('metrics', []))
groupby = fd.get('groupby', [])
other_cols = [utils.DTTM_ALIAS]
for col in metrics + groupby + other_cols:
columns.add(col)
lowercase_mapping[col.lower()] = col
rename_cols = {}
for col in df.columns:
if col not in columns:
orig_col = lowercase_mapping.get(col.lower())
if orig_col:
rename_cols[col] = orig_col
return df.rename(index=str, columns=rename_cols)
@staticmethod @staticmethod
def mutate_expression_label(label): def mutate_expression_label(label):
@ -478,7 +438,8 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
class SnowflakeEngineSpec(PostgresBaseEngineSpec): class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = 'snowflake' engine = 'snowflake'
consistent_case_sensitivity = False force_column_alias_quotes = True
time_grain_functions = { time_grain_functions = {
None: '{col}', None: '{col}',
'PT1S': "DATE_TRUNC('SECOND', {col})", 'PT1S': "DATE_TRUNC('SECOND', {col})",
@ -515,13 +476,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
class RedshiftEngineSpec(PostgresBaseEngineSpec): class RedshiftEngineSpec(PostgresBaseEngineSpec):
engine = 'redshift' engine = 'redshift'
consistent_case_sensitivity = False force_column_alias_quotes = True
class OracleEngineSpec(PostgresBaseEngineSpec): class OracleEngineSpec(PostgresBaseEngineSpec):
engine = 'oracle' engine = 'oracle'
limit_method = LimitMethod.WRAP_SQL limit_method = LimitMethod.WRAP_SQL
consistent_case_sensitivity = False force_column_alias_quotes = True
time_grain_functions = { time_grain_functions = {
None: '{col}', None: '{col}',
@ -545,6 +506,7 @@ class OracleEngineSpec(PostgresBaseEngineSpec):
class Db2EngineSpec(BaseEngineSpec): class Db2EngineSpec(BaseEngineSpec):
engine = 'ibm_db_sa' engine = 'ibm_db_sa'
limit_method = LimitMethod.WRAP_SQL limit_method = LimitMethod.WRAP_SQL
force_column_alias_quotes = True
time_grain_functions = { time_grain_functions = {
None: '{col}', None: '{col}',
'PT1S': 'CAST({col} as TIMESTAMP)' 'PT1S': 'CAST({col} as TIMESTAMP)'

View File

@ -391,10 +391,6 @@ class BaseViz(object):
if query_obj and not is_loaded: if query_obj and not is_loaded:
try: try:
df = self.get_df(query_obj) df = self.get_df(query_obj)
if hasattr(self.datasource, 'database') and \
hasattr(self.datasource.database, 'db_engine_spec'):
db_engine_spec = self.datasource.database.db_engine_spec
df = db_engine_spec.adjust_df_column_names(df, self.form_data)
if self.status != utils.QueryStatus.FAILED: if self.status != utils.QueryStatus.FAILED:
stats_logger.incr('loaded_from_source') stats_logger.incr('loaded_from_source')
is_loaded = True is_loaded = True