diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 5e706c50a4..cd5ab87226 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -104,10 +104,33 @@ class BaseEngineSpec(object): ) return database.compile_sqla_query(qry) elif LimitMethod.FORCE_LIMIT: - sql_without_limit = utils.get_query_without_limit(sql) + sql_without_limit = cls.get_query_without_limit(sql) return '{sql_without_limit} LIMIT {limit}'.format(**locals()) return sql + @classmethod + def get_limit_from_sql(cls, sql): + limit_pattern = re.compile(r""" + (?ix) # case insensitive, verbose + \s+ # whitespace + LIMIT\s+(\d+) # LIMIT $ROWS + ;? # optional semi-colon + (\s|;)*$ # remove trailing spaces tabs or semicolons + """) + matches = limit_pattern.findall(sql) + if matches: + return int(matches[0][0]) + + @classmethod + def get_query_without_limit(cls, sql): + return re.sub(r""" + (?ix) # case insensitive, verbose + \s+ # whitespace + LIMIT\s+\d+ # LIMIT $ROWS + ;? # optional semi-colon + (\s|;)*$ # remove trailing spaces tabs or semicolons + """, '', sql) + @staticmethod def csv_to_df(**kwargs): kwargs['filepath_or_buffer'] = \ diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 7aa5d03efb..c9f07ae906 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -186,11 +186,10 @@ def execute_sql( query.user_id, start_dttm.strftime('%Y_%m_%d_%H_%M_%S')) executed_sql = superset_query.as_create_table(query.tmp_table_name) query.select_as_cta_used = True - elif (superset_query.is_select() and SQL_MAX_ROWS and + if (superset_query.is_select() and SQL_MAX_ROWS and (not query.limit or query.limit > SQL_MAX_ROWS)): query.limit = SQL_MAX_ROWS executed_sql = database.apply_limit_to_sql(executed_sql, query.limit) - query.limit_used = True # Hook to allow environment-specific mutation (usually comments) to the SQL SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR') diff --git a/superset/utils.py b/superset/utils.py index 09131f6d96..08ce0d2f38 100644 --- a/superset/utils.py +++ b/superset/utils.py @@ -18,7 +18,6 @@ import functools import json import logging import os -import re import signal import smtplib import sys @@ -883,29 +882,3 @@ def split_adhoc_filters_into_base_filters(fd): fd['having_filters'] = simple_having_filters fd['filters'] = simple_where_filters del fd['adhoc_filters'] - - -def get_query_without_limit(sql): - return re.sub(r""" - (?ix) # case insensitive, verbose - \s+ # whitespace - LIMIT\s+\d+ # LIMIT $ROWS - ;? # optional semi-colon - (\s|;)*$ # remove trailing spaces tabs or semicolons - """, '', sql) - - -def get_limit_from_sql(sql): - # returns the limit of the quest or None if it has no limit. - - limit_pattern = re.compile(r""" - (?ix) # case insensitive, verbose - \s+ # whitespace - LIMIT\s+(\d+) # LIMIT $ROWS - ;? # optional semi-colon - (\s|;)*$ # remove trailing spaces tabs or semicolons - """) - matches = limit_pattern.findall(sql) - - if matches: - return int(matches[0]) diff --git a/superset/views/core.py b/superset/views/core.py index 84e305889b..b4a1689a91 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2394,7 +2394,7 @@ class Superset(BaseSupersetView): query = Query( database_id=int(database_id), - limit=utils.get_limit_from_sql(sql), + limit=mydb.db_engine_spec.get_limit_from_sql(sql), sql=sql, schema=schema, select_as_cta=request.form.get('select_as_cta') == 'true', diff --git a/tests/celery_tests.py b/tests/celery_tests.py index f6d1a2958f..39b7749ae8 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -196,10 +196,9 @@ class CeleryTestCase(SupersetTestCase): self.assertEqual([{'name': 'Admin'}], df.to_dict(orient='records')) self.assertEqual(QueryStatus.SUCCESS, query.status) self.assertTrue('FROM tmp_async_1' in query.select_sql) - self.assertTrue('LIMIT 666' in query.select_sql) self.assertEqual( 'CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role ' - "WHERE name='Admin'", query.executed_sql) + "WHERE name='Admin' LIMIT 666", query.executed_sql) self.assertEqual(sql_where, query.sql) self.assertEqual(0, query.rows) self.assertEqual(666, query.limit) @@ -207,6 +206,33 @@ class CeleryTestCase(SupersetTestCase): self.assertEqual(True, query.select_as_cta) self.assertEqual(True, query.select_as_cta_used) + def test_run_async_query_with_lower_limit(self): + main_db = self.get_main_database(db.session) + eng = main_db.get_sqla_engine() + sql_where = "SELECT name FROM ab_role WHERE name='Alpha' LIMIT 1" + result = self.run_sql( + main_db.id, sql_where, '5', async='true', tmp_table='tmp_async_2', + cta='true') + assert result['query']['state'] in ( + QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS) + + time.sleep(1) + + query = self.get_query_by_id(result['query']['serverId']) + df = pd.read_sql_query(query.select_sql, con=eng) + self.assertEqual(QueryStatus.SUCCESS, query.status) + self.assertEqual([{'name': 'Alpha'}], df.to_dict(orient='records')) + self.assertEqual(QueryStatus.SUCCESS, query.status) + self.assertTrue('FROM tmp_async_2' in query.select_sql) + self.assertEqual( + 'CREATE TABLE tmp_async_2 AS \nSELECT name FROM ab_role ' + "WHERE name='Alpha' LIMIT 1", query.executed_sql) + self.assertEqual(sql_where, query.sql) + self.assertEqual(0, query.rows) + self.assertEqual(1, query.limit) + self.assertEqual(True, query.select_as_cta) + self.assertEqual(True, query.select_as_cta_used) + @staticmethod def de_unicode_dict(d): def str_if_basestring(o): diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index c38e4f5690..bdce0b060d 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -95,6 +95,19 @@ class DbEngineSpecsTestCase(SupersetTestCase): limited = engine_spec_class.apply_limit_to_sql(sql, limit, main) self.assertEquals(expected_sql, limited) + def test_extract_limit_from_query(self, engine_spec_class=MySQLEngineSpec): + q0 = 'select * from table' + q1 = 'select * from mytable limit 10' + q2 = 'select * from (select * from my_subquery limit 10) where col=1 limit 20' + q3 = 'select * from (select * from my_subquery limit 10);' + q4 = 'select * from (select * from my_subquery limit 10) where col=1 limit 20;' + + self.assertEqual(engine_spec_class.get_limit_from_sql(q0), None) + self.assertEqual(engine_spec_class.get_limit_from_sql(q1), 10) + self.assertEqual(engine_spec_class.get_limit_from_sql(q2), 20) + self.assertEqual(engine_spec_class.get_limit_from_sql(q3), None) + self.assertEqual(engine_spec_class.get_limit_from_sql(q4), 20) + def test_wrapped_query(self): self.sql_limit_regex( 'SELECT * FROM a',