mirror of https://github.com/apache/superset.git
feat: supports multiple filters in samples endpoint (#21008)
parent e214e1ace6
commit 802b69f97b
@@ -38,3 +38,4 @@ class ChartDataResultType(str, Enum):
     SAMPLES = "samples"
     TIMEGRAINS = "timegrains"
     POST_PROCESSED = "post_processed"
+    DRILL_DETAIL = "drill_detail"
@@ -162,6 +162,27 @@ def _get_samples(
     return _get_full(query_context, query_obj, force_cached)
 
 
+def _get_drill_detail(
+    query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
+) -> Dict[str, Any]:
+    # todo(yongjie): Remove this function,
+    # when determining whether samples should be applied to the time filter.
+    datasource = _get_datasource(query_context, query_obj)
+    query_obj = copy.copy(query_obj)
+    query_obj.is_timeseries = False
+    query_obj.orderby = []
+    query_obj.metrics = None
+    query_obj.post_processing = []
+    qry_obj_cols = []
+    for o in datasource.columns:
+        if isinstance(o, dict):
+            qry_obj_cols.append(o.get("column_name"))
+        else:
+            qry_obj_cols.append(o.column_name)
+    query_obj.columns = qry_obj_cols
+    return _get_full(query_context, query_obj, force_cached)
+
+
 def _get_results(
     query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
 ) -> Dict[str, Any]:
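The loop in _get_drill_detail accepts datasource columns in two shapes. Below is a self-contained sketch of that normalization; the Column class is a stand-in for illustration, not Superset's actual TableColumn model:

from typing import Any, List, Union

class Column:  # illustrative stand-in for an ORM-backed column object
    def __init__(self, column_name: str) -> None:
        self.column_name = column_name

def column_names(columns: List[Union[dict, Any]]) -> List[str]:
    # Columns may arrive as plain dicts or as ORM objects, so both
    # shapes are handled, mirroring the loop in _get_drill_detail.
    return [
        c.get("column_name") if isinstance(c, dict) else c.column_name
        for c in columns
    ]

assert column_names([{"column_name": "col1"}, Column("col2")]) == ["col1", "col2"]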
@@ -182,6 +203,7 @@ _result_type_functions: Dict[
     # and post-process it later where we have the chart context, since
     # post-processing is unique to each visualization type
     ChartDataResultType.POST_PROCESSED: _get_full,
+    ChartDataResultType.DRILL_DETAIL: _get_drill_detail,
 }
 
 
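For orientation, a minimal, self-contained sketch of how a result-type dispatch table like _result_type_functions is typically consumed; the run helper and simplified signatures here are illustrative, not Superset's actual call site:

from enum import Enum
from typing import Any, Callable, Dict

class ResultType(str, Enum):
    SAMPLES = "samples"
    DRILL_DETAIL = "drill_detail"

def _get_samples(ctx: Any) -> str:
    return "samples"

def _get_drill_detail(ctx: Any) -> str:
    return "drill_detail"

_result_type_functions: Dict[ResultType, Callable[[Any], str]] = {
    ResultType.SAMPLES: _get_samples,
    ResultType.DRILL_DETAIL: _get_drill_detail,
}

def run(result_type: ResultType, ctx: Any) -> str:
    # Fail fast on result types with no registered handler.
    handler = _result_type_functions.get(result_type)
    if handler is None:
        raise ValueError(f"unsupported result type: {result_type}")
    return handler(ctx)

assert run(ResultType.DRILL_DETAIL, None) == "drill_detail"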
@@ -20,7 +20,7 @@ from marshmallow import fields, post_load, pre_load, Schema, validate
 from typing_extensions import TypedDict
 
 from superset import app
-from superset.charts.schemas import ChartDataFilterSchema
+from superset.charts.schemas import ChartDataExtrasSchema, ChartDataFilterSchema
 from superset.utils.core import DatasourceType
 
 
@@ -62,6 +62,17 @@ class ExternalMetadataSchema(Schema):
 
 class SamplesPayloadSchema(Schema):
     filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
+    granularity = fields.String(
+        allow_none=True,
+    )
+    time_range = fields.String(
+        allow_none=True,
+    )
+    extras = fields.Nested(
+        ChartDataExtrasSchema,
+        description="Extra parameters to add to the query.",
+        allow_none=True,
+    )
 
     @pre_load
     # pylint: disable=no-self-use, unused-argument
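Assuming the standard marshmallow workflow, the extended schema accepts a payload like the one below; the field values are borrowed from the tests later in this diff:

from marshmallow import ValidationError

try:
    payload = SamplesPayloadSchema().load(
        {
            "granularity": "col5",
            "time_range": "2000-01-02 : 2000-01-04",
            "filters": [{"col": "col2", "op": "==", "val": "c"}],
            "extras": {"where": "col3 = 1.2 and col4 is null"},
        }
    )
except ValidationError as err:
    # Malformed fields are rejected at the schema boundary.
    print(err.messages)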
@@ -60,17 +60,30 @@ def get_samples(  # pylint: disable=too-many-arguments,too-many-locals
     limit_clause = get_limit_clause(page, per_page)
 
-    # todo(yongjie): Constructing count(*) and samples in the same query_context,
-    # then remove query_type==SAMPLES
-    # constructing samples query
-    samples_instance = QueryContextFactory().create(
-        datasource={
-            "type": datasource.type,
-            "id": datasource.id,
-        },
-        queries=[{**payload, **limit_clause} if payload else limit_clause],
-        result_type=ChartDataResultType.SAMPLES,
-        force=force,
-    )
+    if payload is None:
+        # constructing samples query
+        samples_instance = QueryContextFactory().create(
+            datasource={
+                "type": datasource.type,
+                "id": datasource.id,
+            },
+            queries=[limit_clause],
+            result_type=ChartDataResultType.SAMPLES,
+            force=force,
+        )
+    else:
+        # constructing drill detail query
+        # When query_type == 'samples' the `time filter` will be removed,
+        # so it is not applicable drill detail query
+        samples_instance = QueryContextFactory().create(
+            datasource={
+                "type": datasource.type,
+                "id": datasource.id,
+            },
+            queries=[{**payload, **limit_clause}],
+            result_type=ChartDataResultType.DRILL_DETAIL,
+            force=force,
+        )
 
     # constructing count(*) query
     count_star_metric = {
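With the branch above, a POST that carries a payload is routed through the drill-detail query, while a bare POST keeps the old samples behavior. A hedged sketch of the same request over plain HTTP; host, port, auth token, and datasource id are placeholders:

import requests

resp = requests.post(
    "http://localhost:8088/datasource/samples",
    params={
        "datasource_id": 1,
        "datasource_type": "table",
        "per_page": 100,
        "page": 1,
    },
    headers={"Authorization": "Bearer <access-token>"},
    json={
        "granularity": "col5",
        "time_range": "2000-01-02 : 2000-01-04",
        "filters": [{"col": "col2", "op": "==", "val": "c"}],
        "extras": {"where": "col3 = 1.2 and col4 is null"},
    },
)
rows = resp.json()["result"]["data"]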
@@ -314,7 +314,7 @@ def physical_dataset():
     col2 VARCHAR(255),
     col3 DECIMAL(4,2),
     col4 VARCHAR(255),
-    col5 VARCHAR(255)
+    col5 TIMESTAMP
     );
     """
     )
@@ -342,11 +342,10 @@ def physical_dataset():
     TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
     TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
     TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
-    TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
+    TableColumn(column_name="col5", type="TIMESTAMP", is_dttm=True, table=dataset)
     SqlMetric(metric_name="count", expression="count(*)", table=dataset)
     db.session.merge(dataset)
-    if example_database.backend == "sqlite":
-        db.session.commit()
+    db.session.commit()
 
     yield dataset
 
@@ -355,5 +354,7 @@ def physical_dataset():
     DROP TABLE physical_dataset;
     """
     )
-    db.session.delete(dataset)
+    dataset = db.session.query(SqlaTable).filter_by(table_name="physical_dataset").all()
+    for ds in dataset:
+        db.session.delete(ds)
     db.session.commit()
@@ -432,14 +432,13 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
     test_client.post(uri)
     # get from cache
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
     assert rv.status_code == 200
-    assert len(rv_data["result"]["data"]) == 10
+    assert len(rv.json["result"]["data"]) == 10
     assert QueryCacheManager.has(
-        rv_data["result"]["cache_key"],
+        rv.json["result"]["cache_key"],
         region=CacheRegion.DATA,
     )
-    assert rv_data["result"]["is_cached"]
+    assert rv.json["result"]["is_cached"]
 
     # 2. should read through cache data
     uri2 = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&force=true"
@@ -447,19 +446,18 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
     test_client.post(uri2)
     # force query
     rv2 = test_client.post(uri2)
-    rv_data2 = json.loads(rv2.data)
     assert rv2.status_code == 200
-    assert len(rv_data2["result"]["data"]) == 10
+    assert len(rv2.json["result"]["data"]) == 10
     assert QueryCacheManager.has(
-        rv_data2["result"]["cache_key"],
+        rv2.json["result"]["cache_key"],
         region=CacheRegion.DATA,
     )
-    assert not rv_data2["result"]["is_cached"]
+    assert not rv2.json["result"]["is_cached"]
 
     # 3. data precision
-    assert "colnames" in rv_data2["result"]
-    assert "coltypes" in rv_data2["result"]
-    assert "data" in rv_data2["result"]
+    assert "colnames" in rv2.json["result"]
+    assert "coltypes" in rv2.json["result"]
+    assert "data" in rv2.json["result"]
 
     eager_samples = virtual_dataset.database.get_df(
         f"select * from ({virtual_dataset.sql}) as tbl"
@@ -468,7 +466,7 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
     # the col3 is Decimal
     eager_samples["col3"] = eager_samples["col3"].apply(float)
     eager_samples = eager_samples.to_dict(orient="records")
-    assert eager_samples == rv_data2["result"]["data"]
+    assert eager_samples == rv2.json["result"]["data"]
 
 
 def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset):
@@ -486,10 +484,9 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset):
     rv = test_client.post(uri)
     assert rv.status_code == 422
 
-    rv_data = json.loads(rv.data)
-    assert "error" in rv_data
+    assert "error" in rv.json
     if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
-        assert "INCORRECT SQL" in rv_data.get("error")
+        assert "INCORRECT SQL" in rv.json.get("error")
 
 
 def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
@@ -498,11 +495,10 @@ def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
     )
     rv = test_client.post(uri)
     assert rv.status_code == 200
-    rv_data = json.loads(rv.data)
     assert QueryCacheManager.has(
-        rv_data["result"]["cache_key"], region=CacheRegion.DATA
+        rv.json["result"]["cache_key"], region=CacheRegion.DATA
     )
-    assert len(rv_data["result"]["data"]) == 10
+    assert len(rv.json["result"]["data"]) == 10
 
 
 def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
@@ -533,9 +529,8 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
         },
     )
     assert rv.status_code == 200
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
-    assert rv_data["result"]["rowcount"] == 1
+    assert rv.json["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
+    assert rv.json["result"]["rowcount"] == 1
 
     # empty results
     rv = test_client.post(
@@ -547,9 +542,64 @@ def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
         },
     )
     assert rv.status_code == 200
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["colnames"] == []
-    assert rv_data["result"]["rowcount"] == 0
+    assert rv.json["result"]["colnames"] == []
+    assert rv.json["result"]["rowcount"] == 0
+
+
+def test_get_samples_with_time_filter(test_client, login_as_admin, physical_dataset):
+    uri = (
+        f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
+    )
+    payload = {
+        "granularity": "col5",
+        "time_range": "2000-01-02 : 2000-01-04",
+    }
+    rv = test_client.post(uri, json=payload)
+    assert len(rv.json["result"]["data"]) == 2
+    if physical_dataset.database.backend != "sqlite":
+        assert [row["col5"] for row in rv.json["result"]["data"]] == [
+            946771200000.0,  # 2000-01-02 00:00:00
+            946857600000.0,  # 2000-01-03 00:00:00
+        ]
+    assert rv.json["result"]["page"] == 1
+    assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
+    assert rv.json["result"]["total_count"] == 2
+
+
+def test_get_samples_with_multiple_filters(
+    test_client, login_as_admin, physical_dataset
+):
+    # 1. empty response
+    uri = (
+        f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
+    )
+    payload = {
+        "granularity": "col5",
+        "time_range": "2000-01-02 : 2000-01-04",
+        "filters": [
+            {"col": "col4", "op": "IS NOT NULL"},
+        ],
+    }
+    rv = test_client.post(uri, json=payload)
+    assert len(rv.json["result"]["data"]) == 0
+
+    # 2. adhoc filters, time filters, and custom where
+    payload = {
+        "granularity": "col5",
+        "time_range": "2000-01-02 : 2000-01-04",
+        "filters": [
+            {"col": "col2", "op": "==", "val": "c"},
+        ],
+        "extras": {"where": "col3 = 1.2 and col4 is null"},
+    }
+    rv = test_client.post(uri, json=payload)
+    assert len(rv.json["result"]["data"]) == 1
+    assert rv.json["result"]["total_count"] == 1
+    assert "2000-01-02" in rv.json["result"]["query"]
+    assert "2000-01-04" in rv.json["result"]["query"]
+    assert "col3 = 1.2" in rv.json["result"]["query"]
+    assert "col4 is null" in rv.json["result"]["query"]
+    assert "col2 = 'c'" in rv.json["result"]["query"]
 
 
 def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
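The epoch values asserted in test_get_samples_with_time_filter are milliseconds since the Unix epoch (UTC); a quick sanity check:

from datetime import datetime, timezone

assert datetime.fromtimestamp(946771200000.0 / 1000, tz=timezone.utc) == datetime(
    2000, 1, 2, tzinfo=timezone.utc
)
assert datetime.fromtimestamp(946857600000.0 / 1000, tz=timezone.utc) == datetime(
    2000, 1, 3, tzinfo=timezone.utc
)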
@@ -558,10 +608,9 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
         f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
     )
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["page"] == 1
-    assert rv_data["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
-    assert rv_data["result"]["total_count"] == 10
+    assert rv.json["result"]["page"] == 1
+    assert rv.json["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
+    assert rv.json["result"]["total_count"] == 10
 
     # 2. incorrect per_page
     per_pages = (app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx")
@@ -582,25 +631,22 @@ def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
     # 4. turning pages
     uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=1"
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["page"] == 1
-    assert rv_data["result"]["per_page"] == 2
-    assert rv_data["result"]["total_count"] == 10
-    assert [row["col1"] for row in rv_data["result"]["data"]] == [0, 1]
+    assert rv.json["result"]["page"] == 1
+    assert rv.json["result"]["per_page"] == 2
+    assert rv.json["result"]["total_count"] == 10
+    assert [row["col1"] for row in rv.json["result"]["data"]] == [0, 1]
 
     uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=2"
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["page"] == 2
-    assert rv_data["result"]["per_page"] == 2
-    assert rv_data["result"]["total_count"] == 10
-    assert [row["col1"] for row in rv_data["result"]["data"]] == [2, 3]
+    assert rv.json["result"]["page"] == 2
+    assert rv.json["result"]["per_page"] == 2
+    assert rv.json["result"]["total_count"] == 10
+    assert [row["col1"] for row in rv.json["result"]["data"]] == [2, 3]
 
     # 5. Exceeding the maximum pages
     uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=6"
     rv = test_client.post(uri)
-    rv_data = json.loads(rv.data)
-    assert rv_data["result"]["page"] == 6
-    assert rv_data["result"]["per_page"] == 2
-    assert rv_data["result"]["total_count"] == 10
-    assert [row["col1"] for row in rv_data["result"]["data"]] == []
+    assert rv.json["result"]["page"] == 6
+    assert rv.json["result"]["per_page"] == 2
+    assert rv.json["result"]["total_count"] == 10
+    assert [row["col1"] for row in rv.json["result"]["data"]] == []
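The pagination assertions are consistent with a limit/offset translation along these lines. This is a hedged sketch of what get_limit_clause presumably returns, not its actual implementation; the real helper also validates page and per_page, which is why 0, "xx", and SAMPLES_ROW_LIMIT + 1 are rejected in case 2:

def get_limit_clause_sketch(page: int, per_page: int) -> dict:
    # page is 1-based; per_page=2, page=2 yields rows 2..3 (col1 == [2, 3]).
    return {
        "row_limit": per_page,
        "row_offset": (page - 1) * per_page,
    }

assert get_limit_clause_sketch(1, 2) == {"row_limit": 2, "row_offset": 0}
assert get_limit_clause_sketch(2, 2) == {"row_limit": 2, "row_offset": 2}
assert get_limit_clause_sketch(6, 2) == {"row_limit": 2, "row_offset": 10}  # past the 10 rows, so the page is empty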