mirror of https://github.com/apache/superset.git
add tests
This commit is contained in:
parent
d38315a307
commit
a9d7fafd9f
|
@ -104,10 +104,33 @@ class BaseEngineSpec(object):
|
|||
)
|
||||
return database.compile_sqla_query(qry)
|
||||
elif LimitMethod.FORCE_LIMIT:
|
||||
sql_without_limit = utils.get_query_without_limit(sql)
|
||||
sql_without_limit = cls.get_query_without_limit(sql)
|
||||
return '{sql_without_limit} LIMIT {limit}'.format(**locals())
|
||||
return sql
|
||||
|
||||
@classmethod
|
||||
def get_limit_from_sql(cls, sql):
|
||||
limit_pattern = re.compile(r"""
|
||||
(?ix) # case insensitive, verbose
|
||||
\s+ # whitespace
|
||||
LIMIT\s+(\d+) # LIMIT $ROWS
|
||||
;? # optional semi-colon
|
||||
(\s|;)*$ # remove trailing spaces tabs or semicolons
|
||||
""")
|
||||
matches = limit_pattern.findall(sql)
|
||||
if matches:
|
||||
return int(matches[0][0])
|
||||
|
||||
@classmethod
|
||||
def get_query_without_limit(cls, sql):
|
||||
return re.sub(r"""
|
||||
(?ix) # case insensitive, verbose
|
||||
\s+ # whitespace
|
||||
LIMIT\s+\d+ # LIMIT $ROWS
|
||||
;? # optional semi-colon
|
||||
(\s|;)*$ # remove trailing spaces tabs or semicolons
|
||||
""", '', sql)
|
||||
|
||||
@staticmethod
|
||||
def csv_to_df(**kwargs):
|
||||
kwargs['filepath_or_buffer'] = \
|
||||
|
|
|
@ -186,11 +186,10 @@ def execute_sql(
|
|||
query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S'))
|
||||
executed_sql = superset_query.as_create_table(query.tmp_table_name)
|
||||
query.select_as_cta_used = True
|
||||
elif (superset_query.is_select() and SQL_MAX_ROWS and
|
||||
if (superset_query.is_select() and SQL_MAX_ROWS and
|
||||
(not query.limit or query.limit > SQL_MAX_ROWS)):
|
||||
query.limit = SQL_MAX_ROWS
|
||||
executed_sql = database.apply_limit_to_sql(executed_sql, query.limit)
|
||||
query.limit_used = True
|
||||
|
||||
# Hook to allow environment-specific mutation (usually comments) to the SQL
|
||||
SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR')
|
||||
|
|
|
@ -18,7 +18,6 @@ import functools
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import smtplib
|
||||
import sys
|
||||
|
@ -883,29 +882,3 @@ def split_adhoc_filters_into_base_filters(fd):
|
|||
fd['having_filters'] = simple_having_filters
|
||||
fd['filters'] = simple_where_filters
|
||||
del fd['adhoc_filters']
|
||||
|
||||
|
||||
def get_query_without_limit(sql):
|
||||
return re.sub(r"""
|
||||
(?ix) # case insensitive, verbose
|
||||
\s+ # whitespace
|
||||
LIMIT\s+\d+ # LIMIT $ROWS
|
||||
;? # optional semi-colon
|
||||
(\s|;)*$ # remove trailing spaces tabs or semicolons
|
||||
""", '', sql)
|
||||
|
||||
|
||||
def get_limit_from_sql(sql):
|
||||
# returns the limit of the quest or None if it has no limit.
|
||||
|
||||
limit_pattern = re.compile(r"""
|
||||
(?ix) # case insensitive, verbose
|
||||
\s+ # whitespace
|
||||
LIMIT\s+(\d+) # LIMIT $ROWS
|
||||
;? # optional semi-colon
|
||||
(\s|;)*$ # remove trailing spaces tabs or semicolons
|
||||
""")
|
||||
matches = limit_pattern.findall(sql)
|
||||
|
||||
if matches:
|
||||
return int(matches[0])
|
||||
|
|
|
@ -2394,7 +2394,7 @@ class Superset(BaseSupersetView):
|
|||
|
||||
query = Query(
|
||||
database_id=int(database_id),
|
||||
limit=utils.get_limit_from_sql(sql),
|
||||
limit=mydb.db_engine_spec.get_limit_from_sql(sql),
|
||||
sql=sql,
|
||||
schema=schema,
|
||||
select_as_cta=request.form.get('select_as_cta') == 'true',
|
||||
|
|
|
@ -196,10 +196,9 @@ class CeleryTestCase(SupersetTestCase):
|
|||
self.assertEqual([{'name': 'Admin'}], df.to_dict(orient='records'))
|
||||
self.assertEqual(QueryStatus.SUCCESS, query.status)
|
||||
self.assertTrue('FROM tmp_async_1' in query.select_sql)
|
||||
self.assertTrue('LIMIT 666' in query.select_sql)
|
||||
self.assertEqual(
|
||||
'CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role '
|
||||
"WHERE name='Admin'", query.executed_sql)
|
||||
"WHERE name='Admin' LIMIT 666", query.executed_sql)
|
||||
self.assertEqual(sql_where, query.sql)
|
||||
self.assertEqual(0, query.rows)
|
||||
self.assertEqual(666, query.limit)
|
||||
|
@ -207,6 +206,33 @@ class CeleryTestCase(SupersetTestCase):
|
|||
self.assertEqual(True, query.select_as_cta)
|
||||
self.assertEqual(True, query.select_as_cta_used)
|
||||
|
||||
def test_run_async_query_with_lower_limit(self):
|
||||
main_db = self.get_main_database(db.session)
|
||||
eng = main_db.get_sqla_engine()
|
||||
sql_where = "SELECT name FROM ab_role WHERE name='Alpha' LIMIT 1"
|
||||
result = self.run_sql(
|
||||
main_db.id, sql_where, '5', async='true', tmp_table='tmp_async_2',
|
||||
cta='true')
|
||||
assert result['query']['state'] in (
|
||||
QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
query = self.get_query_by_id(result['query']['serverId'])
|
||||
df = pd.read_sql_query(query.select_sql, con=eng)
|
||||
self.assertEqual(QueryStatus.SUCCESS, query.status)
|
||||
self.assertEqual([{'name': 'Alpha'}], df.to_dict(orient='records'))
|
||||
self.assertEqual(QueryStatus.SUCCESS, query.status)
|
||||
self.assertTrue('FROM tmp_async_2' in query.select_sql)
|
||||
self.assertEqual(
|
||||
'CREATE TABLE tmp_async_2 AS \nSELECT name FROM ab_role '
|
||||
"WHERE name='Alpha' LIMIT 1", query.executed_sql)
|
||||
self.assertEqual(sql_where, query.sql)
|
||||
self.assertEqual(0, query.rows)
|
||||
self.assertEqual(1, query.limit)
|
||||
self.assertEqual(True, query.select_as_cta)
|
||||
self.assertEqual(True, query.select_as_cta_used)
|
||||
|
||||
@staticmethod
|
||||
def de_unicode_dict(d):
|
||||
def str_if_basestring(o):
|
||||
|
|
|
@ -95,6 +95,19 @@ class DbEngineSpecsTestCase(SupersetTestCase):
|
|||
limited = engine_spec_class.apply_limit_to_sql(sql, limit, main)
|
||||
self.assertEquals(expected_sql, limited)
|
||||
|
||||
def test_extract_limit_from_query(self, engine_spec_class=MySQLEngineSpec):
|
||||
q0 = 'select * from table'
|
||||
q1 = 'select * from mytable limit 10'
|
||||
q2 = 'select * from (select * from my_subquery limit 10) where col=1 limit 20'
|
||||
q3 = 'select * from (select * from my_subquery limit 10);'
|
||||
q4 = 'select * from (select * from my_subquery limit 10) where col=1 limit 20;'
|
||||
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q0), None)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q1), 10)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q2), 20)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q3), None)
|
||||
self.assertEqual(engine_spec_class.get_limit_from_sql(q4), 20)
|
||||
|
||||
def test_wrapped_query(self):
|
||||
self.sql_limit_regex(
|
||||
'SELECT * FROM a',
|
||||
|
|
Loading…
Reference in New Issue