feat: the samples endpoint supports filters and pagination (#20683)

Yongjie Zhao 2022-07-22 20:14:42 +08:00 committed by GitHub
parent 39545352d2
commit f011abae2b
13 changed files with 479 additions and 437 deletions

View File

@@ -129,7 +129,7 @@ describe('Test datatable', () => {
});
it('Datapane loads view samples', () => {
cy.intercept(
'api/v1/explore/samples?force=false&datasource_type=table&datasource_id=*',
'datasource/samples?force=false&datasource_type=table&datasource_id=*',
).as('Samples');
cy.contains('Samples')
.click()

View File

@@ -602,10 +602,11 @@ export const getDatasourceSamples = async (
datasourceType,
datasourceId,
force,
jsonPayload,
) => {
const endpoint = `/api/v1/explore/samples?force=${force}&datasource_type=${datasourceType}&datasource_id=${datasourceId}`;
const endpoint = `/datasource/samples?force=${force}&datasource_type=${datasourceType}&datasource_id=${datasourceId}`;
try {
const response = await SupersetClient.get({ endpoint });
const response = await SupersetClient.post({ endpoint, jsonPayload });
return response.json.result;
} catch (err) {
const clientError = await getClientErrorObject(err);

View File

@@ -29,8 +29,8 @@ import { SamplesPane } from '../components';
import { createSamplesPaneProps } from './fixture';
describe('SamplesPane', () => {
fetchMock.get(
'end:/api/v1/explore/samples?force=false&datasource_type=table&datasource_id=34',
fetchMock.post(
'end:/datasource/samples?force=false&datasource_type=table&datasource_id=34',
{
result: {
data: [],
@@ -40,8 +40,8 @@ describe('SamplesPane', () => {
},
);
fetchMock.get(
'end:/api/v1/explore/samples?force=true&datasource_type=table&datasource_id=35',
fetchMock.post(
'end:/datasource/samples?force=true&datasource_type=table&datasource_id=35',
{
result: {
data: [
@@ -54,8 +54,8 @@ describe('SamplesPane', () => {
},
);
fetchMock.get(
'end:/api/v1/explore/samples?force=false&datasource_type=table&datasource_id=36',
fetchMock.post(
'end:/datasource/samples?force=false&datasource_type=table&datasource_id=36',
400,
);

View File

@@ -21,9 +21,8 @@ from io import BytesIO
from typing import Any
from zipfile import is_zipfile, ZipFile
import simplejson
import yaml
from flask import make_response, request, Response, send_file
from flask import request, Response, send_file
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext
@@ -46,13 +45,11 @@ from superset.datasets.commands.exceptions import (
DatasetInvalidError,
DatasetNotFoundError,
DatasetRefreshFailedError,
DatasetSamplesFailedError,
DatasetUpdateFailedError,
)
from superset.datasets.commands.export import ExportDatasetsCommand
from superset.datasets.commands.importers.dispatcher import ImportDatasetsCommand
from superset.datasets.commands.refresh import RefreshDatasetCommand
from superset.datasets.commands.samples import SamplesDatasetCommand
from superset.datasets.commands.update import UpdateDatasetCommand
from superset.datasets.dao import DatasetDAO
from superset.datasets.filters import DatasetCertifiedFilter, DatasetIsNullOrEmptyFilter
@@ -63,7 +60,7 @@ from superset.datasets.schemas import (
get_delete_ids_schema,
get_export_ids_schema,
)
from superset.utils.core import json_int_dttm_ser, parse_boolean_string
from superset.utils.core import parse_boolean_string
from superset.views.base import DatasourceFilter, generate_download_headers
from superset.views.base_api import (
BaseSupersetModelRestApi,
@@ -93,7 +90,6 @@ class DatasetRestApi(BaseSupersetModelRestApi):
"bulk_delete",
"refresh",
"related_objects",
"samples",
}
list_columns = [
"id",
@@ -775,65 +771,3 @@ class DatasetRestApi(BaseSupersetModelRestApi):
)
command.run()
return self.response(200, message="OK")
@expose("/<pk>/samples")
@protect()
@safe
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.samples",
log_to_statsd=False,
)
def samples(self, pk: int) -> Response:
"""get samples from a Dataset
---
get:
description: >-
get samples from a Dataset
parameters:
- in: path
schema:
type: integer
name: pk
- in: query
schema:
type: boolean
name: force
responses:
200:
description: Dataset samples
content:
application/json:
schema:
type: object
properties:
result:
$ref: '#/components/schemas/ChartDataResponseResult'
401:
$ref: '#/components/responses/401'
403:
$ref: '#/components/responses/403'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
try:
force = parse_boolean_string(request.args.get("force"))
rv = SamplesDatasetCommand(pk, force).run()
response_data = simplejson.dumps(
{"result": rv},
default=json_int_dttm_ser,
ignore_nan=True,
)
resp = make_response(response_data, 200)
resp.headers["Content-Type"] = "application/json; charset=utf-8"
return resp
except DatasetNotFoundError:
return self.response_404()
except DatasetForbiddenError:
return self.response_403()
except DatasetSamplesFailedError as ex:
return self.response_400(message=str(ex))

View File

@@ -1,80 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, Optional
from superset import security_manager
from superset.commands.base import BaseCommand
from superset.common.chart_data import ChartDataResultType
from superset.common.query_context_factory import QueryContextFactory
from superset.common.utils.query_cache_manager import QueryCacheManager
from superset.connectors.sqla.models import SqlaTable
from superset.constants import CacheRegion
from superset.datasets.commands.exceptions import (
DatasetForbiddenError,
DatasetNotFoundError,
DatasetSamplesFailedError,
)
from superset.datasets.dao import DatasetDAO
from superset.exceptions import SupersetSecurityException
from superset.utils.core import QueryStatus
logger = logging.getLogger(__name__)
class SamplesDatasetCommand(BaseCommand):
def __init__(self, model_id: int, force: bool):
self._model_id = model_id
self._force = force
self._model: Optional[SqlaTable] = None
def run(self) -> Dict[str, Any]:
self.validate()
if not self._model:
raise DatasetNotFoundError()
qc_instance = QueryContextFactory().create(
datasource={
"type": self._model.type,
"id": self._model.id,
},
queries=[{}],
result_type=ChartDataResultType.SAMPLES,
force=self._force,
)
results = qc_instance.get_payload()
try:
sample_data = results["queries"][0]
error_msg = sample_data.get("error")
if sample_data.get("status") == QueryStatus.FAILED and error_msg:
cache_key = sample_data.get("cache_key")
QueryCacheManager.delete(cache_key, region=CacheRegion.DATA)
raise DatasetSamplesFailedError(error_msg)
return sample_data
except (IndexError, KeyError) as exc:
raise DatasetSamplesFailedError from exc
def validate(self) -> None:
# Validate/populate model exists
self._model = DatasetDAO.find_by_id(self._model_id)
if not self._model:
raise DatasetNotFoundError()
# Check ownership
try:
security_manager.raise_for_ownership(self._model)
except SupersetSecurityException as ex:
raise DatasetForbiddenError() from ex

View File

@@ -16,22 +16,14 @@
# under the License.
import logging
import simplejson
from flask import g, make_response, request, Response
from flask import g, request, Response
from flask_appbuilder.api import BaseApi, expose, protect, safe
from superset.charts.commands.exceptions import ChartNotFoundError
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.dao.exceptions import DatasourceNotFound
from superset.explore.commands.get import GetExploreCommand
from superset.explore.commands.parameters import CommandParameters
from superset.explore.commands.samples import SamplesDatasourceCommand
from superset.explore.exceptions import (
DatasetAccessDeniedError,
DatasourceForbiddenError,
DatasourceSamplesFailedError,
WrongEndpointError,
)
from superset.explore.exceptions import DatasetAccessDeniedError, WrongEndpointError
from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
from superset.explore.schemas import ExploreContextSchema
from superset.extensions import event_logger
@@ -39,16 +31,13 @@ from superset.temporary_cache.commands.exceptions import (
TemporaryCacheAccessDeniedError,
TemporaryCacheResourceNotFoundError,
)
from superset.utils.core import json_int_dttm_ser, parse_boolean_string
logger = logging.getLogger(__name__)
class ExploreRestApi(BaseApi):
method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
include_route_methods = {RouteMethod.GET} | {
"samples",
}
include_route_methods = {RouteMethod.GET}
allow_browser_login = True
class_permission_name = "Explore"
resource_name = "explore"
@@ -146,70 +135,3 @@ class ExploreRestApi(BaseApi):
return self.response(403, message=str(ex))
except TemporaryCacheResourceNotFoundError as ex:
return self.response(404, message=str(ex))
@expose("/samples", methods=["GET"])
@protect()
@safe
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.samples",
log_to_statsd=False,
)
def samples(self) -> Response:
"""get samples from a Datasource
---
get:
description: >-
get samples from a Datasource
parameters:
- in: path
schema:
type: integer
name: pk
- in: query
schema:
type: boolean
name: force
responses:
200:
description: Datasource samples
content:
application/json:
schema:
type: object
properties:
result:
$ref: '#/components/schemas/ChartDataResponseResult'
401:
$ref: '#/components/responses/401'
403:
$ref: '#/components/responses/403'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
try:
force = parse_boolean_string(request.args.get("force"))
rv = SamplesDatasourceCommand(
user=g.user,
datasource_type=request.args.get("datasource_type", type=str),
datasource_id=request.args.get("datasource_id", type=int),
force=force,
).run()
response_data = simplejson.dumps(
{"result": rv},
default=json_int_dttm_ser,
ignore_nan=True,
)
resp = make_response(response_data, 200)
resp.headers["Content-Type"] = "application/json; charset=utf-8"
return resp
except DatasourceNotFound:
return self.response_404()
except DatasourceForbiddenError:
return self.response_403()
except DatasourceSamplesFailedError as ex:
return self.response_400(message=str(ex))

View File

@@ -1,93 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, Optional
from flask_appbuilder.security.sqla.models import User
from superset import db, security_manager
from superset.commands.base import BaseCommand
from superset.common.chart_data import ChartDataResultType
from superset.common.query_context_factory import QueryContextFactory
from superset.common.utils.query_cache_manager import QueryCacheManager
from superset.constants import CacheRegion
from superset.dao.exceptions import DatasourceNotFound
from superset.datasource.dao import Datasource, DatasourceDAO
from superset.exceptions import SupersetSecurityException
from superset.explore.exceptions import (
DatasourceForbiddenError,
DatasourceSamplesFailedError,
)
from superset.utils.core import DatasourceType, QueryStatus
logger = logging.getLogger(__name__)
class SamplesDatasourceCommand(BaseCommand):
def __init__(
self,
user: User,
datasource_id: Optional[int],
datasource_type: Optional[str],
force: bool,
):
self._actor = user
self._datasource_id = datasource_id
self._datasource_type = datasource_type
self._force = force
self._model: Optional[Datasource] = None
def run(self) -> Dict[str, Any]:
self.validate()
if not self._model:
raise DatasourceNotFound()
qc_instance = QueryContextFactory().create(
datasource={
"type": self._model.type,
"id": self._model.id,
},
queries=[{}],
result_type=ChartDataResultType.SAMPLES,
force=self._force,
)
results = qc_instance.get_payload()
try:
sample_data = results["queries"][0]
error_msg = sample_data.get("error")
if sample_data.get("status") == QueryStatus.FAILED and error_msg:
cache_key = sample_data.get("cache_key")
QueryCacheManager.delete(cache_key, region=CacheRegion.DATA)
raise DatasourceSamplesFailedError(error_msg)
return sample_data
except (IndexError, KeyError) as exc:
raise DatasourceSamplesFailedError from exc
def validate(self) -> None:
# Validate/populate model exists
if self._datasource_type and self._datasource_id:
self._model = DatasourceDAO.get_datasource(
session=db.session,
datasource_type=DatasourceType(self._datasource_type),
datasource_id=self._datasource_id,
)
# Check ownership
try:
security_manager.raise_for_ownership(self._model)
except SupersetSecurityException as ex:
raise DatasourceForbiddenError() from ex

View File

@@ -14,11 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any
from typing import Any, Dict
from marshmallow import fields, post_load, Schema
from marshmallow import fields, post_load, pre_load, Schema, validate
from typing_extensions import TypedDict
from superset import app
from superset.charts.schemas import ChartDataFilterSchema
from superset.utils.core import DatasourceType
class ExternalMetadataParams(TypedDict):
datasource_type: str
@@ -54,3 +58,27 @@ class ExternalMetadataSchema(Schema):
schema_name=data.get("schema_name", ""),
table_name=data["table_name"],
)
class SamplesPayloadSchema(Schema):
filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
@pre_load
# pylint: disable=no-self-use, unused-argument
def handle_none(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
if data is None:
return {}
return data
class SamplesRequestSchema(Schema):
datasource_type = fields.String(
validate=validate.OneOf([e.value for e in DatasourceType]), required=True
)
datasource_id = fields.Integer(required=True)
force = fields.Boolean(load_default=False)
page = fields.Integer(load_default=1)
per_page = fields.Integer(
validate=validate.Range(min=1, max=app.config.get("SAMPLES_ROW_LIMIT", 1000)),
load_default=app.config.get("SAMPLES_ROW_LIMIT", 1000),
)
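The two new schemas split validation between the query string and the JSON body: SamplesRequestSchema checks datasource_type and datasource_id and supplies defaults for force, page and per_page (with per_page capped at SAMPLES_ROW_LIMIT), while SamplesPayloadSchema accepts an optional list of filters and tolerates a missing body. Below is a minimal standalone sketch of the query-string side; the schema here is a stand-in that hardcodes the row limit to 1000 and trims the datasource types so it runs outside a Flask app context.

from marshmallow import Schema, ValidationError, fields, validate

# Stand-in for SamplesRequestSchema above; the real schema reads
# SAMPLES_ROW_LIMIT from app.config and derives the allowed types
# from the DatasourceType enum.
SAMPLES_ROW_LIMIT = 1000

class RequestSchema(Schema):
    datasource_type = fields.String(
        validate=validate.OneOf(["table", "query"]), required=True
    )
    datasource_id = fields.Integer(required=True)
    force = fields.Boolean(load_default=False)
    page = fields.Integer(load_default=1)
    per_page = fields.Integer(
        validate=validate.Range(min=1, max=SAMPLES_ROW_LIMIT),
        load_default=SAMPLES_ROW_LIMIT,
    )

# Missing optional arguments fall back to their load_default values.
params = RequestSchema().load({"datasource_type": "table", "datasource_id": "1"})
assert params == {
    "datasource_type": "table",
    "datasource_id": 1,
    "force": False,
    "page": 1,
    "per_page": SAMPLES_ROW_LIMIT,
}

# An out-of-range per_page is rejected, which is what produces the 400
# responses exercised by the pagination tests further down.
try:
    RequestSchema().load(
        {"datasource_type": "table", "datasource_id": "1", "per_page": "0"}
    )
except ValidationError as err:
    print(err.messages)  # {'per_page': [...]}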

View File

@@ -0,0 +1,115 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Optional
from superset import app, db
from superset.common.chart_data import ChartDataResultType
from superset.common.query_context_factory import QueryContextFactory
from superset.common.utils.query_cache_manager import QueryCacheManager
from superset.constants import CacheRegion
from superset.datasets.commands.exceptions import DatasetSamplesFailedError
from superset.datasource.dao import DatasourceDAO
from superset.utils.core import QueryStatus
from superset.views.datasource.schemas import SamplesPayloadSchema
def get_limit_clause(page: Optional[int], per_page: Optional[int]) -> Dict[str, int]:
samples_row_limit = app.config.get("SAMPLES_ROW_LIMIT", 1000)
limit = samples_row_limit
offset = 0
if isinstance(page, int) and isinstance(per_page, int):
limit = int(per_page)
if limit < 0 or limit > samples_row_limit:
# reset limit value if input is invalid
limit = samples_row_limit
offset = max((int(page) - 1) * limit, 0)
return {"row_offset": offset, "row_limit": limit}
def get_samples( # pylint: disable=too-many-arguments,too-many-locals
datasource_type: str,
datasource_id: int,
force: bool = False,
page: int = 1,
per_page: int = 1000,
payload: Optional[SamplesPayloadSchema] = None,
) -> Dict[str, Any]:
datasource = DatasourceDAO.get_datasource(
session=db.session,
datasource_type=datasource_type,
datasource_id=datasource_id,
)
limit_clause = get_limit_clause(page, per_page)
# TODO(yongjie): construct count(*) and samples in the same query_context,
# then remove query_type == SAMPLES
# construct the samples query
samples_instance = QueryContextFactory().create(
datasource={
"type": datasource.type,
"id": datasource.id,
},
queries=[{**payload, **limit_clause} if payload else limit_clause],
result_type=ChartDataResultType.SAMPLES,
force=force,
)
# constructing count(*) query
count_star_metric = {
"metrics": [
{
"expressionType": "SQL",
"sqlExpression": "COUNT(*)",
"label": "COUNT(*)",
}
]
}
count_star_instance = QueryContextFactory().create(
datasource={
"type": datasource.type,
"id": datasource.id,
},
queries=[{**payload, **count_star_metric} if payload else count_star_metric],
result_type=ChartDataResultType.FULL,
force=force,
)
samples_results = samples_instance.get_payload()
count_star_results = count_star_instance.get_payload()
try:
sample_data = samples_results["queries"][0]
count_star_data = count_star_results["queries"][0]
failed_status = (
sample_data.get("status") == QueryStatus.FAILED
or count_star_data.get("status") == QueryStatus.FAILED
)
error_msg = sample_data.get("error") or count_star_data.get("error")
if failed_status and error_msg:
cache_key = sample_data.get("cache_key")
QueryCacheManager.delete(cache_key, region=CacheRegion.DATA)
raise DatasetSamplesFailedError(error_msg)
sample_data["page"] = page
sample_data["per_page"] = per_page
sample_data["total_count"] = count_star_data["data"][0]["COUNT(*)"]
return sample_data
except (IndexError, KeyError) as exc:
raise DatasetSamplesFailedError from exc
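get_limit_clause translates the validated page/per_page pair into the row_offset/row_limit keys that the SAMPLES query understands, and get_samples runs that query alongside a COUNT(*) query built from the same payload, folding page, per_page and total_count into the response. A worked sketch of the offset arithmetic, copied from the function above with the SAMPLES_ROW_LIMIT config lookup hardcoded to 1000 so it runs standalone:

from typing import Dict, Optional

SAMPLES_ROW_LIMIT = 1000  # stands in for app.config["SAMPLES_ROW_LIMIT"]

def get_limit_clause(page: Optional[int], per_page: Optional[int]) -> Dict[str, int]:
    limit = SAMPLES_ROW_LIMIT
    offset = 0
    if isinstance(page, int) and isinstance(per_page, int):
        limit = int(per_page)
        if limit < 0 or limit > SAMPLES_ROW_LIMIT:
            # reset limit value if input is invalid
            limit = SAMPLES_ROW_LIMIT
        offset = max((int(page) - 1) * limit, 0)
    return {"row_offset": offset, "row_limit": limit}

# page/per_page map onto a plain offset/limit window over the sample rows.
assert get_limit_clause(page=1, per_page=2) == {"row_offset": 0, "row_limit": 2}
assert get_limit_clause(page=2, per_page=2) == {"row_offset": 2, "row_limit": 2}
# Page 6 at 2 rows per page starts past a 10-row dataset, so the query comes
# back empty, matching the "exceeding the maximum pages" case in the tests below.
assert get_limit_clause(page=6, per_page=2) == {"row_offset": 10, "row_limit": 2}
# An out-of-range per_page falls back to the configured row limit.
assert get_limit_clause(page=1, per_page=5000) == {"row_offset": 0, "row_limit": 1000}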

View File

@@ -50,7 +50,10 @@ from superset.views.datasource.schemas import (
ExternalMetadataParams,
ExternalMetadataSchema,
get_external_metadata_schema,
SamplesPayloadSchema,
SamplesRequestSchema,
)
from superset.views.datasource.utils import get_samples
from superset.views.utils import sanitize_datasource_data
@@ -179,3 +182,24 @@ class Datasource(BaseSupersetView):
except (NoResultFound, NoSuchTableError) as ex:
raise DatasetNotFoundError() from ex
return self.json_response(external_metadata)
@expose("/samples", methods=["POST"])
@has_access_api
@api
@handle_api_exception
def samples(self) -> FlaskResponse:
try:
params = SamplesRequestSchema().load(request.args)
payload = SamplesPayloadSchema().load(request.json)
except ValidationError as err:
return json_error_response(err.messages, status=400)
rv = get_samples(
datasource_type=params["datasource_type"],
datasource_id=params["datasource_id"],
force=params["force"],
page=params["page"],
per_page=params["per_page"],
payload=payload,
)
return self.json_response({"result": rv})

View File

@@ -206,3 +206,107 @@ def with_feature_flags(**mock_feature_flags):
return functools.update_wrapper(wrapper, test_fn)
return decorate
@pytest.fixture
def virtual_dataset():
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
dataset = SqlaTable(
table_name="virtual_dataset",
sql=(
"SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5 "
"UNION ALL "
"SELECT 1, 'b', 1.1, NULL, '2000-01-02 00:00:00' "
"UNION ALL "
"SELECT 2 as col1, 'c' as col2, 1.2, NULL, '2000-01-03 00:00:00' "
"UNION ALL "
"SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00' "
"UNION ALL "
"SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00' "
"UNION ALL "
"SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00' "
"UNION ALL "
"SELECT 6 as col1, 'g' as col2, 1.6, NULL, '2000-01-07 00:00:00' "
"UNION ALL "
"SELECT 7 as col1, 'h' as col2, 1.7, NULL, '2000-01-08 00:00:00' "
"UNION ALL "
"SELECT 8 as col1, 'i' as col2, 1.8, NULL, '2000-01-09 00:00:00' "
"UNION ALL "
"SELECT 9 as col1, 'j' as col2, 1.9, NULL, '2000-01-10 00:00:00' "
),
database=get_example_database(),
)
TableColumn(column_name="col1", type="INTEGER", table=dataset)
TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
# Datetime types are not consistent across database dialects, so use VARCHAR for now
TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
SqlMetric(metric_name="count", expression="count(*)", table=dataset)
db.session.merge(dataset)
yield dataset
db.session.delete(dataset)
db.session.commit()
@pytest.fixture
def physical_dataset():
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
example_database = get_example_database()
engine = example_database.get_sqla_engine()
# sqlite can only execute one statement at a time
engine.execute(
"""
CREATE TABLE IF NOT EXISTS physical_dataset(
col1 INTEGER,
col2 VARCHAR(255),
col3 DECIMAL(4,2),
col4 VARCHAR(255),
col5 VARCHAR(255)
);
"""
)
engine.execute(
"""
INSERT INTO physical_dataset values
(0, 'a', 1.0, NULL, '2000-01-01 00:00:00'),
(1, 'b', 1.1, NULL, '2000-01-02 00:00:00'),
(2, 'c', 1.2, NULL, '2000-01-03 00:00:00'),
(3, 'd', 1.3, NULL, '2000-01-04 00:00:00'),
(4, 'e', 1.4, NULL, '2000-01-05 00:00:00'),
(5, 'f', 1.5, NULL, '2000-01-06 00:00:00'),
(6, 'g', 1.6, NULL, '2000-01-07 00:00:00'),
(7, 'h', 1.7, NULL, '2000-01-08 00:00:00'),
(8, 'i', 1.8, NULL, '2000-01-09 00:00:00'),
(9, 'j', 1.9, NULL, '2000-01-10 00:00:00');
"""
)
dataset = SqlaTable(
table_name="physical_dataset",
database=example_database,
)
TableColumn(column_name="col1", type="INTEGER", table=dataset)
TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
SqlMetric(metric_name="count", expression="count(*)", table=dataset)
db.session.merge(dataset)
if example_database.backend == "sqlite":
db.session.commit()
yield dataset
engine.execute(
"""
DROP TABLE physical_dataset;
"""
)
db.session.delete(dataset)
db.session.commit()

View File

@@ -27,9 +27,7 @@ import pytest
import yaml
from sqlalchemy.sql import func
from superset.common.utils.query_cache_manager import QueryCacheManager
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.constants import CacheRegion
from superset.dao.exceptions import (
DAOCreateFailedError,
DAODeleteFailedError,
@@ -2085,102 +2083,3 @@ class TestDatasetApi(SupersetTestCase):
db.session.delete(table_w_certification)
db.session.commit()
@pytest.mark.usefixtures("create_datasets")
def test_get_dataset_samples(self):
"""
Dataset API: Test get dataset samples
"""
if backend() == "sqlite":
return
dataset = self.get_fixture_datasets()[0]
self.login(username="admin")
uri = f"api/v1/dataset/{dataset.id}/samples"
# 1. should cache data
# feeds data
self.client.get(uri)
# get from cache
rv = self.client.get(uri)
rv_data = json.loads(rv.data)
assert rv.status_code == 200
assert "result" in rv_data
assert rv_data["result"]["cached_dttm"] is not None
cache_key1 = rv_data["result"]["cache_key"]
assert QueryCacheManager.has(cache_key1, region=CacheRegion.DATA)
# 2. should through cache
uri2 = f"api/v1/dataset/{dataset.id}/samples?force=true"
# feeds data
self.client.get(uri2)
# force query
rv2 = self.client.get(uri2)
rv_data2 = json.loads(rv2.data)
assert rv_data2["result"]["cached_dttm"] is None
cache_key2 = rv_data2["result"]["cache_key"]
assert QueryCacheManager.has(cache_key2, region=CacheRegion.DATA)
# 3. data precision
assert "colnames" in rv_data2["result"]
assert "coltypes" in rv_data2["result"]
assert "data" in rv_data2["result"]
eager_samples = dataset.database.get_df(
f"select * from {dataset.table_name}"
f' limit {self.app.config["SAMPLES_ROW_LIMIT"]}'
).to_dict(orient="records")
assert eager_samples == rv_data2["result"]["data"]
@pytest.mark.usefixtures("create_datasets")
def test_get_dataset_samples_with_failed_cc(self):
if backend() == "sqlite":
return
dataset = self.get_fixture_datasets()[0]
self.login(username="admin")
failed_column = TableColumn(
column_name="DUMMY CC",
type="VARCHAR(255)",
table=dataset,
expression="INCORRECT SQL",
)
uri = f"api/v1/dataset/{dataset.id}/samples"
dataset.columns.append(failed_column)
rv = self.client.get(uri)
assert rv.status_code == 400
rv_data = json.loads(rv.data)
assert "message" in rv_data
if dataset.database.db_engine_spec.engine_name == "PostgreSQL":
assert "INCORRECT SQL" in rv_data.get("message")
def test_get_dataset_samples_on_virtual_dataset(self):
if backend() == "sqlite":
return
virtual_dataset = SqlaTable(
table_name="virtual_dataset",
sql=("SELECT 'foo' as foo, 'bar' as bar"),
database=get_example_database(),
)
TableColumn(column_name="foo", type="VARCHAR(255)", table=virtual_dataset)
TableColumn(column_name="bar", type="VARCHAR(255)", table=virtual_dataset)
SqlMetric(metric_name="count", expression="count(*)", table=virtual_dataset)
self.login(username="admin")
uri = f"api/v1/dataset/{virtual_dataset.id}/samples"
rv = self.client.get(uri)
assert rv.status_code == 200
rv_data = json.loads(rv.data)
cache_key = rv_data["result"]["cache_key"]
assert QueryCacheManager.has(cache_key, region=CacheRegion.DATA)
# remove original column in dataset
virtual_dataset.sql = "SELECT 'foo' as foo"
rv = self.client.get(uri)
assert rv.status_code == 400
db.session.delete(virtual_dataset)
db.session.commit()

View File

@@ -23,13 +23,15 @@ import prison
import pytest
from superset import app, db
from superset.connectors.sqla.models import SqlaTable
from superset.common.utils.query_cache_manager import QueryCacheManager
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.constants import CacheRegion
from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.exceptions import SupersetGenericDBErrorException
from superset.models.core import Database
from superset.utils.core import DatasourceType, get_example_default_schema
from superset.utils.database import get_example_database
from superset.utils.core import backend, get_example_default_schema
from superset.utils.database import get_example_database, get_main_database
from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@@ -416,3 +418,189 @@ class TestDatasource(SupersetTestCase):
self.login(username="admin")
resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False)
self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType")
def test_get_samples(test_client, login_as_admin, virtual_dataset):
"""
Dataset API: Test get dataset samples
"""
# 1. should cache data
uri = (
f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
)
# feeds data
test_client.post(uri)
# get from cache
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv.status_code == 200
assert len(rv_data["result"]["data"]) == 10
assert QueryCacheManager.has(
rv_data["result"]["cache_key"],
region=CacheRegion.DATA,
)
assert rv_data["result"]["is_cached"]
# 2. should bypass the cached result when force=true
uri2 = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&force=true"
# feeds data
test_client.post(uri2)
# force query
rv2 = test_client.post(uri2)
rv_data2 = json.loads(rv2.data)
assert rv2.status_code == 200
assert len(rv_data2["result"]["data"]) == 10
assert QueryCacheManager.has(
rv_data2["result"]["cache_key"],
region=CacheRegion.DATA,
)
assert not rv_data2["result"]["is_cached"]
# 3. data precision
assert "colnames" in rv_data2["result"]
assert "coltypes" in rv_data2["result"]
assert "data" in rv_data2["result"]
eager_samples = virtual_dataset.database.get_df(
f"select * from ({virtual_dataset.sql}) as tbl"
f' limit {app.config["SAMPLES_ROW_LIMIT"]}'
)
# col3 is a Decimal; cast to float so it compares equal to the JSON result
eager_samples["col3"] = eager_samples["col3"].apply(float)
eager_samples = eager_samples.to_dict(orient="records")
assert eager_samples == rv_data2["result"]["data"]
def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset):
TableColumn(
column_name="DUMMY CC",
type="VARCHAR(255)",
table=virtual_dataset,
expression="INCORRECT SQL",
)
db.session.merge(virtual_dataset)
uri = (
f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
)
rv = test_client.post(uri)
assert rv.status_code == 422
rv_data = json.loads(rv.data)
assert "error" in rv_data
if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
assert "INCORRECT SQL" in rv_data.get("error")
def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
uri = (
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
)
rv = test_client.post(uri)
assert rv.status_code == 200
rv_data = json.loads(rv.data)
assert QueryCacheManager.has(
rv_data["result"]["cache_key"], region=CacheRegion.DATA
)
assert len(rv_data["result"]["data"]) == 10
def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
uri = (
f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
)
rv = test_client.post(uri, json=None)
assert rv.status_code == 200
rv = test_client.post(uri, json={})
assert rv.status_code == 200
rv = test_client.post(uri, json={"foo": "bar"})
assert rv.status_code == 400
rv = test_client.post(
uri, json={"filters": [{"col": "col1", "op": "INVALID", "val": 0}]}
)
assert rv.status_code == 400
rv = test_client.post(
uri,
json={
"filters": [
{"col": "col2", "op": "==", "val": "a"},
{"col": "col1", "op": "==", "val": 0},
]
},
)
assert rv.status_code == 200
rv_data = json.loads(rv.data)
assert rv_data["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
assert rv_data["result"]["rowcount"] == 1
# empty results
rv = test_client.post(
uri,
json={
"filters": [
{"col": "col2", "op": "==", "val": "x"},
]
},
)
assert rv.status_code == 200
rv_data = json.loads(rv.data)
assert rv_data["result"]["colnames"] == []
assert rv_data["result"]["rowcount"] == 0
def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
# 1. default page, per_page and total_count
uri = (
f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
)
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 1
assert rv_data["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
assert rv_data["result"]["total_count"] == 10
# 2. incorrect per_page
per_pages = (app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx")
for per_page in per_pages:
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page={per_page}"
rv = test_client.post(uri)
assert rv.status_code == 400
# 3. incorrect page or datasource_type
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&page=xx"
rv = test_client.post(uri)
assert rv.status_code == 400
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=xx"
rv = test_client.post(uri)
assert rv.status_code == 400
# 4. paging through the results
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=1"
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 1
assert rv_data["result"]["per_page"] == 2
assert rv_data["result"]["total_count"] == 10
assert [row["col1"] for row in rv_data["result"]["data"]] == [0, 1]
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=2"
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 2
assert rv_data["result"]["per_page"] == 2
assert rv_data["result"]["total_count"] == 10
assert [row["col1"] for row in rv_data["result"]["data"]] == [2, 3]
# 5. Exceeding the maximum pages
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=6"
rv = test_client.post(uri)
rv_data = json.loads(rv.data)
assert rv_data["result"]["page"] == 6
assert rv_data["result"]["per_page"] == 2
assert rv_data["result"]["total_count"] == 10
assert [row["col1"] for row in rv_data["result"]["data"]] == []