refactor: Deprecate ensure_user_is_set in favor of override_user (#20502)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
John Bodley 2022-07-05 10:57:40 -07:00 committed by GitHub
parent ad308fbde2
commit 94b3d2f0f0
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 120 additions and 172 deletions
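In practice, call sites that previously mutated flask.g directly through ensure_user_is_set(job_metadata.get("user_id")) now resolve the user up front and run the task body inside the override_user context manager. A minimal sketch of the migrated pattern, mirroring the task code in this diff (the anonymous-user fallback and force=False come straight from the changes below; the task body is elided):

    user = (
        security_manager.get_user_by_id(job_metadata.get("user_id"))
        or security_manager.get_anonymous_user()
    )
    with override_user(user, force=False):
        ...  # the task body now runs with g.user set to `user` (or the anonymous user)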

View File

@@ -33,6 +33,7 @@ from superset.extensions (
security_manager,
)
from superset.utils.cache import generate_cache_key, set_and_log_cache
from superset.utils.core import override_user
from superset.views.utils import get_datasource_info, get_viz
if TYPE_CHECKING:
@@ -44,16 +45,6 @@ query_timeout = current_app.config[
] # TODO: new config key
def ensure_user_is_set(user_id: Optional[int]) -> None:
user_is_not_set = not (hasattr(g, "user") and g.user is not None)
if user_is_not_set and user_id is not None:
# pylint: disable=assigning-non-slot
g.user = security_manager.get_user_by_id(user_id)
elif user_is_not_set:
# pylint: disable=assigning-non-slot
g.user = security_manager.get_anonymous_user()
def set_form_data(form_data: Dict[str, Any]) -> None:
# pylint: disable=assigning-non-slot
g.form_data = form_data
@@ -76,30 +67,35 @@ def load_chart_data_into_cache(
# pylint: disable=import-outside-toplevel
from superset.charts.data.commands.get_data_command import ChartDataCommand
try:
ensure_user_is_set(job_metadata.get("user_id"))
set_form_data(form_data)
query_context = _create_query_context_from_form(form_data)
command = ChartDataCommand(query_context)
result = command.run(cache=True)
cache_key = result["cache_key"]
result_url = f"/api/v1/chart/data/{cache_key}"
async_query_manager.update_job(
job_metadata,
async_query_manager.STATUS_DONE,
result_url=result_url,
)
except SoftTimeLimitExceeded as ex:
logger.warning("A timeout occurred while loading chart data, error: %s", ex)
raise ex
except Exception as ex:
# TODO: QueryContext should support SIP-40 style errors
error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member
errors = [{"message": error}]
async_query_manager.update_job(
job_metadata, async_query_manager.STATUS_ERROR, errors=errors
)
raise ex
user = (
security_manager.get_user_by_id(job_metadata.get("user_id"))
or security_manager.get_anonymous_user()
)
with override_user(user, force=False):
try:
set_form_data(form_data)
query_context = _create_query_context_from_form(form_data)
command = ChartDataCommand(query_context)
result = command.run(cache=True)
cache_key = result["cache_key"]
result_url = f"/api/v1/chart/data/{cache_key}"
async_query_manager.update_job(
job_metadata,
async_query_manager.STATUS_DONE,
result_url=result_url,
)
except SoftTimeLimitExceeded as ex:
logger.warning("A timeout occurred while loading chart data, error: %s", ex)
raise ex
except Exception as ex:
# TODO: QueryContext should support SIP-40 style errors
error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member
errors = [{"message": error}]
async_query_manager.update_job(
job_metadata, async_query_manager.STATUS_ERROR, errors=errors
)
raise ex
@celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout)
@@ -110,53 +106,61 @@ def load_explore_json_into_cache( # pylint: disable=too-many-locals
force: bool = False,
) -> None:
cache_key_prefix = "ejr-" # ejr: explore_json request
try:
ensure_user_is_set(job_metadata.get("user_id"))
set_form_data(form_data)
datasource_id, datasource_type = get_datasource_info(None, None, form_data)
# Perform a deep copy here so that below we can cache the original
# value of the form_data object. This is necessary since the viz
# objects modify the form_data object. If the modified version were
# to be cached here, it will lead to a cache miss when clients
# attempt to retrieve the value of the completed async query.
original_form_data = copy.deepcopy(form_data)
user = (
security_manager.get_user_by_id(job_metadata.get("user_id"))
or security_manager.get_anonymous_user()
)
viz_obj = get_viz(
datasource_type=cast(str, datasource_type),
datasource_id=datasource_id,
form_data=form_data,
force=force,
)
# run query & cache results
payload = viz_obj.get_payload()
if viz_obj.has_error(payload):
raise SupersetVizException(errors=payload["errors"])
with override_user(user, force=False):
try:
set_form_data(form_data)
datasource_id, datasource_type = get_datasource_info(None, None, form_data)
# Cache the original form_data value for async retrieval
cache_value = {
"form_data": original_form_data,
"response_type": response_type,
}
cache_key = generate_cache_key(cache_value, cache_key_prefix)
set_and_log_cache(cache_manager.cache, cache_key, cache_value)
result_url = f"/superset/explore_json/data/{cache_key}"
async_query_manager.update_job(
job_metadata,
async_query_manager.STATUS_DONE,
result_url=result_url,
)
except SoftTimeLimitExceeded as ex:
logger.warning("A timeout occurred while loading explore json, error: %s", ex)
raise ex
except Exception as ex:
if isinstance(ex, SupersetVizException):
errors = ex.errors # pylint: disable=no-member
else:
error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member
errors = [error]
# Perform a deep copy here so that below we can cache the original
# value of the form_data object. This is necessary since the viz
# objects modify the form_data object. If the modified version were
# to be cached here, it will lead to a cache miss when clients
# attempt to retrieve the value of the completed async query.
original_form_data = copy.deepcopy(form_data)
async_query_manager.update_job(
job_metadata, async_query_manager.STATUS_ERROR, errors=errors
)
raise ex
viz_obj = get_viz(
datasource_type=cast(str, datasource_type),
datasource_id=datasource_id,
form_data=form_data,
force=force,
)
# run query & cache results
payload = viz_obj.get_payload()
if viz_obj.has_error(payload):
raise SupersetVizException(errors=payload["errors"])
# Cache the original form_data value for async retrieval
cache_value = {
"form_data": original_form_data,
"response_type": response_type,
}
cache_key = generate_cache_key(cache_value, cache_key_prefix)
set_and_log_cache(cache_manager.cache, cache_key, cache_value)
result_url = f"/superset/explore_json/data/{cache_key}"
async_query_manager.update_job(
job_metadata,
async_query_manager.STATUS_DONE,
result_url=result_url,
)
except SoftTimeLimitExceeded as ex:
logger.warning(
"A timeout occurred while loading explore json, error: %s", ex
)
raise ex
except Exception as ex:
if isinstance(ex, SupersetVizException):
errors = ex.errors # pylint: disable=no-member
else:
error = ex.message if hasattr(ex, "message") else str(ex) # type: ignore # pylint: disable=no-member
errors = [error]
async_query_manager.update_job(
job_metadata, async_query_manager.STATUS_ERROR, errors=errors
)
raise ex

View File

@@ -1453,23 +1453,27 @@ def get_user_id() -> Optional[int]:
@contextmanager
def override_user(user: Optional[User]) -> Iterator[Any]:
def override_user(user: Optional[User], force: bool = True) -> Iterator[Any]:
"""
Temporarily override the current user (if defined) per `flask.g`.
Temporarily override the current user per `flask.g` with the specified user.
Sometimes, often in the context of async Celery tasks, it is useful to switch the
current user (which may be undefined) to a different one, execute some SQLAlchemy
tasks and then revert back to the original one.
tasks et al. and then revert back to the original one.
:param user: The override user
:param force: Whether to override the current user if set
"""
# pylint: disable=assigning-non-slot
if hasattr(g, "user"):
current = g.user
g.user = user
yield
g.user = current
if force or g.user is None:
current = g.user
g.user = user
yield
g.user = current
else:
yield
else:
g.user = user
yield
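
A quick usage sketch of the new force semantics (other_user and the pre-set admin user are illustrative placeholders, not taken from this diff): with the default force=True the override always applies, while force=False leaves an already-populated g.user untouched, which is the behaviour the async tasks above rely on.

    # Assuming g.user is currently set to a logged-in user (e.g. an admin):
    with override_user(other_user):                # force defaults to True
        ...                                        # g.user == other_user inside the block
    with override_user(other_user, force=False):
        ...                                        # g.user is left unchanged, since it is already set
    # In a Celery task where g.user is not yet set, force=False still applies the override.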

View File

@@ -562,34 +562,34 @@ def test_get_username(
assert get_username() == username
@pytest.mark.parametrize(
"username",
[
None,
"alpha",
"gamma",
],
)
@pytest.mark.parametrize("username", [None, "alpha", "gamma"])
@pytest.mark.parametrize("force", [False, True])
def test_override_user(
app_context: AppContext,
mocker: MockFixture,
username: str,
force: bool,
) -> None:
mock_g = mocker.patch("superset.utils.core.g", spec={})
admin = security_manager.find_user(username="admin")
user = security_manager.find_user(username)
assert not hasattr(mock_g, "user")
with override_user(user):
with override_user(user, force):
assert mock_g.user == user
assert not hasattr(mock_g, "user")
mock_g.user = None
with override_user(user, force):
assert mock_g.user == user
assert mock_g.user is None
mock_g.user = admin
with override_user(user):
assert mock_g.user == user
with override_user(user, force):
assert mock_g.user == (user if force else admin)
assert mock_g.user == admin

View File

@@ -28,7 +28,6 @@ from superset.exceptions import SupersetException
from superset.extensions import async_query_manager, security_manager
from superset.tasks import async_queries
from superset.tasks.async_queries import (
ensure_user_is_set,
load_chart_data_into_cache,
load_explore_json_into_cache,
)
@@ -58,12 +57,7 @@ class TestAsyncQueries(SupersetTestCase):
"errors": [],
}
with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_chart_data_into_cache(job_metadata, query_context)
ensure_user_is_set.assert_called_once_with(user.id)
load_chart_data_into_cache(job_metadata, query_context)
mock_set_form_data.assert_called_once_with(query_context)
mock_update_job.assert_called_once_with(
job_metadata, "done", result_url=mock.ANY
@@ -85,11 +79,7 @@ class TestAsyncQueries(SupersetTestCase):
"errors": [],
}
with pytest.raises(ChartDataQueryFailedError):
with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_chart_data_into_cache(job_metadata, query_context)
ensure_user_is_set.assert_called_once_with(user.id)
load_chart_data_into_cache(job_metadata, query_context)
mock_run_command.assert_called_once_with(cache=True)
errors = [{"message": "Error: foo"}]
@@ -115,11 +105,11 @@ class TestAsyncQueries(SupersetTestCase):
with pytest.raises(SoftTimeLimitExceeded):
with mock.patch.object(
async_queries,
"ensure_user_is_set",
) as ensure_user_is_set:
ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
"set_form_data",
) as set_form_data:
set_form_data.side_effect = SoftTimeLimitExceeded()
load_chart_data_into_cache(job_metadata, form_data)
ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)
set_form_data.assert_called_once_with(form_data, "error", errors=errors)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@mock.patch.object(async_query_manager, "update_job")
@@ -145,12 +135,7 @@ class TestAsyncQueries(SupersetTestCase):
"errors": [],
}
with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_explore_json_into_cache(job_metadata, form_data)
ensure_user_is_set.assert_called_once_with(user.id)
load_explore_json_into_cache(job_metadata, form_data)
mock_update_job.assert_called_once_with(
job_metadata, "done", result_url=mock.ANY
)
@@ -172,11 +157,7 @@ class TestAsyncQueries(SupersetTestCase):
}
with pytest.raises(SupersetException):
with mock.patch.object(
async_queries, "ensure_user_is_set"
) as ensure_user_is_set:
load_explore_json_into_cache(job_metadata, form_data)
ensure_user_is_set.assert_called_once_with(user.id)
load_explore_json_into_cache(job_metadata, form_data)
mock_set_form_data.assert_called_once_with(form_data)
errors = ["The dataset associated with this chart no longer exists"]
@@ -202,49 +183,8 @@ class TestAsyncQueries(SupersetTestCase):
with pytest.raises(SoftTimeLimitExceeded):
with mock.patch.object(
async_queries,
"ensure_user_is_set",
) as ensure_user_is_set:
ensure_user_is_set.side_effect = SoftTimeLimitExceeded()
"set_form_data",
) as set_form_data:
set_form_data.side_effect = SoftTimeLimitExceeded()
load_explore_json_into_cache(job_metadata, form_data)
ensure_user_is_set.assert_called_once_with(user.id, "error", errors=errors)
def test_ensure_user_is_set(self):
g_user_is_set = hasattr(g, "user")
original_g_user = g.user if g_user_is_set else None
if g_user_is_set:
del g.user
self.assertFalse(hasattr(g, "user"))
ensure_user_is_set(1)
self.assertTrue(hasattr(g, "user"))
self.assertFalse(g.user.is_anonymous)
self.assertEqual(1, get_user_id())
del g.user
self.assertFalse(hasattr(g, "user"))
ensure_user_is_set(None)
self.assertTrue(hasattr(g, "user"))
self.assertTrue(g.user.is_anonymous)
self.assertEqual(None, get_user_id())
del g.user
g.user = security_manager.get_user_by_id(2)
self.assertEqual(2, get_user_id())
ensure_user_is_set(1)
self.assertTrue(hasattr(g, "user"))
self.assertFalse(g.user.is_anonymous)
self.assertEqual(2, get_user_id())
ensure_user_is_set(None)
self.assertTrue(hasattr(g, "user"))
self.assertFalse(g.user.is_anonymous)
self.assertEqual(2, get_user_id())
if g_user_is_set:
g.user = original_g_user
else:
del g.user
set_form_data.assert_called_once_with(form_data, "error", errors=errors)