feat(embedded+async queries): support async queries to work with embedded guest user (#26332)

This commit is contained in:
Zef Lin 2024-01-08 17:11:45 -08:00 committed by GitHub
parent 4c2e818cd3
commit efdeb9df05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 134 additions and 19 deletions

View File

@ -191,9 +191,14 @@ class AsyncQueryManager:
force: Optional[bool] = False,
user_id: Optional[int] = None,
) -> dict[str, Any]:
# pylint: disable=import-outside-toplevel
from superset import security_manager
job_metadata = self.init_job(channel_id, user_id)
self._load_explore_json_into_cache_job.delay(
job_metadata,
{**job_metadata, "guest_token": guest_user.guest_token}
if (guest_user := security_manager.get_current_guest_user_if_guest())
else job_metadata,
form_data,
response_type,
force,
@ -201,10 +206,25 @@ class AsyncQueryManager:
return job_metadata
def submit_chart_data_job(
self, channel_id: str, form_data: dict[str, Any], user_id: Optional[int]
self,
channel_id: str,
form_data: dict[str, Any],
user_id: Optional[int] = None,
) -> dict[str, Any]:
# pylint: disable=import-outside-toplevel
from superset import security_manager
# if it's guest user, we want to pass the guest token to the celery task
# chart data cache key is calculated based on the current user
# this way we can keep the cache key consistent between sync and async command
# so that it can be looked up consistently
job_metadata = self.init_job(channel_id, user_id)
self._load_chart_data_into_cache_job.delay(job_metadata, form_data)
self._load_chart_data_into_cache_job.delay(
{**job_metadata, "guest_token": guest_user.guest_token}
if (guest_user := security_manager.get_current_guest_user_if_guest())
else job_metadata,
form_data,
)
return job_metadata
def read_events(

View File

@ -600,7 +600,15 @@ class QueryContextProcessor:
set_and_log_cache(
cache_manager.cache,
cache_key,
{"data": self._query_context.cache_values},
{
"data": {
# setting form_data into query context cache value as well
# so that it can be used to reconstruct form_data field
# for query context object when reading from cache
"form_data": self._query_context.form_data,
**self._query_context.cache_values,
},
},
self.get_cache_timeout(),
)
return_value["cache_key"] = cache_key # type: ignore

View File

@ -22,6 +22,7 @@ from typing import Any, cast, TYPE_CHECKING
from celery.exceptions import SoftTimeLimitExceeded
from flask import current_app, g
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError
from superset.charts.schemas import ChartDataQueryContextSchema
@ -58,6 +59,20 @@ def _create_query_context_from_form(form_data: dict[str, Any]) -> QueryContext:
raise error
def _load_user_from_job_metadata(job_metadata: dict[str, Any]) -> User:
if user_id := job_metadata.get("user_id"):
# logged in user
user = security_manager.get_user_by_id(user_id)
elif guest_token := job_metadata.get("guest_token"):
# embedded guest user
user = security_manager.get_guest_user_from_token(guest_token)
del job_metadata["guest_token"]
else:
# default to anonymous user if no user is found
user = security_manager.get_anonymous_user()
return user
@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
def load_chart_data_into_cache(
job_metadata: dict[str, Any],
@ -66,12 +81,7 @@ def load_chart_data_into_cache(
# pylint: disable=import-outside-toplevel
from superset.commands.chart.data.get_data_command import ChartDataCommand
user = (
security_manager.get_user_by_id(job_metadata.get("user_id"))
or security_manager.get_anonymous_user()
)
with override_user(user, force=False):
with override_user(_load_user_from_job_metadata(job_metadata), force=False):
try:
set_form_data(form_data)
query_context = _create_query_context_from_form(form_data)
@ -106,12 +116,7 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
) -> None:
cache_key_prefix = "ejr-" # ejr: explore_json request
user = (
security_manager.get_user_by_id(job_metadata.get("user_id"))
or security_manager.get_anonymous_user()
)
with override_user(user, force=False):
with override_user(_load_user_from_job_metadata(job_metadata), force=False):
try:
set_form_data(form_data)
datasource_id, datasource_type = get_datasource_info(None, None, form_data)
@ -140,7 +145,13 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
"response_type": response_type,
}
cache_key = generate_cache_key(cache_value, cache_key_prefix)
set_and_log_cache(cache_manager.cache, cache_key, cache_value)
cache_instance = cache_manager.cache
cache_timeout = (
cache_instance.cache.default_timeout if cache_instance.cache else None
)
set_and_log_cache(
cache_instance, cache_key, cache_value, cache_timeout=cache_timeout
)
result_url = f"/superset/explore_json/data/{cache_key}"
async_query_manager.update_job(
job_metadata,

View File

@ -121,6 +121,7 @@ class TestQueryContext(SupersetTestCase):
cached = cache_manager.cache.get(cache_key)
assert cached is not None
assert "form_data" in cached["data"]
rehydrated_qc = ChartDataQueryContextSchema().load(cached["data"])
rehydrated_qo = rehydrated_qc.queries[0]

View File

@ -14,12 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
from unittest.mock import ANY, Mock
from unittest.mock import Mock
from flask import g
from jwt import encode
from pytest import fixture, raises
from superset import security_manager
from superset.async_events.async_query_manager import (
AsyncQueryManager,
AsyncQueryTokenException,
@ -38,6 +40,12 @@ def async_query_manager():
return query_manager
def set_current_as_guest_user():
g.user = security_manager.get_guest_user_from_token(
{"user": {}, "resources": [{"type": "dashboard", "id": "some-uuid"}]}
)
def test_parse_channel_id_from_request(async_query_manager):
encoded_token = encode(
{"channel": "test_channel_id"}, JWT_TOKEN_SECRET, algorithm="HS256"
@ -65,3 +73,70 @@ def test_parse_channel_id_from_request_bad_jwt(async_query_manager):
with raises(AsyncQueryTokenException):
async_query_manager.parse_channel_id_from_request(request)
@mock.patch("superset.is_feature_enabled")
def test_submit_chart_data_job_as_guest_user(
is_feature_enabled_mock, async_query_manager
):
is_feature_enabled_mock.return_value = True
set_current_as_guest_user()
job_mock = Mock()
async_query_manager._load_chart_data_into_cache_job = job_mock
job_meta = async_query_manager.submit_chart_data_job(
channel_id="test_channel_id",
form_data={},
)
job_mock.delay.assert_called_once_with(
{
"channel_id": "test_channel_id",
"errors": [],
"guest_token": {
"resources": [{"id": "some-uuid", "type": "dashboard"}],
"user": {},
},
"job_id": ANY,
"result_url": None,
"status": "pending",
"user_id": None,
},
{},
)
assert "guest_token" not in job_meta
@mock.patch("superset.is_feature_enabled")
def test_submit_explore_json_job_as_guest_user(
is_feature_enabled_mock, async_query_manager
):
is_feature_enabled_mock.return_value = True
set_current_as_guest_user()
job_mock = Mock()
async_query_manager._load_explore_json_into_cache_job = job_mock
job_meta = async_query_manager.submit_explore_json_job(
channel_id="test_channel_id",
form_data={},
response_type="json",
)
job_mock.delay.assert_called_once_with(
{
"channel_id": "test_channel_id",
"errors": [],
"guest_token": {
"resources": [{"id": "some-uuid", "type": "dashboard"}],
"user": {},
},
"job_id": ANY,
"result_url": None,
"status": "pending",
"user_id": None,
},
{},
"json",
False,
)
assert "guest_token" not in job_meta