"""Unit tests for Superset Celery worker"""
import datetime
import json
import random
import string
import time
import unittest.mock as mock
from typing import Optional

import flask
import pytest
from flask import current_app

from superset import db, sql_lab
from superset.db_engine_specs.base import BaseEngineSpec
from superset.extensions import celery_app
from superset.models.helpers import QueryStatus
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sql_parse import CtasMethod, ParsedQuery
from superset.utils.core import backend, get_example_database
from tests.base_tests import login
from tests.conftest import CTAS_SCHEMA_NAME
from tests.test_app import app

# Seconds to wait for the Celery worker to finish an async query before
# inspecting its Query row.
CELERY_SLEEP_TIME = 6

# Simple query reused by most of the CTA tests below.
QUERY = "SELECT name FROM birth_names LIMIT 1"

# Base names of the temporary tables/views created by the CTA tests; the
# module-scoped fixture drops every (name, CTAS method, schema) combination
# during teardown.
TEST_SYNC = "test_sync"
TEST_ASYNC_LOWER_LIMIT = "test_async_lower_limit"
TEST_SYNC_CTA = "test_sync_cta"
TEST_ASYNC_CTA = "test_async_cta"
TEST_ASYNC_CTA_CONFIG = "test_async_cta_config"
TMP_TABLES = [
    TEST_SYNC,
    TEST_SYNC_CTA,
    TEST_ASYNC_CTA,
    TEST_ASYNC_CTA_CONFIG,
    TEST_ASYNC_LOWER_LIMIT,
]
# Shared Flask test client; helpers below log in/out around each request.
test_client = app.test_client()


def get_query_by_id(query_id: int):
    """Return the ``Query`` row with the given primary key.

    Commits the session first so rows written by the Celery worker in a
    different transaction become visible here.
    """
    db.session.commit()
    return db.session.query(Query).filter_by(id=query_id).first()


@pytest.fixture(autouse=True, scope="module")
def setup_sqllab():
    """Run the whole module inside an app context; clean up afterwards.

    Teardown deletes all ``Query`` rows and drops every temporary table/view
    the CTA tests may have created, for both CTAS methods, in the default and
    the configured CTAS schema.
    """
    with app.app_context():
        yield

    # NOTE(review): teardown runs after the app context above has exited --
    # confirm db.session and drop_table_if_exists work without an active
    # context here.
    db.session.query(Query).delete()
    db.session.commit()
    for tbl in TMP_TABLES:
        drop_table_if_exists(f"{tbl}_{CtasMethod.TABLE.lower()}", CtasMethod.TABLE)
        drop_table_if_exists(f"{tbl}_{CtasMethod.VIEW.lower()}", CtasMethod.VIEW)
        drop_table_if_exists(
            f"{CTAS_SCHEMA_NAME}.{tbl}_{CtasMethod.TABLE.lower()}", CtasMethod.TABLE
        )
        drop_table_if_exists(
            f"{CTAS_SCHEMA_NAME}.{tbl}_{CtasMethod.VIEW.lower()}", CtasMethod.VIEW
        )


def run_sql(
    sql, cta=False, ctas_method=CtasMethod.TABLE, tmp_table="tmp", async_=False
):
    """POST *sql* to ``/superset/sql_json/`` as admin and return the JSON body.

    A random 5-letter ``client_id`` is generated per call so concurrent test
    queries do not collide. Logs out again before returning.
    """
    login(test_client, username="admin")
    db_id = get_example_database().id
    resp = test_client.post(
        "/superset/sql_json/",
        json=dict(
            database_id=db_id,
            sql=sql,
            runAsync=async_,
            select_as_cta=cta,
            tmp_table_name=tmp_table,
            client_id="".join(random.choice(string.ascii_lowercase) for _ in range(5)),
            ctas_method=ctas_method,
        ),
    )
    test_client.get("/logout/", follow_redirects=True)
    return json.loads(resp.data)


def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None:
    """Drop table if it exists, works on any DB"""
    sql = f"DROP {table_type} IF EXISTS {table_name}"
    get_example_database().get_sqla_engine().execute(sql)


def quote_f(value: Optional[str]):
    """Quote *value* with the example database's identifier preparer.

    Falsy values (``None``, ``""``) are returned unchanged.
    """
    if not value:
        return value
    return get_example_database().inspector.engine.dialect.identifier_preparer.quote_identifier(
        value
    )


def cta_result(ctas_method: CtasMethod):
    """Expected ``(data, columns)`` of a CTA statement on the current backend.

    Only Presto reports anything; other backends return empty results.
    """
    if backend() != "presto":
        return [], []
    if ctas_method == CtasMethod.TABLE:
        return [{"rows": 1}], [{"name": "rows", "type": "BIGINT", "is_date": False}]
    return [{"result": True}], [{"name": "result", "type": "BOOLEAN", "is_date": False}]


# TODO(bkyryliuk): quote table and schema names for all databases
def get_select_star(table: str, schema: Optional[str] = None):
    """Build the ``SELECT *`` statement Superset generates for a CTA table."""
    if backend() in {"presto", "hive"}:
        schema = quote_f(schema)
        table = quote_f(table)
    if schema:
        return f"SELECT *\nFROM {schema}.{table}"
    return f"SELECT *\nFROM {table}"


@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_sync_query_dont_exist(setup_sqllab, ctas_method):
    """CTA over a missing table fails (except sqlite views, created lazily)."""
    sql_dont_exist = "SELECT name FROM table_dont_exist"
    result = run_sql(sql_dont_exist, cta=True, ctas_method=ctas_method)
    if backend() == "sqlite" and ctas_method == CtasMethod.VIEW:
        assert QueryStatus.SUCCESS == result["status"], result
    else:
        assert QueryStatus.FAILED == result["status"], result


@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_sync_query_cta(setup_sqllab, ctas_method):
    """Synchronous CTA succeeds and the created table is selectable."""
    tmp_table_name = f"{TEST_SYNC}_{ctas_method.lower()}"
    result = run_sql(QUERY, tmp_table=tmp_table_name, cta=True, ctas_method=ctas_method)
    assert QueryStatus.SUCCESS == result["query"]["state"], result
    assert cta_result(ctas_method) == (result["data"], result["columns"])

    # Check the data in the tmp table.
    select_query = get_query_by_id(result["query"]["serverId"])
    results = run_sql(select_query.select_sql)
    assert QueryStatus.SUCCESS == results["status"], results
    assert len(results["data"]) > 0


def test_run_sync_query_cta_no_data(setup_sqllab):
    """A query with an empty result set still finishes with SUCCESS."""
    sql_empty_result = "SELECT * FROM birth_names WHERE name='random'"
    result = run_sql(sql_empty_result)
    assert QueryStatus.SUCCESS == result["query"]["state"]
    assert ([], []) == (result["data"], result["columns"])

    query = get_query_by_id(result["query"]["serverId"])
    assert QueryStatus.SUCCESS == query.status


@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
@mock.patch(
    "superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
)
def test_run_sync_query_cta_config(setup_sqllab, ctas_method):
    """Synchronous CTA honors the configured target schema."""
    if backend() == "sqlite":
        # sqlite doesn't support schemas
        return
    tmp_table_name = f"{TEST_SYNC_CTA}_{ctas_method.lower()}"
    result = run_sql(QUERY, cta=True, ctas_method=ctas_method, tmp_table=tmp_table_name)
    assert QueryStatus.SUCCESS == result["query"]["state"], result
    assert cta_result(ctas_method) == (result["data"], result["columns"])

    query = get_query_by_id(result["query"]["serverId"])
    assert (
        f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n{QUERY}"
        == query.executed_sql
    )

    assert query.select_sql == get_select_star(tmp_table_name, schema=CTAS_SCHEMA_NAME)
    time.sleep(CELERY_SLEEP_TIME)
    results = run_sql(query.select_sql)
    assert QueryStatus.SUCCESS == results["status"], results


@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
@mock.patch(
    "superset.views.core.get_cta_schema_name", lambda d, u, s, sql: CTAS_SCHEMA_NAME
)
def test_run_async_query_cta_config(setup_sqllab, ctas_method):
    """Async CTA honors the configured target schema."""
    if backend() == "sqlite":
        # sqlite doesn't support schemas
        return
    tmp_table_name = f"{TEST_ASYNC_CTA_CONFIG}_{ctas_method.lower()}"
    result = run_sql(
        QUERY, cta=True, ctas_method=ctas_method, async_=True, tmp_table=tmp_table_name,
    )

    time.sleep(CELERY_SLEEP_TIME)

    query = get_query_by_id(result["query"]["serverId"])
    assert QueryStatus.SUCCESS == query.status
    assert get_select_star(tmp_table_name, schema=CTAS_SCHEMA_NAME) == query.select_sql
    assert (
        f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n{QUERY}"
        == query.executed_sql
    )


@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_async_cta_query(setup_sqllab, ctas_method):
    """Async CTA creates the table and records the executed SQL."""
    table_name = f"{TEST_ASYNC_CTA}_{ctas_method.lower()}"
    result = run_sql(
        QUERY, cta=True, ctas_method=ctas_method, async_=True, tmp_table=table_name
    )

    time.sleep(CELERY_SLEEP_TIME)

    query = get_query_by_id(result["query"]["serverId"])
    assert QueryStatus.SUCCESS == query.status
    assert get_select_star(table_name) in query.select_sql

    assert f"CREATE {ctas_method} {table_name} AS \n{QUERY}" == query.executed_sql
    assert QUERY == query.sql
    # Only Presto reports the row count of a CTA statement.
    assert query.rows == (1 if backend() == "presto" else 0)
    assert query.select_as_cta
    assert query.select_as_cta_used


@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_async_cta_query_with_lower_limit(setup_sqllab, ctas_method):
    """Async CTA with a query-level LIMIT leaves the Query row limit unset."""
    tmp_table = f"{TEST_ASYNC_LOWER_LIMIT}_{ctas_method.lower()}"
    result = run_sql(
        QUERY, cta=True, ctas_method=ctas_method, async_=True, tmp_table=tmp_table
    )
    time.sleep(CELERY_SLEEP_TIME)

    query = get_query_by_id(result["query"]["serverId"])
    assert QueryStatus.SUCCESS == query.status

    assert get_select_star(tmp_table) == query.select_sql
    assert f"CREATE {ctas_method} {tmp_table} AS \n{QUERY}" == query.executed_sql
    assert QUERY == query.sql
    # Only Presto reports the row count of a CTA statement.
    assert query.rows == (1 if backend() == "presto" else 0)
    assert query.limit is None
    assert query.select_as_cta
    assert query.select_as_cta_used


# Shared fixture row for the serialization tests below.
SERIALIZATION_DATA = [("a", 4, 4.0, datetime.datetime(2019, 8, 18, 16, 39, 16, 660000))]
# Cursor description matching SERIALIZATION_DATA's column types.
CURSOR_DESCR = (
    ("a", "string"),
    ("b", "int"),
    ("c", "float"),
    ("d", "datetime"),
)


def test_default_data_serialization():
    """Default (JSON) path expands data and returns a list."""
    db_engine_spec = BaseEngineSpec()
    results = SupersetResultSet(SERIALIZATION_DATA, CURSOR_DESCR, db_engine_spec)

    with mock.patch.object(
        db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
    ) as expand_data:
        data = sql_lab._serialize_and_expand_data(results, db_engine_spec, False, True)
        expand_data.assert_called_once()
    assert isinstance(data[0], list)


def test_new_data_serialization():
    """New (msgpack) path skips expansion and returns raw bytes."""
    db_engine_spec = BaseEngineSpec()
    results = SupersetResultSet(SERIALIZATION_DATA, CURSOR_DESCR, db_engine_spec)

    with mock.patch.object(
        db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
    ) as expand_data:
        data = sql_lab._serialize_and_expand_data(results, db_engine_spec, True)
        expand_data.assert_not_called()
    assert isinstance(data[0], bytes)


def _serialize_payload_with(use_new_deserialization: bool):
    """Serialize a canned query payload with the given deserialization mode.

    Shared by the two payload-serialization tests, which differ only in the
    flag and the expected serialized type.
    """
    db_engine_spec = BaseEngineSpec()
    results = SupersetResultSet(SERIALIZATION_DATA, CURSOR_DESCR, db_engine_spec)
    query = {
        "database_id": 1,
        "sql": "SELECT * FROM birth_names LIMIT 100",
        "status": QueryStatus.PENDING,
    }
    (
        serialized_data,
        selected_columns,
        all_columns,
        expanded_columns,
    ) = sql_lab._serialize_and_expand_data(
        results, db_engine_spec, use_new_deserialization
    )
    payload = {
        "query_id": 1,
        "status": QueryStatus.SUCCESS,
        "state": QueryStatus.SUCCESS,
        "data": serialized_data,
        "columns": all_columns,
        "selected_columns": selected_columns,
        "expanded_columns": expanded_columns,
        "query": query,
    }
    return sql_lab._serialize_payload(payload, use_new_deserialization)


def test_default_payload_serialization():
    """Default serialization produces a JSON string."""
    assert isinstance(_serialize_payload_with(False), str)


def test_msgpack_payload_serialization():
    """New serialization produces msgpack bytes."""
    assert isinstance(_serialize_payload_with(True), bytes)


# TODO(bkyryliuk): support more cases in CTA function.
def test_create_table_as():
    """ParsedQuery.as_create_table handles ;, overwrite and multiline SQL."""
    q = ParsedQuery("SELECT * FROM outer_space;")

    assert "CREATE TABLE tmp AS \nSELECT * FROM outer_space" == q.as_create_table("tmp")
    assert (
        "DROP TABLE IF EXISTS tmp;\nCREATE TABLE tmp AS \nSELECT * FROM outer_space"
        == q.as_create_table("tmp", overwrite=True)
    )

    # now without a semicolon
    q = ParsedQuery("SELECT * FROM outer_space")
    assert "CREATE TABLE tmp AS \nSELECT * FROM outer_space" == q.as_create_table("tmp")

    # now a multi-line query
    multi_line_query = "SELECT * FROM planets WHERE\n" "Luke_Father = 'Darth Vader'"
    q = ParsedQuery(multi_line_query)
    assert (
        "CREATE TABLE tmp AS \nSELECT * FROM planets WHERE\nLuke_Father = 'Darth Vader'"
        == q.as_create_table("tmp")
    )


def test_in_app_context():
    """Celery tasks run inside a Flask app context, pushed if necessary."""
    @celery_app.task()
    def my_task():
        assert current_app

    # Make sure we can call tasks with an app already setup
    my_task()

    # Make sure the app gets pushed onto the stack properly
    try:
        popped_app = flask._app_ctx_stack.pop()
        my_task()
    finally:
        flask._app_ctx_stack.push(popped_app)