diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py index 74adcd080c..1157c5fd37 100644 --- a/superset/tasks/async_queries.py +++ b/superset/tasks/async_queries.py @@ -33,6 +33,7 @@ from superset.extensions import ( security_manager, ) from superset.utils.cache import generate_cache_key, set_and_log_cache +from superset.utils.core import override_user from superset.views.utils import get_datasource_info, get_viz if TYPE_CHECKING: @@ -44,16 +45,6 @@ query_timeout = current_app.config[ ] # TODO: new config key -def ensure_user_is_set(user_id: Optional[int]) -> None: - user_is_not_set = not (hasattr(g, "user") and g.user is not None) - if user_is_not_set and user_id is not None: - # pylint: disable=assigning-non-slot - g.user = security_manager.get_user_by_id(user_id) - elif user_is_not_set: - # pylint: disable=assigning-non-slot - g.user = security_manager.get_anonymous_user() - - def set_form_data(form_data: Dict[str, Any]) -> None: # pylint: disable=assigning-non-slot g.form_data = form_data @@ -76,30 +67,35 @@ def load_chart_data_into_cache( # pylint: disable=import-outside-toplevel from superset.charts.data.commands.get_data_command import ChartDataCommand - try: - ensure_user_is_set(job_metadata.get("user_id")) - set_form_data(form_data) - query_context = _create_query_context_from_form(form_data) - command = ChartDataCommand(query_context) - result = command.run(cache=True) - cache_key = result["cache_key"] - result_url = f"/api/v1/chart/data/{cache_key}" - async_query_manager.update_job( - job_metadata, - async_query_manager.STATUS_DONE, - result_url=result_url, - ) - except SoftTimeLimitExceeded as ex: - logger.warning("A timeout occurred while loading chart data, error: %s", ex) - raise ex - except Exception as ex: - # TODO: QueryContext should support SIP-40 style errors - error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member - errors = [{"message": error}] - 
async_query_manager.update_job( - job_metadata, async_query_manager.STATUS_ERROR, errors=errors - ) - raise ex + user = ( + security_manager.get_user_by_id(job_metadata.get("user_id")) + or security_manager.get_anonymous_user() + ) + + with override_user(user, force=False): + try: + set_form_data(form_data) + query_context = _create_query_context_from_form(form_data) + command = ChartDataCommand(query_context) + result = command.run(cache=True) + cache_key = result["cache_key"] + result_url = f"/api/v1/chart/data/{cache_key}" + async_query_manager.update_job( + job_metadata, + async_query_manager.STATUS_DONE, + result_url=result_url, + ) + except SoftTimeLimitExceeded as ex: + logger.warning("A timeout occurred while loading chart data, error: %s", ex) + raise ex + except Exception as ex: + # TODO: QueryContext should support SIP-40 style errors + error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member + errors = [{"message": error}] + async_query_manager.update_job( + job_metadata, async_query_manager.STATUS_ERROR, errors=errors + ) + raise ex @celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout) @@ -110,53 +106,61 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals force: bool = False, ) -> None: cache_key_prefix = "ejr-" # ejr: explore_json request - try: - ensure_user_is_set(job_metadata.get("user_id")) - set_form_data(form_data) - datasource_id, datasource_type = get_datasource_info(None, None, form_data) - # Perform a deep copy here so that below we can cache the original - # value of the form_data object. This is necessary since the viz - # objects modify the form_data object. If the modified version were - # to be cached here, it will lead to a cache miss when clients - # attempt to retrieve the value of the completed async query. 
- original_form_data = copy.deepcopy(form_data) + user = ( + security_manager.get_user_by_id(job_metadata.get("user_id")) + or security_manager.get_anonymous_user() + ) - viz_obj = get_viz( - datasource_type=cast(str, datasource_type), - datasource_id=datasource_id, - form_data=form_data, - force=force, - ) - # run query & cache results - payload = viz_obj.get_payload() - if viz_obj.has_error(payload): - raise SupersetVizException(errors=payload["errors"]) + with override_user(user, force=False): + try: + set_form_data(form_data) + datasource_id, datasource_type = get_datasource_info(None, None, form_data) - # Cache the original form_data value for async retrieval - cache_value = { - "form_data": original_form_data, - "response_type": response_type, - } - cache_key = generate_cache_key(cache_value, cache_key_prefix) - set_and_log_cache(cache_manager.cache, cache_key, cache_value) - result_url = f"/superset/explore_json/data/{cache_key}" - async_query_manager.update_job( - job_metadata, - async_query_manager.STATUS_DONE, - result_url=result_url, - ) - except SoftTimeLimitExceeded as ex: - logger.warning("A timeout occurred while loading explore json, error: %s", ex) - raise ex - except Exception as ex: - if isinstance(ex, SupersetVizException): - errors = ex.errors # pylint: disable=no-member - else: - error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member - errors = [error] + # Perform a deep copy here so that below we can cache the original + # value of the form_data object. This is necessary since the viz + # objects modify the form_data object. If the modified version were + # to be cached here, it will lead to a cache miss when clients + # attempt to retrieve the value of the completed async query. 
+ original_form_data = copy.deepcopy(form_data) - async_query_manager.update_job( - job_metadata, async_query_manager.STATUS_ERROR, errors=errors - ) - raise ex + viz_obj = get_viz( + datasource_type=cast(str, datasource_type), + datasource_id=datasource_id, + form_data=form_data, + force=force, + ) + # run query & cache results + payload = viz_obj.get_payload() + if viz_obj.has_error(payload): + raise SupersetVizException(errors=payload["errors"]) + + # Cache the original form_data value for async retrieval + cache_value = { + "form_data": original_form_data, + "response_type": response_type, + } + cache_key = generate_cache_key(cache_value, cache_key_prefix) + set_and_log_cache(cache_manager.cache, cache_key, cache_value) + result_url = f"/superset/explore_json/data/{cache_key}" + async_query_manager.update_job( + job_metadata, + async_query_manager.STATUS_DONE, + result_url=result_url, + ) + except SoftTimeLimitExceeded as ex: + logger.warning( + "A timeout occurred while loading explore json, error: %s", ex + ) + raise ex + except Exception as ex: + if isinstance(ex, SupersetVizException): + errors = ex.errors # pylint: disable=no-member + else: + error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member + errors = [error] + + async_query_manager.update_job( + job_metadata, async_query_manager.STATUS_ERROR, errors=errors + ) + raise ex diff --git a/superset/utils/core.py b/superset/utils/core.py index 336ab4e208..aeb45051b6 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1453,23 +1453,27 @@ def get_user_id() -> Optional[int]: @contextmanager -def override_user(user: Optional[User]) -> Iterator[Any]: +def override_user(user: Optional[User], force: bool = True) -> Iterator[Any]: """ - Temporarily override the current user (if defined) per `flask.g`. + Temporarily override the current user per `flask.g` with the specified user. 
Sometimes, often in the context of async Celery tasks, it is useful to switch the current user (which may be undefined) to different one, execute some SQLAlchemy - tasks and then revert back to the original one. + tasks et al. and then revert back to the original one. :param user: The override user + :param force: Whether to override the current user if set """ # pylint: disable=assigning-non-slot if hasattr(g, "user"): - current = g.user - g.user = user - yield - g.user = current + if force or g.user is None: + current = g.user + g.user = user + yield + g.user = current + else: + yield else: g.user = user yield diff --git a/tests/integration_tests/access_tests.py b/tests/integration_tests/access_tests.py index 2e1e897a4f..5ab03055d9 100644 --- a/tests/integration_tests/access_tests.py +++ b/tests/integration_tests/access_tests.py @@ -562,34 +562,34 @@ def test_get_username( assert get_username() == username -@pytest.mark.parametrize( - "username", - [ - None, - "alpha", - "gamma", - ], -) +@pytest.mark.parametrize("username", [None, "alpha", "gamma"]) +@pytest.mark.parametrize("force", [False, True]) def test_override_user( app_context: AppContext, mocker: MockFixture, username: str, + force: bool, ) -> None: mock_g = mocker.patch("superset.utils.core.g", spec={}) admin = security_manager.find_user(username="admin") user = security_manager.find_user(username) - assert not hasattr(mock_g, "user") - - with override_user(user): + with override_user(user, force): assert mock_g.user == user assert not hasattr(mock_g, "user") + mock_g.user = None + + with override_user(user, force): + assert mock_g.user == user + + assert mock_g.user is None + mock_g.user = admin - with override_user(user): - assert mock_g.user == user + with override_user(user, force): + assert mock_g.user == (user if force else admin) assert mock_g.user == admin diff --git a/tests/integration_tests/tasks/async_queries_tests.py b/tests/integration_tests/tasks/async_queries_tests.py index 
5a51c06601..20d0f39eea 100644 --- a/tests/integration_tests/tasks/async_queries_tests.py +++ b/tests/integration_tests/tasks/async_queries_tests.py @@ -28,7 +28,6 @@ from superset.exceptions import SupersetException from superset.extensions import async_query_manager, security_manager from superset.tasks import async_queries from superset.tasks.async_queries import ( - ensure_user_is_set, load_chart_data_into_cache, load_explore_json_into_cache, ) @@ -58,12 +57,7 @@ class TestAsyncQueries(SupersetTestCase): "errors": [], } - with mock.patch.object( - async_queries, "ensure_user_is_set" - ) as ensure_user_is_set: - load_chart_data_into_cache(job_metadata, query_context) - - ensure_user_is_set.assert_called_once_with(user.id) + load_chart_data_into_cache(job_metadata, query_context) mock_set_form_data.assert_called_once_with(query_context) mock_update_job.assert_called_once_with( job_metadata, "done", result_url=mock.ANY @@ -85,11 +79,7 @@ class TestAsyncQueries(SupersetTestCase): "errors": [], } with pytest.raises(ChartDataQueryFailedError): - with mock.patch.object( - async_queries, "ensure_user_is_set" - ) as ensure_user_is_set: - load_chart_data_into_cache(job_metadata, query_context) - ensure_user_is_set.assert_called_once_with(user.id) + load_chart_data_into_cache(job_metadata, query_context) mock_run_command.assert_called_once_with(cache=True) errors = [{"message": "Error: foo"}] @@ -115,11 +105,11 @@ class TestAsyncQueries(SupersetTestCase): with pytest.raises(SoftTimeLimitExceeded): with mock.patch.object( async_queries, - "ensure_user_is_set", - ) as ensure_user_is_set: - ensure_user_is_set.side_effect = SoftTimeLimitExceeded() + "set_form_data", + ) as set_form_data: + set_form_data.side_effect = SoftTimeLimitExceeded() load_chart_data_into_cache(job_metadata, form_data) - ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors) + set_form_data.assert_called_once_with(form_data, "error", errors=errors) 
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch.object(async_query_manager, "update_job") @@ -145,12 +135,7 @@ class TestAsyncQueries(SupersetTestCase): "errors": [], } - with mock.patch.object( - async_queries, "ensure_user_is_set" - ) as ensure_user_is_set: - load_explore_json_into_cache(job_metadata, form_data) - - ensure_user_is_set.assert_called_once_with(user.id) + load_explore_json_into_cache(job_metadata, form_data) mock_update_job.assert_called_once_with( job_metadata, "done", result_url=mock.ANY ) @@ -172,11 +157,7 @@ class TestAsyncQueries(SupersetTestCase): } with pytest.raises(SupersetException): - with mock.patch.object( - async_queries, "ensure_user_is_set" - ) as ensure_user_is_set: - load_explore_json_into_cache(job_metadata, form_data) - ensure_user_is_set.assert_called_once_with(user.id) + load_explore_json_into_cache(job_metadata, form_data) mock_set_form_data.assert_called_once_with(form_data) errors = ["The dataset associated with this chart no longer exists"] @@ -202,49 +183,8 @@ class TestAsyncQueries(SupersetTestCase): with pytest.raises(SoftTimeLimitExceeded): with mock.patch.object( async_queries, - "ensure_user_is_set", - ) as ensure_user_is_set: - ensure_user_is_set.side_effect = SoftTimeLimitExceeded() + "set_form_data", + ) as set_form_data: + set_form_data.side_effect = SoftTimeLimitExceeded() load_explore_json_into_cache(job_metadata, form_data) - ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors) - - def test_ensure_user_is_set(self): - g_user_is_set = hasattr(g, "user") - original_g_user = g.user if g_user_is_set else None - - if g_user_is_set: - del g.user - - self.assertFalse(hasattr(g, "user")) - ensure_user_is_set(1) - self.assertTrue(hasattr(g, "user")) - self.assertFalse(g.user.is_anonymous) - self.assertEqual(1, get_user_id()) - - del g.user - - self.assertFalse(hasattr(g, "user")) - ensure_user_is_set(None) - self.assertTrue(hasattr(g, "user")) - 
self.assertTrue(g.user.is_anonymous) - self.assertEqual(None, get_user_id()) - - del g.user - - g.user = security_manager.get_user_by_id(2) - self.assertEqual(2, get_user_id()) - - ensure_user_is_set(1) - self.assertTrue(hasattr(g, "user")) - self.assertFalse(g.user.is_anonymous) - self.assertEqual(2, get_user_id()) - - ensure_user_is_set(None) - self.assertTrue(hasattr(g, "user")) - self.assertFalse(g.user.is_anonymous) - self.assertEqual(2, get_user_id()) - - if g_user_is_set: - g.user = original_g_user - else: - del g.user + set_form_data.assert_called_once_with(form_data, "error", errors=errors)