Escaping the user's SQL in the explore view (#3186)

* Escaping the user's SQL in the explore view

When executing SQL from SQL Lab, we use a lower-level database API
that doesn't require escaping the SQL. When going through the explore
view, the call chain leading to the same method may require the SQL to
be escaped, depending on how the DBAPI driver is written; that is the
case for Presto (and perhaps other drivers).

* Using regex to avoid doubling `%` characters that are already doubled (`%%`)
This commit is contained in:
Maxime Beauchemin 2017-07-27 09:47:31 -07:00 committed by GitHub
parent fb866a937b
commit 25c599d040
3 changed files with 19 additions and 15 deletions

View File

@ -285,10 +285,12 @@ class SqlaTable(Model, BaseDatasource):
""" """
cols = {col.column_name: col for col in self.columns} cols = {col.column_name: col for col in self.columns}
target_col = cols[column_name] target_col = cols[column_name]
tp = self.get_template_processor()
db_engine_spec = self.database.db_engine_spec
qry = ( qry = (
select([target_col.sqla_col]) select([target_col.sqla_col])
.select_from(self.get_from_clause()) .select_from(self.get_from_clause(tp, db_engine_spec))
.distinct(column_name) .distinct(column_name)
) )
if limit: if limit:
@ -322,7 +324,6 @@ class SqlaTable(Model, BaseDatasource):
) )
logging.info(sql) logging.info(sql)
sql = sqlparse.format(sql, reindent=True) sql = sqlparse.format(sql, reindent=True)
sql = self.database.db_engine_spec.sql_preprocessor(sql)
return sql return sql
def get_sqla_table(self): def get_sqla_table(self):
@ -331,12 +332,14 @@ class SqlaTable(Model, BaseDatasource):
tbl.schema = self.schema tbl.schema = self.schema
return tbl return tbl
def get_from_clause(self, template_processor=None): def get_from_clause(self, template_processor=None, db_engine_spec=None):
# Supporting arbitrary SQL statements in place of tables # Supporting arbitrary SQL statements in place of tables
if self.sql: if self.sql:
from_sql = self.sql from_sql = self.sql
if template_processor: if template_processor:
from_sql = template_processor.process_template(from_sql) from_sql = template_processor.process_template(from_sql)
if db_engine_spec:
from_sql = db_engine_spec.escape_sql(from_sql)
return TextAsFrom(sa.text(from_sql), []).alias('expr_qry') return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
return self.get_sqla_table() return self.get_sqla_table()
@ -367,13 +370,14 @@ class SqlaTable(Model, BaseDatasource):
'form_data': form_data, 'form_data': form_data,
} }
template_processor = self.get_template_processor(**template_kwargs) template_processor = self.get_template_processor(**template_kwargs)
db_engine_spec = self.database.db_engine_spec
# For backward compatibility # For backward compatibility
if granularity not in self.dttm_cols: if granularity not in self.dttm_cols:
granularity = self.main_dttm_col granularity = self.main_dttm_col
# Database spec supports join-free timeslot grouping # Database spec supports join-free timeslot grouping
time_groupby_inline = self.database.db_engine_spec.time_groupby_inline time_groupby_inline = db_engine_spec.time_groupby_inline
cols = {col.column_name: col for col in self.columns} cols = {col.column_name: col for col in self.columns}
metrics_dict = {m.metric_name: m for m in self.metrics} metrics_dict = {m.metric_name: m for m in self.metrics}
@ -428,7 +432,7 @@ class SqlaTable(Model, BaseDatasource):
groupby_exprs += [timestamp] groupby_exprs += [timestamp]
# Use main dttm column to support index with secondary dttm columns # Use main dttm column to support index with secondary dttm columns
if self.database.db_engine_spec.time_secondary_columns and \ if db_engine_spec.time_secondary_columns and \
self.main_dttm_col in self.dttm_cols and \ self.main_dttm_col in self.dttm_cols and \
self.main_dttm_col != dttm_col.column_name: self.main_dttm_col != dttm_col.column_name:
time_filters.append(cols[self.main_dttm_col]. time_filters.append(cols[self.main_dttm_col].
@ -438,7 +442,7 @@ class SqlaTable(Model, BaseDatasource):
select_exprs += metrics_exprs select_exprs += metrics_exprs
qry = sa.select(select_exprs) qry = sa.select(select_exprs)
tbl = self.get_from_clause(template_processor) tbl = self.get_from_clause(template_processor, db_engine_spec)
if not columns: if not columns:
qry = qry.group_by(*groupby_exprs) qry = qry.group_by(*groupby_exprs)

View File

@ -73,6 +73,11 @@ class BaseEngineSpec(object):
"""Returns engine-specific table metadata""" """Returns engine-specific table metadata"""
return {} return {}
@classmethod
def escape_sql(cls, sql):
"""Escapes the raw SQL"""
return sql
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S')) return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
@ -139,14 +144,6 @@ class BaseEngineSpec(object):
""" """
return uri return uri
@classmethod
def sql_preprocessor(cls, sql):
"""If the SQL needs to be altered prior to running it
For example Presto needs to double `%` characters
"""
return sql
@classmethod @classmethod
def patch(cls): def patch(cls):
pass pass
@ -399,6 +396,10 @@ class PrestoEngineSpec(BaseEngineSpec):
uri.database = database uri.database = database
return uri return uri
@classmethod
def escape_sql(cls, sql):
return re.sub(r'%%|%', "%%", sql)
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
tt = target_type.upper() tt = target_type.upper()

View File

@ -154,7 +154,6 @@ def execute_sql(ctask, query_id, return_results=True, store_results=False):
template_processor = get_template_processor( template_processor = get_template_processor(
database=database, query=query) database=database, query=query)
executed_sql = template_processor.process_template(executed_sql) executed_sql = template_processor.process_template(executed_sql)
executed_sql = db_engine_spec.sql_preprocessor(executed_sql)
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)
msg = "Template rendering failed: " + utils.error_msg_from_exception(e) msg = "Template rendering failed: " + utils.error_msg_from_exception(e)