[SQL Lab] Allow running multiple statements (#6112)

* Allow running multiple statements from SQL Lab

* fix tests

* More tests

* merge heads

* fix heads
Maxime Beauchemin 2018-12-22 10:28:22 -08:00 committed by Beto Dealmeida
parent 6e942c9fb3
commit d427db0a8b
19 changed files with 357 additions and 205 deletions
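The heart of this change is statement splitting: SQL Lab now parses the submitted text into individual statements and runs them one at a time on a shared cursor, reporting progress between statements. A minimal sketch of the splitting logic, mirroring the ParsedQuery.get_statements() method added below (the standalone helper name here is illustrative, not part of the commit):

import sqlparse

def split_statements(sql):
    """Split raw SQL text into stripped, non-empty statements."""
    statements = []
    for statement in sqlparse.parse(sql):
        stripped = str(statement).strip(' \n;\t')
        if stripped:
            statements.append(stripped)
    return statements

print(split_statements('SELECT 1;\nSELECT 2;;'))
# ['SELECT 1', 'SELECT 2']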

View File

@@ -35,7 +35,7 @@ const defaultProps = {
const SEARCH_HEIGHT = 46;
const LOADING_STYLES = { position: 'relative', height: 50 };
const LOADING_STYLES = { position: 'relative', minHeight: 100 };
export default class ResultSet extends React.PureComponent {
constructor(props) {
@@ -231,11 +231,19 @@ export default class ResultSet extends React.PureComponent {
</Button>
);
}
const progressMsg = query && query.extra && query.extra.progress ? query.extra.progress : null;
return (
<div style={LOADING_STYLES}>
<div>
{!progressBar && <Loading position="normal" />}
</div>
<QueryStateLabel query={query} />
{!progressBar && <Loading />}
{progressBar}
<div>
{progressMsg && <Alert bsStyle="success">{progressMsg}</Alert>}
</div>
<div>
{progressBar}
</div>
<div>
{trackingUrl}
</div>

View File

@@ -1,3 +1,4 @@
@import "../../stylesheets/less/cosmo/variables.less";
body {
overflow: hidden;
}
@@ -168,8 +169,8 @@ div.Workspace {
}
.Resizer {
background: #000;
opacity: .2;
background: @brand-primary;
opacity: 0.5;
z-index: 1;
-moz-box-sizing: border-box;
-webkit-box-sizing: border-box;
@@ -180,23 +181,24 @@ div.Workspace {
}
.Resizer:hover {
-webkit-transition: all 2s ease;
transition: all 2s ease;
-webkit-transition: all 0.3s ease;
transition: all 0.3s ease;
opacity: 0.3;
}
.Resizer.horizontal {
height: 10px;
margin: -5px 0;
border-top: 5px solid rgba(255, 255, 255, 0);
border-bottom: 5px solid rgba(255, 255, 255, 0);
border-top: 5px solid transparent;
border-bottom: 5px solid transparent;
cursor: row-resize;
width: 100%;
padding: 1px;
}
.Resizer.horizontal:hover {
border-top: 5px solid rgba(0, 0, 0, 0.5);
border-bottom: 5px solid rgba(0, 0, 0, 0.5);
border-top: 5px solid @brand-primary;
border-bottom: 5px solid @brand-primary;
}
.Resizer.vertical {

View File

@@ -3,27 +3,34 @@ import PropTypes from 'prop-types';
const propTypes = {
size: PropTypes.number,
position: PropTypes.oneOf(['floating', 'normal']),
};
const defaultProps = {
size: 50,
position: 'floating',
};
export default function Loading({ size }) {
const FLOATING_STYLE = {
padding: 0,
margin: 0,
position: 'absolute',
left: '50%',
top: '50%',
transform: 'translate(-50%, -50%)',
};
export default function Loading({ size, position }) {
const style = position === 'floating' ? FLOATING_STYLE : {};
const styleWithWidth = {
...style,
width: size,
};
return (
<img
className="loading"
alt="Loading..."
src="/static/assets/images/loading.gif"
style={{
width: Math.min(size, 50),
// height is auto
padding: 0,
margin: 0,
position: 'absolute',
left: '50%',
top: '50%',
transform: 'translate(-50%, -50%)',
}}
style={styleWithWidth}
/>
);
}

View File

@@ -148,18 +148,18 @@ class BaseEngineSpec(object):
)
return database.compile_sqla_query(qry)
elif LimitMethod.FORCE_LIMIT:
parsed_query = sql_parse.SupersetQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql)
sql = parsed_query.get_query_with_new_limit(limit)
return sql
@classmethod
def get_limit_from_sql(cls, sql):
parsed_query = sql_parse.SupersetQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql)
return parsed_query.limit
@classmethod
def get_query_with_new_limit(cls, sql, limit):
parsed_query = sql_parse.SupersetQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql)
return parsed_query.get_query_with_new_limit(limit)
@staticmethod

View File

@@ -0,0 +1,21 @@
"""Add extra column to Query
Revision ID: 0b1f1ab473c0
Revises: 55e910a74826
Create Date: 2018-11-05 08:42:56.181012
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '0b1f1ab473c0'
down_revision = '55e910a74826'
def upgrade():
op.add_column('query', sa.Column('extra_json', sa.Text(), nullable=True))
def downgrade():
op.drop_column('query', 'extra_json')

View File

@@ -1,22 +0,0 @@
"""empty message
Revision ID: a7ca4a272e0a
Revises: ('3e1b21cd94a4', 'cefabc8f7d38')
Create Date: 2018-12-20 21:53:05.719149
"""
# revision identifiers, used by Alembic.
revision = 'a7ca4a272e0a'
down_revision = ('3e1b21cd94a4', 'cefabc8f7d38')
from alembic import op
import sqlalchemy as sa
def upgrade():
pass
def downgrade():
pass

View File

@@ -0,0 +1,18 @@
"""empty message
Revision ID: de021a1ca60d
Revises: ('0b1f1ab473c0', 'cefabc8f7d38', '3e1b21cd94a4')
Create Date: 2018-12-18 22:45:55.783083
"""
# revision identifiers, used by Alembic.
revision = 'de021a1ca60d'
down_revision = ('0b1f1ab473c0', 'cefabc8f7d38', '3e1b21cd94a4')
def upgrade():
pass
def downgrade():
pass

View File

@@ -294,3 +294,23 @@ class QueryResult(object):
self.duration = duration
self.status = status
self.error_message = error_message
class ExtraJSONMixin:
"""Mixin to add an `extra` column (JSON) and utility methods"""
extra_json = sa.Column(sa.Text, default='{}')
@property
def extra(self):
try:
return json.loads(self.extra_json)
except Exception:
return {}
def set_extra_json(self, d):
self.extra_json = json.dumps(d)
def set_extra_json_key(self, key, value):
extra = self.extra
extra[key] = value
self.extra_json = json.dumps(extra)
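Because `extra` parses `extra_json` lazily and swallows parse errors, callers can treat it as an always-available dict; SQL Lab uses exactly this path to surface per-statement progress, which the ResultSet component above reads from query.extra.progress. A standalone sketch of the round trip (a plain class stands in for an ORM model, since the real column needs a table):

import json

class FakeModel:  # stand-in for a model using ExtraJSONMixin
    extra_json = '{}'

    @property
    def extra(self):
        try:
            return json.loads(self.extra_json)
        except Exception:
            return {}

    def set_extra_json_key(self, key, value):
        extra = self.extra
        extra[key] = value
        self.extra_json = json.dumps(extra)

m = FakeModel()
m.set_extra_json_key('progress', 'Running statement 2 out of 3')
assert m.extra == {'progress': 'Running statement 2 out of 3'}

m.extra_json = 'not valid json'
assert m.extra == {}  # malformed JSON degrades to an empty dict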

View File

@@ -12,12 +12,15 @@ from sqlalchemy import (
from sqlalchemy.orm import backref, relationship
from superset import security_manager
from superset.models.helpers import AuditMixinNullable
from superset.models.helpers import AuditMixinNullable, ExtraJSONMixin
from superset.utils.core import QueryStatus, user_label
class Query(Model):
"""ORM model for SQL query"""
class Query(Model, ExtraJSONMixin):
"""ORM model for SQL query
Now that SQL Lab supports multi-statement execution, an entry in this
table may represent multiple SQL statements executed sequentially."""
__tablename__ = 'query'
id = Column(Integer, primary_key=True)
@@ -105,6 +108,7 @@ class Query(Model):
'limit_reached': self.limit_reached,
'resultsKey': self.results_key,
'trackingUrl': self.tracking_url,
'extra': self.extra,
}
@property

View File

@@ -165,7 +165,7 @@ class SupersetSecurityManager(SecurityManager):
database, table_name, schema=table_schema)
def rejected_datasources(self, sql, database, schema):
superset_query = sql_parse.SupersetQuery(sql)
superset_query = sql_parse.ParsedQuery(sql)
return [
t for t in superset_query.tables if not
self.datasource_access_by_fullname(database, t, schema)]

View File

@@ -1,4 +1,5 @@
# pylint: disable=C,R,W
from contextlib import closing
from datetime import datetime
import logging
from time import sleep
@@ -6,6 +7,7 @@ import uuid
from celery.exceptions import SoftTimeLimitExceeded
from contextlib2 import contextmanager
from flask_babel import lazy_gettext as _
import simplejson as json
import sqlalchemy
from sqlalchemy.orm import sessionmaker
@@ -13,14 +15,15 @@ from sqlalchemy.pool import NullPool
from superset import app, dataframe, db, results_backend, security_manager
from superset.models.sql_lab import Query
from superset.sql_parse import SupersetQuery
from superset.sql_parse import ParsedQuery
from superset.tasks.celery_app import app as celery_app
from superset.utils.core import (
json_iso_dttm_ser,
now_as_float,
QueryStatus,
zlib_compress,
)
from superset.utils.dates import now_as_float
from superset.utils.decorators import stats_timing
config = app.config
stats_logger = config.get('STATS_LOGGER')
@@ -32,6 +35,31 @@ class SqlLabException(Exception):
pass
class SqlLabSecurityException(SqlLabException):
pass
class SqlLabTimeoutException(SqlLabException):
pass
def handle_query_error(msg, query, session, payload=None):
"""Local method handling error while processing the SQL"""
payload = payload or {}
troubleshooting_link = config['TROUBLESHOOTING_LINK']
query.error_message = msg
query.status = QueryStatus.FAILED
query.tmp_table_name = None
session.commit()
payload.update({
'status': query.status,
'error': msg,
})
if troubleshooting_link:
payload['link'] = troubleshooting_link
return payload
def get_query(query_id, session, retry_count=5):
"""attemps to get the query and retry if it cannot"""
query = None
@@ -86,102 +114,52 @@ def get_sql_results(
with session_scope(not ctask.request.called_directly) as session:
try:
return execute_sql(
return execute_sql_statements(
ctask, query_id, rendered_query, return_results, store_results, user_name,
session=session, start_time=start_time)
except Exception as e:
logging.exception(e)
stats_logger.incr('error_sqllab_unhandled')
query = get_query(query_id, session)
query.error_message = str(e)
query.status = QueryStatus.FAILED
query.tmp_table_name = None
session.commit()
raise
return handle_query_error(str(e), query, session)
def execute_sql(
ctask, query_id, rendered_query, return_results=True, store_results=False,
user_name=None, session=None, start_time=None,
):
"""Executes the sql query returns the results."""
if store_results and start_time:
# only asynchronous queries
stats_logger.timing(
'sqllab.query.time_pending', now_as_float() - start_time)
query = get_query(query_id, session)
payload = dict(query_id=query_id)
def execute_sql_statement(
sql_statement, query, user_name, session,
cursor, return_results=False):
"""Executes a single SQL statement"""
database = query.database
db_engine_spec = database.db_engine_spec
db_engine_spec.patch()
def handle_error(msg):
"""Local method handling error while processing the SQL"""
troubleshooting_link = config['TROUBLESHOOTING_LINK']
query.error_message = msg
query.status = QueryStatus.FAILED
query.tmp_table_name = None
session.commit()
payload.update({
'status': query.status,
'error': msg,
})
if troubleshooting_link:
payload['link'] = troubleshooting_link
return payload
if store_results and not results_backend:
return handle_error("Results backend isn't configured.")
# Limit enforced only for retrieving the data, not for the CTA queries.
superset_query = SupersetQuery(rendered_query)
executed_sql = superset_query.stripped()
parsed_query = ParsedQuery(sql_statement)
sql = parsed_query.stripped()
SQL_MAX_ROWS = app.config.get('SQL_MAX_ROW')
if not superset_query.is_readonly() and not database.allow_dml:
return handle_error(
'Only `SELECT` statements are allowed against this database')
if not parsed_query.is_readonly() and not database.allow_dml:
raise SqlLabSecurityException(
_('Only `SELECT` statements are allowed against this database'))
if query.select_as_cta:
if not superset_query.is_select():
return handle_error(
if not parsed_query.is_select():
raise SqlLabException(_(
'Only `SELECT` statements can be used with the CREATE TABLE '
'feature.')
'feature.'))
if not query.tmp_table_name:
start_dttm = datetime.fromtimestamp(query.start_time)
query.tmp_table_name = 'tmp_{}_table_{}'.format(
query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S'))
executed_sql = superset_query.as_create_table(query.tmp_table_name)
sql = parsed_query.as_create_table(query.tmp_table_name)
query.select_as_cta_used = True
if superset_query.is_select():
if parsed_query.is_select():
if SQL_MAX_ROWS and (not query.limit or query.limit > SQL_MAX_ROWS):
query.limit = SQL_MAX_ROWS
if query.limit:
executed_sql = database.apply_limit_to_sql(executed_sql, query.limit)
sql = database.apply_limit_to_sql(sql, query.limit)
# Hook to allow environment-specific mutation (usually comments) to the SQL
SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR')
if SQL_QUERY_MUTATOR:
executed_sql = SQL_QUERY_MUTATOR(
executed_sql, user_name, security_manager, database)
sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
query.executed_sql = executed_sql
query.status = QueryStatus.RUNNING
query.start_running_time = now_as_float()
session.merge(query)
session.commit()
logging.info("Set query to 'running'")
conn = None
try:
engine = database.get_sqla_engine(
schema=query.schema,
nullpool=True,
user_name=user_name,
)
conn = engine.raw_connection()
cursor = conn.cursor()
logging.info('Running query: \n{}'.format(executed_sql))
logging.info(query.executed_sql)
query_start_time = now_as_float()
if log_query:
log_query(
query.database.sqlalchemy_uri,
@@ -191,56 +169,102 @@ def execute_sql(
__name__,
security_manager,
)
db_engine_spec.execute(cursor, query.executed_sql, async_=True)
logging.info('Handling cursor')
db_engine_spec.handle_cursor(cursor, query, session)
logging.info('Fetching data: {}'.format(query.to_dict()))
stats_logger.timing(
'sqllab.query.time_executing_query',
now_as_float() - query_start_time)
fetching_start_time = now_as_float()
data = db_engine_spec.fetch_data(cursor, query.limit)
stats_logger.timing(
'sqllab.query.time_fetching_results',
now_as_float() - fetching_start_time)
query.executed_sql = sql
with stats_timing('sqllab.query.time_executing_query', stats_logger):
logging.info('Running query: \n{}'.format(sql))
db_engine_spec.execute(cursor, sql, async_=True)
logging.info('Handling cursor')
db_engine_spec.handle_cursor(cursor, query, session)
with stats_timing('sqllab.query.time_fetching_results', stats_logger):
logging.debug('Fetching data for query object: {}'.format(query.to_dict()))
data = db_engine_spec.fetch_data(cursor, query.limit)
except SoftTimeLimitExceeded as e:
logging.exception(e)
if conn is not None:
conn.close()
return handle_error(
raise SqlLabTimeoutException(
"SQL Lab timeout. This environment's policy is to kill queries "
'after {} seconds.'.format(SQLLAB_TIMEOUT))
except Exception as e:
logging.exception(e)
if conn is not None:
conn.close()
return handle_error(db_engine_spec.extract_error_message(e))
raise SqlLabException(db_engine_spec.extract_error_message(e))
logging.info('Fetching cursor description')
logging.debug('Fetching cursor description')
cursor_description = cursor.description
if conn is not None:
conn.commit()
conn.close()
return dataframe.SupersetDataFrame(data, cursor_description, db_engine_spec)
if query.status == QueryStatus.STOPPED:
return handle_error('The query has been stopped')
cdf = dataframe.SupersetDataFrame(data, cursor_description, db_engine_spec)
def execute_sql_statements(
ctask, query_id, rendered_query, return_results=True, store_results=False,
user_name=None, session=None, start_time=None,
):
"""Executes the sql query returns the results."""
if store_results and start_time:
# only asynchronous queries
stats_logger.timing(
'sqllab.query.time_pending', now_as_float() - start_time)
query = get_query(query_id, session)
payload = dict(query_id=query_id)
database = query.database
db_engine_spec = database.db_engine_spec
db_engine_spec.patch()
if store_results and not results_backend:
raise SqlLabException("Results backend isn't configured.")
# Breaking down into multiple statements
parsed_query = ParsedQuery(rendered_query)
statements = parsed_query.get_statements()
logging.info(f'Executing {len(statements)} statement(s)')
logging.info("Set query to 'running'")
query.status = QueryStatus.RUNNING
query.start_running_time = now_as_float()
engine = database.get_sqla_engine(
schema=query.schema,
nullpool=True,
user_name=user_name,
)
# Sharing a single connection and cursor across the
# execution of all statements (if many)
with closing(engine.raw_connection()) as conn:
with closing(conn.cursor()) as cursor:
statement_count = len(statements)
for i, statement in enumerate(statements):
# TODO CHECK IF STOPPED
msg = f'Running statement {i+1} out of {statement_count}'
logging.info(msg)
query.set_extra_json_key('progress', msg)
session.commit()
is_last_statement = i == len(statements) - 1
try:
cdf = execute_sql_statement(
statement, query, user_name, session, cursor,
return_results=is_last_statement and return_results)
except Exception as e:
msg = str(e)
if statement_count > 1:
msg = f'[Statement {i+1} out of {statement_count}] ' + msg
payload = handle_query_error(msg, query, session, payload)
return payload
# Success, updating the query entry in database
query.rows = cdf.size
query.progress = 100
query.set_extra_json_key('progress', None)
query.status = QueryStatus.SUCCESS
if query.select_as_cta:
query.select_sql = '{}'.format(
database.select_star(
query.tmp_table_name,
limit=query.limit,
schema=database.force_ctas_schema,
show_cols=False,
latest_partition=False))
query.select_sql = database.select_star(
query.tmp_table_name,
limit=query.limit,
schema=database.force_ctas_schema,
show_cols=False,
latest_partition=False)
query.end_time = now_as_float()
session.merge(query)
session.flush()
session.commit()
payload.update({
'status': query.status,
@@ -248,21 +272,18 @@ def execute_sql(
'columns': cdf.columns if cdf.columns else [],
'query': query.to_dict(),
})
if store_results:
key = '{}'.format(uuid.uuid4())
logging.info('Storing results in results backend, key: {}'.format(key))
write_to_results_backend_start = now_as_float()
json_payload = json.dumps(
payload, default=json_iso_dttm_ser, ignore_nan=True)
cache_timeout = database.cache_timeout
if cache_timeout is None:
cache_timeout = config.get('CACHE_DEFAULT_TIMEOUT', 0)
results_backend.set(key, zlib_compress(json_payload), cache_timeout)
key = str(uuid.uuid4())
logging.info(f'Storing results in results backend, key: {key}')
with stats_timing('sqllab.query.results_backend_write', stats_logger):
json_payload = json.dumps(
payload, default=json_iso_dttm_ser, ignore_nan=True)
cache_timeout = database.cache_timeout
if cache_timeout is None:
cache_timeout = config.get('CACHE_DEFAULT_TIMEOUT', 0)
results_backend.set(key, zlib_compress(json_payload), cache_timeout)
query.results_key = key
stats_logger.timing(
'sqllab.query.results_backend_write',
now_as_float() - write_to_results_backend_start)
session.merge(query)
session.commit()
if return_results:

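For orientation: the new executor shares one DB-API connection and cursor across all statements, so session state such as temporary tables carries over from one statement to the next, and only the last statement's result set is fetched and returned. A trimmed sketch of that control flow (engine and execute_one are stand-ins for the real objects above, not part of the commit):

from contextlib import closing

def run_statements(engine, statements, execute_one):
    """Run statements in order on a single shared connection/cursor."""
    cdf = None
    with closing(engine.raw_connection()) as conn:
        with closing(conn.cursor()) as cursor:
            for i, statement in enumerate(statements):
                # only the final statement's results are fetched and returned
                cdf = execute_one(statement, cursor,
                                  return_results=i == len(statements) - 1)
    return cdf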
View File

@@ -10,13 +10,12 @@ ON_KEYWORD = 'ON'
PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}
class SupersetQuery(object):
class ParsedQuery(object):
def __init__(self, sql_statement):
self.sql = sql_statement
self._table_names = set()
self._alias_names = set()
self._limit = None
# TODO: multistatement support
logging.info('Parsing with sqlparse statement {}'.format(self.sql))
self._parsed = sqlparse.parse(self.sql)
@@ -37,7 +36,7 @@ class SupersetQuery(object):
return self._parsed[0].get_type() == 'SELECT'
def is_explain(self):
return self.sql.strip().upper().startswith('EXPLAIN')
return self.stripped().upper().startswith('EXPLAIN')
def is_readonly(self):
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
@@ -46,6 +45,16 @@ class SupersetQuery(object):
def stripped(self):
return self.sql.strip(' \t\n;')
def get_statements(self):
"""Returns a list of SQL statements as strings, stripped"""
statements = []
for statement in self._parsed:
if statement:
sql = str(statement).strip(' \n;\t')
if sql:
statements.append(sql)
return statements
@staticmethod
def __precedes_table_name(token_value):
for keyword in PRECEDES_TABLE_NAME:

View File

@@ -34,19 +34,18 @@ import pandas as pd
import parsedatetime
from past.builtins import basestring
from pydruid.utils.having import Having
import pytz
import sqlalchemy as sa
from sqlalchemy import event, exc, select, Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.types import TEXT, TypeDecorator
from superset.exceptions import SupersetException, SupersetTimeoutException
from superset.utils.dates import datetime_to_epoch, EPOCH
logging.getLogger('MARKDOWN').setLevel(logging.INFO)
PY3K = sys.version_info >= (3, 0)
EPOCH = datetime(1970, 1, 1)
DTTM_ALIAS = '__timestamp'
ADHOC_METRIC_EXPRESSION_TYPES = {
'SIMPLE': 'SIMPLE',
@@ -357,18 +356,6 @@ def pessimistic_json_iso_dttm_ser(obj):
return json_iso_dttm_ser(obj, pessimistic=True)
def datetime_to_epoch(dttm):
if dttm.tzinfo:
dttm = dttm.replace(tzinfo=pytz.utc)
epoch_with_tz = pytz.utc.localize(EPOCH)
return (dttm - epoch_with_tz).total_seconds() * 1000
return (dttm - EPOCH).total_seconds() * 1000
def now_as_float():
return datetime_to_epoch(datetime.utcnow())
def json_int_dttm_ser(obj):
"""json serializer that deals with dates"""
val = base_json_conv(obj)

superset/utils/dates.py (new file, 17 lines)
View File

@@ -0,0 +1,17 @@
from datetime import datetime
import pytz
EPOCH = datetime(1970, 1, 1)
def datetime_to_epoch(dttm):
if dttm.tzinfo:
dttm = dttm.replace(tzinfo=pytz.utc)
epoch_with_tz = pytz.utc.localize(EPOCH)
return (dttm - epoch_with_tz).total_seconds() * 1000
return (dttm - EPOCH).total_seconds() * 1000
def now_as_float():
return datetime_to_epoch(datetime.utcnow())
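Note the units: both helpers return epoch milliseconds as a float, and timezone-aware datetimes are pinned to UTC before subtraction. Illustratively:

from datetime import datetime
import pytz

from superset.utils.dates import datetime_to_epoch

print(datetime_to_epoch(datetime(1970, 1, 1, 0, 0, 1)))  # 1000.0 ms

# aware datetimes compare against a UTC-localized epoch
aware = datetime(1970, 1, 1, 0, 0, 1, tzinfo=pytz.utc)
print(datetime_to_epoch(aware))  # 1000.0 as well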

View File

@@ -0,0 +1,15 @@
from contextlib2 import contextmanager
from superset.utils.dates import now_as_float
@contextmanager
def stats_timing(stats_key, stats_logger):
"""Provide a transactional scope around a series of operations."""
start_ts = now_as_float()
try:
yield start_ts
finally:
stats_logger.timing(stats_key, now_as_float() - start_ts)
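This context manager replaces the repeated start-timestamp/stats_logger.timing pairs scattered through sql_lab.py, and the finally clause means the duration is recorded even when the timed block raises. A small usage sketch with a stand-in logger (Superset's real STATS_LOGGER comes from config):

import time

from superset.utils.decorators import stats_timing

class PrintStatsLogger:  # stand-in for the configured STATS_LOGGER
    def timing(self, key, value):
        print('{}: {:.0f} ms'.format(key, value))

with stats_timing('sqllab.query.time_executing_query', PrintStatsLogger()):
    time.sleep(0.1)  # stands in for executing the statement
# prints roughly: sqllab.query.time_executing_query: 100 ms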

View File

@@ -39,9 +39,10 @@ from superset.legacy import cast_form_data, update_time_range
import superset.models.core as models
from superset.models.sql_lab import Query
from superset.models.user_attributes import UserAttribute
from superset.sql_parse import SupersetQuery
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils
from superset.utils import dashboard_import_export
from superset.utils.dates import now_as_float
from .base import (
api, BaseSupersetView,
check_ownership,
@@ -2244,7 +2245,7 @@ class Superset(BaseSupersetView):
table.schema = data.get('schema')
table.template_params = data.get('templateParams')
table.is_sqllab_view = True
q = SupersetQuery(data.get('sql'))
q = ParsedQuery(data.get('sql'))
table.sql = q.stripped()
db.session.add(table)
cols = []
@@ -2390,11 +2391,11 @@ class Superset(BaseSupersetView):
if not results_backend:
return json_error_response("Results backend isn't configured")
read_from_results_backend_start = utils.now_as_float()
read_from_results_backend_start = now_as_float()
blob = results_backend.get(key)
stats_logger.timing(
'sqllab.query.results_backend_read',
utils.now_as_float() - read_from_results_backend_start,
now_as_float() - read_from_results_backend_start,
)
if not blob:
return json_error_response(
@@ -2488,7 +2489,7 @@ class Superset(BaseSupersetView):
sql=sql,
schema=schema,
select_as_cta=request.form.get('select_as_cta') == 'true',
start_time=utils.now_as_float(),
start_time=now_as_float(),
tab_name=request.form.get('tab'),
status=QueryStatus.PENDING if async_ else QueryStatus.RUNNING,
sql_editor_id=request.form.get('sql_editor_id'),
@@ -2525,7 +2526,7 @@ class Superset(BaseSupersetView):
return_results=False,
store_results=not query.select_as_cta,
user_name=g.user.username if g.user else None,
start_time=utils.now_as_float())
start_time=now_as_float())
except Exception as e:
logging.exception(e)
msg = _(

View File

@@ -4,13 +4,12 @@ import subprocess
import time
import unittest
import pandas as pd
from past.builtins import basestring
from superset import app, db
from superset.models.helpers import QueryStatus
from superset.models.sql_lab import Query
from superset.sql_parse import SupersetQuery
from superset.sql_parse import ParsedQuery
from superset.utils.core import get_main_database
from .base_tests import SupersetTestCase
@@ -33,7 +32,7 @@ class UtilityFunctionTests(SupersetTestCase):
# TODO(bkyryliuk): support more cases in CTA function.
def test_create_table_as(self):
q = SupersetQuery('SELECT * FROM outer_space;')
q = ParsedQuery('SELECT * FROM outer_space;')
self.assertEqual(
'CREATE TABLE tmp AS \nSELECT * FROM outer_space',
@@ -45,7 +44,7 @@ class UtilityFunctionTests(SupersetTestCase):
q.as_create_table('tmp', overwrite=True))
# now without a semicolon
q = SupersetQuery('SELECT * FROM outer_space')
q = ParsedQuery('SELECT * FROM outer_space')
self.assertEqual(
'CREATE TABLE tmp AS \nSELECT * FROM outer_space',
q.as_create_table('tmp'))
@@ -54,7 +53,7 @@ class UtilityFunctionTests(SupersetTestCase):
multi_line_query = (
'SELECT * FROM planets WHERE\n'
"Luke_Father = 'Darth Vader'")
q = SupersetQuery(multi_line_query)
q = ParsedQuery(multi_line_query)
self.assertEqual(
'CREATE TABLE tmp AS \nSELECT * FROM planets WHERE\n'
"Luke_Father = 'Darth Vader'",
@@ -125,8 +124,8 @@ class CeleryTestCase(SupersetTestCase):
def test_run_sync_query_cta(self):
main_db = get_main_database(db.session)
backend = main_db.backend
db_id = main_db.id
eng = main_db.get_sqla_engine()
tmp_table_name = 'tmp_async_22'
self.drop_table_if_exists(tmp_table_name, main_db)
perm_name = 'can_sql_json'
@@ -140,9 +139,11 @@ class CeleryTestCase(SupersetTestCase):
query2 = self.get_query_by_id(result2['query']['serverId'])
# Check the data in the tmp table.
df2 = pd.read_sql_query(sql=query2.select_sql, con=eng)
data2 = df2.to_dict(orient='records')
self.assertEqual([{'name': perm_name}], data2)
if backend != 'postgresql':
# TODO This test won't work in Postgres
results = self.run_sql(db_id, query2.select_sql, 'sdf2134')
self.assertEquals(results['status'], 'success')
self.assertGreater(len(results['data']), 0)
def test_run_sync_query_cta_no_data(self):
main_db = get_main_database(db.session)
@@ -184,7 +185,8 @@ class CeleryTestCase(SupersetTestCase):
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertTrue('FROM tmp_async_1' in query.select_sql)
self.assertEqual(
'CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role '
'CREATE TABLE tmp_async_1 AS \n'
'SELECT name FROM ab_role '
"WHERE name='Admin' LIMIT 666", query.executed_sql)
self.assertEqual(sql_where, query.sql)
self.assertEqual(0, query.rows)

View File

@@ -6,7 +6,7 @@ from superset import sql_parse
class SupersetTestCase(unittest.TestCase):
def extract_tables(self, query):
sq = sql_parse.SupersetQuery(query)
sq = sql_parse.ParsedQuery(query)
return sq.tables
def test_simple_select(self):
@@ -294,12 +294,12 @@ class SupersetTestCase(unittest.TestCase):
self.assertEquals({'t1', 't2'}, self.extract_tables(query))
def test_update_not_select(self):
sql = sql_parse.SupersetQuery('UPDATE t1 SET col1 = NULL')
sql = sql_parse.ParsedQuery('UPDATE t1 SET col1 = NULL')
self.assertEquals(False, sql.is_select())
self.assertEquals(False, sql.is_readonly())
def test_explain(self):
sql = sql_parse.SupersetQuery('EXPLAIN SELECT 1')
sql = sql_parse.ParsedQuery('EXPLAIN SELECT 1')
self.assertEquals(True, sql.is_explain())
self.assertEquals(False, sql.is_select())
@@ -369,3 +369,35 @@ class SupersetTestCase(unittest.TestCase):
self.assertEquals(
{'a', 'b', 'c', 'd', 'e', 'f'},
self.extract_tables(query))
def test_basic_breakdown_statements(self):
multi_sql = """
SELECT * FROM ab_user;
SELECT * FROM ab_user LIMIT 1;
"""
parsed = sql_parse.ParsedQuery(multi_sql)
statements = parsed.get_statements()
self.assertEquals(len(statements), 2)
expected = [
'SELECT * FROM ab_user',
'SELECT * FROM ab_user LIMIT 1',
]
self.assertEquals(statements, expected)
def test_messy_breakdown_statements(self):
multi_sql = """
SELECT 1;\t\n\n\n \t
\t\nSELECT 2;
SELECT * FROM ab_user;;;
SELECT * FROM ab_user LIMIT 1
"""
parsed = sql_parse.ParsedQuery(multi_sql)
statements = parsed.get_statements()
self.assertEquals(len(statements), 4)
expected = [
'SELECT 1',
'SELECT 2',
'SELECT * FROM ab_user',
'SELECT * FROM ab_user LIMIT 1',
]
self.assertEquals(statements, expected)

View File

@@ -50,6 +50,16 @@ class SqlLabTests(SupersetTestCase):
data = self.run_sql('SELECT * FROM unexistant_table', '2')
self.assertLess(0, len(data['error']))
def test_multi_sql(self):
self.login('admin')
multi_sql = """
SELECT first_name FROM ab_user;
SELECT first_name FROM ab_user;
"""
data = self.run_sql(multi_sql, '2234')
self.assertLess(0, len(data['data']))
def test_explain(self):
self.login('admin')