From 7fcc2af68f79a8a78e1799feb80647ae90ac9370 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Sat, 21 Jul 2018 12:01:26 -0700 Subject: [PATCH] [sql] Correct SQL parameter formatting (#5178) --- .pylintrc | 2 +- superset/connectors/sqla/models.py | 9 +- superset/db_engine_specs.py | 12 ++- .../4451805bbaa1_remove_double_percents.py | 86 +++++++++++++++++++ superset/models/core.py | 43 +++++++--- superset/sql_lab.py | 3 +- tests/core_tests.py | 9 ++ tests/sqllab_tests.py | 2 +- tox.ini | 2 +- 9 files changed, 138 insertions(+), 30 deletions(-) create mode 100644 superset/migrations/versions/4451805bbaa1_remove_double_percents.py diff --git a/.pylintrc b/.pylintrc index 820637dbd0..016b04e367 100644 --- a/.pylintrc +++ b/.pylintrc @@ -282,7 +282,7 @@ ignored-modules=numpy,pandas,alembic.op,sqlalchemy,alembic.context,flask_appbuil # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local,sqlalchemy.orm.scoping.scoped_session +ignored-classes=contextlib.closing,optparse.Values,thread._local,_thread._local,sqlalchemy.orm.scoping.scoped_session # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 3c5b18ea99..c86d4eadda 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -12,7 +12,6 @@ from flask import escape, Markup from flask_appbuilder import Model from flask_babel import lazy_gettext as _ import pandas as pd -import six import sqlalchemy as sa from sqlalchemy import ( and_, asc, Boolean, Column, DateTime, desc, ForeignKey, Integer, or_, @@ -427,14 +426,8 @@ class SqlaTable(Model, BaseDatasource): table=self, database=self.database, **kwargs) def get_query_str(self, query_obj): - engine = self.database.get_sqla_engine() qry = self.get_sqla_query(**query_obj) - sql = six.text_type( - qry.compile( - engine, - compile_kwargs={'literal_binds': True}, - ), - ) + sql = self.database.compile_sqla_query(qry) logging.info(sql) sql = sqlparse.format(sql, reindent=True) if query_obj['is_prequery']: diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 97b8095d71..2b7454160d 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -65,7 +65,6 @@ class BaseEngineSpec(object): """Abstract class for database engine specific configurations""" engine = 'base' # str as defined in sqlalchemy.engine.engine - cursor_execute_kwargs = {} time_grains = tuple() time_groupby_inline = False limit_method = LimitMethod.FORCE_LIMIT @@ -331,6 +330,10 @@ class BaseEngineSpec(object): def normalize_column_name(column_name): return column_name + @staticmethod + def execute(cursor, query, async=False): + cursor.execute(query) + class PostgresBaseEngineSpec(BaseEngineSpec): """ Abstract class for Postgres 'like' databases """ @@ -558,7 +561,6 @@ class SqliteEngineSpec(BaseEngineSpec): class MySQLEngineSpec(BaseEngineSpec): engine = 'mysql' - cursor_execute_kwargs = {'args': {}} time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), Grain('second', _('second'), 'DATE_ADD(DATE({col}), ' @@ -639,7 +641,6 @@ class MySQLEngineSpec(BaseEngineSpec): class PrestoEngineSpec(BaseEngineSpec): engine = 'presto' - cursor_execute_kwargs = {'parameters': None} time_grains = ( Grain('Time Column', _('Time Column'), '{col}', None), @@ -938,7 +939,6 @@ class HiveEngineSpec(PrestoEngineSpec): """Reuses PrestoEngineSpec functionality.""" engine = 'hive' - cursor_execute_kwargs = {'async': True} # Scoping regex at class level to avoid recompiling # 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5 @@ -1230,6 +1230,10 @@ class HiveEngineSpec(PrestoEngineSpec): configuration['hive.server2.proxy.user'] = username return configuration + @staticmethod + def execute(cursor, query, async=False): + cursor.execute(query, async=async) + class MssqlEngineSpec(BaseEngineSpec): engine = 'mssql' diff --git a/superset/migrations/versions/4451805bbaa1_remove_double_percents.py b/superset/migrations/versions/4451805bbaa1_remove_double_percents.py new file mode 100644 index 0000000000..2e57b39d3f --- /dev/null +++ b/superset/migrations/versions/4451805bbaa1_remove_double_percents.py @@ -0,0 +1,86 @@ +"""remove double percents + +Revision ID: 4451805bbaa1 +Revises: afb7730f6a9c +Create Date: 2018-06-13 10:20:35.846744 + +""" + +# revision identifiers, used by Alembic. +revision = '4451805bbaa1' +down_revision = 'bddc498dd179' + + +from alembic import op +import json +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy import Column, create_engine, ForeignKey, Integer, String, Text + +from superset import db + +Base = declarative_base() + + +class Slice(Base): + __tablename__ = 'slices' + + id = Column(Integer, primary_key=True) + datasource_id = Column(Integer, ForeignKey('tables.id')) + datasource_type = Column(String(200)) + params = Column(Text) + + +class Table(Base): + __tablename__ = 'tables' + + id = Column(Integer, primary_key=True) + database_id = Column(Integer, ForeignKey('dbs.id')) + + +class Database(Base): + __tablename__ = 'dbs' + + id = Column(Integer, primary_key=True) + sqlalchemy_uri = Column(String(1024)) + + +def replace(source, target): + bind = op.get_bind() + session = db.Session(bind=bind) + + query = ( + session.query(Slice, Database) + .join(Table) + .join(Database) + .filter(Slice.datasource_type == 'table') + .all() + ) + + for slc, database in query: + try: + engine = create_engine(database.sqlalchemy_uri) + + if engine.dialect.identifier_preparer._double_percents: + params = json.loads(slc.params) + + if 'adhoc_filters' in params: + for filt in params['adhoc_filters']: + if 'sqlExpression' in filt: + filt['sqlExpression'] = ( + filt['sqlExpression'].replace(source, target) + ) + + slc.params = json.dumps(params, sort_keys=True) + except Exception: + pass + + session.commit() + session.close() + + +def upgrade(): + replace('%%', '%') + + +def downgrade(): + replace('%', '%%') diff --git a/superset/models/core.py b/superset/models/core.py index 13021e7dba..d50cf4f3d9 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -6,6 +6,7 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals +from contextlib import closing from copy import copy, deepcopy from datetime import datetime import functools @@ -19,6 +20,7 @@ from flask_appbuilder.models.decorators import renders from future.standard_library import install_aliases import numpy import pandas as pd +import six import sqlalchemy as sqla from sqlalchemy import ( Boolean, Column, create_engine, DateTime, ForeignKey, Integer, @@ -749,12 +751,7 @@ class Database(Model, AuditMixinNullable, ImportMixin): def get_df(self, sql, schema): sqls = [str(s).strip().strip(';') for s in sqlparse.parse(sql)] - eng = self.get_sqla_engine(schema=schema) - - for i in range(len(sqls) - 1): - eng.execute(sqls[i]) - - df = pd.read_sql_query(sqls[-1], eng) + engine = self.get_sqla_engine(schema=schema) def needs_conversion(df_series): if df_series.empty: @@ -763,15 +760,35 @@ class Database(Model, AuditMixinNullable, ImportMixin): return True return False - for k, v in df.dtypes.items(): - if v.type == numpy.object_ and needs_conversion(df[k]): - df[k] = df[k].apply(utils.json_dumps_w_dates) - return df + with closing(engine.raw_connection()) as conn: + with closing(conn.cursor()) as cursor: + for sql in sqls: + self.db_engine_spec.execute(cursor, sql) + df = pd.DataFrame.from_records( + data=list(cursor.fetchall()), + columns=[col_desc[0] for col_desc in cursor.description], + coerce_float=True, + ) + + for k, v in df.dtypes.items(): + if v.type == numpy.object_ and needs_conversion(df[k]): + df[k] = df[k].apply(utils.json_dumps_w_dates) + return df def compile_sqla_query(self, qry, schema=None): - eng = self.get_sqla_engine(schema=schema) - compiled = qry.compile(eng, compile_kwargs={'literal_binds': True}) - return '{}'.format(compiled) + engine = self.get_sqla_engine(schema=schema) + + sql = six.text_type( + qry.compile( + engine, + compile_kwargs={'literal_binds': True}, + ), + ) + + if engine.dialect.identifier_preparer._double_percents: + sql = sql.replace('%%', '%') + + return sql def select_star( self, table_name, schema=None, limit=100, show_cols=False, diff --git a/superset/sql_lab.py b/superset/sql_lab.py index b45cbbb64b..a626b68874 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -172,8 +172,7 @@ def execute_sql( cursor = conn.cursor() logging.info('Running query: \n{}'.format(executed_sql)) logging.info(query.executed_sql) - cursor.execute(query.executed_sql, - **db_engine_spec.cursor_execute_kwargs) + db_engine_spec.execute(cursor, query.executed_sql, async=True) logging.info('Handling cursor') db_engine_spec.handle_cursor(cursor, query, session) logging.info('Fetching data: {}'.format(query.to_dict())) diff --git a/tests/core_tests.py b/tests/core_tests.py index f1a01796b7..eb95f1f142 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -436,6 +436,15 @@ class CoreTests(SupersetTestCase): expected_data = csv.reader( io.StringIO('first_name,last_name\nadmin, user\n')) + sql = "SELECT first_name FROM ab_user WHERE first_name LIKE '%admin%'" + client_id = '{}'.format(random.getrandbits(64))[:10] + self.run_sql(sql, client_id, raise_on_error=True) + + resp = self.get_resp('/superset/csv/{}'.format(client_id)) + data = csv.reader(io.StringIO(resp)) + expected_data = csv.reader( + io.StringIO('first_name\nadmin\n')) + self.assertEqual(list(expected_data), list(data)) self.logout() diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py index a3bb564dd8..51c336b21d 100644 --- a/tests/sqllab_tests.py +++ b/tests/sqllab_tests.py @@ -254,7 +254,7 @@ class SqlLabTests(SupersetTestCase): 'sql': """\ SELECT viz_type, count(1) as ccount FROM slices - WHERE viz_type LIKE '%%a%%' + WHERE viz_type LIKE '%a%' GROUP BY viz_type""", 'dbId': 1, } diff --git a/tox.ini b/tox.ini index 6f3c9fd6bd..464ab1b8ce 100644 --- a/tox.ini +++ b/tox.ini @@ -37,7 +37,7 @@ setenv = SUPERSET_CONFIG = tests.superset_test_config SUPERSET_HOME = {envtmpdir} py27-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset?charset=utf8 - py34-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset + py{34,36}-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset py{27,34,36}-postgres: SUPERSET__SQLALCHEMY_DATABASE_URI = postgresql+psycopg2://postgresuser:pguserpassword@localhost/superset py{27,34,36}-sqlite: SUPERSET__SQLALCHEMY_DATABASE_URI = sqlite:////{envtmpdir}/superset.db whitelist_externals =