Adding app context wrapper to Celery tasks (#8653)

* Adding app context wrapper to Celery tasks
Craig Rueda 2019-11-27 07:06:06 -08:00 committed by Daniel Vaz Gaspar
parent 96fb108894
commit df2ee5cbcb
3 changed files with 56 additions and 20 deletions


@@ -96,6 +96,22 @@ class SupersetAppInitializer:
     def configure_celery(self) -> None:
         celery_app.config_from_object(self.config["CELERY_CONFIG"])
         celery_app.set_default()
+        flask_app = self.flask_app
+
+        # Here, we want to ensure that every call into Celery task has an app context
+        # setup properly
+        task_base = celery_app.Task
+
+        class AppContextTask(task_base):  # type: ignore
+            # pylint: disable=too-few-public-methods
+            abstract = True
+
+            # Grab each call into the task and set up an app context
+            def __call__(self, *args, **kwargs):
+                with flask_app.app_context():
+                    return task_base.__call__(self, *args, **kwargs)
+
+        celery_app.Task = AppContextTask

     @staticmethod
     def init_views() -> None:
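
Note: as a rough illustration of what this wrapper buys downstream code (the task below is hypothetical, not part of this commit), any task defined through celery_app.task now runs inside flask_app.app_context(), so context-bound helpers such as flask.current_app resolve without the task pushing a context itself:

    from flask import current_app
    from superset.extensions import celery_app

    @celery_app.task()
    def ping():  # hypothetical task, used only for illustration
        # AppContextTask.__call__ has already pushed the Flask app context here
        return current_app.name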


@@ -54,6 +54,7 @@ stats_logger = config["STATS_LOGGER"]
 SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]
 SQLLAB_HARD_TIMEOUT = SQLLAB_TIMEOUT + 60
 log_query = config["QUERY_LOGGER"]
+logger = logging.getLogger(__name__)


 class SqlLabException(Exception):
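
Aside: the new module-level logger is the standard stdlib pattern; records it emits are attributed to this module rather than the root logger, so they can be filtered or routed per module. A generic sketch of the pattern (the function below is illustrative, not from this commit):

    import logging

    logger = logging.getLogger(__name__)  # e.g. "superset.sql_lab" when imported as a module

    def do_work():
        logger.info("this record carries the module name, unlike logging.info(...)")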
@@ -84,9 +85,9 @@ def handle_query_error(msg, query, session, payload=None):

 def get_query_backoff_handler(details):
     query_id = details["kwargs"]["query_id"]
-    logging.error(f"Query with id `{query_id}` could not be retrieved")
+    logger.error(f"Query with id `{query_id}` could not be retrieved")
     stats_logger.incr("error_attempting_orm_query_{}".format(details["tries"] - 1))
-    logging.error(f"Query {query_id}: Sleeping for a sec before retrying...")
+    logger.error(f"Query {query_id}: Sleeping for a sec before retrying...")


 def get_query_giveup_handler(details):
@@ -128,7 +129,7 @@ def session_scope(nullpool):
         session.commit()
     except Exception as e:
         session.rollback()
-        logging.exception(e)
+        logger.exception(e)
         raise
     finally:
         session.close()
@@ -166,7 +167,7 @@ def get_sql_results(
                 expand_data=expand_data,
             )
         except Exception as e:
-            logging.exception(f"Query {query_id}: {e}")
+            logger.exception(f"Query {query_id}: {e}")
             stats_logger.incr("error_sqllab_unhandled")
             query = get_query(query_id, session)
             return handle_query_error(str(e), query, session)
@@ -224,13 +225,13 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
         query.executed_sql = sql
         session.commit()
         with stats_timing("sqllab.query.time_executing_query", stats_logger):
-            logging.info(f"Query {query_id}: Running query: \n{sql}")
+            logger.info(f"Query {query_id}: Running query: \n{sql}")
             db_engine_spec.execute(cursor, sql, async_=True)
-            logging.info(f"Query {query_id}: Handling cursor")
+            logger.info(f"Query {query_id}: Handling cursor")
             db_engine_spec.handle_cursor(cursor, query, session)
         with stats_timing("sqllab.query.time_fetching_results", stats_logger):
-            logging.debug(
+            logger.debug(
                 "Query {}: Fetching data for query object: {}".format(
                     query_id, query.to_dict()
                 )
@@ -238,16 +239,16 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
             data = db_engine_spec.fetch_data(cursor, query.limit)
     except SoftTimeLimitExceeded as e:
-        logging.exception(f"Query {query_id}: {e}")
+        logger.exception(f"Query {query_id}: {e}")
         raise SqlLabTimeoutException(
             "SQL Lab timeout. This environment's policy is to kill queries "
             "after {} seconds.".format(SQLLAB_TIMEOUT)
         )
     except Exception as e:
-        logging.exception(f"Query {query_id}: {e}")
+        logger.exception(f"Query {query_id}: {e}")
         raise SqlLabException(db_engine_spec.extract_error_message(e))

-    logging.debug(f"Query {query_id}: Fetching cursor description")
+    logger.debug(f"Query {query_id}: Fetching cursor description")
     cursor_description = cursor.description
     return SupersetDataFrame(data, cursor_description, db_engine_spec)
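
Aside: the SoftTimeLimitExceeded branch pairs with Celery's soft/hard time limits, where the soft limit raises an exception inside the task so it can record the failure before the hard limit kills the worker process. A minimal sketch of that mechanism under illustrative limits (the task and sleep below are hypothetical, not Superset's code):

    import time
    from celery.exceptions import SoftTimeLimitExceeded
    from superset.extensions import celery_app

    @celery_app.task(soft_time_limit=30, time_limit=90)  # illustrative limits
    def run_long_query():
        try:
            time.sleep(600)  # stand-in for a long-running query
        except SoftTimeLimitExceeded:
            # Raised inside the task at the soft limit, leaving time to mark the
            # query as failed before the hard limit terminates the worker.
            return "timed out"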
@@ -255,7 +256,7 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
 def _serialize_payload(
     payload: dict, use_msgpack: Optional[bool] = False
 ) -> Union[bytes, str]:
-    logging.debug(f"Serializing to msgpack: {use_msgpack}")
+    logger.debug(f"Serializing to msgpack: {use_msgpack}")
     if use_msgpack:
         return msgpack.dumps(payload, default=json_iso_dttm_ser, use_bin_type=True)
     else:
@@ -324,9 +325,9 @@ def execute_sql_statements(
     # Breaking down into multiple statements
     parsed_query = ParsedQuery(rendered_query)
     statements = parsed_query.get_statements()
-    logging.info(f"Query {query_id}: Executing {len(statements)} statement(s)")
+    logger.info(f"Query {query_id}: Executing {len(statements)} statement(s)")

-    logging.info(f"Query {query_id}: Set query to 'running'")
+    logger.info(f"Query {query_id}: Set query to 'running'")
     query.status = QueryStatus.RUNNING
     query.start_running_time = now_as_float()
     session.commit()
@@ -350,7 +351,7 @@ def execute_sql_statements(
         # Run statement
         msg = f"Running statement {i+1} out of {statement_count}"
-        logging.info(f"Query {query_id}: {msg}")
+        logger.info(f"Query {query_id}: {msg}")
         query.set_extra_json_key("progress", msg)
         session.commit()
         try:
@@ -396,9 +397,7 @@ def execute_sql_statements(
     if store_results and results_backend:
         key = str(uuid.uuid4())
-        logging.info(
-            f"Query {query_id}: Storing results in results backend, key: {key}"
-        )
+        logger.info(f"Query {query_id}: Storing results in results backend, key: {key}")
         with stats_timing("sqllab.query.results_backend_write", stats_logger):
             with stats_timing(
                 "sqllab.query.results_backend_write_serialization", stats_logger
@@ -411,10 +410,10 @@ def execute_sql_statements(
             cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]

         compressed = zlib_compress(serialized_payload)
-        logging.debug(
+        logger.debug(
             f"*** serialized payload size: {getsizeof(serialized_payload)}"
         )
-        logging.debug(f"*** compressed payload size: {getsizeof(compressed)}")
+        logger.debug(f"*** compressed payload size: {getsizeof(compressed)}")
         results_backend.set(key, compressed, cache_timeout)
         query.results_key = key
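
Aside: the debug lines above bracket the results-backend write path, which is: serialize the payload (msgpack when enabled, JSON otherwise), zlib-compress it, then store it under the generated key with a cache timeout. A self-contained sketch of that flow under simplified assumptions (a plain dict stands in for the real results backend, and the helper is generic, not Superset's):

    import json
    import zlib

    import msgpack

    def store_payload(backend: dict, key: str, payload: dict, use_msgpack: bool = False) -> None:
        # Serialize first (binary msgpack or UTF-8 JSON), then compress, then store.
        if use_msgpack:
            serialized = msgpack.dumps(payload, use_bin_type=True)
        else:
            serialized = json.dumps(payload).encode("utf-8")
        backend[key] = zlib.compress(serialized)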


@@ -23,10 +23,14 @@ import time
 import unittest
 import unittest.mock as mock

-from tests.test_app import app  # isort:skip
+import flask
+from flask import current_app
+
+from tests.test_app import app
 from superset import db, sql_lab
 from superset.dataframe import SupersetDataFrame
 from superset.db_engine_specs.base import BaseEngineSpec
+from superset.extensions import celery_app
 from superset.models.helpers import QueryStatus
 from superset.models.sql_lab import Query
 from superset.sql_parse import ParsedQuery
@@ -69,6 +73,23 @@ class UtilityFunctionTests(SupersetTestCase):
         )


+class AppContextTests(SupersetTestCase):
+    def test_in_app_context(self):
+        @celery_app.task()
+        def my_task():
+            self.assertTrue(current_app)
+
+        # Make sure we can call tasks with an app already setup
+        my_task()
+
+        # Make sure the app gets pushed onto the stack properly
+        try:
+            popped_app = flask._app_ctx_stack.pop()
+            my_task()
+        finally:
+            flask._app_ctx_stack.push(popped_app)
+
+
 class CeleryTestCase(SupersetTestCase):
     def get_query_by_name(self, sql):
         session = db.session
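
Note: flask._app_ctx_stack is an internal Flask attribute, which the test manipulates directly to simulate a task invoked with no context on the stack. A public-API way to make a similar assertion (not what this commit uses) would be flask.has_app_context(); the task below is hypothetical:

    import flask
    from superset.extensions import celery_app

    @celery_app.task()
    def assert_context():  # hypothetical task for illustration only
        assert flask.has_app_context()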