From 25c599d0400cf320ce5accf50ed88e76ccff5980 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Thu, 27 Jul 2017 09:47:31 -0700 Subject: [PATCH] Escaping the user's SQL in the explore view (#3186) * Escaping the user's SQL in the explore view When executing SQL from SQL Lab, we use a lower level API to the database which doesn't require escaping the SQL. When going through the explore view, the stack chain leading to the same method may need escaping depending on how the DBAPI driver is written, and that is the case for Presto (and perhaps other drivers). * Using regex to avoid doubling doubles --- superset/connectors/sqla/models.py | 16 ++++++++++------ superset/db_engine_specs.py | 17 +++++++++-------- superset/sql_lab.py | 1 - 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 147c667df0..0d06bfbde0 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -285,10 +285,12 @@ class SqlaTable(Model, BaseDatasource): """ cols = {col.column_name: col for col in self.columns} target_col = cols[column_name] + tp = self.get_template_processor() + db_engine_spec = self.database.db_engine_spec qry = ( select([target_col.sqla_col]) - .select_from(self.get_from_clause()) + .select_from(self.get_from_clause(tp, db_engine_spec)) .distinct(column_name) ) if limit: @@ -322,7 +324,6 @@ class SqlaTable(Model, BaseDatasource): ) logging.info(sql) sql = sqlparse.format(sql, reindent=True) - sql = self.database.db_engine_spec.sql_preprocessor(sql) return sql def get_sqla_table(self): @@ -331,12 +332,14 @@ class SqlaTable(Model, BaseDatasource): tbl.schema = self.schema return tbl - def get_from_clause(self, template_processor=None): + def get_from_clause(self, template_processor=None, db_engine_spec=None): # Supporting arbitrary SQL statements in place of tables if self.sql: from_sql = self.sql if template_processor: from_sql = template_processor.process_template(from_sql) + if db_engine_spec: + from_sql = db_engine_spec.escape_sql(from_sql) return TextAsFrom(sa.text(from_sql), []).alias('expr_qry') return self.get_sqla_table() @@ -367,13 +370,14 @@ class SqlaTable(Model, BaseDatasource): 'form_data': form_data, } template_processor = self.get_template_processor(**template_kwargs) + db_engine_spec = self.database.db_engine_spec # For backward compatibility if granularity not in self.dttm_cols: granularity = self.main_dttm_col # Database spec supports join-free timeslot grouping - time_groupby_inline = self.database.db_engine_spec.time_groupby_inline + time_groupby_inline = db_engine_spec.time_groupby_inline cols = {col.column_name: col for col in self.columns} metrics_dict = {m.metric_name: m for m in self.metrics} @@ -428,7 +432,7 @@ class SqlaTable(Model, BaseDatasource): groupby_exprs += [timestamp] # Use main dttm column to support index with secondary dttm columns - if self.database.db_engine_spec.time_secondary_columns and \ + if db_engine_spec.time_secondary_columns and \ self.main_dttm_col in self.dttm_cols and \ self.main_dttm_col != dttm_col.column_name: time_filters.append(cols[self.main_dttm_col]. @@ -438,7 +442,7 @@ class SqlaTable(Model, BaseDatasource): select_exprs += metrics_exprs qry = sa.select(select_exprs) - tbl = self.get_from_clause(template_processor) + tbl = self.get_from_clause(template_processor, db_engine_spec) if not columns: qry = qry.group_by(*groupby_exprs) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index b460226c1a..d08f2a8feb 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -73,6 +73,11 @@ class BaseEngineSpec(object): """Returns engine-specific table metadata""" return {} + @classmethod + def escape_sql(cls, sql): + """Escapes the raw SQL""" + return sql + @classmethod def convert_dttm(cls, target_type, dttm): return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S')) @@ -139,14 +144,6 @@ class BaseEngineSpec(object): """ return uri - @classmethod - def sql_preprocessor(cls, sql): - """If the SQL needs to be altered prior to running it - - For example Presto needs to double `%` characters - """ - return sql - @classmethod def patch(cls): pass @@ -399,6 +396,10 @@ class PrestoEngineSpec(BaseEngineSpec): uri.database = database return uri + @classmethod + def escape_sql(cls, sql): + return re.sub(r'%%|%', "%%", sql) + @classmethod def convert_dttm(cls, target_type, dttm): tt = target_type.upper() diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 4b0bd863bc..638b29abbe 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -154,7 +154,6 @@ def execute_sql(ctask, query_id, return_results=True, store_results=False): template_processor = get_template_processor( database=database, query=query) executed_sql = template_processor.process_template(executed_sql) - executed_sql = db_engine_spec.sql_preprocessor(executed_sql) except Exception as e: logging.exception(e) msg = "Template rendering failed: " + utils.error_msg_from_exception(e)