# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset Celery worker"""
import datetime
import json
from typing import Optional
from parameterized import parameterized
import time
import unittest
import unittest.mock as mock

import flask
from flask import current_app
from tests.test_app import app
from superset import db, sql_lab
from superset.result_set import SupersetResultSet
from superset.db_engine_specs.base import BaseEngineSpec
from superset.extensions import celery_app
from superset.models.core import Database
from superset.models.helpers import QueryStatus
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery, CtasMethod
from superset.utils.core import get_example_database

from .base_tests import SupersetTestCase
from .sqllab_test_util import (
setup_presto_if_needed,
CTAS_SCHEMA_NAME,
)  # noqa  (imported for its autouse fixture)
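
# Rough waits, in seconds, used to give the Celery worker (and the database)
# time to finish asynchronous work before the tests inspect the results.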
CELERY_SHORT_SLEEP_TIME = 2
CELERY_SLEEP_TIME = 10
DROP_TABLE_SLEEP_TIME = 10


class TestUtilityFunction(SupersetTestCase):
# TODO(bkyryliuk): support more cases in CTA function.
def test_create_table_as(self):
        q = ParsedQuery("SELECT * FROM outer_space;")
self.assertEqual(
"CREATE TABLE tmp AS \nSELECT * FROM outer_space", q.as_create_table("tmp")
)
self.assertEqual(
2019-06-25 16:34:48 -04:00
"DROP TABLE IF EXISTS tmp;\n"
"CREATE TABLE tmp AS \nSELECT * FROM outer_space",
q.as_create_table("tmp", overwrite=True),
)

        # now without a semicolon
q = ParsedQuery("SELECT * FROM outer_space")
self.assertEqual(
"CREATE TABLE tmp AS \nSELECT * FROM outer_space", q.as_create_table("tmp")
)

        # now a multi-line query
multi_line_query = "SELECT * FROM planets WHERE\n" "Luke_Father = 'Darth Vader'"
q = ParsedQuery(multi_line_query)
self.assertEqual(
"CREATE TABLE tmp AS \nSELECT * FROM planets WHERE\n"
"Luke_Father = 'Darth Vader'",
q.as_create_table("tmp"),
)


class TestAppContext(SupersetTestCase):
def test_in_app_context(self):
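        # Tasks registered on celery_app should always execute with a Flask
        # app context available, whether or not one is already active.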
@celery_app.task()
def my_task():
self.assertTrue(current_app)

        # Make sure we can call tasks with an app already setup
        my_task()

        # Make sure the app gets pushed onto the stack properly
try:
popped_app = flask._app_ctx_stack.pop()
my_task()
finally:
flask._app_ctx_stack.push(popped_app)


class TestCelery(SupersetTestCase):
def get_query_by_name(self, sql):
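        """Return the Query row whose SQL text matches the given statement."""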
session = db.session
query = session.query(Query).filter_by(sql=sql).first()
session.close()
return query

    def get_query_by_id(self, id):
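        """Return the Query row with the given primary key."""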
session = db.session
query = session.query(Query).filter_by(id=id).first()
session.close()
return query

    @classmethod
def setUpClass(cls):
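        # Start from an empty query history so lookups by id/sql are unambiguous.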
with app.app_context():
db.session.query(Query).delete()
db.session.commit()

    def run_sql(
self,
db_id,
sql,
client_id=None,
cta=False,
tmp_table="tmp",
async_=False,
ctas_method=CtasMethod.TABLE,
):
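        """Post the given SQL to /superset/sql_json/ as a logged-in test user
        and return the decoded JSON response."""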
self.login()
resp = self.client.post(
"/superset/sql_json/",
json=dict(
database_id=db_id,
sql=sql,
runAsync=async_,
select_as_cta=cta,
tmp_table_name=tmp_table,
client_id=client_id,
ctas_method=ctas_method,
),
)
self.logout()
return json.loads(resp.data)

    @parameterized.expand(
[CtasMethod.TABLE,]
)
def test_run_sync_query_dont_exist(self, ctas_method):
main_db = get_example_database()
db_id = main_db.id
sql_dont_exist = "SELECT name FROM table_dont_exist"
result = self.run_sql(
db_id, sql_dont_exist, f"1_{ctas_method}", cta=True, ctas_method=ctas_method
)
if (
get_example_database().backend != "sqlite"
and ctas_method == CtasMethod.VIEW
):
self.assertEqual(QueryStatus.SUCCESS, result["status"], msg=result)
else:
self.assertEqual(QueryStatus.FAILED, result["status"], msg=result)

    @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_sync_query_cta(self, ctas_method):
main_db = get_example_database()
backend = main_db.backend
db_id = main_db.id
tmp_table_name = f"tmp_sync_23_{ctas_method.lower()}"
self.drop_table_if_exists(tmp_table_name, ctas_method, main_db)
name = "James"
sql_where = f"SELECT name FROM birth_names WHERE name='{name}' LIMIT 1"
result = self.run_sql(
db_id,
sql_where,
f"2_{ctas_method}",
tmp_table=tmp_table_name,
cta=True,
ctas_method=ctas_method,
)
        # msg=result surfaces the full response when the assertion fails
self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"], msg=result)
expected_result = []
if backend == "presto":
expected_result = (
[{"rows": 1}] if ctas_method == CtasMethod.TABLE else [{"result": True}]
)
self.assertEqual(expected_result, result["data"])
# TODO(bkyryliuk): refactor database specific logic into a separate class
expected_columns = []
if backend == "presto":
expected_columns = [
{
"name": "rows" if ctas_method == CtasMethod.TABLE else "result",
"type": "BIGINT" if ctas_method == CtasMethod.TABLE else "BOOLEAN",
"is_date": False,
}
]
self.assertEqual(expected_columns, result["columns"])

        query2 = self.get_query_by_id(result["query"]["serverId"])
# Check the data in the tmp table.
results = self.run_sql(db_id, query2.select_sql, f"7_{ctas_method}")
self.assertEqual(QueryStatus.SUCCESS, results["status"], msg=results)
self.assertGreater(len(results["data"]), 0)
# cleanup tmp table
self.drop_table_if_exists(tmp_table_name, ctas_method, get_example_database())

    def test_run_sync_query_cta_no_data(self):
main_db = get_example_database()
db_id = main_db.id
sql_empty_result = "SELECT * FROM birth_names WHERE name='random'"
result3 = self.run_sql(db_id, sql_empty_result, "3")
self.assertEqual(QueryStatus.SUCCESS, result3["query"]["state"])
self.assertEqual([], result3["data"])
self.assertEqual([], result3["columns"])

        query3 = self.get_query_by_id(result3["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query3.status)

    def drop_table_if_exists(
self, table_name: str, table_type: CtasMethod, database: Database,
) -> None:
"""Drop table if it exists, works on any DB"""
sql = f"DROP {table_type} IF EXISTS {table_name}"
database.get_sqla_engine().execute(sql)

    @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_sync_query_cta_config(self, ctas_method):
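        # Stub the CTA schema resolution hook so CTAS results land in
        # CTAS_SCHEMA_NAME, then verify the generated CREATE/SELECT statements.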
with mock.patch(
"superset.views.core.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
):
examples_db = get_example_database()
db_id = examples_db.id
backend = examples_db.backend
if backend == "sqlite":
# sqlite doesn't support schemas
return
tmp_table_name = f"tmp_async_22_{ctas_method.lower()}"
quote = (
examples_db.inspector.engine.dialect.identifier_preparer.quote_identifier
)
expected_full_table_name = f"{CTAS_SCHEMA_NAME}.{quote(tmp_table_name)}"
self.drop_table_if_exists(
expected_full_table_name, ctas_method, examples_db
)
name = "James"
sql_where = f"SELECT name FROM birth_names WHERE name='{name}'"
result = self.run_sql(
db_id,
sql_where,
f"3_{ctas_method}",
tmp_table=tmp_table_name,
cta=True,
ctas_method=ctas_method,
)
self.assertEqual(QueryStatus.SUCCESS, result["query"]["state"], msg=result)
expected_result = []
# TODO(bkyryliuk): refactor database specific logic into a separate class
if backend == "presto":
expected_result = (
[{"rows": 1}]
if ctas_method == CtasMethod.TABLE
else [{"result": True}]
)
self.assertEqual(expected_result, result["data"])
expected_columns = []
# TODO(bkyryliuk): refactor database specific logic into a separate class
if backend == "presto":
expected_columns = [
{
"name": "rows" if ctas_method == CtasMethod.TABLE else "result",
"type": "BIGINT"
if ctas_method == CtasMethod.TABLE
else "BOOLEAN",
"is_date": False,
}
]
self.assertEqual(expected_columns, result["columns"])
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(
f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n"
"SELECT name FROM birth_names "
"WHERE name='James'",
query.executed_sql,
)
self.assertEqual(
"SELECT *\n" f"FROM {CTAS_SCHEMA_NAME}.{tmp_table_name}"
if backend != "presto"
else "SELECT *\n"
f"FROM {quote(CTAS_SCHEMA_NAME)}.{quote(tmp_table_name)}",
query.select_sql,
)
time.sleep(CELERY_SHORT_SLEEP_TIME)
results = self.run_sql(db_id, query.select_sql)
self.assertEqual(QueryStatus.SUCCESS, results["status"], msg=result)
self.drop_table_if_exists(
expected_full_table_name, ctas_method, get_example_database()
)

    @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_async_query_cta_config(self, ctas_method):
with mock.patch(
"superset.views.core.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
):
example_db = get_example_database()
db_id = example_db.id
if example_db.backend == "sqlite":
# sqlite doesn't support schemas
return
tmp_table_name = f"sqllab_test_table_async_1_{ctas_method}"
quote = (
example_db.inspector.engine.dialect.identifier_preparer.quote_identifier
)
schema_name = (
quote(CTAS_SCHEMA_NAME)
if example_db.backend == "presto"
else CTAS_SCHEMA_NAME
)
expected_full_table_name = f"{schema_name}.{quote(tmp_table_name)}"
self.drop_table_if_exists(expected_full_table_name, ctas_method, example_db)
sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
result = self.run_sql(
db_id,
sql_where,
f"4_{ctas_method}",
async_=True,
tmp_table=tmp_table_name,
cta=True,
ctas_method=ctas_method,
)
db.session.close()
time.sleep(CELERY_SLEEP_TIME)
query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertIn(expected_full_table_name, query.select_sql)
self.assertEqual(
f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n"
"SELECT name FROM birth_names "
"WHERE name='James' "
"LIMIT 10",
query.executed_sql,
)
self.drop_table_if_exists(
f"{schema_name}.{tmp_table_name}", ctas_method, get_example_database()
)

    @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_async_cta_query(self, ctas_method):
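        # Run a CTAS query asynchronously and verify both the generated SQL
        # and the metadata recorded on the Query row.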
main_db = get_example_database()
db_backend = main_db.backend
db_id = main_db.id
table_name = f"tmp_async_4_{ctas_method}"
self.drop_table_if_exists(table_name, ctas_method, main_db)
time.sleep(DROP_TABLE_SLEEP_TIME)
sql_where = "SELECT name FROM birth_names WHERE name='James' LIMIT 10"
result = self.run_sql(
db_id,
sql_where,
f"5_{ctas_method}",
async_=True,
tmp_table=table_name,
cta=True,
ctas_method=ctas_method,
)
db.session.close()

        assert result["query"]["state"] in (
QueryStatus.PENDING,
QueryStatus.RUNNING,
QueryStatus.SUCCESS,
)

        time.sleep(CELERY_SLEEP_TIME)

        query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertIn(table_name, query.select_sql)
self.assertEqual(
f"CREATE {ctas_method} {table_name} AS \n"
"SELECT name FROM birth_names "
"WHERE name='James' "
"LIMIT 10",
query.executed_sql,
)
self.assertEqual(sql_where, query.sql)
if db_backend == "presto":
self.assertEqual(1, query.rows)
else:
self.assertEqual(0, query.rows)
self.assertEqual(True, query.select_as_cta)
self.assertEqual(True, query.select_as_cta_used)
self.drop_table_if_exists(table_name, ctas_method, get_example_database())

    @parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_async_cta_query_with_lower_limit(self, ctas_method):
example_db = get_example_database()
db_backend = example_db.backend
db_id = example_db.id
tmp_table = f"tmp_async_2_{ctas_method}"
self.drop_table_if_exists(tmp_table, ctas_method, example_db)
sql_where = "SELECT name FROM birth_names LIMIT 1"
result = self.run_sql(
db_id,
sql_where,
f"6_{ctas_method}",
async_=True,
tmp_table=tmp_table,
cta=True,
ctas_method=ctas_method,
)
db.session.close()

        assert result["query"]["state"] in (
QueryStatus.PENDING,
QueryStatus.RUNNING,
QueryStatus.SUCCESS,
)

        time.sleep(CELERY_SLEEP_TIME)

        query = self.get_query_by_id(result["query"]["serverId"])
self.assertEqual(QueryStatus.SUCCESS, query.status)
self.assertIn(tmp_table, query.select_sql)
self.assertEqual(
f"CREATE {ctas_method} {tmp_table} AS \n"
"SELECT name FROM birth_names LIMIT 1",
query.executed_sql,
)
self.assertEqual(sql_where, query.sql)
if db_backend == "presto":
self.assertEqual(1, query.rows)
else:
self.assertEqual(0, query.rows)
self.assertEqual(None, query.limit)
self.assertEqual(True, query.select_as_cta)
self.assertEqual(True, query.select_as_cta_used)
self.drop_table_if_exists(tmp_table, ctas_method, get_example_database())

    def test_default_data_serialization(self):
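        # With the default serialization path the result data stays a Python
        # list and the engine spec's expand_data hook is applied.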
data = [("a", 4, 4.0, datetime.datetime(2019, 8, 18, 16, 39, 16, 660000))]
cursor_descr = (
("a", "string"),
("b", "int"),
("c", "float"),
("d", "datetime"),
)
db_engine_spec = BaseEngineSpec()
results = SupersetResultSet(data, cursor_descr, db_engine_spec)
with mock.patch.object(
db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
) as expand_data:
(
data,
selected_columns,
all_columns,
expanded_columns,
) = sql_lab._serialize_and_expand_data(results, db_engine_spec, False, True)
expand_data.assert_called_once()
self.assertIsInstance(data, list)

    def test_new_data_serialization(self):
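        # With the new serialization path the result data is returned as raw
        # bytes and expand_data is not called.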
data = [("a", 4, 4.0, datetime.datetime(2019, 8, 18, 16, 39, 16, 660000))]
cursor_descr = (
("a", "string"),
("b", "int"),
("c", "float"),
("d", "datetime"),
)
db_engine_spec = BaseEngineSpec()
results = SupersetResultSet(data, cursor_descr, db_engine_spec)
with mock.patch.object(
db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data
) as expand_data:
(
data,
selected_columns,
all_columns,
expanded_columns,
) = sql_lab._serialize_and_expand_data(results, db_engine_spec, True)
expand_data.assert_not_called()
self.assertIsInstance(data, bytes)

    def test_default_payload_serialization(self):
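        # The default payload serializer should produce a JSON string.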
use_new_deserialization = False
data = [("a", 4, 4.0, datetime.datetime(2019, 8, 18, 16, 39, 16, 660000))]
cursor_descr = (
("a", "string"),
("b", "int"),
("c", "float"),
("d", "datetime"),
)
db_engine_spec = BaseEngineSpec()
results = SupersetResultSet(data, cursor_descr, db_engine_spec)
query = {
"database_id": 1,
"sql": "SELECT * FROM birth_names LIMIT 100",
"status": QueryStatus.PENDING,
}
(
serialized_data,
selected_columns,
all_columns,
expanded_columns,
) = sql_lab._serialize_and_expand_data(
results, db_engine_spec, use_new_deserialization
)
payload = {
"query_id": 1,
"status": QueryStatus.SUCCESS,
"state": QueryStatus.SUCCESS,
"data": serialized_data,
"columns": all_columns,
"selected_columns": selected_columns,
"expanded_columns": expanded_columns,
"query": query,
}
serialized = sql_lab._serialize_payload(payload, use_new_deserialization)
self.assertIsInstance(serialized, str)

    def test_msgpack_payload_serialization(self):
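        # With new deserialization enabled the payload should be serialized
        # to bytes (the msgpack path) rather than a JSON string.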
use_new_deserialization = True
data = [("a", 4, 4.0, datetime.datetime(2019, 8, 18, 16, 39, 16, 660000))]
cursor_descr = (
("a", "string"),
("b", "int"),
("c", "float"),
("d", "datetime"),
)
db_engine_spec = BaseEngineSpec()
results = SupersetResultSet(data, cursor_descr, db_engine_spec)
query = {
"database_id": 1,
"sql": "SELECT * FROM birth_names LIMIT 100",
"status": QueryStatus.PENDING,
}
(
serialized_data,
selected_columns,
all_columns,
expanded_columns,
) = sql_lab._serialize_and_expand_data(
results, db_engine_spec, use_new_deserialization
)
payload = {
"query_id": 1,
"status": QueryStatus.SUCCESS,
"state": QueryStatus.SUCCESS,
"data": serialized_data,
"columns": all_columns,
"selected_columns": selected_columns,
"expanded_columns": expanded_columns,
"query": query,
}
serialized = sql_lab._serialize_payload(payload, use_new_deserialization)
self.assertIsInstance(serialized, bytes)

    @staticmethod
def de_unicode_dict(d):
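        """Return a copy of d with string keys and values coerced to str."""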
def str_if_basestring(o):
if isinstance(o, str):
return str(o)
return o

        return {str_if_basestring(k): str_if_basestring(d[k]) for k in d}

    @classmethod
def dictify_list_of_dicts(cls, l, k):
return {str(o[k]): cls.de_unicode_dict(o) for o in l}


if __name__ == "__main__":
    unittest.main()