mirror of https://github.com/apache/superset.git
feat: add global max row limit (#16683)
* feat: add global max limit * fix lint and tests * leave SAMPLES_ROW_LIMIT unchanged * fix sample rowcount test * replace max global limit with existing sql max row limit * fix test * make max_limit optional in util * improve comments
This commit is contained in:
parent
633f29f3e9
commit
4e3d4f6daf
|
@ -15,7 +15,6 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import copy
|
||||
import math
|
||||
from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from flask_babel import _
|
||||
|
@ -131,14 +130,11 @@ def _get_samples(
|
|||
query_context: "QueryContext", query_obj: "QueryObject", force_cached: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
datasource = _get_datasource(query_context, query_obj)
|
||||
row_limit = query_obj.row_limit or math.inf
|
||||
query_obj = copy.copy(query_obj)
|
||||
query_obj.is_timeseries = False
|
||||
query_obj.orderby = []
|
||||
query_obj.metrics = []
|
||||
query_obj.post_processing = []
|
||||
query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])
|
||||
query_obj.row_offset = 0
|
||||
query_obj.columns = [o.column_name for o in datasource.columns]
|
||||
return _get_full(query_context, query_obj, force_cached)
|
||||
|
||||
|
|
|
@ -100,11 +100,11 @@ class QueryContext:
|
|||
self.datasource = ConnectorRegistry.get_datasource(
|
||||
str(datasource["type"]), int(datasource["id"]), db.session
|
||||
)
|
||||
self.queries = [QueryObject(**query_obj) for query_obj in queries]
|
||||
self.force = force
|
||||
self.custom_cache_timeout = custom_cache_timeout
|
||||
self.result_type = result_type or ChartDataResultType.FULL
|
||||
self.result_format = result_format or ChartDataResultFormat.JSON
|
||||
self.queries = [QueryObject(self, **query_obj) for query_obj in queries]
|
||||
self.cache_values = {
|
||||
"datasource": datasource,
|
||||
"queries": queries,
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
# under the License.
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, NamedTuple, Optional
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
|
||||
|
||||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame
|
||||
|
@ -28,6 +28,7 @@ from superset.exceptions import QueryObjectValidationError
|
|||
from superset.typing import Metric, OrderBy
|
||||
from superset.utils import pandas_postprocessing
|
||||
from superset.utils.core import (
|
||||
apply_max_row_limit,
|
||||
ChartDataResultType,
|
||||
DatasourceDict,
|
||||
DTTM_ALIAS,
|
||||
|
@ -41,6 +42,10 @@ from superset.utils.date_parser import get_since_until, parse_human_timedelta
|
|||
from superset.utils.hashing import md5_sha_from_dict
|
||||
from superset.views.utils import get_time_range_endpoints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.common.query_context import QueryContext # pragma: no cover
|
||||
|
||||
|
||||
config = app.config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -103,6 +108,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
|
||||
def __init__( # pylint: disable=too-many-arguments,too-many-locals
|
||||
self,
|
||||
query_context: "QueryContext",
|
||||
annotation_layers: Optional[List[Dict[str, Any]]] = None,
|
||||
applied_time_extras: Optional[Dict[str, str]] = None,
|
||||
apply_fetch_values_predicate: bool = False,
|
||||
|
@ -146,7 +152,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
self.datasource = ConnectorRegistry.get_datasource(
|
||||
str(datasource["type"]), int(datasource["id"]), db.session
|
||||
)
|
||||
self.result_type = result_type
|
||||
self.result_type = result_type or query_context.result_type
|
||||
self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
|
||||
self.annotation_layers = [
|
||||
layer
|
||||
|
@ -186,7 +192,12 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
|
|||
for x in metrics
|
||||
]
|
||||
|
||||
self.row_limit = config["ROW_LIMIT"] if row_limit is None else row_limit
|
||||
default_row_limit = (
|
||||
config["SAMPLES_ROW_LIMIT"]
|
||||
if self.result_type == ChartDataResultType.SAMPLES
|
||||
else config["ROW_LIMIT"]
|
||||
)
|
||||
self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
|
||||
self.row_offset = row_offset or 0
|
||||
self.filter = filters or []
|
||||
self.series_limit = series_limit
|
||||
|
|
|
@ -121,9 +121,9 @@ BUILD_NUMBER = None
|
|||
# default viz used in chart explorer
|
||||
DEFAULT_VIZ_TYPE = "table"
|
||||
|
||||
# default row limit when requesting chart data
|
||||
ROW_LIMIT = 50000
|
||||
VIZ_ROW_LIMIT = 10000
|
||||
# max rows retreieved when requesting samples from datasource in explore view
|
||||
# default row limit when requesting samples from datasource in explore view
|
||||
SAMPLES_ROW_LIMIT = 1000
|
||||
# max rows retrieved by filter select auto complete
|
||||
FILTER_SELECT_ROW_LIMIT = 10000
|
||||
|
@ -671,9 +671,7 @@ QUERY_LOGGER = None
|
|||
# Set this API key to enable Mapbox visualizations
|
||||
MAPBOX_API_KEY = os.environ.get("MAPBOX_API_KEY", "")
|
||||
|
||||
# Maximum number of rows returned from a database
|
||||
# in async mode, no more than SQL_MAX_ROW will be returned and stored
|
||||
# in the results backend. This also becomes the limit when exporting CSVs
|
||||
# Maximum number of rows returned for any analytical database query
|
||||
SQL_MAX_ROW = 100000
|
||||
|
||||
# Maximum number of rows displayed in SQL Lab UI
|
||||
|
|
|
@ -1762,3 +1762,25 @@ def parse_boolean_string(bool_str: Optional[str]) -> bool:
|
|||
return bool(strtobool(bool_str.lower()))
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
|
||||
"""
|
||||
Override row limit if max global limit is defined
|
||||
|
||||
:param limit: requested row limit
|
||||
:param max_limit: Maximum allowed row limit
|
||||
:return: Capped row limit
|
||||
|
||||
>>> apply_max_row_limit(100000, 10)
|
||||
10
|
||||
>>> apply_max_row_limit(10, 100000)
|
||||
10
|
||||
>>> apply_max_row_limit(0, 10000)
|
||||
10000
|
||||
"""
|
||||
if max_limit is None:
|
||||
max_limit = current_app.config["SQL_MAX_ROW"]
|
||||
if limit != 0:
|
||||
return min(max_limit, limit)
|
||||
return max_limit
|
||||
|
|
|
@ -23,10 +23,11 @@ from typing import Any, cast, Dict, Optional, TYPE_CHECKING
|
|||
|
||||
from flask import g
|
||||
|
||||
from superset import app, is_feature_enabled
|
||||
from superset import is_feature_enabled
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.sql_parse import CtasMethod
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import apply_max_row_limit
|
||||
from superset.utils.dates import now_as_float
|
||||
from superset.views.utils import get_cta_schema_name
|
||||
|
||||
|
@ -102,7 +103,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
|
|||
|
||||
@staticmethod
|
||||
def _get_limit_param(query_params: Dict[str, Any]) -> int:
|
||||
limit: int = query_params.get("queryLimit") or app.config["SQL_MAX_ROW"]
|
||||
limit = apply_max_row_limit(query_params.get("queryLimit") or 0)
|
||||
if limit < 0:
|
||||
logger.warning(
|
||||
"Invalid limit of %i specified. Defaulting to max limit.", limit
|
||||
|
|
|
@ -104,7 +104,7 @@ from superset.typing import FlaskResponse
|
|||
from superset.utils import core as utils, csv
|
||||
from superset.utils.async_query_manager import AsyncQueryTokenException
|
||||
from superset.utils.cache import etag_cache
|
||||
from superset.utils.core import ReservedUrlParameters
|
||||
from superset.utils.core import apply_max_row_limit, ReservedUrlParameters
|
||||
from superset.utils.dates import now_as_float
|
||||
from superset.utils.decorators import check_dashboard_access
|
||||
from superset.utils.sqllab_execution_context import SqlJsonExecutionContext
|
||||
|
@ -897,8 +897,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
return json_error_response(DATASOURCE_MISSING_ERR)
|
||||
|
||||
datasource.raise_for_access()
|
||||
row_limit = apply_max_row_limit(config["FILTER_SELECT_ROW_LIMIT"])
|
||||
payload = json.dumps(
|
||||
datasource.values_for_column(column, config["FILTER_SELECT_ROW_LIMIT"]),
|
||||
datasource.values_for_column(column, row_limit),
|
||||
default=utils.json_int_dttm_ser,
|
||||
ignore_nan=True,
|
||||
)
|
||||
|
|
|
@ -69,6 +69,7 @@ from superset.typing import Metric, QueryObjectDict, VizData, VizPayload
|
|||
from superset.utils import core as utils, csv
|
||||
from superset.utils.cache import set_and_log_cache
|
||||
from superset.utils.core import (
|
||||
apply_max_row_limit,
|
||||
DTTM_ALIAS,
|
||||
ExtraFiltersReasonType,
|
||||
JS_MAX_INTEGER,
|
||||
|
@ -324,7 +325,10 @@ class BaseViz: # pylint: disable=too-many-public-methods
|
|||
)
|
||||
limit = int(self.form_data.get("limit") or 0)
|
||||
timeseries_limit_metric = self.form_data.get("timeseries_limit_metric")
|
||||
|
||||
# apply row limit to query
|
||||
row_limit = int(self.form_data.get("row_limit") or config["ROW_LIMIT"])
|
||||
row_limit = apply_max_row_limit(row_limit)
|
||||
|
||||
# default order direction
|
||||
order_desc = self.form_data.get("order_desc", True)
|
||||
|
@ -1687,9 +1691,6 @@ class HistogramViz(BaseViz):
|
|||
def query_obj(self) -> QueryObjectDict:
|
||||
"""Returns the query object for this visualization"""
|
||||
query_obj = super().query_obj()
|
||||
query_obj["row_limit"] = self.form_data.get(
|
||||
"row_limit", int(config["VIZ_ROW_LIMIT"])
|
||||
)
|
||||
numeric_columns = self.form_data.get("all_columns_x")
|
||||
if numeric_columns is None:
|
||||
raise QueryObjectValidationError(
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
"""Unit tests for Superset"""
|
||||
import json
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
|
@ -1203,6 +1203,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
|
|||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
del request_payload["queries"][0]["row_limit"]
|
||||
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
|
@ -1210,11 +1211,46 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
|
|||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@mock.patch(
|
||||
"superset.common.query_actions.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
|
||||
"superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 10},
|
||||
)
|
||||
def test_chart_data_default_sample_limit(self):
|
||||
def test_chart_data_sql_max_row_limit(self):
|
||||
"""
|
||||
Chart data API: Ensure sample response row count doesn't exceed default limit
|
||||
Chart data API: Ensure row count doesn't exceed max global row limit
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["queries"][0]["row_limit"] = 10000000
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
self.assertEqual(result["rowcount"], 10)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@mock.patch(
|
||||
"superset.common.query_object.config", {**app.config, "SAMPLES_ROW_LIMIT": 5},
|
||||
)
|
||||
def test_chart_data_sample_default_limit(self):
|
||||
"""
|
||||
Chart data API: Ensure sample response row count defaults to config defaults
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
|
||||
del request_payload["queries"][0]["row_limit"]
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
self.assertEqual(result["rowcount"], 5)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@mock.patch(
|
||||
"superset.common.query_actions.config",
|
||||
{**app.config, "SAMPLES_ROW_LIMIT": 5, "SQL_MAX_ROW": 15},
|
||||
)
|
||||
def test_chart_data_sample_custom_limit(self):
|
||||
"""
|
||||
Chart data API: Ensure requested sample response row count is between
|
||||
default and SQL max row limit
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
|
@ -1223,6 +1259,24 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
|
|||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
self.assertEqual(result["rowcount"], 10)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
@mock.patch(
|
||||
"superset.utils.core.current_app.config", {**app.config, "SQL_MAX_ROW": 5},
|
||||
)
|
||||
def test_chart_data_sql_max_row_sample_limit(self):
|
||||
"""
|
||||
Chart data API: Ensure requested sample response row count doesn't
|
||||
exceed SQL max row limit
|
||||
"""
|
||||
self.login(username="admin")
|
||||
request_payload = get_query_context("birth_names")
|
||||
request_payload["result_type"] = utils.ChartDataResultType.SAMPLES
|
||||
request_payload["queries"][0]["row_limit"] = 10000000
|
||||
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
|
||||
response_payload = json.loads(rv.data.decode("utf-8"))
|
||||
result = response_payload["result"][0]
|
||||
self.assertEqual(result["rowcount"], 5)
|
||||
|
||||
def test_chart_data_incorrect_result_type(self):
|
||||
|
|
|
@ -16,17 +16,25 @@
|
|||
# under the License.
|
||||
# isort:skip_file
|
||||
"""Unit tests for Superset"""
|
||||
from typing import Any, Dict, Tuple
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from marshmallow import ValidationError
|
||||
from tests.integration_tests.test_app import app
|
||||
from superset.charts.schemas import ChartDataQueryContextSchema
|
||||
from superset.common.query_context import QueryContext
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices,
|
||||
)
|
||||
from tests.integration_tests.fixtures.query_context import get_query_context
|
||||
|
||||
|
||||
class TestSchema(SupersetTestCase):
|
||||
@mock.patch(
|
||||
"superset.common.query_object.config", {**app.config, "ROW_LIMIT": 5000},
|
||||
)
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_query_context_limit_and_offset(self):
|
||||
self.login(username="admin")
|
||||
payload = get_query_context("birth_names")
|
||||
|
@ -36,7 +44,7 @@ class TestSchema(SupersetTestCase):
|
|||
payload["queries"][0].pop("row_offset", None)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"])
|
||||
self.assertEqual(query_object.row_limit, 5000)
|
||||
self.assertEqual(query_object.row_offset, 0)
|
||||
|
||||
# Valid limit and offset
|
||||
|
@ -55,12 +63,14 @@ class TestSchema(SupersetTestCase):
|
|||
self.assertIn("row_limit", context.exception.messages["queries"][0])
|
||||
self.assertIn("row_offset", context.exception.messages["queries"][0])
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_query_context_null_timegrain(self):
|
||||
self.login(username="admin")
|
||||
payload = get_query_context("birth_names")
|
||||
payload["queries"][0]["extras"]["time_grain_sqla"] = None
|
||||
_ = ChartDataQueryContextSchema().load(payload)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_query_context_series_limit(self):
|
||||
self.login(username="admin")
|
||||
payload = get_query_context("birth_names")
|
||||
|
@ -82,6 +92,7 @@ class TestSchema(SupersetTestCase):
|
|||
}
|
||||
_ = ChartDataQueryContextSchema().load(payload)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_query_context_null_post_processing_op(self):
|
||||
self.login(username="admin")
|
||||
payload = get_query_context("birth_names")
|
||||
|
|
|
@ -90,6 +90,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
self.assertEqual(post_proc["operation"], payload_post_proc["operation"])
|
||||
self.assertEqual(post_proc["options"], payload_post_proc["options"])
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_cache(self):
|
||||
table_name = "birth_names"
|
||||
table = self.get_table(name=table_name)
|
||||
|
|
Loading…
Reference in New Issue