feat: add support for generic series limit (#16660)

* feat: add support for generic series limit

* refine series_columns logic

* update docs

* bump superset-ui

* add note to UPDATING.md

* remove default value for timeseries_limit
Ville Brofeldt 2021-09-16 12:09:08 +03:00 committed by GitHub
parent 21f98ddc21
commit 836b5e2c86
14 changed files with 537 additions and 460 deletions


@ -26,6 +26,7 @@ assists people when migrating to a new version.
### Breaking Changes
- [16660](https://github.com/apache/incubator-superset/pull/16660): The `columns` Jinja parameter has been renamed to `table_columns` to make the `columns` query object parameter available in the Jinja context.
- [16711](https://github.com/apache/incubator-superset/pull/16711): The `url_param` Jinja function will now escape the result by default. For instance, the value `O'Brien` will now be changed to `O''Brien`. To disable this behavior, call `url_param` with `escape_result` set to `False`: `url_param("my_key", "my default", escape_result=False)`. A sketch of the new escaping behavior is shown below.
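As a rough illustration of the new default (a simplified stand-in, not Superset's actual `url_param` implementation), single quotes in the returned value are doubled unless escaping is disabled:

```python
# Simplified stand-in for the escaping behavior described above; the real
# `url_param` is a Superset Jinja function with a different, richer signature.
def escape_url_param(value: str, escape_result: bool = True) -> str:
    # Doubling single quotes makes the value safe to splice into a SQL string literal.
    return value.replace("'", "''") if escape_result else value

print(escape_url_param("O'Brien"))                       # O''Brien
print(escape_url_param("O'Brien", escape_result=False))  # O'Brien
```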
### Potential Downtime


@ -16,14 +16,15 @@ To enable templating, the `ENABLE_TEMPLATE_PROCESSING` feature flag needs to be
in Custom SQL in the filter and metric controls in Explore. By default, the following variables are
made available in the Jinja context (a usage sketch follows this list):
- `columns`: columns available in the dataset
- `columns`: columns to group by in the query
- `filter`: filters applied in the query
- `from_dttm`: start `datetime` value from the selected time range (`None` if undefined)
- `to_dttm`: end `datetime` value from the selected time range (`None` if undefined)
- `groupby`: columns to group by in the query
- `groupby`: columns to group by in the query (deprecated)
- `metrics`: aggregate expressions in the query
- `row_limit`: row limit of the query
- `row_offset`: row offset of the query
- `table_columns`: columns available in the dataset
- `time_column`: temporal column of the query (`None` if undefined)
- `time_grain`: selected time grain (`None` if undefined)
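A minimal sketch of how these variables could be referenced from Custom SQL, rendered here with plain Jinja2 and made-up values (in Superset the real values are injected at query time):

```python
# Hypothetical, self-contained rendering of a templated SQL fragment; the
# variable names mirror the documented Jinja context, the values are invented.
from jinja2 import Template

snippet = Template(
    "SELECT {{ columns | join(', ') }}, COUNT(*) AS cnt "
    "FROM my_table "
    "WHERE ds >= '{{ from_dttm }}' AND ds < '{{ to_dttm }}' "
    "GROUP BY {{ columns | join(', ') }}"
)
print(
    snippet.render(
        columns=["state", "name"],        # columns to group by in the query
        from_dttm="2021-01-01T00:00:00",  # start of the selected time range
        to_dttm="2021-12-31T00:00:00",    # end of the selected time range
    )
)
```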

File diff suppressed because it is too large.


@ -68,35 +68,35 @@
"@emotion/cache": "^11.4.0",
"@emotion/react": "^11.4.1",
"@emotion/styled": "^11.3.0",
"@superset-ui/chart-controls": "^0.18.2",
"@superset-ui/core": "^0.18.2",
"@superset-ui/legacy-plugin-chart-calendar": "^0.18.2",
"@superset-ui/legacy-plugin-chart-chord": "^0.18.2",
"@superset-ui/legacy-plugin-chart-country-map": "^0.18.2",
"@superset-ui/legacy-plugin-chart-event-flow": "^0.18.2",
"@superset-ui/legacy-plugin-chart-force-directed": "^0.18.2",
"@superset-ui/legacy-plugin-chart-heatmap": "^0.18.2",
"@superset-ui/legacy-plugin-chart-histogram": "^0.18.2",
"@superset-ui/legacy-plugin-chart-horizon": "^0.18.2",
"@superset-ui/legacy-plugin-chart-map-box": "^0.18.2",
"@superset-ui/legacy-plugin-chart-paired-t-test": "^0.18.2",
"@superset-ui/legacy-plugin-chart-parallel-coordinates": "^0.18.2",
"@superset-ui/legacy-plugin-chart-partition": "^0.18.2",
"@superset-ui/legacy-plugin-chart-pivot-table": "^0.18.2",
"@superset-ui/legacy-plugin-chart-rose": "^0.18.2",
"@superset-ui/legacy-plugin-chart-sankey": "^0.18.2",
"@superset-ui/legacy-plugin-chart-sankey-loop": "^0.18.2",
"@superset-ui/legacy-plugin-chart-sunburst": "^0.18.2",
"@superset-ui/legacy-plugin-chart-treemap": "^0.18.2",
"@superset-ui/legacy-plugin-chart-world-map": "^0.18.2",
"@superset-ui/legacy-preset-chart-big-number": "^0.18.2",
"@superset-ui/chart-controls": "^0.18.4",
"@superset-ui/core": "^0.18.4",
"@superset-ui/legacy-plugin-chart-calendar": "^0.18.4",
"@superset-ui/legacy-plugin-chart-chord": "^0.18.4",
"@superset-ui/legacy-plugin-chart-country-map": "^0.18.4",
"@superset-ui/legacy-plugin-chart-event-flow": "^0.18.4",
"@superset-ui/legacy-plugin-chart-force-directed": "^0.18.4",
"@superset-ui/legacy-plugin-chart-heatmap": "^0.18.4",
"@superset-ui/legacy-plugin-chart-histogram": "^0.18.4",
"@superset-ui/legacy-plugin-chart-horizon": "^0.18.4",
"@superset-ui/legacy-plugin-chart-map-box": "^0.18.4",
"@superset-ui/legacy-plugin-chart-paired-t-test": "^0.18.4",
"@superset-ui/legacy-plugin-chart-parallel-coordinates": "^0.18.4",
"@superset-ui/legacy-plugin-chart-partition": "^0.18.4",
"@superset-ui/legacy-plugin-chart-pivot-table": "^0.18.4",
"@superset-ui/legacy-plugin-chart-rose": "^0.18.4",
"@superset-ui/legacy-plugin-chart-sankey": "^0.18.4",
"@superset-ui/legacy-plugin-chart-sankey-loop": "^0.18.4",
"@superset-ui/legacy-plugin-chart-sunburst": "^0.18.4",
"@superset-ui/legacy-plugin-chart-treemap": "^0.18.4",
"@superset-ui/legacy-plugin-chart-world-map": "^0.18.4",
"@superset-ui/legacy-preset-chart-big-number": "^0.18.4",
"@superset-ui/legacy-preset-chart-deckgl": "^0.4.12",
"@superset-ui/legacy-preset-chart-nvd3": "^0.18.2",
"@superset-ui/plugin-chart-echarts": "^0.18.2",
"@superset-ui/plugin-chart-pivot-table": "^0.18.2",
"@superset-ui/plugin-chart-table": "^0.18.2",
"@superset-ui/plugin-chart-word-cloud": "^0.18.2",
"@superset-ui/preset-chart-xy": "^0.18.2",
"@superset-ui/legacy-preset-chart-nvd3": "^0.18.4",
"@superset-ui/plugin-chart-echarts": "^0.18.4",
"@superset-ui/plugin-chart-pivot-table": "^0.18.4",
"@superset-ui/plugin-chart-table": "^0.18.4",
"@superset-ui/plugin-chart-word-cloud": "^0.18.4",
"@superset-ui/preset-chart-xy": "^0.18.4",
"@vx/responsive": "^0.0.195",
"abortcontroller-polyfill": "^1.1.9",
"antd": "^4.9.4",


@ -639,18 +639,15 @@ class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema)
index = (
fields.List(
fields.String(
allow_none=False,
description="Columns to group by on the table index (=rows)",
),
fields.String(allow_none=False),
description="Columns to group by on the table index (=rows)",
minLength=1,
required=True,
),
)
columns = fields.List(
fields.String(
allow_none=False, description="Columns to group by on the table columns",
),
fields.String(allow_none=False),
description="Columns to group by on the table columns",
)
metric_fill_value = fields.Number(
description="Value to replace missing values with in aggregate calculations.",
@ -964,7 +961,9 @@ class ChartDataQueryObjectSchema(Schema):
deprecated=True,
)
groupby = fields.List(
fields.String(description="Columns by which to group the query.",),
fields.String(),
description="Columns by which to group the query. "
"This field is deprecated, use `columns` instead.",
allow_none=True,
)
metrics = fields.List(
@ -1012,12 +1011,33 @@ class ChartDataQueryObjectSchema(Schema):
is_timeseries = fields.Boolean(
description="Is the `query_object` a timeseries.", allow_none=True,
)
series_columns = fields.List(
fields.String(),
description="Columns to use when limiting series count. "
"All columns must be present in the `columns` property. "
"Requires `series_limit` and `series_limit_metric` to be set.",
allow_none=True,
)
series_limit = fields.Integer(
description="Maximum number of series. "
"Requires `series` and `series_limit_metric` to be set.",
allow_none=True,
)
series_limit_metric = fields.Raw(
description="Metric used to limit timeseries queries by. "
"Requires `series` and `series_limit` to be set.",
allow_none=True,
)
timeseries_limit = fields.Integer(
description="Maximum row count for timeseries queries. Default: `0`",
description="Maximum row count for timeseries queries. "
"This field is deprecated, use `series_limit` instead."
"Default: `0`",
allow_none=True,
)
timeseries_limit_metric = fields.Raw(
description="Metric used to limit timeseries queries by.", allow_none=True,
description="Metric used to limit timeseries queries by. "
"This field is deprecated, use `series_limit_metric` instead.",
allow_none=True,
)
row_limit = fields.Integer(
description='Maximum row count (0=disabled). Default: `config["ROW_LIMIT"]`',
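For reference, a hedged sketch of a chart data query payload that exercises the new fields; the column and metric names are borrowed from the `birth_names` test further down in this diff:

```python
# Illustrative query object payload using the new generic series-limit fields:
# group by both "state" and "name", but keep only the top 5 values of "name"
# ranked by sum__num.
query_object = {
    "columns": ["state", "name"],
    "metrics": ["sum__num"],
    "series_columns": ["name"],         # subset of `columns` used for limiting
    "series_limit": 5,                  # keep at most 5 series
    "series_limit_metric": "sum__num",  # metric used to rank the series
    "row_limit": 100,
}
```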


@ -135,7 +135,6 @@ def _get_samples(
query_obj = copy.copy(query_obj)
query_obj.is_timeseries = False
query_obj.orderby = []
query_obj.groupby = []
query_obj.metrics = []
query_obj.post_processing = []
query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"])


@ -451,7 +451,6 @@ class QueryContext:
invalid_columns = [
col
for col in query_obj.columns
+ query_obj.groupby
+ get_column_names_from_metrics(query_obj.metrics or [])
if col not in self.datasource.column_names and col != DTTM_ALIAS
]


@ -55,6 +55,9 @@ class DeprecatedField(NamedTuple):
DEPRECATED_FIELDS = (
DeprecatedField(old_name="granularity_sqla", new_name="granularity"),
DeprecatedField(old_name="groupby", new_name="columns"),
DeprecatedField(old_name="timeseries_limit", new_name="series_limit"),
DeprecatedField(old_name="timeseries_limit_metric", new_name="series_limit_metric"),
)
DEPRECATED_EXTRAS_FIELDS = (
@ -74,63 +77,68 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
annotation_layers: List[Dict[str, Any]]
applied_time_extras: Dict[str, str]
apply_fetch_values_predicate: bool
granularity: Optional[str]
columns: List[str]
datasource: Optional[BaseDatasource]
extras: Dict[str, Any]
filter: List[QueryObjectFilterClause]
from_dttm: Optional[datetime]
to_dttm: Optional[datetime]
granularity: Optional[str]
inner_from_dttm: Optional[datetime]
inner_to_dttm: Optional[datetime]
is_rowcount: bool
is_timeseries: bool
time_shift: Optional[timedelta]
groupby: List[str]
order_desc: bool
orderby: List[OrderBy]
metrics: Optional[List[Metric]]
result_type: Optional[ChartDataResultType]
row_limit: int
row_offset: int
filter: List[QueryObjectFilterClause]
timeseries_limit: int
timeseries_limit_metric: Optional[Metric]
order_desc: bool
extras: Dict[str, Any]
columns: List[str]
orderby: List[OrderBy]
post_processing: List[Dict[str, Any]]
datasource: Optional[BaseDatasource]
result_type: Optional[ChartDataResultType]
is_rowcount: bool
series_columns: List[str]
series_limit: int
series_limit_metric: Optional[Metric]
time_offsets: List[str]
time_shift: Optional[timedelta]
to_dttm: Optional[datetime]
post_processing: List[Dict[str, Any]]
def __init__( # pylint: disable=too-many-arguments,too-many-locals
self,
datasource: Optional[DatasourceDict] = None,
result_type: Optional[ChartDataResultType] = None,
annotation_layers: Optional[List[Dict[str, Any]]] = None,
applied_time_extras: Optional[Dict[str, str]] = None,
apply_fetch_values_predicate: bool = False,
granularity: Optional[str] = None,
metrics: Optional[List[Metric]] = None,
groupby: Optional[List[str]] = None,
filters: Optional[List[QueryObjectFilterClause]] = None,
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
is_timeseries: Optional[bool] = None,
timeseries_limit: int = 0,
row_limit: Optional[int] = None,
row_offset: Optional[int] = None,
timeseries_limit_metric: Optional[Metric] = None,
order_desc: bool = True,
extras: Optional[Dict[str, Any]] = None,
columns: Optional[List[str]] = None,
datasource: Optional[DatasourceDict] = None,
extras: Optional[Dict[str, Any]] = None,
filters: Optional[List[QueryObjectFilterClause]] = None,
granularity: Optional[str] = None,
is_rowcount: bool = False,
is_timeseries: Optional[bool] = None,
metrics: Optional[List[Metric]] = None,
order_desc: bool = True,
orderby: Optional[List[OrderBy]] = None,
post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
is_rowcount: bool = False,
result_type: Optional[ChartDataResultType] = None,
row_limit: Optional[int] = None,
row_offset: Optional[int] = None,
series_columns: Optional[List[str]] = None,
series_limit: int = 0,
series_limit_metric: Optional[Metric] = None,
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
**kwargs: Any,
):
columns = columns or []
groupby = groupby or []
extras = extras or {}
annotation_layers = annotation_layers or []
self.time_offsets = kwargs.get("time_offsets", [])
self.inner_from_dttm = kwargs.get("inner_from_dttm")
self.inner_to_dttm = kwargs.get("inner_to_dttm")
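# series_columns defaults: use the explicit value when provided; otherwise fall
# back to all grouped columns for timeseries queries with metrics, else none.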
if series_columns:
self.series_columns = series_columns
elif is_timeseries and metrics:
self.series_columns = columns
else:
self.series_columns = []
self.is_rowcount = is_rowcount
self.datasource = None
@ -161,9 +169,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
# is_timeseries is True if the time column is among the queried columns
# (which are dimensions)
self.is_timeseries = (
is_timeseries
if is_timeseries is not None
else DTTM_ALIAS in columns + groupby
is_timeseries if is_timeseries is not None else DTTM_ALIAS in columns
)
self.time_range = time_range
self.time_shift = parse_human_timedelta(time_shift)
@ -183,8 +189,8 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
self.row_limit = config["ROW_LIMIT"] if row_limit is None else row_limit
self.row_offset = row_offset or 0
self.filter = filters or []
self.timeseries_limit = timeseries_limit
self.timeseries_limit_metric = timeseries_limit_metric
self.series_limit = series_limit
self.series_limit_metric = series_limit_metric
self.order_desc = order_desc
self.extras = extras
@ -194,9 +200,12 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
)
self.columns = columns
self.groupby = groupby or []
self.orderby = orderby or []
self._rename_deprecated_fields(kwargs)
self._move_deprecated_extra_fields(kwargs)
def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None:
# rename deprecated fields
for field in DEPRECATED_FIELDS:
if field.old_name in kwargs:
@ -216,6 +225,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
)
setattr(self, field.new_name, value)
def _move_deprecated_extra_fields(self, kwargs: Dict[str, Any]) -> None:
# move deprecated extras fields to extras
for field in DEPRECATED_EXTRAS_FIELDS:
if field.old_name in kwargs:
@ -254,6 +264,14 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
"""Validate query object"""
error: Optional[QueryObjectValidationError] = None
all_labels = self.metric_names + self.column_names
missing_series = [col for col in self.series_columns if col not in self.columns]
if missing_series:
error = QueryObjectValidationError(
_(
"The following entries in `series_columns` are missing "
"in `columns`: %(columns)s. ",
columns=", ".join(f'"{x}"' for x in missing_series),
)
)
if len(set(all_labels)) < len(all_labels):
dup_labels = find_duplicates(all_labels)
error = QueryObjectValidationError(
@ -270,24 +288,24 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
def to_dict(self) -> Dict[str, Any]:
query_object_dict = {
"apply_fetch_values_predicate": self.apply_fetch_values_predicate,
"granularity": self.granularity,
"groupby": self.groupby,
"columns": self.columns,
"extras": self.extras,
"filter": self.filter,
"from_dttm": self.from_dttm,
"to_dttm": self.to_dttm,
"granularity": self.granularity,
"inner_from_dttm": self.inner_from_dttm,
"inner_to_dttm": self.inner_to_dttm,
"is_rowcount": self.is_rowcount,
"is_timeseries": self.is_timeseries,
"metrics": self.metrics,
"order_desc": self.order_desc,
"orderby": self.orderby,
"row_limit": self.row_limit,
"row_offset": self.row_offset,
"filter": self.filter,
"timeseries_limit": self.timeseries_limit,
"timeseries_limit_metric": self.timeseries_limit_metric,
"order_desc": self.order_desc,
"extras": self.extras,
"columns": self.columns,
"orderby": self.orderby,
"series_columns": self.series_columns,
"series_limit": self.series_limit,
"series_limit_metric": self.series_limit_metric,
"to_dttm": self.to_dttm,
}
return query_object_dict
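A hedged usage sketch of the deprecated-field mapping above (assuming a working Superset installation and application context, since the constructor reads config defaults; the module path is the one current at this change):

```python
# Illustrative only: deprecated keyword arguments are renamed onto the new
# attributes by _rename_deprecated_fields, as the deprecated-field test near
# the end of this diff also shows.
from superset.common.query_object import QueryObject

qobj = QueryObject(
    columns=["state", "name"],
    metrics=["sum__num"],
    row_limit=100,                       # set explicitly to avoid the config default lookup
    timeseries_limit=5,                  # deprecated alias for series_limit
    timeseries_limit_metric="sum__num",  # deprecated alias for series_limit_metric
)
assert qobj.series_limit == 5
assert qobj.series_limit_metric == "sum__num"
```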


@ -19,7 +19,7 @@ import dataclasses
import json
import logging
import re
from collections import defaultdict, OrderedDict
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import (
@ -931,27 +931,30 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
self,
metrics: Optional[List[Metric]] = None,
granularity: Optional[str] = None,
from_dttm: Optional[datetime] = None,
to_dttm: Optional[datetime] = None,
apply_fetch_values_predicate: bool = False,
columns: Optional[List[str]] = None,
groupby: Optional[List[str]] = None,
extras: Optional[Dict[str, Any]] = None,
filter: Optional[ # pylint: disable=redefined-builtin
List[QueryObjectFilterClause]
] = None,
is_timeseries: bool = True,
timeseries_limit: int = 15,
timeseries_limit_metric: Optional[Metric] = None,
row_limit: Optional[int] = None,
row_offset: Optional[int] = None,
from_dttm: Optional[datetime] = None,
granularity: Optional[str] = None,
groupby: Optional[List[str]] = None,
inner_from_dttm: Optional[datetime] = None,
inner_to_dttm: Optional[datetime] = None,
orderby: Optional[List[OrderBy]] = None,
extras: Optional[Dict[str, Any]] = None,
order_desc: bool = True,
is_rowcount: bool = False,
apply_fetch_values_predicate: bool = False,
is_timeseries: bool = True,
metrics: Optional[List[Metric]] = None,
orderby: Optional[List[OrderBy]] = None,
order_desc: bool = True,
to_dttm: Optional[datetime] = None,
series_columns: Optional[List[str]] = None,
series_limit: Optional[int] = None,
series_limit_metric: Optional[Metric] = None,
row_limit: Optional[int] = None,
row_offset: Optional[int] = None,
timeseries_limit: Optional[int] = None,
timeseries_limit_metric: Optional[Metric] = None,
) -> SqlaQuery:
"""Querying any sqla table from this common interface"""
if granularity not in self.dttm_cols and granularity is not None:
@ -961,6 +964,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
time_grain = extras.get("time_grain_sqla")
template_kwargs = {
"columns": columns,
"from_dttm": from_dttm.isoformat() if from_dttm else None,
"groupby": groupby,
"metrics": metrics,
@ -969,9 +973,14 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
"time_column": granularity,
"time_grain": time_grain,
"to_dttm": to_dttm.isoformat() if to_dttm else None,
"table_columns": [col.column_name for col in self.columns],
"filter": filter,
"columns": [col.column_name for col in self.columns],
}
series_columns = series_columns or []
# deprecated, to be removed in 2.0
if is_timeseries and timeseries_limit:
series_limit = timeseries_limit
series_limit_metric = series_limit_metric or timeseries_limit_metric
template_kwargs.update(self.template_params_dict)
extra_cache_keys: List[Any] = []
template_kwargs["extra_cache_keys"] = extra_cache_keys
@ -984,8 +993,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
need_groupby = bool(metrics is not None or groupby)
metrics = metrics or []
# Database spec supports join-free timeslot grouping
time_groupby_inline = db_engine_spec.time_groupby_inline
# For backward compatibility
if granularity not in self.dttm_cols and granularity is not None:
granularity = self.main_dttm_col
columns_by_name: Dict[str, TableColumn] = {
col.column_name: col for col in self.columns
@ -1057,7 +1067,8 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
select_exprs: List[Union[Column, Label]] = []
groupby_exprs_sans_timestamp = OrderedDict()
groupby_all_columns = {}
groupby_series_columns = {}
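# groupby_all_columns feeds the GROUP BY clause; groupby_series_columns is the
# subset of those columns that is used when limiting the number of series.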
# filter out the pseudo column __timestamp from columns
columns = columns or []
@ -1078,7 +1089,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
else:
outer = literal_column(f"({selected})")
outer = self.make_sqla_column_compatible(outer, selected)
groupby_exprs_sans_timestamp[outer.name] = outer
groupby_all_columns[outer.name] = outer
if not series_columns or outer.name in series_columns:
groupby_series_columns[outer.name] = outer
select_exprs.append(outer)
elif columns:
for selected in columns:
@ -1090,7 +1103,6 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
metrics_exprs = []
time_range_endpoints = extras.get("time_range_endpoints")
groupby_exprs_with_timestamp = OrderedDict(groupby_exprs_sans_timestamp.items())
if granularity:
if granularity not in columns_by_name or not dttm_col:
@ -1106,7 +1118,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
timestamp = dttm_col.get_timestamp_expression(time_grain)
# always put timestamp as the first column
select_exprs.insert(0, timestamp)
groupby_exprs_with_timestamp[timestamp.name] = timestamp
groupby_all_columns[timestamp.name] = timestamp
# Use main dttm column to support index with secondary dttm columns.
if (
@ -1142,8 +1154,8 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
tbl = self.get_from_clause(template_processor)
if groupby_exprs_with_timestamp:
qry = qry.group_by(*groupby_exprs_with_timestamp.values())
if groupby_all_columns:
qry = qry.group_by(*groupby_all_columns.values())
where_clause_and = []
having_clause_and = []
@ -1289,13 +1301,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
if row_offset:
qry = qry.offset(row_offset)
if (
is_timeseries
and timeseries_limit
and not time_groupby_inline
and groupby_exprs_sans_timestamp
and dttm_col
):
if db_engine_spec.allows_subqueries and series_limit and groupby_series_columns:
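# Apply the series limit either by joining against a subquery of the top
# series (when the engine also allows joins) or, in the else branch below, by
# running a separate prequery whose top groups filter the main query.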
if db_engine_spec.allows_joins:
# some sql dialects require for order by expressions
# to also be in the select clause -- others, e.g. vertica,
@ -1305,32 +1311,37 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
inner_groupby_exprs = []
inner_select_exprs = []
for gby_name, gby_obj in groupby_exprs_sans_timestamp.items():
for gby_name, gby_obj in groupby_series_columns.items():
inner = self.make_sqla_column_compatible(gby_obj, gby_name + "__")
inner_groupby_exprs.append(inner)
inner_select_exprs.append(inner)
inner_select_exprs += [inner_main_metric_expr]
subq = select(inner_select_exprs).select_from(tbl)
inner_time_filter = dttm_col.get_time_filter(
inner_from_dttm or from_dttm,
inner_to_dttm or to_dttm,
time_range_endpoints,
)
subq = subq.where(and_(*(where_clause_and + [inner_time_filter])))
inner_time_filter = []
if dttm_col and not db_engine_spec.time_groupby_inline:
inner_time_filter = [
dttm_col.get_time_filter(
inner_from_dttm or from_dttm,
inner_to_dttm or to_dttm,
time_range_endpoints,
)
]
subq = subq.where(and_(*(where_clause_and + inner_time_filter)))
subq = subq.group_by(*inner_groupby_exprs)
ob = inner_main_metric_expr
if timeseries_limit_metric:
ob = self._get_timeseries_orderby(
timeseries_limit_metric, metrics_by_name, columns_by_name
if series_limit_metric:
ob = self._get_series_orderby(
series_limit_metric, metrics_by_name, columns_by_name
)
direction = desc if order_desc else asc
subq = subq.order_by(direction(ob))
subq = subq.limit(timeseries_limit)
subq = subq.limit(series_limit)
on_clause = []
for gby_name, gby_obj in groupby_exprs_sans_timestamp.items():
for gby_name, gby_obj in groupby_series_columns.items():
# in this case the column name, not the alias, needs to be
# conditionally mutated, as it refers to the column alias in
# the inner query
@ -1339,13 +1350,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
tbl = tbl.join(subq.alias(), and_(*on_clause))
else:
if timeseries_limit_metric:
if series_limit_metric:
orderby = [
(
self._get_timeseries_orderby(
timeseries_limit_metric,
metrics_by_name,
columns_by_name,
self._get_series_orderby(
series_limit_metric, metrics_by_name, columns_by_name,
),
False,
)
@ -1354,7 +1363,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
# run prequery to get top groups
prequery_obj = {
"is_timeseries": False,
"row_limit": timeseries_limit,
"row_limit": series_limit,
"metrics": metrics,
"granularity": granularity,
"groupby": groupby,
@ -1372,10 +1381,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
dimensions = [
c
for c in result.df.columns
if c not in metrics and c in groupby_exprs_sans_timestamp
if c not in metrics and c in groupby_series_columns
]
top_groups = self._get_top_groups(
result.df, dimensions, groupby_exprs_sans_timestamp
result.df, dimensions, groupby_series_columns
)
qry = qry.where(top_groups)
@ -1398,31 +1407,29 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
prequeries=prequeries,
)
def _get_timeseries_orderby(
def _get_series_orderby(
self,
timeseries_limit_metric: Metric,
series_limit_metric: Metric,
metrics_by_name: Dict[str, SqlMetric],
columns_by_name: Dict[str, TableColumn],
) -> Column:
if utils.is_adhoc_metric(timeseries_limit_metric):
assert isinstance(timeseries_limit_metric, dict)
ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, columns_by_name)
if utils.is_adhoc_metric(series_limit_metric):
assert isinstance(series_limit_metric, dict)
ob = self.adhoc_metric_to_sqla(series_limit_metric, columns_by_name)
elif (
isinstance(timeseries_limit_metric, str)
and timeseries_limit_metric in metrics_by_name
isinstance(series_limit_metric, str)
and series_limit_metric in metrics_by_name
):
ob = metrics_by_name[timeseries_limit_metric].get_sqla_col()
ob = metrics_by_name[series_limit_metric].get_sqla_col()
else:
raise QueryObjectValidationError(
_("Metric '%(metric)s' does not exist", metric=timeseries_limit_metric)
_("Metric '%(metric)s' does not exist", metric=series_limit_metric)
)
return ob
def _get_top_groups( # pylint: disable=no-self-use
self,
df: pd.DataFrame,
dimensions: List[str],
groupby_exprs: "OrderedDict[str, Any]",
@staticmethod
def _get_top_groups(
df: pd.DataFrame, dimensions: List[str], groupby_exprs: Dict[str, Any],
) -> ColumnElement:
groups = []
for _unused, row in df.iterrows():


@ -272,6 +272,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
GenericDataType.BOOLEAN,
),
)
# Does database support join-free timeslot grouping
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
time_secondary_columns = False
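A hypothetical illustration (not an existing Superset engine spec) of how a dialect that supports join-free timeslot grouping could opt in via the relocated attribute:

```python
# Hypothetical example spec; only the class attribute override matters here.
from superset.db_engine_specs.base import BaseEngineSpec


class ExampleInlineTimeslotSpec(BaseEngineSpec):
    engine = "exampledb"        # hypothetical engine identifier
    engine_name = "Example DB"  # hypothetical display name
    # This database can group by time buckets without the extra self-join.
    time_groupby_inline = True
```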


@ -35,7 +35,7 @@ import humanize
import prison
import pytest
import yaml
from sqlalchemy import and_, or_
from sqlalchemy import and_
from sqlalchemy.sql import func
from tests.integration_tests.fixtures.world_bank_dashboard import (
@ -1956,3 +1956,24 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
"verbose_name",
"dtype",
]
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_series_limit(self):
"""
Chart data API: Query with a series limit applied
"""
SERIES_LIMIT = 5
self.login(username="admin")
request_payload = get_query_context("birth_names")
request_payload["queries"][0]["columns"] = ["state", "name"]
request_payload["queries"][0]["series_columns"] = ["name"]
request_payload["queries"][0]["series_limit"] = SERIES_LIMIT
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
response_payload = json.loads(rv.data.decode("utf-8"))
data = response_payload["result"][0]["data"]
unique_names = set(row["name"] for row in data)
self.maxDiff = None
self.assertEqual(len(unique_names), SERIES_LIMIT)
self.assertEqual(
set(column for column in data[0].keys()), {"state", "name", "sum__num"}
)


@ -29,7 +29,7 @@ query_birth_names = {
),
"time_grain_sqla": "P1D",
},
"groupby": ["name"],
"columns": ["name"],
"metrics": [{"label": "sum__num"}],
"orderby": [("sum__num", False)],
"row_limit": 100,


@ -432,6 +432,7 @@ class TestSqlaTableModel(SupersetTestCase):
from_dttm=None,
to_dttm=None,
extras=dict(time_grain_sqla="P1Y"),
series_limit=15 if inner_join and is_timeseries else None,
)
qr = tbl.query(query_obj)
self.assertEqual(qr.status, QueryStatus.SUCCESS)


@ -69,7 +69,7 @@ class TestQueryContext(SupersetTestCase):
# check basic properties
self.assertEqual(query.extras, payload_query["extras"])
self.assertEqual(query.filter, payload_query["filters"])
self.assertEqual(query.groupby, payload_query["groupby"])
self.assertEqual(query.columns, payload_query["columns"])
# metrics are mutated during creation
for metric_idx, metric in enumerate(query.metrics):
@ -277,12 +277,20 @@ class TestQueryContext(SupersetTestCase):
"""
self.login(username="admin")
payload = get_query_context("birth_names")
columns = payload["queries"][0]["columns"]
payload["queries"][0]["groupby"] = columns
payload["queries"][0]["timeseries_limit"] = 99
payload["queries"][0]["timeseries_limit_metric"] = "sum__num"
del payload["queries"][0]["columns"]
payload["queries"][0]["granularity_sqla"] = "timecol"
payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
query_context = ChartDataQueryContextSchema().load(payload)
self.assertEqual(len(query_context.queries), 1)
query_object = query_context.queries[0]
self.assertEqual(query_object.granularity, "timecol")
self.assertEqual(query_object.columns, columns)
self.assertEqual(query_object.series_limit, 99)
self.assertEqual(query_object.series_limit_metric, "sum__num")
self.assertIn("having_druid", query_object.extras)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")