From 4e3d4f6daf01749b8f28e0770e138db5ed8fae91 Mon Sep 17 00:00:00 2001
From: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
Date: Thu, 16 Sep 2021 19:33:41 +0300
Subject: [PATCH] feat: add global max row limit (#16683)

* feat: add global max limit
* fix lint and tests
* leave SAMPLES_ROW_LIMIT unchanged
* fix sample rowcount test
* replace max global limit with existing sql max row limit
* fix test
* make max_limit optional in util
* improve comments
---
 superset/common/query_actions.py              |  4 --
 superset/common/query_context.py              |  2 +-
 superset/common/query_object.py               | 17 ++++-
 superset/config.py                            |  8 +--
 superset/utils/core.py                        | 22 +++++++
 superset/utils/sqllab_execution_context.py    |  5 +-
 superset/views/core.py                        |  5 +-
 superset/viz.py                               |  7 ++-
 tests/integration_tests/charts/api_tests.py   | 62 +++++++++++++++++--
 .../integration_tests/charts/schema_tests.py  | 17 ++++-
 .../integration_tests/query_context_tests.py  |  1 +
 11 files changed, 123 insertions(+), 27 deletions(-)

diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py
index 440ab7ce65..d0058bd68d 100644
--- a/superset/common/query_actions.py
+++ b/superset/common/query_actions.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 import copy
-import math
 from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING
 
 from flask_babel import _
@@ -131,14 +130,11 @@ def _get_samples(
     query_context: "QueryContext", query_obj: "QueryObject", force_cached: bool = False
 ) -> Dict[str, Any]:
     datasource = _get_datasource(query_context, query_obj)
-    row_limit = query_obj.row_limit or math.inf
     query_obj = copy.copy(query_obj)
     query_obj.is_timeseries = False
     query_obj.orderby = []
     query_obj.metrics = []
     query_obj.post_processing = []
-    query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
-    query_obj.row_offset = 0
     query_obj.columns = [o.column_name for o in datasource.columns]
     return _get_full(query_context, query_obj, force_cached)
 
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 2cf04e2645..17ba7c4823 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -100,11 +100,11 @@ class QueryContext:
         self.datasource = ConnectorRegistry.get_datasource(
             str(datasource["type"]), int(datasource["id"]), db.session
         )
-        self.queries = [QueryObject(**query_obj) for query_obj in queries]
         self.force = force
         self.custom_cache_timeout = custom_cache_timeout
         self.result_type = result_type or ChartDataResultType.FULL
         self.result_format = result_format or ChartDataResultFormat.JSON
+        self.queries = [QueryObject(self, **query_obj) for query_obj in queries]
         self.cache_values = {
             "datasource": datasource,
             "queries": queries,
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index c8f0814673..abc94ea11d 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -16,7 +16,7 @@
 # under the License.
 import logging
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, NamedTuple, Optional
+from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
 
 from flask_babel import gettext as _
 from pandas import DataFrame
@@ -28,6 +28,7 @@ from superset.exceptions import QueryObjectValidationError
 from superset.typing import Metric, OrderBy
 from superset.utils import pandas_postprocessing
 from superset.utils.core import (
+    apply_max_row_limit,
     ChartDataResultType,
     DatasourceDict,
     DTTM_ALIAS,
@@ -41,6 +42,10 @@ from superset.utils.date_parser import get_since_until, parse_human_timedelta
 from superset.utils.hashing import md5_sha_from_dict
 from superset.views.utils import get_time_range_endpoints
 
+if TYPE_CHECKING:
+    from superset.common.query_context import QueryContext  # pragma: no cover
+
+
 config = app.config
 logger = logging.getLogger(__name__)
 
@@ -103,6 +108,7 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
 
     def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         self,
+        query_context: "QueryContext",
         annotation_layers: Optional[List[Dict[str, Any]]] = None,
         applied_time_extras: Optional[Dict[str, str]] = None,
         apply_fetch_values_predicate: bool = False,
@@ -146,7 +152,7 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
         self.datasource = ConnectorRegistry.get_datasource(
             str(datasource["type"]), int(datasource["id"]), db.session
         )
-        self.result_type = result_type
+        self.result_type = result_type or query_context.result_type
         self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
         self.annotation_layers = [
             layer
@@ -186,7 +192,12 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
             for x in metrics
         ]
 
-        self.row_limit = config["ROW_LIMIT"] if row_limit is None else row_limit
+        default_row_limit = (
+            config["SAMPLES_ROW_LIMIT"]
+            if self.result_type == ChartDataResultType.SAMPLES
+            else config["ROW_LIMIT"]
+        )
+        self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
         self.row_offset = row_offset or 0
         self.filter = filters or []
         self.series_limit = series_limit
diff --git a/superset/config.py b/superset/config.py
index 8fea2efab2..c7809690e1 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -121,9 +121,9 @@ BUILD_NUMBER = None
 # default viz used in chart explorer
 DEFAULT_VIZ_TYPE = "table"
 
+# default row limit when requesting chart data
 ROW_LIMIT = 50000
-VIZ_ROW_LIMIT = 10000
-# max rows retreieved when requesting samples from datasource in explore view
+# default row limit when requesting samples from datasource in explore view
 SAMPLES_ROW_LIMIT = 1000
 # max rows retrieved by filter select auto complete
 FILTER_SELECT_ROW_LIMIT = 10000
@@ -671,9 +671,7 @@ QUERY_LOGGER = None
 
 # Set this API key to enable Mapbox visualizations
 MAPBOX_API_KEY = os.environ.get("MAPBOX_API_KEY", "")
-# Maximum number of rows returned from a database
-# in async mode, no more than SQL_MAX_ROW will be returned and stored
-# in the results backend. This also becomes the limit when exporting CSVs
+# Maximum number of rows returned for any analytical database query
 SQL_MAX_ROW = 100000
 
 # Maximum number of rows displayed in SQL Lab UI
diff --git a/superset/utils/core.py b/superset/utils/core.py
index a0647b6dc5..7146a3b57e 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1762,3 +1762,25 @@ def parse_boolean_string(bool_str: Optional[str]) -> bool:
         return bool(strtobool(bool_str.lower()))
     except ValueError:
         return False
+
+
+def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
+    """
+    Override the requested row limit if a global max limit is defined
+
+    :param limit: requested row limit
+    :param max_limit: maximum allowed row limit
+    :return: capped row limit
+
+    >>> apply_max_row_limit(100000, 10)
+    10
+    >>> apply_max_row_limit(10, 100000)
+    10
+    >>> apply_max_row_limit(0, 10000)
+    10000
+    """
+    if max_limit is None:
+        max_limit = current_app.config["SQL_MAX_ROW"]
+    if limit != 0:
+        return min(max_limit, limit)
+    return max_limit
diff --git a/superset/utils/sqllab_execution_context.py b/superset/utils/sqllab_execution_context.py
index 6c5532b348..09ae33d54d 100644
--- a/superset/utils/sqllab_execution_context.py
+++ b/superset/utils/sqllab_execution_context.py
@@ -23,10 +23,11 @@ from typing import Any, cast, Dict, Optional, TYPE_CHECKING
 
 from flask import g
 
-from superset import app, is_feature_enabled
+from superset import is_feature_enabled
 from superset.models.sql_lab import Query
 from superset.sql_parse import CtasMethod
 from superset.utils import core as utils
+from superset.utils.core import apply_max_row_limit
 from superset.utils.dates import now_as_float
 from superset.views.utils import get_cta_schema_name
 
@@ -102,7 +103,7 @@ class SqlJsonExecutionContext:  # pylint: disable=too-many-instance-attributes
 
     @staticmethod
     def _get_limit_param(query_params: Dict[str, Any]) -> int:
-        limit: int = query_params.get("queryLimit") or app.config["SQL_MAX_ROW"]
+        limit = apply_max_row_limit(query_params.get("queryLimit") or 0)
         if limit < 0:
             logger.warning(
                 "Invalid limit of %i specified. Defaulting to max limit.", limit
diff --git a/superset/views/core.py b/superset/views/core.py
index 4ed7e89cdb..a07f7d509b 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -104,7 +104,7 @@ from superset.typing import FlaskResponse
 from superset.utils import core as utils, csv
 from superset.utils.async_query_manager import AsyncQueryTokenException
 from superset.utils.cache import etag_cache
-from superset.utils.core import ReservedUrlParameters
+from superset.utils.core import apply_max_row_limit, ReservedUrlParameters
 from superset.utils.dates import now_as_float
 from superset.utils.decorators import check_dashboard_access
 from superset.utils.sqllab_execution_context import SqlJsonExecutionContext
@@ -897,8 +897,9 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             return json_error_response(DATASOURCE_MISSING_ERR)
 
         datasource.raise_for_access()
+        row_limit = apply_max_row_limit(config["FILTER_SELECT_ROW_LIMIT"])
         payload = json.dumps(
-            datasource.values_for_column(column, config["FILTER_SELECT_ROW_LIMIT"]),
+            datasource.values_for_column(column, row_limit),
             default=utils.json_int_dttm_ser,
             ignore_nan=True,
         )
diff --git a/superset/viz.py b/superset/viz.py
index cff9eda8d8..3ab2e2b564 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -69,6 +69,7 @@ from superset.typing import Metric, QueryObjectDict, VizData, VizPayload
 from superset.utils import core as utils, csv
 from superset.utils.cache import set_and_log_cache
 from superset.utils.core import (
+    apply_max_row_limit,
     DTTM_ALIAS,
     ExtraFiltersReasonType,
     JS_MAX_INTEGER,
@@ -324,7 +325,10 @@ class BaseViz:  # pylint: disable=too-many-public-methods
         )
         limit = int(self.form_data.get("limit") or 0)
         timeseries_limit_metric = self.form_data.get("timeseries_limit_metric")
+
+        # apply row limit to query
         row_limit = int(self.form_data.get("row_limit") or config["ROW_LIMIT"])
+        row_limit = apply_max_row_limit(row_limit)
 
         # default order direction
         order_desc = self.form_data.get("order_desc", True)
@@ -1687,9 +1691,6 @@ class HistogramViz(BaseViz):
     def query_obj(self) -> QueryObjectDict:
         """Returns the query object for this visualization"""
         query_obj = super().query_obj()
-        query_obj["row_limit"] = self.form_data.get(
-            "row_limit", int(config["VIZ_ROW_LIMIT"])
-        )
         numeric_columns = self.form_data.get("all_columns_x")
         if numeric_columns is None:
             raise QueryObjectValidationError(
diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py
index 05ddb0832c..3647442eba 100644
--- a/tests/integration_tests/charts/api_tests.py
+++ b/tests/integration_tests/charts/api_tests.py
@@ -18,7 +18,7 @@
 """Unit tests for Superset"""
 import json
 import unittest
-from datetime import datetime, timedelta
+from datetime import datetime
 from io import BytesIO
 from typing import Optional
 from unittest import mock
@@ -1203,6 +1203,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
         del request_payload["queries"][0]["row_limit"]
+
         rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         response_payload = json.loads(rv.data.decode("utf-8"))
         result = response_payload["result"][0]
@@ -1210,11 +1211,46 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @mock.patch(
-        "superset.common.query_actions.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
+        "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 10},
     )
-    def test_chart_data_default_sample_limit(self):
+    def test_chart_data_sql_max_row_limit(self):
         """
-        Chart data API: Ensure sample response row count doesn't exceed default limit
+        Chart data API: Ensure row count doesn't exceed max global row limit
+        """
+        self.login(username="admin")
+        request_payload = get_query_context("birth_names")
+        request_payload["queries"][0]["row_limit"] = 10000000
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 10)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @mock.patch(
+        "superset.common.query_object.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
+    )
+    def test_chart_data_sample_default_limit(self):
+        """
+        Chart data API: Ensure sample response row count defaults to configured limit
+        """
+        self.login(username="admin")
+        request_payload = get_query_context("birth_names")
+        request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
+        del request_payload["queries"][0]["row_limit"]
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 5)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @mock.patch(
+        "superset.common.query_actions.config",
+        {**app.config, "SAMPLES_ROW_LIMIT": 5, "SQL_MAX_ROW": 15},
+    )
+    def test_chart_data_sample_custom_limit(self):
+        """
+        Chart data API: Ensure a requested sample row count between the
+        default and the SQL max row limit is honored
         """
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
@@ -1223,6 +1259,24 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
         rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         response_payload = json.loads(rv.data.decode("utf-8"))
         result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 10)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @mock.patch(
+        "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 5},
+    )
+    def test_chart_data_sql_max_row_sample_limit(self):
+        """
+        Chart data API: Ensure requested sample response row count doesn't
+        exceed SQL max row limit
+        """
+        self.login(username="admin")
+        request_payload = get_query_context("birth_names")
+        request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
+        request_payload["queries"][0]["row_limit"] = 10000000
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
         self.assertEqual(result["rowcount"], 5)
 
     def test_chart_data_incorrect_result_type(self):
diff --git a/tests/integration_tests/charts/schema_tests.py b/tests/integration_tests/charts/schema_tests.py
index e34b7d71fb..977cf72957 100644
--- a/tests/integration_tests/charts/schema_tests.py
+++ b/tests/integration_tests/charts/schema_tests.py
@@ -16,17 +16,25 @@
 # under the License.
 # isort:skip_file
 """Unit tests for Superset"""
-from typing import Any, Dict, Tuple
+from unittest import mock
+
+import pytest
 
 from marshmallow import ValidationError
 from tests.integration_tests.test_app import app
 from superset.charts.schemas import ChartDataQueryContextSchema
-from superset.common.query_context import QueryContext
 from tests.integration_tests.base_tests import SupersetTestCase
+from tests.integration_tests.fixtures.birth_names_dashboard import (
+    load_birth_names_dashboard_with_slices,
+)
 from tests.integration_tests.fixtures.query_context import get_query_context
 
 
 class TestSchema(SupersetTestCase):
+    @mock.patch(
+        "superset.common.query_object.config", {**app.config, "ROW_LIMIT": 5000},
+    )
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_limit_and_offset(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
@@ -36,7 +44,7 @@ class TestSchema(SupersetTestCase):
         payload["queries"][0].pop("row_offset", None)
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
-        self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"])
+        self.assertEqual(query_object.row_limit, 5000)
         self.assertEqual(query_object.row_offset, 0)
 
         # Valid limit and offset
@@ -55,12 +63,14 @@ class TestSchema(SupersetTestCase):
         self.assertIn("row_limit", context.exception.messages["queries"][0])
         self.assertIn("row_offset", context.exception.messages["queries"][0])
 
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_null_timegrain(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
         payload["queries"][0]["extras"]["time_grain_sqla"] = None
         _ = ChartDataQueryContextSchema().load(payload)
 
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_series_limit(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
@@ -82,6 +92,7 @@ class TestSchema(SupersetTestCase):
         }
         _ = ChartDataQueryContextSchema().load(payload)
 
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_null_post_processing_op(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index b6d7f97e2d..ecc69b7b8d 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -90,6 +90,7 @@ class TestQueryContext(SupersetTestCase):
             self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
             self.assertEqual(post_proc["options"], payload_post_proc["options"])
 
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_cache(self):
         table_name = "birth_names"
         table = self.get_table(name=table_name)
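
For reference, the capping semantics of the new helper can be exercised outside of
Superset. A minimal sketch, assuming the config lookup is replaced by an explicit
max_limit argument (the real helper in superset/utils/core.py falls back to
current_app.config["SQL_MAX_ROW"] when max_limit is omitted, which requires a Flask
application context):

    def apply_max_row_limit(limit: int, max_limit: int) -> int:
        """Cap a requested row limit at the configured maximum.

        A limit of 0 means "no limit requested", in which case the
        configured maximum itself is used.
        """
        if limit != 0:
            return min(max_limit, limit)
        return max_limit

    # The three doctest cases from the patch:
    assert apply_max_row_limit(100000, 10) == 10   # excessive request is capped
    assert apply_max_row_limit(10, 100000) == 10   # modest request is honored
    assert apply_max_row_limit(0, 10000) == 10000  # 0 falls back to the maximum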
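
The row-limit resolution that this patch moves into QueryObject.__init__ composes
with the helper as follows. A simplified sketch, assuming a plain dict stands in
for Superset's Flask config and reusing the apply_max_row_limit sketch above;
ChartDataResultType is reduced here to the two values the logic distinguishes:

    from enum import Enum


    class ChartDataResultType(str, Enum):
        FULL = "full"
        SAMPLES = "samples"


    config = {"ROW_LIMIT": 50000, "SAMPLES_ROW_LIMIT": 1000, "SQL_MAX_ROW": 100000}


    def resolve_row_limit(row_limit: int, result_type: ChartDataResultType) -> int:
        # Samples requests fall back to the smaller samples default;
        # everything else falls back to the general row limit.
        default_row_limit = (
            config["SAMPLES_ROW_LIMIT"]
            if result_type == ChartDataResultType.SAMPLES
            else config["ROW_LIMIT"]
        )
        # The requested (or defaulted) limit is then capped globally.
        return apply_max_row_limit(row_limit or default_row_limit, config["SQL_MAX_ROW"])

    # 0 stands in for "not specified" (None in the real code): defaults apply.
    assert resolve_row_limit(0, ChartDataResultType.FULL) == 50000
    assert resolve_row_limit(0, ChartDataResultType.SAMPLES) == 1000
    # An excessive request is capped at SQL_MAX_ROW, as the new API tests assert.
    assert resolve_row_limit(10000000, ChartDataResultType.FULL) == 100000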
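
SQL Lab's _get_limit_param now follows the same pattern: a missing queryLimit
coerces to 0, which the helper maps to the configured maximum. A hypothetical
standalone version, again with the maximum passed explicitly rather than read
from app config:

    from typing import Any, Dict


    def get_limit_param(query_params: Dict[str, Any], sql_max_row: int) -> int:
        # A missing or falsy queryLimit becomes 0, which apply_max_row_limit
        # maps to the configured maximum; any other value is capped by it.
        return apply_max_row_limit(query_params.get("queryLimit") or 0, sql_max_row)


    assert get_limit_param({}, 100000) == 100000  # nothing requested: use the max
    assert get_limit_param({"queryLimit": 1000}, 100000) == 1000
    assert get_limit_param({"queryLimit": 10 ** 9}, 100000) == 100000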