[sql] Correct SQL parameter formatting (#5178)

This commit is contained in:
John Bodley 2018-07-21 12:01:26 -07:00 committed by GitHub
parent 6e7b5879be
commit 7fcc2af68f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 138 additions and 30 deletions

View File

@ -282,7 +282,7 @@ ignored-modules=numpy,pandas,alembic.op,sqlalchemy,alembic.context,flask_appbuil
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local,sqlalchemy.orm.scoping.scoped_session
ignored-classes=contextlib.closing,optparse.Values,thread._local,_thread._local,sqlalchemy.orm.scoping.scoped_session
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular

View File

@ -12,7 +12,6 @@ from flask import escape, Markup
from flask_appbuilder import Model
from flask_babel import lazy_gettext as _
import pandas as pd
import six
import sqlalchemy as sa
from sqlalchemy import (
and_, asc, Boolean, Column, DateTime, desc, ForeignKey, Integer, or_,
@ -427,14 +426,8 @@ class SqlaTable(Model, BaseDatasource):
table=self, database=self.database, **kwargs)
def get_query_str(self, query_obj):
engine = self.database.get_sqla_engine()
qry = self.get_sqla_query(**query_obj)
sql = six.text_type(
qry.compile(
engine,
compile_kwargs={'literal_binds': True},
),
)
sql = self.database.compile_sqla_query(qry)
logging.info(sql)
sql = sqlparse.format(sql, reindent=True)
if query_obj['is_prequery']:

View File

@ -65,7 +65,6 @@ class BaseEngineSpec(object):
"""Abstract class for database engine specific configurations"""
engine = 'base' # str as defined in sqlalchemy.engine.engine
cursor_execute_kwargs = {}
time_grains = tuple()
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
@ -331,6 +330,10 @@ class BaseEngineSpec(object):
def normalize_column_name(column_name):
return column_name
@staticmethod
def execute(cursor, query, async=False):
    """Run ``query`` on the given DB-API ``cursor``.

    The ``async`` argument is accepted but not used by this implementation;
    only ``query`` is passed to ``cursor.execute``.
    """
    # NOTE(review): ``async`` is a reserved keyword from Python 3.7 onwards,
    # so this signature only parses on older interpreters.
    cursor.execute(query)
class PostgresBaseEngineSpec(BaseEngineSpec):
""" Abstract class for Postgres 'like' databases """
@ -558,7 +561,6 @@ class SqliteEngineSpec(BaseEngineSpec):
class MySQLEngineSpec(BaseEngineSpec):
engine = 'mysql'
cursor_execute_kwargs = {'args': {}}
time_grains = (
Grain('Time Column', _('Time Column'), '{col}', None),
Grain('second', _('second'), 'DATE_ADD(DATE({col}), '
@ -639,7 +641,6 @@ class MySQLEngineSpec(BaseEngineSpec):
class PrestoEngineSpec(BaseEngineSpec):
engine = 'presto'
cursor_execute_kwargs = {'parameters': None}
time_grains = (
Grain('Time Column', _('Time Column'), '{col}', None),
@ -938,7 +939,6 @@ class HiveEngineSpec(PrestoEngineSpec):
"""Reuses PrestoEngineSpec functionality."""
engine = 'hive'
cursor_execute_kwargs = {'async': True}
# Scoping regex at class level to avoid recompiling
# 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
@ -1230,6 +1230,10 @@ class HiveEngineSpec(PrestoEngineSpec):
configuration['hive.server2.proxy.user'] = username
return configuration
@staticmethod
def execute(cursor, query, async=False):
    """Run ``query`` on ``cursor``, forwarding ``async`` to ``cursor.execute``."""
    # NOTE(review): ``async`` is a reserved keyword from Python 3.7 onwards,
    # so both this signature and the keyword argument below only parse on
    # older interpreters.
    cursor.execute(query, async=async)
class MssqlEngineSpec(BaseEngineSpec):
engine = 'mssql'

View File

@ -0,0 +1,86 @@
"""remove double percents
Revision ID: 4451805bbaa1
Revises: bddc498dd179
Create Date: 2018-06-13 10:20:35.846744
"""
# revision identifiers, used by Alembic.
revision = '4451805bbaa1'
down_revision = 'bddc498dd179'
from alembic import op
import json
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, create_engine, ForeignKey, Integer, String, Text
from superset import db
Base = declarative_base()
class Slice(Base):
    """Minimal mapping of the ``slices`` table for use inside this migration."""

    __tablename__ = 'slices'

    id = Column(Integer, primary_key=True)
    # Foreign key to ``tables.id``.
    datasource_id = Column(Integer, ForeignKey('tables.id'))
    # The migration's query keeps only rows where this equals 'table'.
    datasource_type = Column(String(200))
    # Text payload that the migration parses and re-serialises as JSON.
    params = Column(Text)
class Table(Base):
    """Minimal mapping of the ``tables`` table for use inside this migration."""

    __tablename__ = 'tables'

    id = Column(Integer, primary_key=True)
    # Foreign key to ``dbs.id``.
    database_id = Column(Integer, ForeignKey('dbs.id'))
class Database(Base):
    """Minimal mapping of the ``dbs`` table for use inside this migration."""

    __tablename__ = 'dbs'

    id = Column(Integer, primary_key=True)
    # Connection URI; the migration passes it to ``create_engine``.
    sqlalchemy_uri = Column(String(1024))
def _rewrite_sql_expressions(raw_params, source, target):
    """Return ``raw_params`` with ``source`` swapped for ``target``, or None.

    ``raw_params`` is the JSON text stored on a slice. Only the
    ``sqlExpression`` value of each entry under ``adhoc_filters`` is touched.
    ``None`` is returned when nothing changed so the caller can skip the
    write entirely.

    Raises ``ValueError`` if ``raw_params`` is not valid JSON.
    """
    if not raw_params:
        return None

    params = json.loads(raw_params)
    if not isinstance(params, dict):
        return None

    changed = False
    for filt in params.get('adhoc_filters') or []:
        if not isinstance(filt, dict):
            continue
        expression = filt.get('sqlExpression')
        if expression and source in expression:
            filt['sqlExpression'] = expression.replace(source, target)
            changed = True

    return json.dumps(params, sort_keys=True) if changed else None


def replace(source, target):
    """Replace ``source`` with ``target`` in stored ad-hoc filter SQL.

    Walks every slice whose datasource type is ``'table'`` and, when the
    slice's database uses a SQLAlchemy dialect that doubles percent signs,
    rewrites the ``sqlExpression`` of each ad-hoc filter in the slice params.

    The migration is best-effort: a database whose engine cannot be built, or
    a slice whose params cannot be processed, is logged and skipped rather
    than aborting the whole migration.
    """
    import logging  # local import: nothing else in this module logs

    session = db.Session(bind=op.get_bind())
    # database id -> whether its dialect doubles percent signs. Cached so an
    # engine is built once per database instead of once per slice.
    doubles_percents = {}

    try:
        rows = (
            session.query(Slice, Database)
            .join(Table)
            .join(Database)
            .filter(Slice.datasource_type == 'table')
            .all()
        )

        for slc, database in rows:
            if database.id not in doubles_percents:
                try:
                    engine = create_engine(database.sqlalchemy_uri)
                    preparer = engine.dialect.identifier_preparer
                    doubles_percents[database.id] = bool(
                        preparer._double_percents)
                except Exception:
                    # Typically a missing driver or a malformed URI. The id
                    # is logged instead of the URI, which may hold credentials.
                    logging.warning(
                        'Skipping slices of database %s: could not build '
                        'its engine', database.id, exc_info=True)
                    doubles_percents[database.id] = False

            if not doubles_percents[database.id]:
                continue

            try:
                updated = _rewrite_sql_expressions(slc.params, source, target)
            except (ValueError, TypeError, AttributeError):
                logging.warning(
                    'Skipping slice %s: could not rewrite its params',
                    slc.id, exc_info=True)
                continue

            if updated is not None:
                slc.params = updated

        session.commit()
    finally:
        # Release the session even when the query or the commit fails.
        session.close()
def upgrade():
    """Collapse doubled percent signs (``%%``) into single ones (``%``)."""
    replace(source='%%', target='%')
def downgrade():
    """Expand every single percent sign (``%``) into a doubled one (``%%``)."""
    replace(source='%', target='%%')

View File

@ -6,6 +6,7 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from contextlib import closing
from copy import copy, deepcopy
from datetime import datetime
import functools
@ -19,6 +20,7 @@ from flask_appbuilder.models.decorators import renders
from future.standard_library import install_aliases
import numpy
import pandas as pd
import six
import sqlalchemy as sqla
from sqlalchemy import (
Boolean, Column, create_engine, DateTime, ForeignKey, Integer,
@ -749,12 +751,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
def get_df(self, sql, schema):
sqls = [str(s).strip().strip(';') for s in sqlparse.parse(sql)]
eng = self.get_sqla_engine(schema=schema)
for i in range(len(sqls) - 1):
eng.execute(sqls[i])
df = pd.read_sql_query(sqls[-1], eng)
engine = self.get_sqla_engine(schema=schema)
def needs_conversion(df_series):
if df_series.empty:
@ -763,15 +760,35 @@ class Database(Model, AuditMixinNullable, ImportMixin):
return True
return False
for k, v in df.dtypes.items():
if v.type == numpy.object_ and needs_conversion(df[k]):
df[k] = df[k].apply(utils.json_dumps_w_dates)
return df
with closing(engine.raw_connection()) as conn:
with closing(conn.cursor()) as cursor:
for sql in sqls:
self.db_engine_spec.execute(cursor, sql)
df = pd.DataFrame.from_records(
data=list(cursor.fetchall()),
columns=[col_desc[0] for col_desc in cursor.description],
coerce_float=True,
)
for k, v in df.dtypes.items():
if v.type == numpy.object_ and needs_conversion(df[k]):
df[k] = df[k].apply(utils.json_dumps_w_dates)
return df
def compile_sqla_query(self, qry, schema=None):
eng = self.get_sqla_engine(schema=schema)
compiled = qry.compile(eng, compile_kwargs={'literal_binds': True})
return '{}'.format(compiled)
engine = self.get_sqla_engine(schema=schema)
sql = six.text_type(
qry.compile(
engine,
compile_kwargs={'literal_binds': True},
),
)
if engine.dialect.identifier_preparer._double_percents:
sql = sql.replace('%%', '%')
return sql
def select_star(
self, table_name, schema=None, limit=100, show_cols=False,

View File

@ -172,8 +172,7 @@ def execute_sql(
cursor = conn.cursor()
logging.info('Running query: \n{}'.format(executed_sql))
logging.info(query.executed_sql)
cursor.execute(query.executed_sql,
**db_engine_spec.cursor_execute_kwargs)
db_engine_spec.execute(cursor, query.executed_sql, async=True)
logging.info('Handling cursor')
db_engine_spec.handle_cursor(cursor, query, session)
logging.info('Fetching data: {}'.format(query.to_dict()))

View File

@ -436,6 +436,15 @@ class CoreTests(SupersetTestCase):
expected_data = csv.reader(
io.StringIO('first_name,last_name\nadmin, user\n'))
sql = "SELECT first_name FROM ab_user WHERE first_name LIKE '%admin%'"
client_id = '{}'.format(random.getrandbits(64))[:10]
self.run_sql(sql, client_id, raise_on_error=True)
resp = self.get_resp('/superset/csv/{}'.format(client_id))
data = csv.reader(io.StringIO(resp))
expected_data = csv.reader(
io.StringIO('first_name\nadmin\n'))
self.assertEqual(list(expected_data), list(data))
self.logout()

View File

@ -254,7 +254,7 @@ class SqlLabTests(SupersetTestCase):
'sql': """\
SELECT viz_type, count(1) as ccount
FROM slices
WHERE viz_type LIKE '%%a%%'
WHERE viz_type LIKE '%a%'
GROUP BY viz_type""",
'dbId': 1,
}

View File

@ -37,7 +37,7 @@ setenv =
SUPERSET_CONFIG = tests.superset_test_config
SUPERSET_HOME = {envtmpdir}
py27-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset?charset=utf8
py34-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset
py{34,36}-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset
py{27,34,36}-postgres: SUPERSET__SQLALCHEMY_DATABASE_URI = postgresql+psycopg2://postgresuser:pguserpassword@localhost/superset
py{27,34,36}-sqlite: SUPERSET__SQLALCHEMY_DATABASE_URI = sqlite:////{envtmpdir}/superset.db
whitelist_externals =