feat: add global max row limit (#16683)

* feat: add global max limit

* fix lint and tests

* leave SAMPLES_ROW_LIMIT unchanged

* fix sample rowcount test

* replace max global limit with existing sql max row limit

* fix test

* make max_limit optional in util

* improve comments
Ville Brofeldt committed 2021-09-16 19:33:41 +03:00 (via GitHub)
parent 633f29f3e9, commit 4e3d4f6daf
11 changed files with 123 additions and 27 deletions
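At a high level, SQL_MAX_ROW becomes a global ceiling applied to every row limit Superset computes, via the new apply_max_row_limit helper in superset/utils/core.py. The call sites touched by this diff, summarized as simplified pseudocode (not verbatim code):

# Simplified summary of the call sites changed below; each routes its limit
# through apply_max_row_limit, which caps the value at SQL_MAX_ROW:
#   chart data:    row_limit = apply_max_row_limit(row_limit or ROW_LIMIT)
#   samples:       row_limit = apply_max_row_limit(row_limit or SAMPLES_ROW_LIMIT)
#   SQL Lab:       limit = apply_max_row_limit(query_params.get("queryLimit") or 0)
#   filter select: row_limit = apply_max_row_limit(FILTER_SELECT_ROW_LIMIT)
#   viz queries:   row_limit = apply_max_row_limit(int(form_data.get("row_limit") or ROW_LIMIT))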

superset/common/query_actions.py

@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 import copy
-import math
 from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING

 from flask_babel import _
@@ -131,14 +130,11 @@ def _get_samples(
     query_context: "QueryContext", query_obj: "QueryObject", force_cached: bool = False
 ) -> Dict[str, Any]:
     datasource = _get_datasource(query_context, query_obj)
-    row_limit = query_obj.row_limit or math.inf
     query_obj = copy.copy(query_obj)
     query_obj.is_timeseries = False
     query_obj.orderby = []
     query_obj.metrics = []
     query_obj.post_processing = []
-    query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
     query_obj.row_offset = 0
     query_obj.columns = [o.column_name for o in datasource.columns]
     return _get_full(query_context, query_obj, force_cached)
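The effect of this change: the hard per-sample cap moves out of _get_samples and into QueryObject (see the query_object.py hunks below), so an explicit row_limit may now exceed SAMPLES_ROW_LIMIT, bounded only by SQL_MAX_ROW. A before/after sketch with the config defaults inlined (simplified stand-ins, not the actual Superset functions):

import math

SAMPLES_ROW_LIMIT = 1000   # default sample size
SQL_MAX_ROW = 100000       # global ceiling

def old_sample_limit(requested):
    # pre-change: samples were always clamped to SAMPLES_ROW_LIMIT
    return int(min(requested or math.inf, SAMPLES_ROW_LIMIT))

def new_sample_limit(requested):
    # post-change: SAMPLES_ROW_LIMIT is only the default, SQL_MAX_ROW the cap
    return min(SQL_MAX_ROW, requested or SAMPLES_ROW_LIMIT)

assert old_sample_limit(5000) == 1000   # explicit requests used to be clamped
assert new_sample_limit(5000) == 5000   # now honored up to SQL_MAX_ROW
assert new_sample_limit(None) == 1000   # default unchanged when unset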

superset/common/query_context.py

@@ -100,11 +100,11 @@ class QueryContext:
         self.datasource = ConnectorRegistry.get_datasource(
             str(datasource["type"]), int(datasource["id"]), db.session
         )
-        self.queries = [QueryObject(**query_obj) for query_obj in queries]
         self.force = force
         self.custom_cache_timeout = custom_cache_timeout
         self.result_type = result_type or ChartDataResultType.FULL
         self.result_format = result_format or ChartDataResultFormat.JSON
+        self.queries = [QueryObject(self, **query_obj) for query_obj in queries]
         self.cache_values = {
             "datasource": datasource,
             "queries": queries,

superset/common/query_object.py

@@ -16,7 +16,7 @@
 # under the License.
 import logging
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, NamedTuple, Optional
+from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING

 from flask_babel import gettext as _
 from pandas import DataFrame
@@ -28,6 +28,7 @@ from superset.exceptions import QueryObjectValidationError
 from superset.typing import Metric, OrderBy
 from superset.utils import pandas_postprocessing
 from superset.utils.core import (
+    apply_max_row_limit,
     ChartDataResultType,
     DatasourceDict,
     DTTM_ALIAS,
@@ -41,6 +42,10 @@ from superset.utils.date_parser import get_since_until, parse_human_timedelta
 from superset.utils.hashing import md5_sha_from_dict
 from superset.views.utils import get_time_range_endpoints

+if TYPE_CHECKING:
+    from superset.common.query_context import QueryContext  # pragma: no cover
+
 config = app.config
 logger = logging.getLogger(__name__)
@@ -103,6 +108,7 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
     def __init__(  # pylint: disable=too-many-arguments,too-many-locals
         self,
+        query_context: "QueryContext",
         annotation_layers: Optional[List[Dict[str, Any]]] = None,
         applied_time_extras: Optional[Dict[str, str]] = None,
         apply_fetch_values_predicate: bool = False,
@@ -146,7 +152,7 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
         self.datasource = ConnectorRegistry.get_datasource(
             str(datasource["type"]), int(datasource["id"]), db.session
         )
-        self.result_type = result_type
+        self.result_type = result_type or query_context.result_type
         self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
         self.annotation_layers = [
             layer
@@ -186,7 +192,12 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
             for x in metrics
         ]
-        self.row_limit = config["ROW_LIMIT"] if row_limit is None else row_limit
+        default_row_limit = (
+            config["SAMPLES_ROW_LIMIT"]
+            if self.result_type == ChartDataResultType.SAMPLES
+            else config["ROW_LIMIT"]
+        )
+        self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
         self.row_offset = row_offset or 0
         self.filter = filters or []
         self.series_limit = series_limit
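Taken together, __init__ now resolves the row limit in three steps: choose a default based on the result type, substitute it when the request carries no limit (or 0), then cap at the global maximum. A condensed sketch with the config defaults inlined (assumed values; the real code reads them from app.config):

from typing import Optional

ROW_LIMIT = 50000          # default for regular chart data
SAMPLES_ROW_LIMIT = 1000   # default for datasource samples
SQL_MAX_ROW = 100000       # global ceiling

def resolve_row_limit(requested: Optional[int], is_samples: bool) -> int:
    default = SAMPLES_ROW_LIMIT if is_samples else ROW_LIMIT
    return min(SQL_MAX_ROW, requested or default)

assert resolve_row_limit(None, is_samples=False) == 50000         # default
assert resolve_row_limit(10_000_000, is_samples=False) == 100000  # capped
assert resolve_row_limit(None, is_samples=True) == 1000           # sample default
assert resolve_row_limit(10, is_samples=True) == 10               # explicit wins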

superset/config.py

@@ -121,9 +121,9 @@ BUILD_NUMBER = None
 # default viz used in chart explorer
 DEFAULT_VIZ_TYPE = "table"
 # default row limit when requesting chart data
 ROW_LIMIT = 50000
 VIZ_ROW_LIMIT = 10000
-# max rows retreieved when requesting samples from datasource in explore view
+# default row limit when requesting samples from datasource in explore view
 SAMPLES_ROW_LIMIT = 1000
 # max rows retrieved by filter select auto complete
 FILTER_SELECT_ROW_LIMIT = 10000
@@ -671,9 +671,7 @@ QUERY_LOGGER = None
 # Set this API key to enable Mapbox visualizations
 MAPBOX_API_KEY = os.environ.get("MAPBOX_API_KEY", "")

-# Maximum number of rows returned from a database
-# in async mode, no more than SQL_MAX_ROW will be returned and stored
-# in the results backend. This also becomes the limit when exporting CSVs
+# Maximum number of rows returned for any analytical database query
 SQL_MAX_ROW = 100000

 # Maximum number of rows displayed in SQL Lab UI
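For operators, the practical knob is SQL_MAX_ROW: it now bounds every analytical query rather than only async results and CSV exports. A hypothetical superset_config.py override (illustrative values only, tune per deployment):

# superset_config.py -- example values, not recommendations
SQL_MAX_ROW = 50000        # hard ceiling for any analytical query
ROW_LIMIT = 10000          # default when a chart request omits row_limit
SAMPLES_ROW_LIMIT = 500    # default row count for datasource samples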

superset/utils/core.py

@@ -1762,3 +1762,25 @@ def parse_boolean_string(bool_str: Optional[str]) -> bool:
         return bool(strtobool(bool_str.lower()))
     except ValueError:
         return False
+
+
+def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
+    """
+    Override row limit if max global limit is defined
+
+    :param limit: requested row limit
+    :param max_limit: Maximum allowed row limit
+    :return: Capped row limit
+
+    >>> apply_max_row_limit(100000, 10)
+    10
+    >>> apply_max_row_limit(10, 100000)
+    10
+    >>> apply_max_row_limit(0, 10000)
+    10000
+    """
+    if max_limit is None:
+        max_limit = current_app.config["SQL_MAX_ROW"]
+    if limit != 0:
+        return min(max_limit, limit)
+    return max_limit
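The optional max_limit parameter lets a caller supply its own ceiling; when omitted, the helper reads SQL_MAX_ROW from the active Flask app config. Restated as a standalone sketch with the default inlined, so it runs without an app context:

def apply_max_row_limit(limit: int, max_limit: int = 100000) -> int:
    # same rule as the helper above; 0 means "no limit requested" -> use the max
    return min(max_limit, limit) if limit != 0 else max_limit

assert apply_max_row_limit(100000, 10) == 10       # explicit cap wins
assert apply_max_row_limit(10, 100000) == 10       # small requests pass through
assert apply_max_row_limit(0, 10000) == 10000      # zero falls back to the cap
assert apply_max_row_limit(10_000_000) == 100000   # default global ceiling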

superset/utils/sqllab_execution_context.py

@@ -23,10 +23,11 @@ from typing import Any, cast, Dict, Optional, TYPE_CHECKING
 from flask import g

-from superset import app, is_feature_enabled
+from superset import is_feature_enabled
 from superset.models.sql_lab import Query
 from superset.sql_parse import CtasMethod
 from superset.utils import core as utils
+from superset.utils.core import apply_max_row_limit
 from superset.utils.dates import now_as_float
 from superset.views.utils import get_cta_schema_name
@@ -102,7 +103,7 @@ class SqlJsonExecutionContext:  # pylint: disable=too-many-instance-attributes
     @staticmethod
     def _get_limit_param(query_params: Dict[str, Any]) -> int:
-        limit: int = query_params.get("queryLimit") or app.config["SQL_MAX_ROW"]
+        limit = apply_max_row_limit(query_params.get("queryLimit") or 0)
         if limit < 0:
             logger.warning(
                 "Invalid limit of %i specified. Defaulting to max limit.", limit

superset/views/core.py

@@ -104,7 +104,7 @@ from superset.typing import FlaskResponse
 from superset.utils import core as utils, csv
 from superset.utils.async_query_manager import AsyncQueryTokenException
 from superset.utils.cache import etag_cache
-from superset.utils.core import ReservedUrlParameters
+from superset.utils.core import apply_max_row_limit, ReservedUrlParameters
 from superset.utils.dates import now_as_float
 from superset.utils.decorators import check_dashboard_access
 from superset.utils.sqllab_execution_context import SqlJsonExecutionContext
@@ -897,8 +897,9 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             return json_error_response(DATASOURCE_MISSING_ERR)

         datasource.raise_for_access()
+        row_limit = apply_max_row_limit(config["FILTER_SELECT_ROW_LIMIT"])
         payload = json.dumps(
-            datasource.values_for_column(column, config["FILTER_SELECT_ROW_LIMIT"]),
+            datasource.values_for_column(column, row_limit),
             default=utils.json_int_dttm_ser,
             ignore_nan=True,
         )

superset/viz.py

@@ -69,6 +69,7 @@ from superset.typing import Metric, QueryObjectDict, VizData, VizPayload
 from superset.utils import core as utils, csv
 from superset.utils.cache import set_and_log_cache
 from superset.utils.core import (
+    apply_max_row_limit,
     DTTM_ALIAS,
     ExtraFiltersReasonType,
     JS_MAX_INTEGER,
@@ -324,7 +325,10 @@ class BaseViz:  # pylint: disable=too-many-public-methods
         )
         limit = int(self.form_data.get("limit") or 0)
         timeseries_limit_metric = self.form_data.get("timeseries_limit_metric")
+
+        # apply row limit to query
         row_limit = int(self.form_data.get("row_limit") or config["ROW_LIMIT"])
+        row_limit = apply_max_row_limit(row_limit)

         # default order direction
         order_desc = self.form_data.get("order_desc", True)
@@ -1687,9 +1691,6 @@ class HistogramViz(BaseViz):
     def query_obj(self) -> QueryObjectDict:
         """Returns the query object for this visualization"""
         query_obj = super().query_obj()
-        query_obj["row_limit"] = self.form_data.get(
-            "row_limit", int(config["VIZ_ROW_LIMIT"])
-        )
         numeric_columns = self.form_data.get("all_columns_x")
         if numeric_columns is None:
             raise QueryObjectValidationError(
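A behavioral consequence worth flagging: HistogramViz no longer overrides its limit with VIZ_ROW_LIMIT (10000), so histograms follow the shared BaseViz path, defaulting to ROW_LIMIT (50000) capped by SQL_MAX_ROW. Sketch of the shared resolution with the config defaults inlined (simplified, not the actual method):

ROW_LIMIT = 50000
SQL_MAX_ROW = 100000

def base_viz_row_limit(form_data: dict) -> int:
    # BaseViz.query_obj after this change; HistogramViz no longer special-cases
    requested = int(form_data.get("row_limit") or ROW_LIMIT)  # falsy -> default
    return min(SQL_MAX_ROW, requested)                        # global cap

assert base_viz_row_limit({}) == 50000                      # was 10000 for histograms
assert base_viz_row_limit({"row_limit": 200000}) == 100000  # capped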

tests/integration_tests/charts/api_tests.py

@@ -18,7 +18,7 @@
 """Unit tests for Superset"""
 import json
 import unittest
-from datetime import datetime, timedelta
+from datetime import datetime
 from io import BytesIO
 from typing import Optional
 from unittest import mock
@@ -1203,6 +1203,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
         del request_payload["queries"][0]["row_limit"]
+
         rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         response_payload = json.loads(rv.data.decode("utf-8"))
         result = response_payload["result"][0]
@@ -1210,11 +1211,46 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     @mock.patch(
-        "superset.common.query_actions.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
+        "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 10},
     )
-    def test_chart_data_default_sample_limit(self):
+    def test_chart_data_sql_max_row_limit(self):
         """
-        Chart data API: Ensure sample response row count doesn't exceed default limit
+        Chart data API: Ensure row count doesn't exceed max global row limit
         """
+        self.login(username="admin")
+        request_payload = get_query_context("birth_names")
+        request_payload["queries"][0]["row_limit"] = 10000000
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 10)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @mock.patch(
+        "superset.common.query_object.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
+    )
+    def test_chart_data_sample_default_limit(self):
+        """
+        Chart data API: Ensure sample response row count defaults to config defaults
+        """
+        self.login(username="admin")
+        request_payload = get_query_context("birth_names")
+        request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
+        del request_payload["queries"][0]["row_limit"]
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 5)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @mock.patch(
+        "superset.common.query_actions.config",
+        {**app.config, "SAMPLES_ROW_LIMIT": 5, "SQL_MAX_ROW": 15},
+    )
+    def test_chart_data_sample_custom_limit(self):
+        """
+        Chart data API: Ensure requested sample response row count is between
+        default and SQL max row limit
+        """
         self.login(username="admin")
         request_payload = get_query_context("birth_names")
@@ -1223,6 +1259,24 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
         rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
         response_payload = json.loads(rv.data.decode("utf-8"))
         result = response_payload["result"][0]
         self.assertEqual(result["rowcount"], 10)
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    @mock.patch(
+        "superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 5},
+    )
+    def test_chart_data_sql_max_row_sample_limit(self):
+        """
+        Chart data API: Ensure requested sample response row count doesn't
+        exceed SQL max row limit
+        """
+        self.login(username="admin")
+        request_payload = get_query_context("birth_names")
+        request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
+        request_payload["queries"][0]["row_limit"] = 10000000
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        response_payload = json.loads(rv.data.decode("utf-8"))
+        result = response_payload["result"][0]
+        self.assertEqual(result["rowcount"], 5)

     def test_chart_data_incorrect_result_type(self):

tests/integration_tests/charts/schema_tests.py

@@ -16,17 +16,25 @@
 # under the License.
 # isort:skip_file
 """Unit tests for Superset"""
-from typing import Any, Dict, Tuple
+from unittest import mock

+import pytest
 from marshmallow import ValidationError

 from tests.integration_tests.test_app import app
 from superset.charts.schemas import ChartDataQueryContextSchema
 from superset.common.query_context import QueryContext
 from tests.integration_tests.base_tests import SupersetTestCase
+from tests.integration_tests.fixtures.birth_names_dashboard import (
+    load_birth_names_dashboard_with_slices,
+)
 from tests.integration_tests.fixtures.query_context import get_query_context


 class TestSchema(SupersetTestCase):
+    @mock.patch(
+        "superset.common.query_object.config", {**app.config, "ROW_LIMIT": 5000},
+    )
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_limit_and_offset(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
@@ -36,7 +44,7 @@ class TestSchema(SupersetTestCase):
         payload["queries"][0].pop("row_offset", None)
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
-        self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"])
+        self.assertEqual(query_object.row_limit, 5000)
         self.assertEqual(query_object.row_offset, 0)

         # Valid limit and offset
@@ -55,12 +63,14 @@ class TestSchema(SupersetTestCase):
         self.assertIn("row_limit", context.exception.messages["queries"][0])
         self.assertIn("row_offset", context.exception.messages["queries"][0])

+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_null_timegrain(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
         payload["queries"][0]["extras"]["time_grain_sqla"] = None
         _ = ChartDataQueryContextSchema().load(payload)

+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_series_limit(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")
@@ -82,6 +92,7 @@ class TestSchema(SupersetTestCase):
         }
         _ = ChartDataQueryContextSchema().load(payload)

+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_query_context_null_post_processing_op(self):
         self.login(username="admin")
         payload = get_query_context("birth_names")

tests/integration_tests/query_context_tests.py

@@ -90,6 +90,7 @@ class TestQueryContext(SupersetTestCase):
         self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
         self.assertEqual(post_proc["options"], payload_post_proc["options"])

+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_cache(self):
         table_name = "birth_names"
         table = self.get_table(name=table_name)