# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for Superset Celery worker"""
import json
import subprocess
import time
import unittest

from superset import app, db
from superset.models.helpers import QueryStatus
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery
from superset.utils.core import get_main_database
from .base_tests import SupersetTestCase

BASE_DIR = app.config.get('BASE_DIR')
# Seconds the async tests wait before inspecting the stored query row.
CELERY_SLEEP_TIME = 5


class CeleryConfig(object):
    """Celery settings installed into ``app.config`` for this test module."""

    BROKER_URL = app.config.get('CELERY_RESULT_BACKEND')
    CELERY_IMPORTS = ('superset.sql_lab', )
    CELERY_ANNOTATIONS = {'sql_lab.add': {'rate_limit': '10/s'}}
    CONCURRENCY = 1


app.config['CELERY_CONFIG'] = CeleryConfig


class UtilityFunctionTests(SupersetTestCase):
    """Tests for ``ParsedQuery.as_create_table``."""

    # TODO(bkyryliuk): support more cases in CTA function.
    def test_create_table_as(self):
        """Check the CREATE TABLE AS statement built for several inputs."""
        q = ParsedQuery('SELECT * FROM outer_space;')

        self.assertEqual(
            'CREATE TABLE tmp AS \nSELECT * FROM outer_space',
            q.as_create_table('tmp'))

        self.assertEqual(
            'DROP TABLE IF EXISTS tmp;\n'
            'CREATE TABLE tmp AS \nSELECT * FROM outer_space',
            q.as_create_table('tmp', overwrite=True))

        # now without a semicolon
        q = ParsedQuery('SELECT * FROM outer_space')
        self.assertEqual(
            'CREATE TABLE tmp AS \nSELECT * FROM outer_space',
            q.as_create_table('tmp'))

        # now a multi-line query
        multi_line_query = (
            'SELECT * FROM planets WHERE\n'
            "Luke_Father = 'Darth Vader'")
        q = ParsedQuery(multi_line_query)
        self.assertEqual(
            'CREATE TABLE tmp AS \nSELECT * FROM planets WHERE\n'
            "Luke_Father = 'Darth Vader'",
            q.as_create_table('tmp'),
        )


class CeleryTestCase(SupersetTestCase):
    """Tests that post SQL to ``/superset/sql_json/`` and inspect the results.

    ``setUpClass`` launches ``superset worker`` as a subprocess and
    ``tearDownClass`` kills it again.
    """

    def __init__(self, *args, **kwargs):
        super(CeleryTestCase, self).__init__(*args, **kwargs)
        self.client = app.test_client()

    def get_query_by_name(self, sql):
        """Return the first stored ``Query`` whose ``sql`` equals *sql*."""
        session = db.session
        query = session.query(Query).filter_by(sql=sql).first()
        session.close()
        return query

    def get_query_by_id(self, id):
        """Return the first stored ``Query`` with the given *id*."""
        session = db.session
        query = session.query(Query).filter_by(id=id).first()
        session.close()
        return query

    @classmethod
    def setUpClass(cls):
        """Delete all stored queries and start the worker subprocess."""
        db.session.query(Query).delete()
        db.session.commit()

        worker_command = BASE_DIR + '/bin/superset worker -w 2'
        # NOTE(review): stdout is piped but never read in this file, and the
        # Popen handle is not kept; teardown relies on the `ps | grep` kills.
        subprocess.Popen(
            worker_command, shell=True, stdout=subprocess.PIPE)

    @classmethod
    def tearDownClass(cls):
        """Kill every process matching 'celeryd' or 'superset worker'."""
        subprocess.call(
            "ps auxww | grep 'celeryd' | awk '{print $2}' | xargs kill -9",
            shell=True,
        )
        subprocess.call(
            "ps auxww | grep 'superset worker' | awk '{print $2}' | xargs kill -9",
            shell=True,
        )

    def run_sql(self, db_id, sql, client_id=None, cta='false', tmp_table='tmp',
                async_='false'):
        """Log in, POST *sql* to ``/superset/sql_json/`` and log out again.

        :param db_id: id of the database to run against
        :param sql: SQL text to submit
        :param client_id: value sent as ``client_id``
        :param cta: value sent as ``select_as_cta`` (``'true'``/``'false'``)
        :param tmp_table: value sent as ``tmp_table_name``
        :param async_: value sent as ``runAsync`` (``'true'``/``'false'``)
        :returns: the response body decoded from JSON
        """
        self.login()
        resp = self.client.post(
            '/superset/sql_json/',
            data=dict(
                database_id=db_id,
                sql=sql,
                runAsync=async_,
                select_as_cta=cta,
                tmp_table_name=tmp_table,
                client_id=client_id,
            ),
        )
        self.logout()
        return json.loads(resp.data.decode('utf-8'))

    def test_run_sync_query_dont_exist(self):
        """A sync CTA query on a missing table yields an 'error' key."""
        main_db = get_main_database(db.session)
        db_id = main_db.id
        sql_dont_exist = 'SELECT name FROM table_dont_exist'
        result1 = self.run_sql(db_id, sql_dont_exist, '1', cta='true')
        self.assertIn('error', result1)

    def test_run_sync_query_cta(self):
        """A sync CTA query succeeds and returns no data or columns."""
        main_db = get_main_database(db.session)
        backend = main_db.backend
        db_id = main_db.id
        tmp_table_name = 'tmp_async_22'
        self.drop_table_if_exists(tmp_table_name, main_db)
        perm_name = 'can_sql_json'
        sql_where = (
            "SELECT name FROM ab_permission WHERE name='{}'".format(perm_name))
        result2 = self.run_sql(
            db_id, sql_where, '2', tmp_table=tmp_table_name, cta='true')
        self.assertEqual(QueryStatus.SUCCESS, result2['query']['state'])
        self.assertEqual([], result2['data'])
        self.assertEqual([], result2['columns'])
        query2 = self.get_query_by_id(result2['query']['serverId'])

        # Check the data in the tmp table.
        if backend != 'postgresql':
            # TODO This test won't work in Postgres
            results = self.run_sql(db_id, query2.select_sql, 'sdf2134')
            self.assertEqual(results['status'], 'success')
            self.assertGreater(len(results['data']), 0)

    def test_run_sync_query_cta_no_data(self):
        """A sync query expected to return no rows still reports success."""
        main_db = get_main_database(db.session)
        db_id = main_db.id
        sql_empty_result = 'SELECT * FROM ab_user WHERE id=666'
        result3 = self.run_sql(db_id, sql_empty_result, '3')
        self.assertEqual(QueryStatus.SUCCESS, result3['query']['state'])
        self.assertEqual([], result3['data'])
        self.assertEqual([], result3['columns'])

        query3 = self.get_query_by_id(result3['query']['serverId'])
        self.assertEqual(QueryStatus.SUCCESS, query3.status)

    def drop_table_if_exists(self, table_name, database=None):
        """Drop table if it exists, works on any DB

        Issues a plain ``DROP TABLE`` through ``run_sql`` and returns the
        decoded response without checking it for errors.

        :param table_name: name of the table to drop
        :param database: database to run against; when omitted the main
            database is used (previously ``None`` raised ``AttributeError``
            because ``database.id`` was read before the ``None`` check)
        :returns: the decoded JSON response from ``run_sql``
        """
        if database is None:
            database = get_main_database(db.session)
        sql = 'DROP TABLE {}'.format(table_name)
        db_id = database.id
        # DML is enabled on the database before the DROP is submitted.
        database.allow_dml = True
        db.session.flush()
        return self.run_sql(db_id, sql)

    def test_run_async_query(self):
        """An async CTA query finishes and records the expected metadata."""
        main_db = get_main_database(db.session)
        db_id = main_db.id

        self.drop_table_if_exists('tmp_async_1', main_db)

        sql_where = "SELECT name FROM ab_role WHERE name='Admin'"
        result = self.run_sql(
            db_id, sql_where, '4', async_='true', tmp_table='tmp_async_1',
            cta='true')
        # assertIn rather than a bare assert so the check survives `python -O`.
        self.assertIn(
            result['query']['state'],
            (QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS))

        # Wait a fixed interval before reading the stored query row.
        time.sleep(CELERY_SLEEP_TIME)

        query = self.get_query_by_id(result['query']['serverId'])
        self.assertEqual(QueryStatus.SUCCESS, query.status)
        self.assertIn('FROM tmp_async_1', query.select_sql)
        self.assertEqual(
            'CREATE TABLE tmp_async_1 AS \n'
            'SELECT name FROM ab_role '
            "WHERE name='Admin'\n"
            'LIMIT 666', query.executed_sql)
        self.assertEqual(sql_where, query.sql)
        self.assertEqual(0, query.rows)
        self.assertEqual(False, query.limit_used)
        self.assertEqual(True, query.select_as_cta)
        self.assertEqual(True, query.select_as_cta_used)

    def test_run_async_query_with_lower_limit(self):
        """An async CTA query with its own LIMIT keeps that limit."""
        main_db = get_main_database(db.session)
        db_id = main_db.id
        self.drop_table_if_exists('tmp_async_2', main_db)

        sql_where = "SELECT name FROM ab_role WHERE name='Alpha' LIMIT 1"
        result = self.run_sql(
            db_id, sql_where, '5', async_='true', tmp_table='tmp_async_2',
            cta='true')
        # assertIn rather than a bare assert so the check survives `python -O`.
        self.assertIn(
            result['query']['state'],
            (QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS))

        # Wait a fixed interval before reading the stored query row.
        time.sleep(CELERY_SLEEP_TIME)

        query = self.get_query_by_id(result['query']['serverId'])
        self.assertEqual(QueryStatus.SUCCESS, query.status)
        self.assertIn('FROM tmp_async_2', query.select_sql)
        self.assertEqual(
            'CREATE TABLE tmp_async_2 AS \nSELECT name FROM ab_role '
            "WHERE name='Alpha' LIMIT 1", query.executed_sql)
        self.assertEqual(sql_where, query.sql)
        self.assertEqual(0, query.rows)
        self.assertEqual(1, query.limit)
        self.assertEqual(True, query.select_as_cta)
        self.assertEqual(True, query.select_as_cta_used)

    @staticmethod
    def de_unicode_dict(d):
        """Return a copy of *d* with ``str`` keys and values passed via ``str()``.

        Keys and values that are not ``str`` instances are kept unchanged.
        """
        def str_if_basestring(o):
            if isinstance(o, str):
                return str(o)
            return o
        return {str_if_basestring(k): str_if_basestring(d[k]) for k in d}

    @classmethod
    def dictify_list_of_dicts(cls, l, k):
        """Index the dicts in *l* by ``str(o[k])``.

        Each value is the corresponding dict passed through
        ``de_unicode_dict``.
        """
        return {str(o[k]): cls.de_unicode_dict(o) for o in l}


if __name__ == '__main__':
    unittest.main()