From 7980b767c0ee5bbbd23b13a9434bd182e20f9613 Mon Sep 17 00:00:00 2001
From: Lily Kuang
Date: Mon, 12 Apr 2021 13:18:17 -0700
Subject: [PATCH] feat: Implement Celery SoftTimeLimit handling (#13740)

* log soft time limit error

* lint

* update test
---
 superset/reports/commands/alert.py   |  3 +-
 superset/reports/commands/execute.py |  1 +
 superset/sql_lab.py                  | 17 +++++-----
 superset/tasks/async_queries.py      | 17 +++++++---
 superset/tasks/scheduler.py          |  3 ++
 tests/reports/commands_tests.py      | 28 +++++++++++----
 tests/sqllab_tests.py                | 30 +++++++++++++++-
 tests/tasks/async_queries_tests.py   | 51 ++++++++++++++++++++++++++++
 tests/thumbnails_tests.py            | 24 ++++++-------
 9 files changed, 142 insertions(+), 32 deletions(-)

diff --git a/superset/reports/commands/alert.py b/superset/reports/commands/alert.py
index 469e85c189..5dd9797055 100644
--- a/superset/reports/commands/alert.py
+++ b/superset/reports/commands/alert.py
@@ -155,7 +155,8 @@ class AlertCommand(BaseCommand):
                 (stop - start) * 1000.0,
             )
             return df
-        except SoftTimeLimitExceeded:
+        except SoftTimeLimitExceeded as ex:
+            logger.warning("A timeout occurred while executing the alert query: %s", ex)
             raise AlertQueryTimeout()
         except Exception as ex:
             raise AlertQueryError(message=str(ex))
diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py
index 9f9696259d..ed9691cea2 100644
--- a/superset/reports/commands/execute.py
+++ b/superset/reports/commands/execute.py
@@ -184,6 +184,7 @@ class BaseReportState:
         try:
             image_data = screenshot.get_screenshot(user=user)
         except SoftTimeLimitExceeded:
+            logger.warning("A timeout occurred while taking a screenshot.")
             raise ReportScheduleScreenshotTimeout()
         except Exception as ex:
             raise ReportScheduleScreenshotFailedError(
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index dd8282652c..234b1dd110 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -159,6 +159,15 @@ def get_sql_results(  # pylint: disable=too-many-arguments
             expand_data=expand_data,
             log_params=log_params,
         )
+    except SoftTimeLimitExceeded as ex:
+        logger.warning("Query %d: Time limit exceeded", query_id)
+        logger.debug("Query %d: %s", query_id, ex)
+        raise SqlLabTimeoutException(
+            _(
+                "SQL Lab timeout. This environment's policy is to kill queries "
+                "after {} seconds.".format(SQLLAB_TIMEOUT)
+            )
+        )
     except Exception as ex:  # pylint: disable=broad-except
         logger.debug("Query %d: %s", query_id, ex)
         stats_logger.incr("error_sqllab_unhandled")
@@ -237,14 +246,6 @@ def execute_sql_statement(
             str(query.to_dict()),
         )
         data = db_engine_spec.fetch_data(cursor, query.limit)
-
-    except SoftTimeLimitExceeded as ex:
-        logger.error("Query %d: Time limit exceeded", query.id)
-        logger.debug("Query %d: %s", query.id, ex)
-        raise SqlLabTimeoutException(
-            "SQL Lab timeout. This environment's policy is to kill queries "
-            "after {} seconds.".format(SQLLAB_TIMEOUT)
-        )
     except Exception as ex:
         logger.error("Query %d: %s", query.id, type(ex))
         logger.debug("Query %d: %s", query.id, ex)
diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py
index f5f3c14f6b..f008bc1f17 100644
--- a/superset/tasks/async_queries.py
+++ b/superset/tasks/async_queries.py
@@ -18,6 +18,7 @@
 import logging
 from typing import Any, cast, Dict, Optional
 
+from celery.exceptions import SoftTimeLimitExceeded
 from flask import current_app, g
 
 from superset import app
@@ -47,9 +48,7 @@ def ensure_user_is_set(user_id: Optional[int]) -> None:
 def load_chart_data_into_cache(
     job_metadata: Dict[str, Any], form_data: Dict[str, Any],
 ) -> None:
-    from superset.charts.commands.data import (
-        ChartDataCommand,
-    )  # load here due to circular imports
+    from superset.charts.commands.data import ChartDataCommand
 
     with app.app_context():  # type: ignore
         try:
@@ -62,6 +61,11 @@
             async_query_manager.update_job(
                 job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
             )
+        except SoftTimeLimitExceeded as exc:
+            logger.warning(
+                "A timeout occurred while loading chart data, error: %s", exc
+            )
+            raise exc
         except Exception as exc:
             # TODO: QueryContext should support SIP-40 style errors
             error = exc.message if hasattr(exc, "message") else str(exc)  # type: ignore # pylint: disable=no-member
@@ -75,7 +79,7 @@
 
 
 @celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout)
-def load_explore_json_into_cache(
+def load_explore_json_into_cache(  # pylint: disable=too-many-locals
     job_metadata: Dict[str, Any],
     form_data: Dict[str, Any],
     response_type: Optional[str] = None,
@@ -106,6 +110,11 @@
             async_query_manager.update_job(
                 job_metadata, async_query_manager.STATUS_DONE, result_url=result_url,
             )
+        except SoftTimeLimitExceeded as ex:
+            logger.warning(
+                "A timeout occurred while loading explore json, error: %s", ex
+            )
+            raise ex
         except Exception as exc:
             if isinstance(exc, SupersetVizException):
                 errors = exc.errors  # pylint: disable=no-member
diff --git a/superset/tasks/scheduler.py b/superset/tasks/scheduler.py
index 84026aefb4..86bc9ca180 100644
--- a/superset/tasks/scheduler.py
+++ b/superset/tasks/scheduler.py
@@ -19,6 +19,7 @@ from datetime import datetime, timedelta
 from typing import Iterator
 
 import croniter
+from celery.exceptions import SoftTimeLimitExceeded
 from dateutil import parser
 
 from superset import app
@@ -91,5 +92,7 @@ def execute(report_schedule_id: int, scheduled_dttm: str) -> None:
 def prune_log() -> None:
     try:
         AsyncPruneReportScheduleLogCommand().run()
+    except SoftTimeLimitExceeded as ex:
+        logger.warning("A timeout occurred while pruning report schedule logs: %s", ex)
     except CommandException as ex:
         logger.error("An exception occurred while pruning report schedule logs: %s", ex)
diff --git a/tests/reports/commands_tests.py b/tests/reports/commands_tests.py
index 8d97269ef4..7a20ab3c8f 100644
--- a/tests/reports/commands_tests.py
+++ b/tests/reports/commands_tests.py
@@ -47,11 +47,13 @@ from superset.reports.commands.exceptions import (
     ReportScheduleNotFoundError,
     ReportScheduleNotificationError,
     ReportSchedulePreviousWorkingError,
+    ReportSchedulePruneLogError,
     ReportScheduleScreenshotFailedError,
     ReportScheduleScreenshotTimeout,
     ReportScheduleWorkingTimeoutError,
 )
 from superset.reports.commands.execute import AsyncExecuteReportScheduleCommand
+from superset.reports.commands.log_prune import AsyncPruneReportScheduleLogCommand
 from superset.utils.core import get_example_database
 from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices
 from tests.fixtures.world_bank_dashboard import (
@@ -193,7 +195,7 @@ def create_test_table_context(database: Database):
         database.get_sqla_engine().execute("DROP TABLE test_table")
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_report_email_chart():
     with app.app_context():
         chart = db.session.query(Slice).first()
@@ -205,7 +207,7 @@ def create_report_email_chart():
         cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_report_email_dashboard():
     with app.app_context():
         dashboard = db.session.query(Dashboard).first()
@@ -217,7 +219,7 @@ def create_report_email_dashboard():
         cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_report_slack_chart():
     with app.app_context():
         chart = db.session.query(Slice).first()
@@ -229,7 +231,7 @@ def create_report_slack_chart():
         cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_report_slack_chart_working():
     with app.app_context():
         chart = db.session.query(Slice).first()
@@ -255,7 +257,7 @@ def create_report_slack_chart_working():
         cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_alert_slack_chart_success():
     with app.app_context():
         chart = db.session.query(Slice).first()
@@ -281,7 +283,7 @@ def create_alert_slack_chart_success():
         cleanup_report_schedule(report_schedule)
 
 
-@pytest.yield_fixture()
+@pytest.fixture()
 def create_alert_slack_chart_grace():
     with app.app_context():
         chart = db.session.query(Slice).first()
@@ -1115,3 +1117,17 @@ def test_grace_period_error_flap(
     assert (
         get_notification_error_sent_count(create_invalid_sql_alert_email_chart) == 2
     )
+
+
+@pytest.mark.usefixtures(
+    "load_birth_names_dashboard_with_slices", "create_report_email_dashboard"
+)
+@patch("superset.reports.dao.ReportScheduleDAO.bulk_delete_logs")
+def test_prune_log_soft_time_out(bulk_delete_logs, create_report_email_dashboard):
+    from celery.exceptions import SoftTimeLimitExceeded
+    from datetime import datetime, timedelta
+
+    bulk_delete_logs.side_effect = SoftTimeLimitExceeded()
+    with pytest.raises(SoftTimeLimitExceeded) as excinfo:
+        AsyncPruneReportScheduleLogCommand().run()
+    assert str(excinfo.value) == "SoftTimeLimitExceeded()"
diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py
index ee0fac541a..afed7a678a 100644
--- a/tests/sqllab_tests.py
+++ b/tests/sqllab_tests.py
@@ -33,7 +33,12 @@ from superset.errors import ErrorLevel, SupersetErrorType
 from superset.models.core import Database
 from superset.models.sql_lab import Query, SavedQuery
 from superset.result_set import SupersetResultSet
-from superset.sql_lab import execute_sql_statements, SqlLabException
+from superset.sql_lab import (
+    execute_sql_statements,
+    get_sql_results,
+    SqlLabException,
+    SqlLabTimeoutException,
+)
 from superset.sql_parse import CtasMethod
 from superset.utils.core import (
     datetime_to_epoch,
@@ -793,3 +798,26 @@ class TestSqlLab(SupersetTestCase):
             "sure your query has only a SELECT statement. Then, "
             "try running your query again."
         )
+
+    @mock.patch("superset.sql_lab.get_query")
+    @mock.patch("superset.sql_lab.execute_sql_statement")
+    def test_get_sql_results_soft_time_limit(
+        self, mock_execute_sql_statement, mock_get_query
+    ):
+        from celery.exceptions import SoftTimeLimitExceeded
+
+        sql = """
+        -- comment
+        SET @value = 42;
+        SELECT @value AS foo;
+        -- comment
+        """
+        mock_get_query.side_effect = SoftTimeLimitExceeded()
+        with pytest.raises(SqlLabTimeoutException) as excinfo:
+            get_sql_results(
+                1, sql, return_results=True, store_results=False,
+            )
+        assert (
+            str(excinfo.value)
+            == "SQL Lab timeout. This environment's policy is to kill queries after 21600 seconds."
+        )
diff --git a/tests/tasks/async_queries_tests.py b/tests/tasks/async_queries_tests.py
index cd4f0c0ce7..cca58d8e2a 100644
--- a/tests/tasks/async_queries_tests.py
+++ b/tests/tasks/async_queries_tests.py
@@ -20,6 +20,7 @@ from unittest import mock
 from uuid import uuid4
 
 import pytest
+from celery.exceptions import SoftTimeLimitExceeded
 
 from superset import db
 from superset.charts.commands.data import ChartDataCommand
@@ -94,6 +95,31 @@ class TestAsyncQueries(SupersetTestCase):
         errors = [{"message": "Error: foo"}]
         mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
 
+    @mock.patch.object(ChartDataCommand, "run")
+    @mock.patch.object(async_query_manager, "update_job")
+    def test_soft_timeout_load_chart_data_into_cache(
+        self, mock_update_job, mock_run_command
+    ):
+        async_query_manager.init_app(app)
+        user = security_manager.find_user("gamma")
+        form_data = {}
+        job_metadata = {
+            "channel_id": str(uuid4()),
+            "job_id": str(uuid4()),
+            "user_id": user.id,
+            "status": "pending",
+            "errors": [],
+        }
+        errors = ["A timeout occurred while loading chart data"]
+
+        with pytest.raises(SoftTimeLimitExceeded):
+            with mock.patch.object(
+                async_queries, "ensure_user_is_set",
+            ) as ensure_user_is_set:
+                ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
+                load_chart_data_into_cache(job_metadata, form_data)
+            ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)
+
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @mock.patch.object(async_query_manager, "update_job")
     def test_load_explore_json_into_cache(self, mock_update_job):
@@ -151,3 +177,28 @@ class TestAsyncQueries(SupersetTestCase):
 
         errors = ["The dataset associated with this chart no longer exists"]
         mock_update_job.assert_called_once_with(job_metadata, "error", errors=errors)
+
+    @mock.patch.object(ChartDataCommand, "run")
+    @mock.patch.object(async_query_manager, "update_job")
+    def test_soft_timeout_load_explore_json_into_cache(
+        self, mock_update_job, mock_run_command
+    ):
+        async_query_manager.init_app(app)
+        user = security_manager.find_user("gamma")
+        form_data = {}
+        job_metadata = {
+            "channel_id": str(uuid4()),
+            "job_id": str(uuid4()),
+            "user_id": user.id,
+            "status": "pending",
+            "errors": [],
+        }
+        errors = ["A timeout occurred while loading explore json, error"]
+
+        with pytest.raises(SoftTimeLimitExceeded):
+            with mock.patch.object(
+                async_queries, "ensure_user_is_set",
+            ) as ensure_user_is_set:
+                ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
+                load_explore_json_into_cache(job_metadata, form_data)
+            ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)
diff --git a/tests/thumbnails_tests.py b/tests/thumbnails_tests.py
index fb1fd689aa..5879f4a8b6 100644
--- a/tests/thumbnails_tests.py
+++ b/tests/thumbnails_tests.py
@@ -48,7 +48,7 @@ class TestThumbnailsSeleniumLive(LiveServerTestCase):
     @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
     def test_get_async_dashboard_screenshot(self):
         """
-        Thumbnails: Simple get async dashboard screenshot
+        Thumbnails: Simple get async dashboard screenshot
         """
         dashboard = db.session.query(Dashboard).all()[0]
         with patch("superset.dashboards.api.DashboardRestApi.get") as mock_get:
@@ -65,7 +65,7 @@ class TestThumbnails(SupersetTestCase):
 
     def test_dashboard_thumbnail_disabled(self):
         """
-        Thumbnails: Dashboard thumbnail disabled
+        Thumbnails: Dashboard thumbnail disabled
         """
         if is_feature_enabled("THUMBNAILS"):
             return
@@ -77,7 +77,7 @@ class TestThumbnails(SupersetTestCase):
 
     def test_chart_thumbnail_disabled(self):
         """
-        Thumbnails: Chart thumbnail disabled
+        Thumbnails: Chart thumbnail disabled
         """
         if is_feature_enabled("THUMBNAILS"):
             return
@@ -90,7 +90,7 @@ class TestThumbnails(SupersetTestCase):
     @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
     def test_get_async_dashboard_screenshot(self):
         """
-        Thumbnails: Simple get async dashboard screenshot
+        Thumbnails: Simple get async dashboard screenshot
         """
         dashboard = db.session.query(Dashboard).all()[0]
         self.login(username="admin")
@@ -105,7 +105,7 @@ class TestThumbnails(SupersetTestCase):
     @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
     def test_get_async_dashboard_notfound(self):
         """
-        Thumbnails: Simple get async dashboard not found
+        Thumbnails: Simple get async dashboard not found
         """
         max_id = db.session.query(func.max(Dashboard.id)).scalar()
         self.login(username="admin")
@@ -116,7 +116,7 @@ class TestThumbnails(SupersetTestCase):
     @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
     def test_get_async_dashboard_not_allowed(self):
         """
-        Thumbnails: Simple get async dashboard not allowed
+        Thumbnails: Simple get async dashboard not allowed
         """
         dashboard = db.session.query(Dashboard).all()[0]
         self.login(username="gamma")
@@ -127,7 +127,7 @@ class TestThumbnails(SupersetTestCase):
     @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
     def test_get_async_chart_screenshot(self):
         """
-        Thumbnails: Simple get async chart screenshot
+        Thumbnails: Simple get async chart screenshot
         """
         chart = db.session.query(Slice).all()[0]
         self.login(username="admin")
@@ -142,7 +142,7 @@ class TestThumbnails(SupersetTestCase):
     @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
     def test_get_async_chart_notfound(self):
         """
-        Thumbnails: Simple get async chart not found
+        Thumbnails: Simple get async chart not found
         """
         max_id = db.session.query(func.max(Slice.id)).scalar()
         self.login(username="admin")
@@ -153,7 +153,7 @@ class TestThumbnails(SupersetTestCase):
     @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
     def test_get_cached_chart_wrong_digest(self):
         """
-        Thumbnails: Simple get chart with wrong digest
+        Thumbnails: Simple get chart with wrong digest
         """
         chart = db.session.query(Slice).all()[0]
         chart_url = get_url_path("Superset.slice", slice_id=chart.id, standalone="true")
@@ -169,7 +169,7 @@ class TestThumbnails(SupersetTestCase):
     @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
     def test_get_cached_dashboard_screenshot(self):
         """
-        Thumbnails: Simple get cached dashboard screenshot
+        Thumbnails: Simple get cached dashboard screenshot
         """
         dashboard = db.session.query(Dashboard).all()[0]
         dashboard_url = get_url_path("Superset.dashboard", dashboard_id=dashboard.id)
@@ -185,7 +185,7 @@ class TestThumbnails(SupersetTestCase):
     @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
     def test_get_cached_chart_screenshot(self):
         """
-        Thumbnails: Simple get cached chart screenshot
+        Thumbnails: Simple get cached chart screenshot
         """
         chart = db.session.query(Slice).all()[0]
         chart_url = get_url_path("Superset.slice", slice_id=chart.id, standalone="true")
@@ -201,7 +201,7 @@ class TestThumbnails(SupersetTestCase):
     @skipUnless((is_feature_enabled("THUMBNAILS")), "Thumbnails feature")
     def test_get_cached_dashboard_wrong_digest(self):
         """
-        Thumbnails: Simple get dashboard with wrong digest
+        Thumbnails: Simple get dashboard with wrong digest
         """
         dashboard = db.session.query(Dashboard).all()[0]
         dashboard_url = get_url_path("Superset.dashboard", dashboard_id=dashboard.id)
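
For reference, a minimal standalone sketch of the pattern this patch applies across Superset's Celery tasks: register the task with a soft_time_limit, let Celery raise SoftTimeLimitExceeded inside the task body once that limit elapses, log a warning, and convert the exception into a domain-specific timeout error. The Celery app, broker URL, task name, QUERY_TIMEOUT value, QueryTimeoutError class, and run_query helper below are illustrative assumptions, not Superset code.

import logging

from celery import Celery
from celery.exceptions import SoftTimeLimitExceeded

logger = logging.getLogger(__name__)

# Assumed app, broker, and timeout purely for illustration; Superset wires these
# up through its own celery_app and config values such as SQLLAB_ASYNC_TIME_LIMIT_SEC.
app = Celery("example", broker="memory://")
QUERY_TIMEOUT = 30


class QueryTimeoutError(Exception):
    """Illustrative stand-in for SqlLabTimeoutException / AlertQueryTimeout."""


def run_query(sql: str) -> None:
    """Hypothetical placeholder for the long-running work."""


@app.task(name="example_query_task", soft_time_limit=QUERY_TIMEOUT)
def example_query_task(sql: str) -> None:
    try:
        run_query(sql)
    except SoftTimeLimitExceeded as ex:
        # Same shape as the patch: log at warning level, then surface a
        # user-facing timeout instead of a generic unhandled error.
        logger.warning("Soft time limit exceeded: %s", ex)
        raise QueryTimeoutError(
            "Query killed after {} seconds.".format(QUERY_TIMEOUT)
        ) from ex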