feat: Implement Celery SoftTimeLimit handling (#13740)

* log soft time limit error

* lint

* update test
This commit is contained in:
Lily Kuang 2021-04-12 13:18:17 -07:00 committed by GitHub
parent 911462a148
commit 7980b767c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 142 additions and 32 deletions

View File

@@ -155,7 +155,8 @@ class AlertCommand(BaseCommand):
                 (stop - start) * 1000.0,
             )
             return df
-        except SoftTimeLimitExceeded:
+        except SoftTimeLimitExceeded as ex:
+            logger.warning("A timeout occurred while executing the alert query: %s", ex)
             raise AlertQueryTimeout()
         except Exception as ex:
             raise AlertQueryError(message=str(ex))

View File

@@ -184,6 +184,7 @@ class BaseReportState:
         try:
             image_data = screenshot.get_screenshot(user=user)
         except SoftTimeLimitExceeded:
+            logger.warning("A timeout occurred while taking a screenshot.")
             raise ReportScheduleScreenshotTimeout()
         except Exception as ex:
             raise ReportScheduleScreenshotFailedError(

View File

@@ -159,6 +159,15 @@ def get_sql_results(  # pylint: disable=too-many-arguments
             expand_data=expand_data,
             log_params=log_params,
         )
+    except SoftTimeLimitExceeded as ex:
+        logger.warning("Query %d: Time limit exceeded", query_id)
+        logger.debug("Query %d: %s", query_id, ex)
+        raise SqlLabTimeoutException(
+            _(
+                "SQL Lab timeout. This environment's policy is to kill queries "
+                "after {} seconds.".format(SQLLAB_TIMEOUT)
+            )
+        )
    except Exception as ex:  # pylint: disable=broad-except
        logger.debug("Query %d: %s", query_id, ex)
        stats_logger.incr("error_sqllab_unhandled")
@@ -237,14 +246,6 @@ def execute_sql_statement(
             str(query.to_dict()),
         )
         data = db_engine_spec.fetch_data(cursor, query.limit)
-    except SoftTimeLimitExceeded as ex:
-        logger.error("Query %d: Time limit exceeded", query.id)
-        logger.debug("Query %d: %s", query.id, ex)
-        raise SqlLabTimeoutException(
-            "SQL Lab timeout. This environment's policy is to kill queries "
-            "after {} seconds.".format(SQLLAB_TIMEOUT)
-        )
     except Exception as ex:
         logger.error("Query %d: %s", query.id, type(ex))
         logger.debug("Query %d: %s", query.id, ex)

View File

@@ -18,6 +18,7 @@
 import logging
 from typing import Any, cast, Dict, Optional

+from celery.exceptions import SoftTimeLimitExceeded
 from flask import current_app, g

 from superset import app
@@ -47,9 +48,7 @@ def ensure_user_is_set(user_id: Optional[int]) -> None:
 def load_chart_data_into_cache(
     job_metadata: Dict[str, Any], form_data: Dict[str, Any],
 ) -> None:
-    from superset.charts.commands.data import (
-        ChartDataCommand,
-    )  # load here due to circular imports
+    from superset.charts.commands.data import ChartDataCommand

     with app.app_context():  # type: ignore
         try:
@@ -62,6 +61,11 @@ def load_chart_data_into_cache(
             async_query_manager.update_job(
                 job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
             )
+        except SoftTimeLimitExceeded as exc:
+            logger.warning(
+                "A timeout occurred while loading chart data, error: %s", exc
+            )
+            raise exc
         except Exception as exc:
             # TODO: QueryContext should support SIP-40 style errors
             error = exc.message if hasattr(exc, "message") else str(exc)  # type: ignore # pylint: disable=no-member
except Exception as exc: except Exception as exc:
# TODO: QueryContext should support SIP-40 style errors # TODO: QueryContext should support SIP-40 style errors
error = exc.message if hasattr(exc, "message") else str(exc) # type: ignore # pylint: disable=no-member error = exc.message if hasattr(exc, "message") else str(exc) # type: ignore # pylint: disable=no-member
@@ -75,7 +79,7 @@ def load_chart_data_into_cache(
 @celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout)
-def load_explore_json_into_cache(
+def load_explore_json_into_cache(  # pylint: disable=too-many-locals
     job_metadata: Dict[str, Any],
     form_data: Dict[str, Any],
     response_type: Optional[str] = None,
@@ -106,6 +110,11 @@ def load_explore_json_into_cache(
             async_query_manager.update_job(
                 job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
             )
+        except SoftTimeLimitExceeded as ex:
+            logger.warning(
+                "A timeout occurred while loading explore json, error: %s", ex
+            )
+            raise ex
         except Exception as exc:
             if isinstance(exc, SupersetVizException):
                 errors = exc.errors  # pylint: disable=no-member

View File

@@ -19,6 +19,7 @@ from datetime import datetime, timedelta
 from typing import Iterator

 import croniter
+from celery.exceptions import SoftTimeLimitExceeded
 from dateutil import parser

 from superset import app
@@ -91,5 +92,7 @@ def execute(report_schedule_id: int, scheduled_dttm: str) -> None:
 def prune_log() -> None:
     try:
         AsyncPruneReportScheduleLogCommand().run()
+    except SoftTimeLimitExceeded as ex:
+        logger.warning("A timeout occurred while pruning report schedule logs: %s", ex)
     except CommandException as ex:
         logger.error("An exception occurred while pruning report schedule logs: %s", ex)

View File

@ -47,11 +47,13 @@ from superset.reports.commands.exceptions import (
ReportScheduleNotFoundError, ReportScheduleNotFoundError,
ReportScheduleNotificationError, ReportScheduleNotificationError,
ReportSchedulePreviousWorkingError, ReportSchedulePreviousWorkingError,
ReportSchedulePruneLogError,
ReportScheduleScreenshotFailedError, ReportScheduleScreenshotFailedError,
ReportScheduleScreenshotTimeout, ReportScheduleScreenshotTimeout,
ReportScheduleWorkingTimeoutError, ReportScheduleWorkingTimeoutError,
) )
from superset.reports.commands.execute import AsyncExecuteReportScheduleCommand from superset.reports.commands.execute import AsyncExecuteReportScheduleCommand
from superset.reports.commands.log_prune import AsyncPruneReportScheduleLogCommand
from superset.utils.core import get_example_database from superset.utils.core import get_example_database
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices
from tests.fixtures.world_bank_dashboard import ( from tests.fixtures.world_bank_dashboard import (
@ -193,7 +195,7 @@ def create_test_table_context(database: Database):
database.get_sqla_engine().execute("DROP TABLE test_table") database.get_sqla_engine().execute("DROP TABLE test_table")
@pytest.yield_fixture() @pytest.fixture()
def create_report_email_chart(): def create_report_email_chart():
with app.app_context(): with app.app_context():
chart = db.session.query(Slice).first() chart = db.session.query(Slice).first()
@ -205,7 +207,7 @@ def create_report_email_chart():
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.yield_fixture() @pytest.fixture()
def create_report_email_dashboard(): def create_report_email_dashboard():
with app.app_context(): with app.app_context():
dashboard = db.session.query(Dashboard).first() dashboard = db.session.query(Dashboard).first()
@ -217,7 +219,7 @@ def create_report_email_dashboard():
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.yield_fixture() @pytest.fixture()
def create_report_slack_chart(): def create_report_slack_chart():
with app.app_context(): with app.app_context():
chart = db.session.query(Slice).first() chart = db.session.query(Slice).first()
@ -229,7 +231,7 @@ def create_report_slack_chart():
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.yield_fixture() @pytest.fixture()
def create_report_slack_chart_working(): def create_report_slack_chart_working():
with app.app_context(): with app.app_context():
chart = db.session.query(Slice).first() chart = db.session.query(Slice).first()
@ -255,7 +257,7 @@ def create_report_slack_chart_working():
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.yield_fixture() @pytest.fixture()
def create_alert_slack_chart_success(): def create_alert_slack_chart_success():
with app.app_context(): with app.app_context():
chart = db.session.query(Slice).first() chart = db.session.query(Slice).first()
@ -281,7 +283,7 @@ def create_alert_slack_chart_success():
cleanup_report_schedule(report_schedule) cleanup_report_schedule(report_schedule)
@pytest.yield_fixture() @pytest.fixture()
def create_alert_slack_chart_grace(): def create_alert_slack_chart_grace():
with app.app_context(): with app.app_context():
chart = db.session.query(Slice).first() chart = db.session.query(Slice).first()
@@ -1115,3 +1117,17 @@ def test_grace_period_error_flap(
     assert (
         get_notification_error_sent_count(create_invalid_sql_alert_email_chart) == 2
     )
+
+
+@pytest.mark.usefixtures(
+    "load_birth_names_dashboard_with_slices", "create_report_email_dashboard"
+)
+@patch("superset.reports.dao.ReportScheduleDAO.bulk_delete_logs")
+def test_prune_log_soft_time_out(bulk_delete_logs, create_report_email_dashboard):
+    from celery.exceptions import SoftTimeLimitExceeded
+
+    bulk_delete_logs.side_effect = SoftTimeLimitExceeded()
+
+    with pytest.raises(SoftTimeLimitExceeded) as excinfo:
+        AsyncPruneReportScheduleLogCommand().run()
+    assert str(excinfo.value) == "SoftTimeLimitExceeded()"

View File

@ -33,7 +33,12 @@ from superset.errors import ErrorLevel, SupersetErrorType
from superset.models.core import Database from superset.models.core import Database
from superset.models.sql_lab import Query, SavedQuery from superset.models.sql_lab import Query, SavedQuery
from superset.result_set import SupersetResultSet from superset.result_set import SupersetResultSet
from superset.sql_lab import execute_sql_statements, SqlLabException from superset.sql_lab import (
execute_sql_statements,
get_sql_results,
SqlLabException,
SqlLabTimeoutException,
)
from superset.sql_parse import CtasMethod from superset.sql_parse import CtasMethod
from superset.utils.core import ( from superset.utils.core import (
datetime_to_epoch, datetime_to_epoch,
@ -793,3 +798,26 @@ class TestSqlLab(SupersetTestCase):
"sure your query has only a SELECT statement. Then, " "sure your query has only a SELECT statement. Then, "
"try running your query again." "try running your query again."
) )
@mock.patch("superset.sql_lab.get_query")
@mock.patch("superset.sql_lab.execute_sql_statement")
def test_get_sql_results_soft_time_limit(
self, mock_execute_sql_statement, mock_get_query
):
from celery.exceptions import SoftTimeLimitExceeded
sql = """
-- comment
SET @value = 42;
SELECT @value AS foo;
-- comment
"""
mock_get_query.side_effect = SoftTimeLimitExceeded()
with pytest.raises(SqlLabTimeoutException) as excinfo:
get_sql_results(
1, sql, return_results=True, store_results=False,
)
assert (
str(excinfo.value)
== "SQL Lab timeout. This environment's policy is to kill queries after 21600 seconds."
)

View File

@ -20,6 +20,7 @@ from unittest import mock
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from celery.exceptions import SoftTimeLimitExceeded
from superset import db from superset import db
from superset.charts.commands.data import ChartDataCommand from superset.charts.commands.data import ChartDataCommand
@ -94,6 +95,31 @@ class TestAsyncQueries(SupersetTestCase):
errors = [{"message": "Error: foo"}] errors = [{"message": "Error: foo"}]
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors) mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
@mock.patch.object(ChartDataCommand, "run")
@mock.patch.object(async_query_manager, "update_job")
def test_soft_timeout_load_chart_data_into_cache(
self, mock_update_job, mock_run_command
):
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
form_data = {}
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": user.id,
"status": "pending",
"errors": [],
}
errors = ["A timeout occurred while loading chart data"]
with pytest.raises(SoftTimeLimitExceeded):
with mock.patch.object(
async_queries, "ensure_user_is_set",
) as ensure_user_is_set:
ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
load_chart_data_into_cache(job_metadata, form_data)
ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch.object(async_query_manager, "update_job") @mock.patch.object(async_query_manager, "update_job")
def test_load_explore_json_into_cache(self, mock_update_job): def test_load_explore_json_into_cache(self, mock_update_job):
@ -151,3 +177,28 @@ class TestAsyncQueries(SupersetTestCase):
errors = ["The dataset associated with this chart no longer exists"] errors = ["The dataset associated with this chart no longer exists"]
mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors) mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
@mock.patch.object(ChartDataCommand, "run")
@mock.patch.object(async_query_manager, "update_job")
def test_soft_timeout_load_explore_json_into_cache(
self, mock_update_job, mock_run_command
):
async_query_manager.init_app(app)
user = security_manager.find_user("gamma")
form_data = {}
job_metadata = {
"channel_id": str(uuid4()),
"job_id": str(uuid4()),
"user_id": user.id,
"status": "pending",
"errors": [],
}
errors = ["A timeout occurred while loading explore json, error"]
with pytest.raises(SoftTimeLimitExceeded):
with mock.patch.object(
async_queries, "ensure_user_is_set",
) as ensure_user_is_set:
ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
load_explore_json_into_cache(job_metadata, form_data)
ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)

View File

@ -48,7 +48,7 @@ class TestThumbnailsSeleniumLive(LiveServerTestCase):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_dashboard_screenshot(self): def test_get_async_dashboard_screenshot(self):
""" """
Thumbnails: Simple get async dashboard screenshot Thumbnails: Simple get async dashboard screenshot
""" """
dashboard = db.session.query(Dashboard).all()[0] dashboard = db.session.query(Dashboard).all()[0]
with patch("superset.dashboards.api.DashboardRestApi.get") as mock_get: with patch("superset.dashboards.api.DashboardRestApi.get") as mock_get:
@ -65,7 +65,7 @@ class TestThumbnails(SupersetTestCase):
def test_dashboard_thumbnail_disabled(self): def test_dashboard_thumbnail_disabled(self):
""" """
Thumbnails: Dashboard thumbnail disabled Thumbnails: Dashboard thumbnail disabled
""" """
if is_feature_enabled("THUMBNAILS"): if is_feature_enabled("THUMBNAILS"):
return return
@ -77,7 +77,7 @@ class TestThumbnails(SupersetTestCase):
def test_chart_thumbnail_disabled(self): def test_chart_thumbnail_disabled(self):
""" """
Thumbnails: Chart thumbnail disabled Thumbnails: Chart thumbnail disabled
""" """
if is_feature_enabled("THUMBNAILS"): if is_feature_enabled("THUMBNAILS"):
return return
@ -90,7 +90,7 @@ class TestThumbnails(SupersetTestCase):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_dashboard_screenshot(self): def test_get_async_dashboard_screenshot(self):
""" """
Thumbnails: Simple get async dashboard screenshot Thumbnails: Simple get async dashboard screenshot
""" """
dashboard = db.session.query(Dashboard).all()[0] dashboard = db.session.query(Dashboard).all()[0]
self.login(username="admin") self.login(username="admin")
@ -105,7 +105,7 @@ class TestThumbnails(SupersetTestCase):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_dashboard_notfound(self): def test_get_async_dashboard_notfound(self):
""" """
Thumbnails: Simple get async dashboard not found Thumbnails: Simple get async dashboard not found
""" """
max_id = db.session.query(func.max(Dashboard.id)).scalar() max_id = db.session.query(func.max(Dashboard.id)).scalar()
self.login(username="admin") self.login(username="admin")
@ -116,7 +116,7 @@ class TestThumbnails(SupersetTestCase):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_dashboard_not_allowed(self): def test_get_async_dashboard_not_allowed(self):
""" """
Thumbnails: Simple get async dashboard not allowed Thumbnails: Simple get async dashboard not allowed
""" """
dashboard = db.session.query(Dashboard).all()[0] dashboard = db.session.query(Dashboard).all()[0]
self.login(username="gamma") self.login(username="gamma")
@ -127,7 +127,7 @@ class TestThumbnails(SupersetTestCase):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_chart_screenshot(self): def test_get_async_chart_screenshot(self):
""" """
Thumbnails: Simple get async chart screenshot Thumbnails: Simple get async chart screenshot
""" """
chart = db.session.query(Slice).all()[0] chart = db.session.query(Slice).all()[0]
self.login(username="admin") self.login(username="admin")
@ -142,7 +142,7 @@ class TestThumbnails(SupersetTestCase):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_async_chart_notfound(self): def test_get_async_chart_notfound(self):
""" """
Thumbnails: Simple get async chart not found Thumbnails: Simple get async chart not found
""" """
max_id = db.session.query(func.max(Slice.id)).scalar() max_id = db.session.query(func.max(Slice.id)).scalar()
self.login(username="admin") self.login(username="admin")
@ -153,7 +153,7 @@ class TestThumbnails(SupersetTestCase):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_cached_chart_wrong_digest(self): def test_get_cached_chart_wrong_digest(self):
""" """
Thumbnails: Simple get chart with wrong digest Thumbnails: Simple get chart with wrong digest
""" """
chart = db.session.query(Slice).all()[0] chart = db.session.query(Slice).all()[0]
chart_url = get_url_path("Superset.slice", slice_id=chart.id, standalone="true") chart_url = get_url_path("Superset.slice", slice_id=chart.id, standalone="true")
@ -169,7 +169,7 @@ class TestThumbnails(SupersetTestCase):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_cached_dashboard_screenshot(self): def test_get_cached_dashboard_screenshot(self):
""" """
Thumbnails: Simple get cached dashboard screenshot Thumbnails: Simple get cached dashboard screenshot
""" """
dashboard = db.session.query(Dashboard).all()[0] dashboard = db.session.query(Dashboard).all()[0]
dashboard_url = get_url_path("Superset.dashboard", dashboard_id=dashboard.id) dashboard_url = get_url_path("Superset.dashboard", dashboard_id=dashboard.id)
@ -185,7 +185,7 @@ class TestThumbnails(SupersetTestCase):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_cached_chart_screenshot(self): def test_get_cached_chart_screenshot(self):
""" """
Thumbnails: Simple get cached chart screenshot Thumbnails: Simple get cached chart screenshot
""" """
chart = db.session.query(Slice).all()[0] chart = db.session.query(Slice).all()[0]
chart_url = get_url_path("Superset.slice", slice_id=chart.id, standalone="true") chart_url = get_url_path("Superset.slice", slice_id=chart.id, standalone="true")
@ -201,7 +201,7 @@ class TestThumbnails(SupersetTestCase):
@skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature") @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
def test_get_cached_dashboard_wrong_digest(self): def test_get_cached_dashboard_wrong_digest(self):
""" """
Thumbnails: Simple get dashboard with wrong digest Thumbnails: Simple get dashboard with wrong digest
""" """
dashboard = db.session.query(Dashboard).all()[0] dashboard = db.session.query(Dashboard).all()[0]
dashboard_url = get_url_path("Superset.dashboard", dashboard_id=dashboard.id) dashboard_url = get_url_path("Superset.dashboard", dashboard_id=dashboard.id)