Force quoted column aliases for Oracle-like databases (#5686)

* Replace dataframe label override logic with table column override

* Add mutation to any_date_col

* Linting

* Add mutation to oracle and redshift

* Fine tune how and which labels are mutated

* Implement alias quoting logic for oracle-like databases

* Fix and align column and metric sqla_col methods

* Clean up typos and redundant logic

* Move new attribute to old location

* Linting

* Replace old sqla_col property references with function calls

* Remove redundant calls to mutate_column_label

* Move duplicated logic to common function

* Add db_engine_specs to all sqla_col calls

* Add missing mydb

* Add note about snowflake-sqlalchemy regression

* Make db_engine_spec mandatory in sqla_col

* Small refactoring and cleanup

* Remove db_engine_spec from get_from_clause call

* Make db_engine_spec mandatory in adhoc_metric_to_sa

* Remove redundant mutate_expression_label call

* Add missing db_engine_specs to adhoc_metric_to_sa

* Rename arg label_name to label in get_column_label()

* Rename label function and add docstring

* Remove redundant db_engine_spec args

* Rename col_label to label

* Remove get_column_name wrapper and make direct calls to db_engine_spec

* Remove unneeded db_engine_specs

* Rename sa_ vars to sqla_
This commit is contained in:
Ville Brofeldt 2018-09-04 08:49:58 +03:00 committed by Maxime Beauchemin
parent 8a4b1b7c25
commit 77fe9ef130
5 changed files with 75 additions and 107 deletions

View File

@ -393,6 +393,10 @@ Make sure the user has privileges to access and use all required
databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does
not test for user rights during engine creation.
*Note*: At the time of writing, there is a regression in the current stable version (1.1.2) of
snowflake-sqlalchemy package that causes problems when used with Superset. It is recommended to
use version 1.1.0 or try a newer version.
See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_.
Caching

View File

@ -99,13 +99,13 @@ class TableColumn(Model, BaseColumn):
s for s in export_fields if s not in ('table_id',)]
export_parent = 'table'
@property
def sqla_col(self):
name = self.column_name
def get_sqla_col(self, label=None):
db_engine_spec = self.table.database.db_engine_spec
label = db_engine_spec.make_label_compatible(label if label else self.column_name)
if not self.expression:
col = column(self.column_name).label(name)
col = column(self.column_name).label(label)
else:
col = literal_column(self.expression).label(name)
col = literal_column(self.expression).label(label)
return col
@property
@ -113,7 +113,7 @@ class TableColumn(Model, BaseColumn):
return self.table
def get_time_filter(self, start_dttm, end_dttm):
col = self.sqla_col.label('__time')
col = self.get_sqla_col(label='__time')
l = [] # noqa: E741
if start_dttm:
l.append(col >= text(self.dttm_sql_literal(start_dttm)))
@ -231,10 +231,10 @@ class SqlMetric(Model, BaseMetric):
s for s in export_fields if s not in ('table_id', )])
export_parent = 'table'
@property
def sqla_col(self):
name = self.metric_name
return literal_column(self.expression).label(name)
def get_sqla_col(self, label=None):
db_engine_spec = self.table.database.db_engine_spec
label = db_engine_spec.make_label_compatible(label if label else self.metric_name)
return literal_column(self.expression).label(label)
@property
def perm(self):
@ -421,11 +421,10 @@ class SqlaTable(Model, BaseDatasource):
cols = {col.column_name: col for col in self.columns}
target_col = cols[column_name]
tp = self.get_template_processor()
db_engine_spec = self.database.db_engine_spec
qry = (
select([target_col.sqla_col])
.select_from(self.get_from_clause(tp, db_engine_spec))
select([target_col.get_sqla_col()])
.select_from(self.get_from_clause(tp))
.distinct()
)
if limit:
@ -474,7 +473,7 @@ class SqlaTable(Model, BaseDatasource):
tbl.schema = self.schema
return tbl
def get_from_clause(self, template_processor=None, db_engine_spec=None):
def get_from_clause(self, template_processor=None):
# Supporting arbitrary SQL statements in place of tables
if self.sql:
from_sql = self.sql
@ -484,7 +483,7 @@ class SqlaTable(Model, BaseDatasource):
return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
return self.get_sqla_table()
def adhoc_metric_to_sa(self, metric, cols):
def adhoc_metric_to_sqla(self, metric, cols):
"""
Turn an adhoc metric into a sqlalchemy column.
@ -493,22 +492,25 @@ class SqlaTable(Model, BaseDatasource):
:returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column
"""
expressionType = metric.get('expressionType')
if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
expression_type = metric.get('expressionType')
db_engine_spec = self.database.db_engine_spec
label = db_engine_spec.make_label_compatible(metric.get('label'))
if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
column_name = metric.get('column').get('column_name')
sa_column = column(column_name)
sqla_column = column(column_name)
table_column = cols.get(column_name)
if table_column:
sa_column = table_column.sqla_col
sqla_column = table_column.get_sqla_col()
sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column)
sa_metric = sa_metric.label(metric.get('label'))
return sa_metric
elif expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
sa_metric = literal_column(metric.get('sqlExpression'))
sa_metric = sa_metric.label(metric.get('label'))
return sa_metric
sqla_metric = self.sqla_aggregations[metric.get('aggregate')](sqla_column)
sqla_metric = sqla_metric.label(label)
return sqla_metric
elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
sqla_metric = literal_column(metric.get('sqlExpression'))
sqla_metric = sqla_metric.label(label)
return sqla_metric
else:
return None
@ -566,15 +568,16 @@ class SqlaTable(Model, BaseDatasource):
metrics_exprs = []
for m in metrics:
if utils.is_adhoc_metric(m):
metrics_exprs.append(self.adhoc_metric_to_sa(m, cols))
metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
elif m in metrics_dict:
metrics_exprs.append(metrics_dict.get(m).sqla_col)
metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
else:
raise Exception(_("Metric '{}' is not valid".format(m)))
if metrics_exprs:
main_metric_expr = metrics_exprs[0]
else:
main_metric_expr = literal_column('COUNT(*)').label('ccount')
main_metric_expr = literal_column('COUNT(*)').label(
db_engine_spec.make_label_compatible('count'))
select_exprs = []
groupby_exprs = []
@ -585,8 +588,8 @@ class SqlaTable(Model, BaseDatasource):
inner_groupby_exprs = []
for s in groupby:
col = cols[s]
outer = col.sqla_col
inner = col.sqla_col.label(col.column_name + '__')
outer = col.get_sqla_col()
inner = col.get_sqla_col(col.column_name + '__')
groupby_exprs.append(outer)
select_exprs.append(outer)
@ -594,7 +597,7 @@ class SqlaTable(Model, BaseDatasource):
inner_select_exprs.append(inner)
elif columns:
for s in columns:
select_exprs.append(cols[s].sqla_col)
select_exprs.append(cols[s].get_sqla_col())
metrics_exprs = []
if granularity:
@ -618,7 +621,7 @@ class SqlaTable(Model, BaseDatasource):
select_exprs += metrics_exprs
qry = sa.select(select_exprs)
tbl = self.get_from_clause(template_processor, db_engine_spec)
tbl = self.get_from_clause(template_processor)
if not columns:
qry = qry.group_by(*groupby_exprs)
@ -638,9 +641,9 @@ class SqlaTable(Model, BaseDatasource):
target_column_is_numeric=col_obj.is_num,
is_list_target=is_list_target)
if op in ('in', 'not in'):
cond = col_obj.sqla_col.in_(eq)
cond = col_obj.get_sqla_col().in_(eq)
if '<NULL>' in eq:
cond = or_(cond, col_obj.sqla_col == None) # noqa
cond = or_(cond, col_obj.get_sqla_col() == None) # noqa
if op == 'not in':
cond = ~cond
where_clause_and.append(cond)
@ -648,23 +651,24 @@ class SqlaTable(Model, BaseDatasource):
if col_obj.is_num:
eq = utils.string_to_num(flt['val'])
if op == '==':
where_clause_and.append(col_obj.sqla_col == eq)
where_clause_and.append(col_obj.get_sqla_col() == eq)
elif op == '!=':
where_clause_and.append(col_obj.sqla_col != eq)
where_clause_and.append(col_obj.get_sqla_col() != eq)
elif op == '>':
where_clause_and.append(col_obj.sqla_col > eq)
where_clause_and.append(col_obj.get_sqla_col() > eq)
elif op == '<':
where_clause_and.append(col_obj.sqla_col < eq)
where_clause_and.append(col_obj.get_sqla_col() < eq)
elif op == '>=':
where_clause_and.append(col_obj.sqla_col >= eq)
where_clause_and.append(col_obj.get_sqla_col() >= eq)
elif op == '<=':
where_clause_and.append(col_obj.sqla_col <= eq)
where_clause_and.append(col_obj.get_sqla_col() <= eq)
elif op == 'LIKE':
where_clause_and.append(col_obj.sqla_col.like(eq))
where_clause_and.append(col_obj.get_sqla_col().like(eq))
elif op == 'IS NULL':
where_clause_and.append(col_obj.sqla_col == None) # noqa
where_clause_and.append(col_obj.get_sqla_col() == None) # noqa
elif op == 'IS NOT NULL':
where_clause_and.append(col_obj.sqla_col != None) # noqa
where_clause_and.append(
col_obj.get_sqla_col() != None) # noqa
if extras:
where = extras.get('where')
if where:
@ -686,7 +690,7 @@ class SqlaTable(Model, BaseDatasource):
for col, ascending in orderby:
direction = asc if ascending else desc
if utils.is_adhoc_metric(col):
col = self.adhoc_metric_to_sa(col, cols)
col = self.adhoc_metric_to_sqla(col, cols)
qry = qry.order_by(direction(col))
if row_limit:
@ -712,12 +716,12 @@ class SqlaTable(Model, BaseDatasource):
ob = inner_main_metric_expr
if timeseries_limit_metric:
if utils.is_adhoc_metric(timeseries_limit_metric):
ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols)
ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
elif timeseries_limit_metric in metrics_dict:
timeseries_limit_metric = metrics_dict.get(
timeseries_limit_metric,
)
ob = timeseries_limit_metric.sqla_col
ob = timeseries_limit_metric.get_sqla_col()
else:
raise Exception(_("Metric '{}' is not valid".format(m)))
direction = desc if order_desc else asc
@ -762,7 +766,7 @@ class SqlaTable(Model, BaseDatasource):
group = []
for dimension in dimensions:
col_obj = cols.get(dimension)
group.append(col_obj.sqla_col == row[dimension])
group.append(col_obj.get_sqla_col() == row[dimension])
groups.append(and_(*group))
return or_(*groups)
@ -816,6 +820,7 @@ class SqlaTable(Model, BaseDatasource):
.filter(or_(TableColumn.column_name == col.name
for col in table.columns)))
dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
db_engine_spec = self.database.db_engine_spec
for col in table.columns:
try:
@ -848,6 +853,9 @@ class SqlaTable(Model, BaseDatasource):
))
if not self.main_dttm_col:
self.main_dttm_col = any_date_col
for metric in metrics:
metric.metric_name = db_engine_spec.mutate_expression_label(
metric.metric_name)
self.add_missing_metrics(metrics)
db.session.merge(self)
db.session.commit()

View File

@ -73,9 +73,7 @@ class SupersetDataFrame(object):
if cursor_description:
column_names = [col[0] for col in cursor_description]
case_sensitive = db_engine_spec.consistent_case_sensitivity
self.column_names = dedup(column_names,
case_sensitive=case_sensitive)
self.column_names = dedup(column_names)
data = data or []
self.df = (

View File

@ -35,7 +35,7 @@ import sqlalchemy as sqla
from sqlalchemy import select
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.sql import text
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import TextAsFrom
import sqlparse
from tableschema import Table
@ -101,7 +101,7 @@ class BaseEngineSpec(object):
time_secondary_columns = False
inner_joins = True
allows_subquery = True
consistent_case_sensitivity = True # do results have same case as qry for col names?
force_column_alias_quotes = False
arraysize = None
@classmethod
@ -376,55 +376,15 @@ class BaseEngineSpec(object):
cursor.execute(query)
@classmethod
def adjust_df_column_names(cls, df, fd):
"""Based of fields in form_data, return dataframe with new column names
Usually sqla engines return column names whose case matches that of the
original query. For example:
SELECT 1 as col1, 2 as COL2, 3 as Col_3
will usually result in the following df.columns:
['col1', 'COL2', 'Col_3'].
For these engines there is no need to adjust the dataframe column names
(default behavior). However, some engines (at least Snowflake, Oracle and
Redshift) return column names with different case than in the original query,
usually all uppercase. For these the column names need to be adjusted to
correspond to the case of the fields specified in the form data for Viz
to work properly. This adjustment can be done here.
def make_label_compatible(cls, label):
"""
if cls.consistent_case_sensitivity:
return df
else:
return cls.align_df_col_names_with_form_data(df, fd)
@staticmethod
def align_df_col_names_with_form_data(df, fd):
"""Helper function to rename columns that have changed case during query.
Returns a dataframe where column names have been adjusted to correspond with
column names in form data (case insensitive). Examples:
dataframe: 'col1', form_data: 'col1' -> no change
dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'
Return a sqlalchemy.sql.elements.quoted_name if the engine requires
quoting of aliases to ensure that select query and query results
have same case.
"""
columns = set()
lowercase_mapping = {}
metrics = utils.get_metric_names(fd.get('metrics', []))
groupby = fd.get('groupby', [])
other_cols = [utils.DTTM_ALIAS]
for col in metrics + groupby + other_cols:
columns.add(col)
lowercase_mapping[col.lower()] = col
rename_cols = {}
for col in df.columns:
if col not in columns:
orig_col = lowercase_mapping.get(col.lower())
if orig_col:
rename_cols[col] = orig_col
return df.rename(index=str, columns=rename_cols)
if cls.force_column_alias_quotes is True:
return quoted_name(label, True)
return label
@staticmethod
def mutate_expression_label(label):
@ -478,7 +438,8 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = 'snowflake'
consistent_case_sensitivity = False
force_column_alias_quotes = True
time_grain_functions = {
None: '{col}',
'PT1S': "DATE_TRUNC('SECOND', {col})",
@ -515,13 +476,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
class RedshiftEngineSpec(PostgresBaseEngineSpec):
engine = 'redshift'
consistent_case_sensitivity = False
force_column_alias_quotes = True
class OracleEngineSpec(PostgresBaseEngineSpec):
engine = 'oracle'
limit_method = LimitMethod.WRAP_SQL
consistent_case_sensitivity = False
force_column_alias_quotes = True
time_grain_functions = {
None: '{col}',
@ -545,6 +506,7 @@ class OracleEngineSpec(PostgresBaseEngineSpec):
class Db2EngineSpec(BaseEngineSpec):
engine = 'ibm_db_sa'
limit_method = LimitMethod.WRAP_SQL
force_column_alias_quotes = True
time_grain_functions = {
None: '{col}',
'PT1S': 'CAST({col} as TIMESTAMP)'

View File

@ -391,10 +391,6 @@ class BaseViz(object):
if query_obj and not is_loaded:
try:
df = self.get_df(query_obj)
if hasattr(self.datasource, 'database') and \
hasattr(self.datasource.database, 'db_engine_spec'):
db_engine_spec = self.datasource.database.db_engine_spec
df = db_engine_spec.adjust_df_column_names(df, self.form_data)
if self.status != utils.QueryStatus.FAILED:
stats_logger.incr('loaded_from_source')
is_loaded = True