mirror of https://github.com/apache/superset.git
Adding app context wrapper to Celery tasks (#8653)
* Adding app context wrapper to Celery tasks
parent 96fb108894
commit df2ee5cbcb
--- a/superset/app.py
+++ b/superset/app.py
@@ -96,6 +96,22 @@ class SupersetAppInitializer:
     def configure_celery(self) -> None:
         celery_app.config_from_object(self.config["CELERY_CONFIG"])
         celery_app.set_default()
+        flask_app = self.flask_app
+
+        # Here, we want to ensure that every call into a Celery task has an
+        # app context set up properly
+        task_base = celery_app.Task
+
+        class AppContextTask(task_base):  # type: ignore
+            # pylint: disable=too-few-public-methods
+            abstract = True
+
+            # Grab each call into the task and set up an app context
+            def __call__(self, *args, **kwargs):
+                with flask_app.app_context():
+                    return task_base.__call__(self, *args, **kwargs)
+
+        celery_app.Task = AppContextTask

     @staticmethod
     def init_views() -> None:
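For readers new to the pattern, here is a minimal standalone sketch of what the hunk above does: subclass celery.Task so that every task invocation runs inside a Flask application context. The make_celery helper name and the memory:// broker are illustrative assumptions, not Superset code.

# Minimal sketch of the app-context wrapper pattern shown above.
# make_celery and the memory:// broker are illustrative assumptions.
from celery import Celery
from flask import Flask, current_app


def make_celery(flask_app: Flask) -> Celery:
    celery = Celery(flask_app.import_name, broker="memory://")
    task_base = celery.Task

    class AppContextTask(task_base):  # type: ignore
        abstract = True

        def __call__(self, *args, **kwargs):
            # Push an app context around every task body so current_app,
            # extensions, and DB sessions resolve inside the worker.
            with flask_app.app_context():
                return task_base.__call__(self, *args, **kwargs)

    celery.Task = AppContextTask
    return celery


flask_app = Flask(__name__)
celery = make_celery(flask_app)


@celery.task()
def whoami() -> str:
    return current_app.name


print(whoami())  # direct call: the wrapper still pushes an app context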
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -54,6 +54,7 @@ stats_logger = config["STATS_LOGGER"]
 SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]
 SQLLAB_HARD_TIMEOUT = SQLLAB_TIMEOUT + 60
 log_query = config["QUERY_LOGGER"]
+logger = logging.getLogger(__name__)


 class SqlLabException(Exception):
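Every remaining sql_lab.py hunk in this commit is mechanical: calls on the root logging module become calls on this new module-level logger. A short illustrative sketch (not from the commit) of what that buys:

# Illustrative sketch, not part of the commit: module-level loggers
# tag records with the emitting module and can be tuned individually.
import logging

logging.basicConfig(format="%(levelname)s:%(name)s:%(message)s")

logger = logging.getLogger(__name__)  # "superset.sql_lab" when imported

# Verbosity can now be raised for this one module, leaving the root alone:
logger.setLevel(logging.DEBUG)

logger.error("Query with id `%s` could not be retrieved", 42)
# -> ERROR:__main__:Query with id `42` could not be retrieved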
@@ -84,9 +85,9 @@ def handle_query_error(msg, query, session, payload=None):

 def get_query_backoff_handler(details):
     query_id = details["kwargs"]["query_id"]
-    logging.error(f"Query with id `{query_id}` could not be retrieved")
+    logger.error(f"Query with id `{query_id}` could not be retrieved")
     stats_logger.incr("error_attempting_orm_query_{}".format(details["tries"] - 1))
-    logging.error(f"Query {query_id}: Sleeping for a sec before retrying...")
+    logger.error(f"Query {query_id}: Sleeping for a sec before retrying...")


 def get_query_giveup_handler(details):
@@ -128,7 +129,7 @@ def session_scope(nullpool):
         session.commit()
     except Exception as e:
         session.rollback()
-        logging.exception(e)
+        logger.exception(e)
         raise
     finally:
         session.close()
@@ -166,7 +167,7 @@ def get_sql_results(
             expand_data=expand_data,
         )
     except Exception as e:
-        logging.exception(f"Query {query_id}: {e}")
+        logger.exception(f"Query {query_id}: {e}")
         stats_logger.incr("error_sqllab_unhandled")
         query = get_query(query_id, session)
         return handle_query_error(str(e), query, session)
@@ -224,13 +225,13 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
         query.executed_sql = sql
         session.commit()
         with stats_timing("sqllab.query.time_executing_query", stats_logger):
-            logging.info(f"Query {query_id}: Running query: \n{sql}")
+            logger.info(f"Query {query_id}: Running query: \n{sql}")
             db_engine_spec.execute(cursor, sql, async_=True)
-            logging.info(f"Query {query_id}: Handling cursor")
+            logger.info(f"Query {query_id}: Handling cursor")
             db_engine_spec.handle_cursor(cursor, query, session)

         with stats_timing("sqllab.query.time_fetching_results", stats_logger):
-            logging.debug(
+            logger.debug(
                 "Query {}: Fetching data for query object: {}".format(
                     query_id, query.to_dict()
                 )
@@ -238,16 +239,16 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
             data = db_engine_spec.fetch_data(cursor, query.limit)

     except SoftTimeLimitExceeded as e:
-        logging.exception(f"Query {query_id}: {e}")
+        logger.exception(f"Query {query_id}: {e}")
         raise SqlLabTimeoutException(
             "SQL Lab timeout. This environment's policy is to kill queries "
             "after {} seconds.".format(SQLLAB_TIMEOUT)
         )
     except Exception as e:
-        logging.exception(f"Query {query_id}: {e}")
+        logger.exception(f"Query {query_id}: {e}")
         raise SqlLabException(db_engine_spec.extract_error_message(e))

-    logging.debug(f"Query {query_id}: Fetching cursor description")
+    logger.debug(f"Query {query_id}: Fetching cursor description")
     cursor_description = cursor.description
     return SupersetDataFrame(data, cursor_description, db_engine_spec)

@@ -255,7 +256,7 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor):
 def _serialize_payload(
     payload: dict, use_msgpack: Optional[bool] = False
 ) -> Union[bytes, str]:
-    logging.debug(f"Serializing to msgpack: {use_msgpack}")
+    logger.debug(f"Serializing to msgpack: {use_msgpack}")
     if use_msgpack:
         return msgpack.dumps(payload, default=json_iso_dttm_ser, use_bin_type=True)
     else:
@@ -324,9 +325,9 @@ def execute_sql_statements(
     # Breaking down into multiple statements
     parsed_query = ParsedQuery(rendered_query)
     statements = parsed_query.get_statements()
-    logging.info(f"Query {query_id}: Executing {len(statements)} statement(s)")
+    logger.info(f"Query {query_id}: Executing {len(statements)} statement(s)")

-    logging.info(f"Query {query_id}: Set query to 'running'")
+    logger.info(f"Query {query_id}: Set query to 'running'")
     query.status = QueryStatus.RUNNING
     query.start_running_time = now_as_float()
     session.commit()
@@ -350,7 +351,7 @@ def execute_sql_statements(

         # Run statement
         msg = f"Running statement {i+1} out of {statement_count}"
-        logging.info(f"Query {query_id}: {msg}")
+        logger.info(f"Query {query_id}: {msg}")
         query.set_extra_json_key("progress", msg)
         session.commit()
         try:
@@ -396,9 +397,7 @@ def execute_sql_statements(

     if store_results and results_backend:
         key = str(uuid.uuid4())
-        logging.info(
-            f"Query {query_id}: Storing results in results backend, key: {key}"
-        )
+        logger.info(f"Query {query_id}: Storing results in results backend, key: {key}")
         with stats_timing("sqllab.query.results_backend_write", stats_logger):
             with stats_timing(
                 "sqllab.query.results_backend_write_serialization", stats_logger
@@ -411,10 +410,10 @@ def execute_sql_statements(
             cache_timeout = config["CACHE_DEFAULT_TIMEOUT"]

         compressed = zlib_compress(serialized_payload)
-        logging.debug(
+        logger.debug(
             f"*** serialized payload size: {getsizeof(serialized_payload)}"
         )
-        logging.debug(f"*** compressed payload size: {getsizeof(compressed)}")
+        logger.debug(f"*** compressed payload size: {getsizeof(compressed)}")
         results_backend.set(key, compressed, cache_timeout)
         query.results_key = key

--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -23,10 +23,14 @@ import time
 import unittest
 import unittest.mock as mock

-from tests.test_app import app  # isort:skip
+import flask
+from flask import current_app
+
+from tests.test_app import app
 from superset import db, sql_lab
 from superset.dataframe import SupersetDataFrame
 from superset.db_engine_specs.base import BaseEngineSpec
+from superset.extensions import celery_app
 from superset.models.helpers import QueryStatus
 from superset.models.sql_lab import Query
 from superset.sql_parse import ParsedQuery
@@ -69,6 +73,23 @@ class UtilityFunctionTests(SupersetTestCase):
         )


+class AppContextTests(SupersetTestCase):
+    def test_in_app_context(self):
+        @celery_app.task()
+        def my_task():
+            self.assertTrue(current_app)
+
+        # Make sure we can call tasks with an app already set up
+        my_task()
+
+        # Make sure the app gets pushed onto the stack properly
+        try:
+            popped_app = flask._app_ctx_stack.pop()
+            my_task()
+        finally:
+            flask._app_ctx_stack.push(popped_app)
+
+
 class CeleryTestCase(SupersetTestCase):
     def get_query_by_name(self, sql):
         session = db.session
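One detail worth calling out in the test above: flask._app_ctx_stack is a private Flask API (a werkzeug LocalStack in Flask 1.x). The test pops the current app context to simulate a bare Celery worker with nothing on the stack, invokes the task to prove AppContextTask pushes its own context, then restores the stack. A hedged sketch of that pop/push dance, using the same private API:

# Sketch of the stack manipulation outside the test harness.
# flask._app_ctx_stack is private Flask 1.x API; treat with care.
import flask
from flask import Flask

app = Flask(__name__)

with app.app_context():
    popped = flask._app_ctx_stack.pop()  # simulate a bare Celery worker
    try:
        assert flask._app_ctx_stack.top is None  # no ambient context now
    finally:
        flask._app_ctx_stack.push(popped)  # restore so the `with` exits cleanly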