diff --git a/superset/app.py b/superset/app.py
index efa0622808..5eb246eb52 100644
--- a/superset/app.py
+++ b/superset/app.py
@@ -96,6 +96,22 @@ class SupersetAppInitializer:
     def configure_celery(self) -> None:
         celery_app.config_from_object(self.config["CELERY_CONFIG"])
         celery_app.set_default()
+        flask_app = self.flask_app
+
+        # Here, we want to ensure that every call into Celery task has an app context
+        # setup properly
+        task_base = celery_app.Task
+
+        class AppContextTask(task_base):  # type: ignore
+            # pylint: disable=too-few-public-methods
+            abstract = True
+
+            # Grab each call into the task and set up an app context
+            def __call__(self, *args, **kwargs):
+                with flask_app.app_context():
+                    return task_base.__call__(self, *args, **kwargs)
+
+        celery_app.Task = AppContextTask
 
     @staticmethod
     def init_views() -> None:
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 842f292e27..d7fe551cd9 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -54,6 +54,7 @@ stats_logger = config["STATS_LOGGER"]
 SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]
 SQLLAB_HARD_TIMEOUT = SQLLAB_TIMEOUT + 60
 log_query = config["QUERY_LOGGER"]
+logger = logging.getLogger(__name__)
 
 
 class SqlLabException(Exception):
@@ -84,9 +85,9 @@ def handle_query_error(msg, query, session, payload=None):
 
 def get_query_backoff_handler(details):
     query_id = details["kwargs"]["query_id"]
-    logging.error(f"Query with id `{query_id}` could not be retrieved")
+    logger.error(f"Query with id `{query_id}` could not be retrieved")
     stats_logger.incr("error_attempting_orm_query_{}".format(details["tries"] - 1))
-    logging.error(f"Query {query_id}: Sleeping for a sec before retrying...")
+    logger.error(f"Query {query_id}: Sleeping for a sec before retrying...")
 
 
 def get_query_giveup_handler(details):
@@ -128,7 +129,7 @@ def session_scope(nullpool):
         session.commit()
     except Exception as e:
         session.rollback()
-        logging.exception(e)
+        logger.exception(e)
         raise
     finally:
         session.close()
@@ -166,7 +167,7 @@ def get_sql_results(
             expand_data=expand_data,
         )
     except Exception as e:
-        logging.exception(f"Query {query_id}: {e}")
+        logger.exception(f"Query {query_id}: {e}")
         stats_logger.incr("error_sqllab_unhandled")
         query = get_query(query_id, session)
         return handle_query_error(str(e), query, session)
@@ -224,13 +225,13 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
         query.executed_sql = sql
         session.commit()
         with stats_timing("sqllab.query.time_executing_query", stats_logger):
-            logging.info(f"Query {query_id}: Running query: \n{sql}")
+            logger.info(f"Query {query_id}: Running query: \n{sql}")
             db_engine_spec.execute(cursor, sql, async_=True)
-            logging.info(f"Query {query_id}: Handling cursor")
+            logger.info(f"Query {query_id}: Handling cursor")
             db_engine_spec.handle_cursor(cursor, query, session)
 
             with stats_timing("sqllab.query.time_fetching_results", stats_logger):
-                logging.debug(
+                logger.debug(
                     "Query {}: Fetching data for query object: {}".format(
                         query_id, query.to_dict()
                     )
@@ -238,16 +239,16 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
                 data = db_engine_spec.fetch_data(cursor, query.limit)
 
     except SoftTimeLimitExceeded as e:
-        logging.exception(f"Query {query_id}: {e}")
+        logger.exception(f"Query {query_id}: {e}")
         raise SqlLabTimeoutException(
             "SQL Lab timeout. This environment's policy is to kill queries "
             "after {} seconds.".format(SQLLAB_TIMEOUT)
         )
     except Exception as e:
-        logging.exception(f"Query {query_id}: {e}")
+        logger.exception(f"Query {query_id}: {e}")
         raise SqlLabException(db_engine_spec.extract_error_message(e))
 
-    logging.debug(f"Query {query_id}: Fetching cursor description")
+    logger.debug(f"Query {query_id}: Fetching cursor description")
     cursor_description = cursor.description
     return SupersetDataFrame(data, cursor_description, db_engine_spec)
 
@@ -255,7 +256,7 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
 def _serialize_payload(
     payload: dict, use_msgpack: Optional[bool] = False
 ) -> Union[bytes, str]:
-    logging.debug(f"Serializing to msgpack: {use_msgpack}")
+    logger.debug(f"Serializing to msgpack: {use_msgpack}")
     if use_msgpack:
         return msgpack.dumps(payload, default=json_iso_dttm_ser, use_bin_type=True)
     else:
@@ -324,9 +325,9 @@ def execute_sql_statements(
     # Breaking down into multiple statements
     parsed_query = ParsedQuery(rendered_query)
     statements = parsed_query.get_statements()
-    logging.info(f"Query {query_id}: Executing {len(statements)} statement(s)")
+    logger.info(f"Query {query_id}: Executing {len(statements)} statement(s)")
 
-    logging.info(f"Query {query_id}: Set query to 'running'")
+    logger.info(f"Query {query_id}: Set query to 'running'")
     query.status = QueryStatus.RUNNING
     query.start_running_time = now_as_float()
     session.commit()
@@ -350,7 +351,7 @@ def execute_sql_statements(
 
         # Run statement
         msg = f"Running statement {i+1} out of {statement_count}"
-        logging.info(f"Query {query_id}: {msg}")
+        logger.info(f"Query {query_id}: {msg}")
         query.set_extra_json_key("progress", msg)
         session.commit()
         try:
@@ -396,9 +397,7 @@ def execute_sql_statements(
 
     if store_results and results_backend:
         key = str(uuid.uuid4())
-        logging.info(
-            f"Query {query_id}: Storing results in results backend, key: {key}"
-        )
+        logger.info(f"Query {query_id}: Storing results in results backend, key: {key}")
         with stats_timing("sqllab.query.results_backend_write", stats_logger):
             with stats_timing(
                 "sqllab.query.results_backend_write_serialization", stats_logger
@@ -411,10 +410,10 @@ def execute_sql_statements(
                 cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]
 
             compressed = zlib_compress(serialized_payload)
-            logging.debug(
+            logger.debug(
                 f"*** serialized payload size: {getsizeof(serialized_payload)}"
             )
-            logging.debug(f"*** compressed payload size: {getsizeof(compressed)}")
+            logger.debug(f"*** compressed payload size: {getsizeof(compressed)}")
             results_backend.set(key, compressed, cache_timeout)
             query.results_key = key
 
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index 954c84d5fa..521be0bf8e 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -23,10 +23,14 @@ import time
 import unittest
 import unittest.mock as mock
 
-from tests.test_app import app  # isort:skip
+import flask
+from flask import current_app
+
+from tests.test_app import app
 from superset import db, sql_lab
 from superset.dataframe import SupersetDataFrame
 from superset.db_engine_specs.base import BaseEngineSpec
+from superset.extensions import celery_app
 from superset.models.helpers import QueryStatus
 from superset.models.sql_lab import Query
 from superset.sql_parse import ParsedQuery
@@ -69,6 +73,23 @@ class UtilityFunctionTests(SupersetTestCase):
         )
 
 
+class AppContextTests(SupersetTestCase):
+    def test_in_app_context(self):
+        @celery_app.task()
+        def my_task():
+            self.assertTrue(current_app)
+
+        # Make sure we can call tasks with an app already setup
+        my_task()
+
+        # Make sure the app gets pushed onto the stack properly
+        try:
+            popped_app = flask._app_ctx_stack.pop()
+            my_task()
+        finally:
+            flask._app_ctx_stack.push(popped_app)
+
+
 class CeleryTestCase(SupersetTestCase):
     def get_query_by_name(self, sql):
         session = db.session