From 9ce6b7de839b28308d0f84e5d9abeb87d33a33a7 Mon Sep 17 00:00:00 2001 From: ofekisr <35701650+ofekisr@users.noreply.github.com> Date: Fri, 12 Nov 2021 14:44:21 +0200 Subject: [PATCH] refactor ChartDataCommand - separate loading query_context from cache into different module (#17405) --- superset/charts/api.py | 44 ++----------------- superset/charts/commands/data.py | 10 ----- superset/charts/data/api.py | 21 +++++---- .../charts/data/query_context_cache_loader.py | 30 +++++++++++++ tests/integration_tests/charts/api_tests.py | 18 ++++---- 5 files changed, 55 insertions(+), 68 deletions(-) create mode 100644 superset/charts/data/query_context_cache_loader.py diff --git a/superset/charts/api.py b/superset/charts/api.py index f44a901615..312ad9a2b6 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -18,11 +18,10 @@ import json import logging from datetime import datetime from io import BytesIO -from typing import Any, Dict, Optional +from typing import Any, Optional from zipfile import ZipFile -import simplejson -from flask import g, make_response, redirect, request, Response, send_file, url_for +from flask import g, redirect, request, Response, send_file, url_for from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.hooks import before_request from flask_appbuilder.models.sqla.interface import SQLAInterface @@ -49,7 +48,6 @@ from superset.charts.commands.importers.dispatcher import ImportChartsCommand from superset.charts.commands.update import UpdateChartCommand from superset.charts.dao import ChartDAO from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter -from superset.charts.post_processing import apply_post_process from superset.charts.schemas import ( CHART_SCHEMAS, ChartPostSchema, @@ -63,12 +61,10 @@ from superset.charts.schemas import ( ) from superset.commands.importers.exceptions import NoValidFilesFoundError from superset.commands.importers.v1.utils import 
get_contents_from_bundle -from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod -from superset.extensions import event_logger, security_manager +from superset.extensions import event_logger from superset.models.slice import Slice from superset.tasks.thumbnails import cache_chart_thumbnail -from superset.utils.core import json_int_dttm_ser from superset.utils.screenshots import ChartScreenshot from superset.utils.urls import get_url_path from superset.views.base_api import ( @@ -76,7 +72,6 @@ from superset.views.base_api import ( RelatedFieldFilter, statsd_metrics, ) -from superset.views.core import CsvResponse, generate_download_headers from superset.views.filters import FilterRelatedOwners logger = logging.getLogger(__name__) @@ -483,39 +478,6 @@ class ChartRestApi(BaseSupersetModelRestApi): except ChartBulkDeleteFailedError as ex: return self.response_422(message=str(ex)) - def send_chart_response( - self, result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None, - ) -> Response: - result_type = result["query_context"].result_type - result_format = result["query_context"].result_format - - # Post-process the data so it matches the data presented in the chart. - # This is needed for sending reports based on text charts that do the - # post-processing of data, eg, the pivot table. 
- if result_type == ChartDataResultType.POST_PROCESSED: - result = apply_post_process(result, form_data) - - if result_format == ChartDataResultFormat.CSV: - # Verify user has permission to export CSV file - if not security_manager.can_access("can_csv", "Superset"): - return self.response_403() - - # return the first result - data = result["queries"][0]["data"] - return CsvResponse(data, headers=generate_download_headers("csv")) - - if result_format == ChartDataResultFormat.JSON: - response_data = simplejson.dumps( - {"result": result["queries"]}, - default=json_int_dttm_ser, - ignore_nan=True, - ) - resp = make_response(response_data, 200) - resp.headers["Content-Type"] = "application/json; charset=utf-8" - return resp - - return self.response_400(message=f"Unsupported result_format: {result_format}") - @expose("//cache_screenshot/", methods=["GET"]) @protect() @rison(screenshot_query_schema) diff --git a/superset/charts/commands/data.py b/superset/charts/commands/data.py index 619244c239..ec63362a5c 100644 --- a/superset/charts/commands/data.py +++ b/superset/charts/commands/data.py @@ -20,7 +20,6 @@ from typing import Any, Dict, Optional from flask import Request from marshmallow import ValidationError -from superset import cache from superset.charts.commands.exceptions import ( ChartDataCacheLoadError, ChartDataQueryFailedError, @@ -90,12 +89,3 @@ class ChartDataCommand(BaseCommand): def validate_async_request(self, request: Request) -> None: jwt_data = async_query_manager.parse_jwt_from_request(request) self._async_channel_id = jwt_data["channel"] - - def load_query_context_from_cache( # pylint: disable=no-self-use - self, cache_key: str - ) -> Dict[str, Any]: - cache_value = cache.get(cache_key) - if not cache_value: - raise ChartDataCacheLoadError("Cached data not found") - - return cache_value["data"] diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index c68760edbd..37703339e7 100644 --- a/superset/charts/data/api.py +++ 
b/superset/charts/data/api.py @@ -33,6 +33,7 @@ from superset.charts.commands.exceptions import ( ChartDataCacheLoadError, ChartDataQueryFailedError, ) +from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader from superset.charts.post_processing import apply_post_process from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.exceptions import QueryObjectValidationError @@ -151,7 +152,7 @@ class ChartDataRestApi(ChartRestApi): except (TypeError, json.decoder.JSONDecodeError): form_data = {} - return self.get_data_response(command, form_data=form_data) + return self._get_data_response(command, form_data=form_data) @expose("/data", methods=["POST"]) @protect() @@ -232,7 +233,7 @@ class ChartDataRestApi(ChartRestApi): ): return self._run_async(command) - return self.get_data_response(command) + return self._get_data_response(command) @expose("/data/", methods=["GET"]) @protect() @@ -276,7 +277,7 @@ class ChartDataRestApi(ChartRestApi): """ command = ChartDataCommand() try: - cached_data = command.load_query_context_from_cache(cache_key) + cached_data = self._load_query_context_form_from_cache(cache_key) command.set_query_context(cached_data) command.validate() except ChartDataCacheLoadError: @@ -286,7 +287,7 @@ class ChartDataRestApi(ChartRestApi): message=_("Request is incorrect: %(error)s", error=error.messages) ) - return self.get_data_response(command, True) + return self._get_data_response(command, True) def _run_async(self, command: ChartDataCommand) -> Response: """ @@ -302,7 +303,7 @@ class ChartDataRestApi(ChartRestApi): # If the chart query has already been cached, return it immediately. if already_cached_result: - return self.send_chart_response(result) + return self._send_chart_response(result) # Otherwise, kick off a background job to run the chart query. 
# Clients will either poll or be notified of query completion, @@ -316,7 +317,7 @@ class ChartDataRestApi(ChartRestApi): result = command.run_async(g.user.get_id()) return self.response(202, **result) - def send_chart_response( + def _send_chart_response( self, result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None, ) -> Response: result_type = result["query_context"].result_type @@ -349,7 +350,7 @@ class ChartDataRestApi(ChartRestApi): return self.response_400(message=f"Unsupported result_format: {result_format}") - def get_data_response( + def _get_data_response( self, command: ChartDataCommand, force_cached: bool = False, @@ -362,4 +363,8 @@ class ChartDataRestApi(ChartRestApi): except ChartDataQueryFailedError as exc: return self.response_400(message=exc.message) - return self.send_chart_response(result, form_data) + return self._send_chart_response(result, form_data) + + # pylint: disable=invalid-name, no-self-use + def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]: + return QueryContextCacheLoader.load(cache_key) diff --git a/superset/charts/data/query_context_cache_loader.py b/superset/charts/data/query_context_cache_loader.py new file mode 100644 index 0000000000..b5ff3bdae8 --- /dev/null +++ b/superset/charts/data/query_context_cache_loader.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict + +from superset import cache +from superset.charts.commands.exceptions import ChartDataCacheLoadError + + +class QueryContextCacheLoader: # pylint: disable=too-few-public-methods + @staticmethod + def load(cache_key: str) -> Dict[str, Any]: + cache_value = cache.get(cache_key) + if not cache_value: + raise ChartDataCacheLoadError("Cached data not found") + + return cache_value["data"] diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 696b52154b..dddd16f201 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -1620,15 +1620,15 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) - @mock.patch.object(ChartDataCommand, "load_query_context_from_cache") - def test_chart_data_cache(self, load_qc_mock): + @mock.patch("superset.charts.data.api.QueryContextCacheLoader") + def test_chart_data_cache(self, cache_loader): """ Chart data cache API: Test chart data async cache request """ async_query_manager.init_app(app) self.login(username="admin") query_context = get_query_context("birth_names") - load_qc_mock.return_value = query_context + cache_loader.load.return_value = query_context orig_run = ChartDataCommand.run def mock_run(self, **kwargs): @@ -1647,16 +1647,16 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): 
self.assertEqual(data["result"][0]["rowcount"], expected_row_count) @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) - @mock.patch.object(ChartDataCommand, "load_query_context_from_cache") + @mock.patch("superset.charts.data.api.QueryContextCacheLoader") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_cache_run_failed(self, load_qc_mock): + def test_chart_data_cache_run_failed(self, cache_loader): """ Chart data cache API: Test chart data async cache request with run failure """ async_query_manager.init_app(app) self.login(username="admin") query_context = get_query_context("birth_names") - load_qc_mock.return_value = query_context + cache_loader.load.return_value = query_context rv = self.get_assert_metric( f"{CHART_DATA_URI}/test-cache-key", "data_from_cache" ) @@ -1666,15 +1666,15 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): self.assertEqual(data["message"], "Error loading data from cache") @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) - @mock.patch.object(ChartDataCommand, "load_query_context_from_cache") + @mock.patch("superset.charts.data.api.QueryContextCacheLoader") @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") - def test_chart_data_cache_no_login(self, load_qc_mock): + def test_chart_data_cache_no_login(self, cache_loader): """ Chart data cache API: Test chart data async cache request (no login) """ async_query_manager.init_app(app) query_context = get_query_context("birth_names") - load_qc_mock.return_value = query_context + cache_loader.load.return_value = query_context orig_run = ChartDataCommand.run def mock_run(self, **kwargs):