diff --git a/superset/async_events/async_query_manager.py b/superset/async_events/async_query_manager.py
index 94941541fb..32cf247cf3 100644
--- a/superset/async_events/async_query_manager.py
+++ b/superset/async_events/async_query_manager.py
@@ -191,9 +191,14 @@ class AsyncQueryManager:
         force: Optional[bool] = False,
         user_id: Optional[int] = None,
     ) -> dict[str, Any]:
+        # pylint: disable=import-outside-toplevel
+        from superset import security_manager
+
         job_metadata = self.init_job(channel_id, user_id)
         self._load_explore_json_into_cache_job.delay(
-            job_metadata,
+            {**job_metadata, "guest_token": guest_user.guest_token}
+            if (guest_user := security_manager.get_current_guest_user_if_guest())
+            else job_metadata,
             form_data,
             response_type,
             force,
@@ -201,10 +206,25 @@ class AsyncQueryManager:
         return job_metadata
 
     def submit_chart_data_job(
-        self, channel_id: str, form_data: dict[str, Any], user_id: Optional[int]
+        self,
+        channel_id: str,
+        form_data: dict[str, Any],
+        user_id: Optional[int] = None,
     ) -> dict[str, Any]:
+        # pylint: disable=import-outside-toplevel
+        from superset import security_manager
+
+        # if it's guest user, we want to pass the guest token to the celery task
+        # chart data cache key is calculated based on the current user
+        # this way we can keep the cache key consistent between sync and async command
+        # so that it can be looked up consistently
         job_metadata = self.init_job(channel_id, user_id)
-        self._load_chart_data_into_cache_job.delay(job_metadata, form_data)
+        self._load_chart_data_into_cache_job.delay(
+            {**job_metadata, "guest_token": guest_user.guest_token}
+            if (guest_user := security_manager.get_current_guest_user_if_guest())
+            else job_metadata,
+            form_data,
+        )
         return job_metadata
 
     def read_events(
diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index 5b1414d53b..d8b5bea4bb 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -600,7 +600,15 @@ class QueryContextProcessor:
             set_and_log_cache(
                 cache_manager.cache,
                 cache_key,
-                {"data": self._query_context.cache_values},
+                {
+                    "data": {
+                        # setting form_data into query context cache value as well
+                        # so that it can be used to reconstruct form_data field
+                        # for query context object when reading from cache
+                        "form_data": self._query_context.form_data,
+                        **self._query_context.cache_values,
+                    },
+                },
                 self.get_cache_timeout(),
             )
             return_value["cache_key"] = cache_key  # type: ignore
diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py
index 61970ca1f3..b804847cd8 100644
--- a/superset/tasks/async_queries.py
+++ b/superset/tasks/async_queries.py
@@ -22,6 +22,7 @@ from typing import Any, cast, TYPE_CHECKING
 
 from celery.exceptions import SoftTimeLimitExceeded
 from flask import current_app, g
+from flask_appbuilder.security.sqla.models import User
 from marshmallow import ValidationError
 
 from superset.charts.schemas import ChartDataQueryContextSchema
@@ -58,6 +59,33 @@ def _create_query_context_from_form(form_data: dict[str, Any]) -> QueryContext:
         raise error
 
 
+def _load_user_from_job_metadata(job_metadata: dict[str, Any]) -> User:
+    """
+    Resolve the user a Celery job should run as from its job metadata.
+
+    A truthy ``user_id`` is looked up through the security manager; otherwise
+    a ``guest_token`` entry is turned into a guest user and removed from
+    ``job_metadata`` (the dict is mutated in place). In every other case,
+    including a ``user_id`` lookup that returns nothing, the anonymous user
+    is returned.
+    """
+    if user_id := job_metadata.get("user_id"):
+        # logged in user; keep the anonymous fallback the tasks applied before
+        # this helper existed, in case the lookup comes back empty
+        user = (
+            security_manager.get_user_by_id(user_id)
+            or security_manager.get_anonymous_user()
+        )
+    elif guest_token := job_metadata.get("guest_token"):
+        # embedded guest user
+        user = security_manager.get_guest_user_from_token(guest_token)
+        del job_metadata["guest_token"]
+    else:
+        # default to anonymous user if no user is found
+        user = security_manager.get_anonymous_user()
+    return user
+
+
 @celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
 def load_chart_data_into_cache(
     job_metadata: dict[str, Any],
@@ -66,12 +94,7 @@ def load_chart_data_into_cache(
     # pylint: disable=import-outside-toplevel
     from superset.commands.chart.data.get_data_command import ChartDataCommand
 
-    user = (
-        security_manager.get_user_by_id(job_metadata.get("user_id"))
-        or security_manager.get_anonymous_user()
-    )
-
-    with override_user(user, force=False):
+    with override_user(_load_user_from_job_metadata(job_metadata), force=False):
         try:
             set_form_data(form_data)
             query_context = _create_query_context_from_form(form_data)
@@ -106,12 +129,7 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
 ) -> None:
     cache_key_prefix = "ejr-"  # ejr: explore_json request
 
-    user = (
-        security_manager.get_user_by_id(job_metadata.get("user_id"))
-        or security_manager.get_anonymous_user()
-    )
-
-    with override_user(user, force=False):
+    with override_user(_load_user_from_job_metadata(job_metadata), force=False):
         try:
             set_form_data(form_data)
             datasource_id, datasource_type = get_datasource_info(None, None, form_data)
@@ -140,7 +158,13 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
                 "response_type": response_type,
             }
             cache_key = generate_cache_key(cache_value, cache_key_prefix)
-            set_and_log_cache(cache_manager.cache, cache_key, cache_value)
+            cache_instance = cache_manager.cache
+            cache_timeout = (
+                cache_instance.cache.default_timeout if cache_instance.cache else None
+            )
+            set_and_log_cache(
+                cache_instance, cache_key, cache_value, cache_timeout=cache_timeout
+            )
             result_url = f"/superset/explore_json/data/{cache_key}"
             async_query_manager.update_job(
                 job_metadata,
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index 8c2082d1c4..30cd160d7e 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -121,6 +121,7 @@ class TestQueryContext(SupersetTestCase):
 
         cached = cache_manager.cache.get(cache_key)
         assert cached is not None
+        assert "form_data" in cached["data"]
 
         rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
         rehydrated_qo = rehydrated_qc.queries[0]
diff --git a/tests/unit_tests/async_events/async_query_manager_tests.py b/tests/unit_tests/async_events/async_query_manager_tests.py
index b4ae06dfc3..85ea114201 100644
--- a/tests/unit_tests/async_events/async_query_manager_tests.py
+++ b/tests/unit_tests/async_events/async_query_manager_tests.py
@@ -14,12 +14,14 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from unittest import mock
+from unittest.mock import ANY, Mock
 
-from unittest.mock import Mock
-
+from flask import g
 from jwt import encode
 from pytest import fixture, raises
 
+from superset import security_manager
 from superset.async_events.async_query_manager import (
     AsyncQueryManager,
     AsyncQueryTokenException,
@@ -38,6 +40,12 @@ def async_query_manager():
     return query_manager
 
 
+def set_current_as_guest_user():
+    g.user = security_manager.get_guest_user_from_token(
+        {"user": {}, "resources": [{"type": "dashboard", "id": "some-uuid"}]}
+    )
+
+
 def test_parse_channel_id_from_request(async_query_manager):
     encoded_token = encode(
         {"channel": "test_channel_id"}, JWT_TOKEN_SECRET, algorithm="HS256"
@@ -65,3 +73,70 @@ def test_parse_channel_id_from_request_bad_jwt(async_query_manager):
 
     with raises(AsyncQueryTokenException):
         async_query_manager.parse_channel_id_from_request(request)
+
+
+@mock.patch("superset.is_feature_enabled")
+def test_submit_chart_data_job_as_guest_user(
+    is_feature_enabled_mock, async_query_manager
+):
+    is_feature_enabled_mock.return_value = True
+    set_current_as_guest_user()
+    job_mock = Mock()
+    async_query_manager._load_chart_data_into_cache_job = job_mock
+    job_meta = async_query_manager.submit_chart_data_job(
+        channel_id="test_channel_id",
+        form_data={},
+    )
+
+    job_mock.delay.assert_called_once_with(
+        {
+            "channel_id": "test_channel_id",
+            "errors": [],
+            "guest_token": {
+                "resources": [{"id": "some-uuid", "type": "dashboard"}],
+                "user": {},
+            },
+            "job_id": ANY,
+            "result_url": None,
+            "status": "pending",
+            "user_id": None,
+        },
+        {},
+    )
+
+    assert "guest_token" not in job_meta
+
+
+@mock.patch("superset.is_feature_enabled")
+def test_submit_explore_json_job_as_guest_user(
+    is_feature_enabled_mock, async_query_manager
+):
+    is_feature_enabled_mock.return_value = True
+    set_current_as_guest_user()
+    job_mock = Mock()
+    async_query_manager._load_explore_json_into_cache_job = job_mock
+    job_meta = async_query_manager.submit_explore_json_job(
+        channel_id="test_channel_id",
+        form_data={},
+        response_type="json",
+    )
+
+    job_mock.delay.assert_called_once_with(
+        {
+            "channel_id": "test_channel_id",
+            "errors": [],
+            "guest_token": {
+                "resources": [{"id": "some-uuid", "type": "dashboard"}],
+                "user": {},
+            },
+            "job_id": ANY,
+            "result_url": None,
+            "status": "pending",
+            "user_id": None,
+        },
+        {},
+        "json",
+        False,
+    )
+
+    assert "guest_token" not in job_meta