feat: add global max row limit (#16683)

* feat: add global max limit

* fix lint and tests

* leave SAMPLES_ROW_LIMIT unchanged

* fix sample rowcount test

* replace max global limit with existing sql max row limit

* fix test

* make max_limit optional in util

* improve comments
Ville Brofeldt authored 2021-09-16 19:33:41 +03:00, committed by GitHub
parent 633f29f3e9
commit 4e3d4f6daf
11 changed files with 123 additions and 27 deletions
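
Taken together, these changes route every row limit that reaches an analytical database (chart data, samples, SQL Lab queries, filter auto-complete) through a single cap: the existing SQL_MAX_ROW setting. A minimal standalone sketch of the capping rule added in superset/utils/core.py (SQL_MAX_ROW is inlined here for illustration; the real helper reads it from the Flask app config):

    from typing import Optional

    SQL_MAX_ROW = 100000  # stand-in for current_app.config["SQL_MAX_ROW"]

    def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
        # a limit of 0 means "no limit requested" and resolves to the cap itself
        if max_limit is None:
            max_limit = SQL_MAX_ROW
        if limit != 0:
            return min(max_limit, limit)
        return max_limit

    assert apply_max_row_limit(10_000_000) == 100_000  # oversized request is capped
    assert apply_max_row_limit(50) == 50               # modest request passes through
    assert apply_max_row_limit(0) == 100_000           # "unlimited" resolves to the cap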

superset/common/query_actions.py

@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 import copy
-import math
 from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING

 from flask_babel import _
@@ -131,14 +130,11 @@ def _get_samples(
     query_context: "QueryContext", query_obj: "QueryObject", force_cached: bool = False
 ) -> Dict[str, Any]:
     datasource = _get_datasource(query_context, query_obj)
-    row_limit = query_obj.row_limit or math.inf
     query_obj = copy.copy(query_obj)
     query_obj.is_timeseries = False
     query_obj.orderby = []
     query_obj.metrics = []
     query_obj.post_processing = []
-    query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
-    query_obj.row_offset = 0
     query_obj.columns = [o.column_name for o in datasource.columns]
     return _get_full(query_context, query_obj, force_cached)

superset/common/query_context.py

@@ -100,11 +100,11 @@ class QueryContext:
         self.datasource = ConnectorRegistry.get_datasource(
             str(datasource["type"]), int(datasource["id"]), db.session
         )
-        self.queries = [QueryObject(**query_obj) for query_obj in queries]
         self.force = force
         self.custom_cache_timeout = custom_cache_timeout
         self.result_type = result_type or ChartDataResultType.FULL
         self.result_format = result_format or ChartDataResultFormat.JSON
+        self.queries = [QueryObject(self, **query_obj) for query_obj in queries]
         self.cache_values = {
             "datasource": datasource,
             "queries": queries,

superset/common/query_object.py

@@ -16,7 +16,7 @@
 # under the License.
 import logging
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, NamedTuple, Optional
+from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING

 from flask_babel import gettext as _
 from pandas import DataFrame
@@ -28,6 +28,7 @@ from superset.exceptions import QueryObjectValidationError
 from superset.typing import Metric, OrderBy
 from superset.utils import pandas_postprocessing
 from superset.utils.core import (
+    apply_max_row_limit,
     ChartDataResultType,
     DatasourceDict,
     DTTM_ALIAS,
@@ -41,6 +42,10 @@ from superset.utils.date_parser import get_since_until, parse_human_timedelta
 from superset.utils.hashing import md5_sha_from_dict
 from superset.views.utils import get_time_range_endpoints

+if TYPE_CHECKING:
+    from superset.common.query_context import QueryContext  # pragma: no cover
+
 config = app.config
 logger = logging.getLogger(__name__)
@@ -103,6 +108,7 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
     def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         self,
+        query_context: "QueryContext",
         annotation_layers: Optional[List[Dict[str, Any]]] = None,
         applied_time_extras: Optional[Dict[str, str]] = None,
         apply_fetch_values_predicate: bool = False,
@@ -146,7 +152,7 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
         self.datasource = ConnectorRegistry.get_datasource(
             str(datasource["type"]), int(datasource["id"]), db.session
         )
-        self.result_type = result_type
+        self.result_type = result_type or query_context.result_type
         self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
         self.annotation_layers = [
             layer
@@ -186,7 +192,12 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
             for x in metrics
         ]
-        self.row_limit = config["ROW_LIMIT"] if row_limit is None else row_limit
+        default_row_limit = (
+            config["SAMPLES_ROW_LIMIT"]
+            if self.result_type == ChartDataResultType.SAMPLES
+            else config["ROW_LIMIT"]
+        )
+        self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
         self.row_offset = row_offset or 0
         self.filter = filters or []
         self.series_limit = series_limit
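
Two things happen in this file: QueryObject now receives its parent QueryContext, so an unset result_type can be inherited from the context, and the default row limit is chosen by result type (SAMPLES_ROW_LIMIT for sample requests, ROW_LIMIT otherwise) before the global cap is applied. This is also why _get_samples in query_actions.py above no longer clamps the limit itself. A rough restatement of the selection logic with the shipped config defaults inlined (illustrative only, not the actual implementation):

    SAMPLES_ROW_LIMIT = 1000  # config["SAMPLES_ROW_LIMIT"]
    ROW_LIMIT = 50000         # config["ROW_LIMIT"]
    SQL_MAX_ROW = 100000      # config["SQL_MAX_ROW"]

    def effective_row_limit(requested: int, result_type: str) -> int:
        # pick the default by result type, then cap at the global maximum
        default = SAMPLES_ROW_LIMIT if result_type == "samples" else ROW_LIMIT
        return min(SQL_MAX_ROW, requested or default)

    assert effective_row_limit(0, "samples") == 1000           # samples default
    assert effective_row_limit(0, "full") == 50000             # chart-data default
    assert effective_row_limit(10_000_000, "full") == 100_000  # capped by SQL_MAX_ROW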

superset/config.py

@@ -121,9 +121,9 @@ BUILD_NUMBER = None
 # default viz used in chart explorer
 DEFAULT_VIZ_TYPE = "table"

+# default row limit when requesting chart data
 ROW_LIMIT = 50000
-VIZ_ROW_LIMIT = 10000
-# max rows retreieved when requesting samples from datasource in explore view
+# default row limit when requesting samples from datasource in explore view
 SAMPLES_ROW_LIMIT = 1000
 # max rows retrieved by filter select auto complete
 FILTER_SELECT_ROW_LIMIT = 10000
@@ -671,9 +671,7 @@ QUERY_LOGGER = None
 # Set this API key to enable Mapbox visualizations
 MAPBOX_API_KEY = os.environ.get("MAPBOX_API_KEY", "")

-# Maximum number of rows returned from a database
-# in async mode, no more than SQL_MAX_ROW will be returned and stored
-# in the results backend. This also becomes the limit when exporting CSVs
+# Maximum number of rows returned for any analytical database query
 SQL_MAX_ROW = 100000

 # Maximum number of rows displayed in SQL Lab UI

superset/utils/core.py

@@ -1762,3 +1762,25 @@ def parse_boolean_string(bool_str: Optional[str]) -> bool:
         return bool(strtobool(bool_str.lower()))
     except ValueError:
         return False
+
+
+def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
+    """
+    Override row limit if max global limit is defined
+
+    :param limit: requested row limit
+    :param max_limit: Maximum allowed row limit
+    :return: Capped row limit
+
+    >>> apply_max_row_limit(100000, 10)
+    10
+    >>> apply_max_row_limit(10, 100000)
+    10
+    >>> apply_max_row_limit(0, 10000)
+    10000
+    """
+    if max_limit is None:
+        max_limit = current_app.config["SQL_MAX_ROW"]
+    if limit != 0:
+        return min(max_limit, limit)
+    return max_limit
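
Note that when max_limit is omitted, the helper reads SQL_MAX_ROW through Flask's current_app, so it must run inside an application context; passing max_limit explicitly avoids that requirement. A sketch of both call styles, assuming a Flask app object named app:

    from superset.utils.core import apply_max_row_limit

    with app.app_context():
        capped = apply_max_row_limit(10_000_000)  # falls back to config["SQL_MAX_ROW"]

    capped_explicit = apply_max_row_limit(50, 10)  # explicit cap of 10, no app context needed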

superset/utils/sqllab_execution_context.py

@@ -23,10 +23,11 @@ from typing import Any, cast, Dict, Optional, TYPE_CHECKING
 from flask import g

-from superset import app, is_feature_enabled
+from superset import is_feature_enabled
 from superset.models.sql_lab import Query
 from superset.sql_parse import CtasMethod
 from superset.utils import core as utils
+from superset.utils.core import apply_max_row_limit
 from superset.utils.dates import now_as_float
 from superset.views.utils import get_cta_schema_name
@@ -102,7 +103,7 @@ class SqlJsonExecutionContext:  # pylint: disable=too-many-instance-attributes
     @staticmethod
     def _get_limit_param(query_params: Dict[str, Any]) -> int:
-        limit: int = query_params.get("queryLimit") or app.config["SQL_MAX_ROW"]
+        limit = apply_max_row_limit(query_params.get("queryLimit") or 0)
         if limit < 0:
             logger.warning(
                 "Invalid limit of %i specified. Defaulting to max limit.", limit

superset/views/core.py

@@ -104,7 +104,7 @@ from superset.typing import FlaskResponse
 from superset.utils import core as utils, csv
 from superset.utils.async_query_manager import AsyncQueryTokenException
 from superset.utils.cache import etag_cache
-from superset.utils.core import ReservedUrlParameters
+from superset.utils.core import apply_max_row_limit, ReservedUrlParameters
 from superset.utils.dates import now_as_float
 from superset.utils.decorators import check_dashboard_access
 from superset.utils.sqllab_execution_context import SqlJsonExecutionContext
@@ -897,8 +897,9 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             return json_error_response(DATASOURCE_MISSING_ERR)

         datasource.raise_for_access()
+        row_limit = apply_max_row_limit(config["FILTER_SELECT_ROW_LIMIT"])
         payload = json.dumps(
-            datasource.values_for_column(column, config["FILTER_SELECT_ROW_LIMIT"]),
+            datasource.values_for_column(column, row_limit),
             default=utils.json_int_dttm_ser,
             ignore_nan=True,
         )

superset/viz.py

@@ -69,6 +69,7 @@ from superset.typing import Metric, QueryObjectDict, VizData, VizPayload
 from superset.utils import core as utils, csv
 from superset.utils.cache import set_and_log_cache
 from superset.utils.core import (
+    apply_max_row_limit,
     DTTM_ALIAS,
     ExtraFiltersReasonType,
     JS_MAX_INTEGER,
@@ -324,7 +325,10 @@ class BaseViz:  # pylint: disable=too-many-public-methods
         )
         limit = int(self.form_data.get("limit") or 0)
         timeseries_limit_metric = self.form_data.get("timeseries_limit_metric")
+
+        # apply row limit to query
         row_limit = int(self.form_data.get("row_limit") or config["ROW_LIMIT"])
+        row_limit = apply_max_row_limit(row_limit)

         # default order direction
         order_desc = self.form_data.get("order_desc", True)
@@ -1687,9 +1691,6 @@ class HistogramViz(BaseViz):
     def query_obj(self) -> QueryObjectDict:
         """Returns the query object for this visualization"""
         query_obj = super().query_obj()
-        query_obj["row_limit"] = self.form_data.get(
-            "row_limit", int(config["VIZ_ROW_LIMIT"])
-        )
         numeric_columns = self.form_data.get("all_columns_x")
         if numeric_columns is None:
             raise QueryObjectValidationError(

tests/integration_tests/charts/api_tests.py

@@ -18,7 +18,7 @@
 """Unit tests for Superset"""
 import json
 import unittest
-from datetime import datetime, timedelta
+from datetime import datetime
 from io import BytesIO
 from typing import Optional
 from unittest import mock
@@ -1203,6 +1203,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
         del request_payload["queries"][0]["row_limit"]
+
         rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         response_payload = json.loads(rv.data.decode("utf-8"))
         result = response_payload["result"][0]
@@ -1210,11 +1211,46 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @mock.patch(
-        "superset.common.query_actions.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
+        "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 10},
     )
-    def test_chart_data_default_sample_limit(self):
+    def test_chart_data_sql_max_row_limit(self):
         """
-        Chart data API: Ensure sample response row count doesn't exceed default limit
+        Chart data API: Ensure row count doesn't exceed max global row limit
+        """
+        self.login(username="admin")
+        request_payload = get_query_context("birth_names")
+        request_payload["queries"][0]["row_limit"] = 10000000
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 10)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @mock.patch(
+        "superset.common.query_object.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
+    )
+    def test_chart_data_sample_default_limit(self):
+        """
+        Chart data API: Ensure sample response row count defaults to config defaults
+        """
+        self.login(username="admin")
+        request_payload = get_query_context("birth_names")
+        request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
+        del request_payload["queries"][0]["row_limit"]
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 5)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @mock.patch(
+        "superset.common.query_actions.config",
+        {**app.config, "SAMPLES_ROW_LIMIT": 5, "SQL_MAX_ROW": 15},
+    )
+    def test_chart_data_sample_custom_limit(self):
+        """
+        Chart data API: Ensure requested sample response row count is between
+        default and SQL max row limit
         """
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
@@ -1223,6 +1259,24 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
         rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         response_payload = json.loads(rv.data.decode("utf-8"))
         result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 10)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @mock.patch(
+        "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 5},
+    )
+    def test_chart_data_sql_max_row_sample_limit(self):
+        """
+        Chart data API: Ensure requested sample response row count doesn't
+        exceed SQL max row limit
+        """
+        self.login(username="admin")
+        request_payload = get_query_context("birth_names")
+        request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
+        request_payload["queries"][0]["row_limit"] = 10000000
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
         self.assertEqual(result["rowcount"], 5)

     def test_chart_data_incorrect_result_type(self):
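
A subtlety in the tests above: config must be patched in the module where it is read, which is why the mock targets differ per test (superset.common.query_object.config for the defaults, superset.common.query_actions.config for the samples path, superset.utils.core.current_app.config for the global cap). A standalone illustration of that mock.patch principle, using a throwaway module rather than Superset's own:

    import sys
    import types
    from unittest import mock

    # Stand-in for a module that binds a module-level config, as query_object.py does.
    mod = types.ModuleType("demo_mod")
    mod.config = {"ROW_LIMIT": 50000}
    mod.get_limit = lambda: mod.config["ROW_LIMIT"]
    sys.modules["demo_mod"] = mod

    # Patch the name where it is looked up ("demo_mod.config"), not where it originated.
    with mock.patch("demo_mod.config", {"ROW_LIMIT": 5000}):
        assert mod.get_limit() == 5000
    assert mod.get_limit() == 50000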

tests/integration_tests/charts/schema_tests.py

@@ -16,17 +16,25 @@
 # under the License.
 # isort:skip_file
 """Unit tests for Superset"""
-from typing import Any, Dict, Tuple
+from unittest import mock
+
+import pytest
 from marshmallow import ValidationError

 from tests.integration_tests.test_app import app
 from superset.charts.schemas import ChartDataQueryContextSchema
-from superset.common.query_context import QueryContext
 from tests.integration_tests.base_tests import SupersetTestCase
+from tests.integration_tests.fixtures.birth_names_dashboard import (
+    load_birth_names_dashboard_with_slices,
+)
 from tests.integration_tests.fixtures.query_context import get_query_context


 class TestSchema(SupersetTestCase):
+    @mock.patch(
+        "superset.common.query_object.config", {**app.config, "ROW_LIMIT": 5000},
+    )
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_limit_and_offset(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
@@ -36,7 +44,7 @@ class TestSchema(SupersetTestCase):
         payload["queries"][0].pop("row_offset", None)
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
-        self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"])
+        self.assertEqual(query_object.row_limit, 5000)
         self.assertEqual(query_object.row_offset, 0)

         # Valid limit and offset
@@ -55,12 +63,14 @@ class TestSchema(SupersetTestCase):
         self.assertIn("row_limit", context.exception.messages["queries"][0])
         self.assertIn("row_offset", context.exception.messages["queries"][0])

+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_null_timegrain(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
         payload["queries"][0]["extras"]["time_grain_sqla"] = None
         _ = ChartDataQueryContextSchema().load(payload)

+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_series_limit(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
@@ -82,6 +92,7 @@ class TestSchema(SupersetTestCase):
         }
         _ = ChartDataQueryContextSchema().load(payload)

+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_null_post_processing_op(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")

tests/integration_tests/query_context_tests.py

@@ -90,6 +90,7 @@ class TestQueryContext(SupersetTestCase):
         self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
         self.assertEqual(post_proc["options"], payload_post_proc["options"])

+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_cache(self):
         table_name = "birth_names"
         table = self.get_table(name=table_name)