mirror of https://github.com/apache/superset.git

[SQL Lab] Allow running multiple statements (#6112)

* Allow running multiple statements from SQL Lab
* fix tests
* More tests
* merge heads
* fix heads

This commit is contained in:
parent 6e942c9fb3
commit d427db0a8b
@@ -35,7 +35,7 @@ const defaultProps = {
 const SEARCH_HEIGHT = 46;
 
-const LOADING_STYLES = { position: 'relative', height: 50 };
+const LOADING_STYLES = { position: 'relative', minHeight: 100 };
 
 export default class ResultSet extends React.PureComponent {
   constructor(props) {
@@ -231,11 +231,19 @@ export default class ResultSet extends React.PureComponent {
         </Button>
       );
     }
+    const progressMsg = query && query.extra && query.extra.progress ? query.extra.progress : null;
     return (
       <div style={LOADING_STYLES}>
+        <div>
+          {!progressBar && <Loading position="normal" />}
+        </div>
         <QueryStateLabel query={query} />
-        {!progressBar && <Loading />}
-        {progressBar}
+        <div>
+          {progressMsg && <Alert bsStyle="success">{progressMsg}</Alert>}
+        </div>
+        <div>
+          {progressBar}
+        </div>
         <div>
           {trackingUrl}
         </div>
@@ -1,3 +1,4 @@
+@import "../../stylesheets/less/cosmo/variables.less";
 body {
   overflow: hidden;
 }
@@ -168,8 +169,8 @@ div.Workspace {
 }
 
 .Resizer {
-  background: #000;
-  opacity: .2;
+  background: @brand-primary;
+  opacity: 0.5;
   z-index: 1;
   -moz-box-sizing: border-box;
   -webkit-box-sizing: border-box;
@@ -180,23 +181,24 @@ div.Workspace {
 }
 
 .Resizer:hover {
-  -webkit-transition: all 2s ease;
-  transition: all 2s ease;
+  -webkit-transition: all 0.3s ease;
+  transition: all 0.3s ease;
+  opacity: 0.3;
 }
 
 .Resizer.horizontal {
   height: 10px;
   margin: -5px 0;
-  border-top: 5px solid rgba(255, 255, 255, 0);
-  border-bottom: 5px solid rgba(255, 255, 255, 0);
+  border-top: 5px solid transparent;
+  border-bottom: 5px solid transparent;
   cursor: row-resize;
   width: 100%;
   padding: 1px;
 }
 
 .Resizer.horizontal:hover {
-  border-top: 5px solid rgba(0, 0, 0, 0.5);
-  border-bottom: 5px solid rgba(0, 0, 0, 0.5);
+  border-top: 5px solid @brand-primary;
+  border-bottom: 5px solid @brand-primary;
 }
 
 .Resizer.vertical {
@@ -3,27 +3,34 @@ import PropTypes from 'prop-types';
 
 const propTypes = {
   size: PropTypes.number,
+  position: PropTypes.oneOf(['floating', 'normal']),
 };
 const defaultProps = {
   size: 50,
+  position: 'floating',
 };
 
-export default function Loading({ size }) {
+const FLOATING_STYLE = {
+  padding: 0,
+  margin: 0,
+  position: 'absolute',
+  left: '50%',
+  top: '50%',
+  transform: 'translate(-50%, -50%)',
+};
+
+export default function Loading({ size, position }) {
+  const style = position === 'floating' ? FLOATING_STYLE : {};
+  const styleWithWidth = {
+    ...style,
+    size,
+  };
   return (
     <img
      className="loading"
      alt="Loading..."
      src="/static/assets/images/loading.gif"
-      style={{
-        width: Math.min(size, 50),
-        // height is auto
-        padding: 0,
-        margin: 0,
-        position: 'absolute',
-        left: '50%',
-        top: '50%',
-        transform: 'translate(-50%, -50%)',
-      }}
+      style={styleWithWidth}
    />
  );
 }
@@ -148,18 +148,18 @@ class BaseEngineSpec(object):
             )
             return database.compile_sqla_query(qry)
         elif LimitMethod.FORCE_LIMIT:
-            parsed_query = sql_parse.SupersetQuery(sql)
+            parsed_query = sql_parse.ParsedQuery(sql)
             sql = parsed_query.get_query_with_new_limit(limit)
         return sql
 
     @classmethod
     def get_limit_from_sql(cls, sql):
-        parsed_query = sql_parse.SupersetQuery(sql)
+        parsed_query = sql_parse.ParsedQuery(sql)
         return parsed_query.limit
 
     @classmethod
     def get_query_with_new_limit(cls, sql, limit):
-        parsed_query = sql_parse.SupersetQuery(sql)
+        parsed_query = sql_parse.ParsedQuery(sql)
         return parsed_query.get_query_with_new_limit(limit)
 
     @staticmethod
@@ -0,0 +1,21 @@
+"""Add extra column to Query
+
+Revision ID: 0b1f1ab473c0
+Revises: 55e910a74826
+Create Date: 2018-11-05 08:42:56.181012
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = '0b1f1ab473c0'
+down_revision = '55e910a74826'
+
+
+def upgrade():
+    op.add_column('query', sa.Column('extra_json', sa.Text(), nullable=True))
+
+
+def downgrade():
+    op.drop_column('query', 'extra_json')
@@ -1,22 +0,0 @@
-"""empty message
-
-Revision ID: a7ca4a272e0a
-Revises: ('3e1b21cd94a4', 'cefabc8f7d38')
-Create Date: 2018-12-20 21:53:05.719149
-
-"""
-
-# revision identifiers, used by Alembic.
-revision = 'a7ca4a272e0a'
-down_revision = ('3e1b21cd94a4', 'cefabc8f7d38')
-
-from alembic import op
-import sqlalchemy as sa
-
-
-def upgrade():
-    pass
-
-
-def downgrade():
-    pass
@@ -0,0 +1,18 @@
+"""empty message
+
+Revision ID: de021a1ca60d
+Revises: ('0b1f1ab473c0', 'cefabc8f7d38')
+Create Date: 2018-12-18 22:45:55.783083
+
+"""
+# revision identifiers, used by Alembic.
+revision = 'de021a1ca60d'
+down_revision = ('0b1f1ab473c0', 'cefabc8f7d38', '3e1b21cd94a4')
+
+
+def upgrade():
+    pass
+
+
+def downgrade():
+    pass
@@ -294,3 +294,23 @@ class QueryResult(object):
         self.duration = duration
         self.status = status
         self.error_message = error_message
+
+
+class ExtraJSONMixin:
+    """Mixin to add an `extra` column (JSON) and utility methods"""
+    extra_json = sa.Column(sa.Text, default='{}')
+
+    @property
+    def extra(self):
+        try:
+            return json.loads(self.extra_json)
+        except Exception:
+            return {}
+
+    def set_extra_json(self, d):
+        self.extra_json = json.dumps(d)
+
+    def set_extra_json_key(self, key, value):
+        extra = self.extra
+        extra[key] = value
+        self.extra_json = json.dumps(extra)
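As a quick illustration of the mixin's round-trip behavior, here is a self-contained sketch: the SQLAlchemy column is replaced by a plain string attribute so it runs without an ORM, while the JSON helpers are copied from the hunk above.

import json

class ExtraJSONMixinDemo:
    # stands in for sa.Column(sa.Text, default='{}') in the real mixin
    extra_json = '{}'

    @property
    def extra(self):
        try:
            return json.loads(self.extra_json)
        except Exception:
            return {}

    def set_extra_json_key(self, key, value):
        extra = self.extra
        extra[key] = value
        self.extra_json = json.dumps(extra)

demo = ExtraJSONMixinDemo()
demo.set_extra_json_key('progress', 'Running statement 1 out of 2')
assert demo.extra == {'progress': 'Running statement 1 out of 2'}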
@@ -12,12 +12,15 @@ from sqlalchemy import (
 from sqlalchemy.orm import backref, relationship
 
 from superset import security_manager
-from superset.models.helpers import AuditMixinNullable
+from superset.models.helpers import AuditMixinNullable, ExtraJSONMixin
 from superset.utils.core import QueryStatus, user_label
 
 
-class Query(Model):
-    """ORM model for SQL query"""
+class Query(Model, ExtraJSONMixin):
+    """ORM model for SQL query
+
+    Now that SQL Lab support multi-statement execution, an entry in this
+    table may represent multiple SQL statements executed sequentially"""
 
     __tablename__ = 'query'
     id = Column(Integer, primary_key=True)
@@ -105,6 +108,7 @@ class Query(Model):
             'limit_reached': self.limit_reached,
             'resultsKey': self.results_key,
             'trackingUrl': self.tracking_url,
+            'extra': self.extra,
         }
 
     @property
@@ -165,7 +165,7 @@ class SupersetSecurityManager(SecurityManager):
             database, table_name, schema=table_schema)
 
     def rejected_datasources(self, sql, database, schema):
-        superset_query = sql_parse.SupersetQuery(sql)
+        superset_query = sql_parse.ParsedQuery(sql)
         return [
             t for t in superset_query.tables if not
             self.datasource_access_by_fullname(database, t, schema)]
@@ -1,4 +1,5 @@
 # pylint: disable=C,R,W
+from contextlib import closing
 from datetime import datetime
 import logging
 from time import sleep
@@ -6,6 +7,7 @@ import uuid
 
 from celery.exceptions import SoftTimeLimitExceeded
+from contextlib2 import contextmanager
 from flask_babel import lazy_gettext as _
 import simplejson as json
 import sqlalchemy
 from sqlalchemy.orm import sessionmaker
@@ -13,14 +15,15 @@ from sqlalchemy.pool import NullPool
 
 from superset import app, dataframe, db, results_backend, security_manager
 from superset.models.sql_lab import Query
-from superset.sql_parse import SupersetQuery
+from superset.sql_parse import ParsedQuery
 from superset.tasks.celery_app import app as celery_app
 from superset.utils.core import (
     json_iso_dttm_ser,
-    now_as_float,
     QueryStatus,
     zlib_compress,
 )
+from superset.utils.dates import now_as_float
+from superset.utils.decorators import stats_timing
 
 config = app.config
 stats_logger = config.get('STATS_LOGGER')
@@ -32,6 +35,31 @@ class SqlLabException(Exception):
     pass
 
 
+class SqlLabSecurityException(SqlLabException):
+    pass
+
+
+class SqlLabTimeoutException(SqlLabException):
+    pass
+
+
+def handle_query_error(msg, query, session, payload=None):
+    """Local method handling error while processing the SQL"""
+    payload = payload or {}
+    troubleshooting_link = config['TROUBLESHOOTING_LINK']
+    query.error_message = msg
+    query.status = QueryStatus.FAILED
+    query.tmp_table_name = None
+    session.commit()
+    payload.update({
+        'status': query.status,
+        'error': msg,
+    })
+    if troubleshooting_link:
+        payload['link'] = troubleshooting_link
+    return payload
+
+
 def get_query(query_id, session, retry_count=5):
     """attemps to get the query and retry if it cannot"""
     query = None
@@ -86,102 +114,52 @@ def get_sql_results(
     with session_scope(not ctask.request.called_directly) as session:
 
         try:
-            return execute_sql(
+            return execute_sql_statements(
                 ctask, query_id, rendered_query, return_results, store_results, user_name,
                 session=session, start_time=start_time)
         except Exception as e:
             logging.exception(e)
             stats_logger.incr('error_sqllab_unhandled')
             query = get_query(query_id, session)
-            query.error_message = str(e)
-            query.status = QueryStatus.FAILED
-            query.tmp_table_name = None
-            session.commit()
-            raise
+            return handle_query_error(str(e), query, session)
 
 
-def execute_sql(
-    ctask, query_id, rendered_query, return_results=True, store_results=False,
-    user_name=None, session=None, start_time=None,
-):
-    """Executes the sql query returns the results."""
-    if store_results and start_time:
-        # only asynchronous queries
-        stats_logger.timing(
-            'sqllab.query.time_pending', now_as_float() - start_time)
-    query = get_query(query_id, session)
-    payload = dict(query_id=query_id)
+def execute_sql_statement(
+        sql_statement, query, user_name, session,
+        cursor, return_results=False):
+    """Executes a single SQL statement"""
     database = query.database
     db_engine_spec = database.db_engine_spec
     db_engine_spec.patch()
 
-    def handle_error(msg):
-        """Local method handling error while processing the SQL"""
-        troubleshooting_link = config['TROUBLESHOOTING_LINK']
-        query.error_message = msg
-        query.status = QueryStatus.FAILED
-        query.tmp_table_name = None
-        session.commit()
-        payload.update({
-            'status': query.status,
-            'error': msg,
-        })
-        if troubleshooting_link:
-            payload['link'] = troubleshooting_link
-        return payload
-
-    if store_results and not results_backend:
-        return handle_error("Results backend isn't configured.")
-
     # Limit enforced only for retrieving the data, not for the CTA queries.
-    superset_query = SupersetQuery(rendered_query)
-    executed_sql = superset_query.stripped()
+    parsed_query = ParsedQuery(sql_statement)
+    sql = parsed_query.stripped()
     SQL_MAX_ROWS = app.config.get('SQL_MAX_ROW')
-    if not superset_query.is_readonly() and not database.allow_dml:
-        return handle_error(
-            'Only `SELECT` statements are allowed against this database')
+    if not parsed_query.is_readonly() and not database.allow_dml:
+        raise SqlLabSecurityException(
+            _('Only `SELECT` statements are allowed against this database'))
     if query.select_as_cta:
-        if not superset_query.is_select():
-            return handle_error(
+        if not parsed_query.is_select():
+            raise SqlLabException(_(
                 'Only `SELECT` statements can be used with the CREATE TABLE '
-                'feature.')
+                'feature.'))
         if not query.tmp_table_name:
             start_dttm = datetime.fromtimestamp(query.start_time)
             query.tmp_table_name = 'tmp_{}_table_{}'.format(
                 query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S'))
-        executed_sql = superset_query.as_create_table(query.tmp_table_name)
+        sql = parsed_query.as_create_table(query.tmp_table_name)
         query.select_as_cta_used = True
-    if superset_query.is_select():
+    if parsed_query.is_select():
         if SQL_MAX_ROWS and (not query.limit or query.limit > SQL_MAX_ROWS):
             query.limit = SQL_MAX_ROWS
         if query.limit:
-            executed_sql = database.apply_limit_to_sql(executed_sql, query.limit)
+            sql = database.apply_limit_to_sql(sql, query.limit)
 
     # Hook to allow environment-specific mutation (usually comments) to the SQL
     SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR')
     if SQL_QUERY_MUTATOR:
-        executed_sql = SQL_QUERY_MUTATOR(
-            executed_sql, user_name, security_manager, database)
+        sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
 
-    query.executed_sql = executed_sql
-    query.status = QueryStatus.RUNNING
-    query.start_running_time = now_as_float()
     session.merge(query)
     session.commit()
-    logging.info("Set query to 'running'")
-    conn = None
     try:
-        engine = database.get_sqla_engine(
-            schema=query.schema,
-            nullpool=True,
-            user_name=user_name,
-        )
-        conn = engine.raw_connection()
-        cursor = conn.cursor()
-        logging.info('Running query: \n{}'.format(executed_sql))
-        query_start_time = now_as_float()
+        logging.info(query.executed_sql)
         if log_query:
             log_query(
                 query.database.sqlalchemy_uri,
@@ -191,56 +169,102 @@ def execute_sql(
                 __name__,
                 security_manager,
             )
-        db_engine_spec.execute(cursor, query.executed_sql, async_=True)
-        logging.info('Handling cursor')
-        db_engine_spec.handle_cursor(cursor, query, session)
-        logging.info('Fetching data: {}'.format(query.to_dict()))
-        stats_logger.timing(
-            'sqllab.query.time_executing_query',
-            now_as_float() - query_start_time)
-        fetching_start_time = now_as_float()
-        data = db_engine_spec.fetch_data(cursor, query.limit)
-        stats_logger.timing(
-            'sqllab.query.time_fetching_results',
-            now_as_float() - fetching_start_time)
+        query.executed_sql = sql
+        with stats_timing('sqllab.query.time_executing_query', stats_logger):
+            logging.info('Running query: \n{}'.format(sql))
+            db_engine_spec.execute(cursor, sql, async_=True)
+            logging.info('Handling cursor')
+            db_engine_spec.handle_cursor(cursor, query, session)
+
+        with stats_timing('sqllab.query.time_fetching_results', stats_logger):
+            logging.debug('Fetching data for query object: {}'.format(query.to_dict()))
+            data = db_engine_spec.fetch_data(cursor, query.limit)
+
     except SoftTimeLimitExceeded as e:
         logging.exception(e)
-        if conn is not None:
-            conn.close()
-        return handle_error(
+        raise SqlLabTimeoutException(
             "SQL Lab timeout. This environment's policy is to kill queries "
             'after {} seconds.'.format(SQLLAB_TIMEOUT))
     except Exception as e:
         logging.exception(e)
-        if conn is not None:
-            conn.close()
-        return handle_error(db_engine_spec.extract_error_message(e))
+        raise SqlLabException(db_engine_spec.extract_error_message(e))
 
-    logging.info('Fetching cursor description')
+    logging.debug('Fetching cursor description')
     cursor_description = cursor.description
-    if conn is not None:
-        conn.commit()
-        conn.close()
+    return dataframe.SupersetDataFrame(data, cursor_description, db_engine_spec)
 
-    if query.status == QueryStatus.STOPPED:
-        return handle_error('The query has been stopped')
 
-    cdf = dataframe.SupersetDataFrame(data, cursor_description, db_engine_spec)
+def execute_sql_statements(
+    ctask, query_id, rendered_query, return_results=True, store_results=False,
+    user_name=None, session=None, start_time=None,
+):
+    """Executes the sql query returns the results."""
+    if store_results and start_time:
+        # only asynchronous queries
+        stats_logger.timing(
+            'sqllab.query.time_pending', now_as_float() - start_time)
+
+    query = get_query(query_id, session)
+    payload = dict(query_id=query_id)
+    database = query.database
+    db_engine_spec = database.db_engine_spec
+    db_engine_spec.patch()
+
+    if store_results and not results_backend:
+        raise SqlLabException("Results backend isn't configured.")
+
+    # Breaking down into multiple statements
+    parsed_query = ParsedQuery(rendered_query)
+    statements = parsed_query.get_statements()
+    logging.info(f'Executing {len(statements)} statement(s)')
+
+    logging.info("Set query to 'running'")
+    query.status = QueryStatus.RUNNING
+    query.start_running_time = now_as_float()
+
+    engine = database.get_sqla_engine(
+        schema=query.schema,
+        nullpool=True,
+        user_name=user_name,
+    )
+    # Sharing a single connection and cursor across the
+    # execution of all statements (if many)
+    with closing(engine.raw_connection()) as conn:
+        with closing(conn.cursor()) as cursor:
+            statement_count = len(statements)
+            for i, statement in enumerate(statements):
+                # TODO CHECK IF STOPPED
+                msg = f'Running statement {i+1} out of {statement_count}'
+                logging.info(msg)
+                query.set_extra_json_key('progress', msg)
+                session.commit()
+                is_last_statement = i == len(statements) - 1
+                try:
+                    cdf = execute_sql_statement(
+                        statement, query, user_name, session, cursor,
+                        return_results=is_last_statement and return_results)
+                except Exception as e:
+                    msg = str(e)
+                    if statement_count > 1:
+                        msg = f'[Statement {i+1} out of {statement_count}] ' + msg
+                    payload = handle_query_error(msg, query, session, payload)
+                    return payload
+
     # Success, updating the query entry in database
     query.rows = cdf.size
     query.progress = 100
+    query.set_extra_json_key('progress', None)
     query.status = QueryStatus.SUCCESS
     if query.select_as_cta:
-        query.select_sql = '{}'.format(
-            database.select_star(
-                query.tmp_table_name,
-                limit=query.limit,
-                schema=database.force_ctas_schema,
-                show_cols=False,
-                latest_partition=False))
+        query.select_sql = database.select_star(
+            query.tmp_table_name,
+            limit=query.limit,
+            schema=database.force_ctas_schema,
+            show_cols=False,
+            latest_partition=False)
     query.end_time = now_as_float()
-    session.merge(query)
-    session.flush()
+    session.commit()
 
     payload.update({
         'status': query.status,
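Stripped of the Superset-specific plumbing, the per-statement loop above reduces to the following self-contained sketch. The names run_statements and execute are illustrative stand-ins for execute_sql_statements and execute_sql_statement; the real loop additionally shares one DB-API cursor across statements and commits a progress message onto the Query row between iterations.

def run_statements(statements, execute):
    """Minimal sketch of the new multi-statement loop: every statement runs
    in order, but only the last one's result set is returned to the caller."""
    result = None
    count = len(statements)
    for i, statement in enumerate(statements):
        print(f'Running statement {i + 1} out of {count}')  # progress, as above
        is_last = i == count - 1
        result = execute(statement, return_results=is_last)
    return result

# usage sketch: run_statements(['SELECT 1', 'SELECT 2'], execute=some_execute_fn)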
@@ -248,21 +272,18 @@ def execute_sql(
         'columns': cdf.columns if cdf.columns else [],
         'query': query.to_dict(),
     })
 
     if store_results:
-        key = '{}'.format(uuid.uuid4())
-        logging.info('Storing results in results backend, key: {}'.format(key))
-        write_to_results_backend_start = now_as_float()
-        json_payload = json.dumps(
-            payload, default=json_iso_dttm_ser, ignore_nan=True)
-        cache_timeout = database.cache_timeout
-        if cache_timeout is None:
-            cache_timeout = config.get('CACHE_DEFAULT_TIMEOUT', 0)
-        results_backend.set(key, zlib_compress(json_payload), cache_timeout)
+        key = str(uuid.uuid4())
+        logging.info(f'Storing results in results backend, key: {key}')
+        with stats_timing('sqllab.query.results_backend_write', stats_logger):
+            json_payload = json.dumps(
+                payload, default=json_iso_dttm_ser, ignore_nan=True)
+            cache_timeout = database.cache_timeout
+            if cache_timeout is None:
+                cache_timeout = config.get('CACHE_DEFAULT_TIMEOUT', 0)
+            results_backend.set(key, zlib_compress(json_payload), cache_timeout)
         query.results_key = key
-        stats_logger.timing(
-            'sqllab.query.results_backend_write',
-            now_as_float() - write_to_results_backend_start)
     session.merge(query)
     session.commit()
 
     if return_results:
@@ -10,13 +10,12 @@ ON_KEYWORD = 'ON'
 PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}
 
 
-class SupersetQuery(object):
+class ParsedQuery(object):
     def __init__(self, sql_statement):
         self.sql = sql_statement
         self._table_names = set()
         self._alias_names = set()
         self._limit = None
-        # TODO: multistatement support
 
         logging.info('Parsing with sqlparse statement {}'.format(self.sql))
         self._parsed = sqlparse.parse(self.sql)
@@ -37,7 +36,7 @@ class SupersetQuery(object):
         return self._parsed[0].get_type() == 'SELECT'
 
     def is_explain(self):
-        return self.sql.strip().upper().startswith('EXPLAIN')
+        return self.stripped().upper().startswith('EXPLAIN')
 
     def is_readonly(self):
         """Pessimistic readonly, 100% sure statement won't mutate anything"""
@@ -46,6 +45,16 @@ class SupersetQuery(object):
     def stripped(self):
         return self.sql.strip(' \t\n;')
 
+    def get_statements(self):
+        """Returns a list of SQL statements as strings, stripped"""
+        statements = []
+        for statement in self._parsed:
+            if statement:
+                sql = str(statement).strip(' \n;\t')
+                if sql:
+                    statements.append(sql)
+        return statements
+
     @staticmethod
     def __precedes_table_name(token_value):
         for keyword in PRECEDES_TABLE_NAME:
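To see what get_statements produces, its logic can be exercised standalone with sqlparse, the same library the method uses; the input SQL below is illustrative.

import sqlparse

def get_statements(sql):
    # mirrors ParsedQuery.get_statements from the hunk above
    statements = []
    for statement in sqlparse.parse(sql):
        if statement:
            stripped = str(statement).strip(' \n;\t')
            if stripped:
                statements.append(stripped)
    return statements

print(get_statements('SELECT 1;\t\nSELECT 2;\nSELECT * FROM ab_user;;'))
# ['SELECT 1', 'SELECT 2', 'SELECT * FROM ab_user']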
@@ -34,19 +34,18 @@ import pandas as pd
 import parsedatetime
 from past.builtins import basestring
 from pydruid.utils.having import Having
-import pytz
 import sqlalchemy as sa
 from sqlalchemy import event, exc, select, Text
 from sqlalchemy.dialects.mysql import MEDIUMTEXT
 from sqlalchemy.types import TEXT, TypeDecorator
 
 from superset.exceptions import SupersetException, SupersetTimeoutException
+from superset.utils.dates import datetime_to_epoch, EPOCH
 
 
 logging.getLogger('MARKDOWN').setLevel(logging.INFO)
 
 PY3K = sys.version_info >= (3, 0)
-EPOCH = datetime(1970, 1, 1)
 DTTM_ALIAS = '__timestamp'
 ADHOC_METRIC_EXPRESSION_TYPES = {
     'SIMPLE': 'SIMPLE',
@@ -357,18 +356,6 @@ def pessimistic_json_iso_dttm_ser(obj):
     return json_iso_dttm_ser(obj, pessimistic=True)
 
 
-def datetime_to_epoch(dttm):
-    if dttm.tzinfo:
-        dttm = dttm.replace(tzinfo=pytz.utc)
-        epoch_with_tz = pytz.utc.localize(EPOCH)
-        return (dttm - epoch_with_tz).total_seconds() * 1000
-    return (dttm - EPOCH).total_seconds() * 1000
-
-
-def now_as_float():
-    return datetime_to_epoch(datetime.utcnow())
-
-
 def json_int_dttm_ser(obj):
     """json serializer that deals with dates"""
     val = base_json_conv(obj)
@@ -0,0 +1,17 @@
+from datetime import datetime
+
+import pytz
+
+EPOCH = datetime(1970, 1, 1)
+
+
+def datetime_to_epoch(dttm):
+    if dttm.tzinfo:
+        dttm = dttm.replace(tzinfo=pytz.utc)
+        epoch_with_tz = pytz.utc.localize(EPOCH)
+        return (dttm - epoch_with_tz).total_seconds() * 1000
+    return (dttm - EPOCH).total_seconds() * 1000
+
+
+def now_as_float():
+    return datetime_to_epoch(datetime.utcnow())
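The relocated helpers keep their old semantics: naive datetimes are measured against a naive epoch, while tz-aware ones are compared against a UTC-localized epoch, so both branches below yield the same value. A small sanity check, assuming the superset.utils.dates import path introduced elsewhere in this diff:

from datetime import datetime

import pytz

from superset.utils.dates import datetime_to_epoch

# one second past the epoch, naive and tz-aware, both come out as 1000.0 ms
assert datetime_to_epoch(datetime(1970, 1, 1, 0, 0, 1)) == 1000.0
assert datetime_to_epoch(
    datetime(1970, 1, 1, 0, 0, 1, tzinfo=pytz.utc)) == 1000.0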
@@ -0,0 +1,15 @@
+from contextlib2 import contextmanager
+
+from superset.utils.dates import now_as_float
+
+
+@contextmanager
+def stats_timing(stats_key, stats_logger):
+    """Provide a transactional scope around a series of operations."""
+    start_ts = now_as_float()
+    try:
+        yield start_ts
+    except Exception as e:
+        raise e
+    finally:
+        stats_logger.timing(stats_key, now_as_float() - start_ts)
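A usage sketch for the new context manager follows; DummyStatsLogger is illustrative (in Superset the logger comes from config.get('STATS_LOGGER')), while the import path matches the one added to sql_lab.py above.

import time

from superset.utils.decorators import stats_timing

class DummyStatsLogger:
    def timing(self, key, value):
        # value is elapsed milliseconds, since now_as_float() is epoch ms
        print(f'{key}: {value:.0f} ms')

with stats_timing('sqllab.query.time_executing_query', DummyStatsLogger()):
    time.sleep(0.05)  # stand-in for the actual query execution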
@@ -39,9 +39,10 @@ from superset.legacy import cast_form_data, update_time_range
 import superset.models.core as models
 from superset.models.sql_lab import Query
 from superset.models.user_attributes import UserAttribute
-from superset.sql_parse import SupersetQuery
+from superset.sql_parse import ParsedQuery
 from superset.utils import core as utils
 from superset.utils import dashboard_import_export
+from superset.utils.dates import now_as_float
 from .base import (
     api, BaseSupersetView,
     check_ownership,
@@ -2244,7 +2245,7 @@ class Superset(BaseSupersetView):
         table.schema = data.get('schema')
         table.template_params = data.get('templateParams')
         table.is_sqllab_view = True
-        q = SupersetQuery(data.get('sql'))
+        q = ParsedQuery(data.get('sql'))
         table.sql = q.stripped()
         db.session.add(table)
         cols = []
@@ -2390,11 +2391,11 @@ class Superset(BaseSupersetView):
         if not results_backend:
             return json_error_response("Results backend isn't configured")
 
-        read_from_results_backend_start = utils.now_as_float()
+        read_from_results_backend_start = now_as_float()
         blob = results_backend.get(key)
         stats_logger.timing(
             'sqllab.query.results_backend_read',
-            utils.now_as_float() - read_from_results_backend_start,
+            now_as_float() - read_from_results_backend_start,
         )
         if not blob:
             return json_error_response(
@@ -2488,7 +2489,7 @@ class Superset(BaseSupersetView):
             sql=sql,
             schema=schema,
             select_as_cta=request.form.get('select_as_cta') == 'true',
-            start_time=utils.now_as_float(),
+            start_time=now_as_float(),
             tab_name=request.form.get('tab'),
             status=QueryStatus.PENDING if async_ else QueryStatus.RUNNING,
             sql_editor_id=request.form.get('sql_editor_id'),
@@ -2525,7 +2526,7 @@ class Superset(BaseSupersetView):
                 return_results=False,
                 store_results=not query.select_as_cta,
                 user_name=g.user.username if g.user else None,
-                start_time=utils.now_as_float())
+                start_time=now_as_float())
         except Exception as e:
             logging.exception(e)
             msg = _(
@@ -4,13 +4,12 @@ import subprocess
 import time
 import unittest
 
 import pandas as pd
-from past.builtins import basestring
 
 from superset import app, db
 from superset.models.helpers import QueryStatus
 from superset.models.sql_lab import Query
-from superset.sql_parse import SupersetQuery
+from superset.sql_parse import ParsedQuery
 from superset.utils.core import get_main_database
 from .base_tests import SupersetTestCase
@@ -33,7 +32,7 @@ class UtilityFunctionTests(SupersetTestCase):
 
     # TODO(bkyryliuk): support more cases in CTA function.
     def test_create_table_as(self):
-        q = SupersetQuery('SELECT * FROM outer_space;')
+        q = ParsedQuery('SELECT * FROM outer_space;')
 
         self.assertEqual(
             'CREATE TABLE tmp AS \nSELECT * FROM outer_space',
@@ -45,7 +44,7 @@ class UtilityFunctionTests(SupersetTestCase):
             q.as_create_table('tmp', overwrite=True))
 
         # now without a semicolon
-        q = SupersetQuery('SELECT * FROM outer_space')
+        q = ParsedQuery('SELECT * FROM outer_space')
         self.assertEqual(
             'CREATE TABLE tmp AS \nSELECT * FROM outer_space',
             q.as_create_table('tmp'))
@@ -54,7 +53,7 @@ class UtilityFunctionTests(SupersetTestCase):
         multi_line_query = (
             'SELECT * FROM planets WHERE\n'
             "Luke_Father = 'Darth Vader'")
-        q = SupersetQuery(multi_line_query)
+        q = ParsedQuery(multi_line_query)
         self.assertEqual(
             'CREATE TABLE tmp AS \nSELECT * FROM planets WHERE\n'
             "Luke_Father = 'Darth Vader'",
@@ -125,8 +124,8 @@ class CeleryTestCase(SupersetTestCase):
 
     def test_run_sync_query_cta(self):
         main_db = get_main_database(db.session)
+        backend = main_db.backend
         db_id = main_db.id
-        eng = main_db.get_sqla_engine()
         tmp_table_name = 'tmp_async_22'
         self.drop_table_if_exists(tmp_table_name, main_db)
         perm_name = 'can_sql_json'
@@ -140,9 +139,11 @@ class CeleryTestCase(SupersetTestCase):
         query2 = self.get_query_by_id(result2['query']['serverId'])
 
         # Check the data in the tmp table.
-        df2 = pd.read_sql_query(sql=query2.select_sql, con=eng)
-        data2 = df2.to_dict(orient='records')
-        self.assertEqual([{'name': perm_name}], data2)
+        if backend != 'postgresql':
+            # TODO This test won't work in Postgres
+            results = self.run_sql(db_id, query2.select_sql, 'sdf2134')
+            self.assertEquals(results['status'], 'success')
+            self.assertGreater(len(results['data']), 0)
 
     def test_run_sync_query_cta_no_data(self):
         main_db = get_main_database(db.session)
@@ -184,7 +185,8 @@ class CeleryTestCase(SupersetTestCase):
         self.assertEqual(QueryStatus.SUCCESS, query.status)
         self.assertTrue('FROM tmp_async_1' in query.select_sql)
         self.assertEqual(
-            'CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role '
+            'CREATE TABLE tmp_async_1 AS \n'
+            'SELECT name FROM ab_role '
             "WHERE name='Admin' LIMIT 666", query.executed_sql)
         self.assertEqual(sql_where, query.sql)
         self.assertEqual(0, query.rows)
@@ -6,7 +6,7 @@ from superset import sql_parse
 class SupersetTestCase(unittest.TestCase):
 
     def extract_tables(self, query):
-        sq = sql_parse.SupersetQuery(query)
+        sq = sql_parse.ParsedQuery(query)
         return sq.tables
 
     def test_simple_select(self):
@@ -294,12 +294,12 @@ class SupersetTestCase(unittest.TestCase):
         self.assertEquals({'t1', 't2'}, self.extract_tables(query))
 
     def test_update_not_select(self):
-        sql = sql_parse.SupersetQuery('UPDATE t1 SET col1 = NULL')
+        sql = sql_parse.ParsedQuery('UPDATE t1 SET col1 = NULL')
         self.assertEquals(False, sql.is_select())
         self.assertEquals(False, sql.is_readonly())
 
     def test_explain(self):
-        sql = sql_parse.SupersetQuery('EXPLAIN SELECT 1')
+        sql = sql_parse.ParsedQuery('EXPLAIN SELECT 1')
 
         self.assertEquals(True, sql.is_explain())
         self.assertEquals(False, sql.is_select())
@@ -369,3 +369,35 @@ class SupersetTestCase(unittest.TestCase):
         self.assertEquals(
             {'a', 'b', 'c', 'd', 'e', 'f'},
             self.extract_tables(query))
+
+    def test_basic_breakdown_statements(self):
+        multi_sql = """
+        SELECT * FROM ab_user;
+        SELECT * FROM ab_user LIMIT 1;
+        """
+        parsed = sql_parse.ParsedQuery(multi_sql)
+        statements = parsed.get_statements()
+        self.assertEquals(len(statements), 2)
+        expected = [
+            'SELECT * FROM ab_user',
+            'SELECT * FROM ab_user LIMIT 1',
+        ]
+        self.assertEquals(statements, expected)
+
+    def test_messy_breakdown_statements(self):
+        multi_sql = """
+        SELECT 1;\t\n\n\n  \t
+        \t\nSELECT 2;
+        SELECT * FROM ab_user;;;
+        SELECT * FROM ab_user LIMIT 1
+        """
+        parsed = sql_parse.ParsedQuery(multi_sql)
+        statements = parsed.get_statements()
+        self.assertEquals(len(statements), 4)
+        expected = [
+            'SELECT 1',
+            'SELECT 2',
+            'SELECT * FROM ab_user',
+            'SELECT * FROM ab_user LIMIT 1',
+        ]
+        self.assertEquals(statements, expected)
@@ -50,6 +50,16 @@ class SqlLabTests(SupersetTestCase):
         data = self.run_sql('SELECT * FROM unexistant_table', '2')
         self.assertLess(0, len(data['error']))
 
+    def test_multi_sql(self):
+        self.login('admin')
+
+        multi_sql = """
+        SELECT first_name FROM ab_user;
+        SELECT first_name FROM ab_user;
+        """
+        data = self.run_sql(multi_sql, '2234')
+        self.assertLess(0, len(data['data']))
+
     def test_explain(self):
         self.login('admin')
 