mirror of https://github.com/apache/superset.git
feat: the samples endpoint supports filters and pagination (#20683)
This commit is contained in:
parent
39545352d2
commit
f011abae2b
|
@ -129,7 +129,7 @@ describe('Test datatable', () => {
|
|||
});
|
||||
it('Datapane loads view samples', () => {
|
||||
cy.intercept(
|
||||
'api/v1/explore/samples?force=false&datasource_type=table&datasource_id=*',
|
||||
'datasource/samples?force=false&datasource_type=table&datasource_id=*',
|
||||
).as('Samples');
|
||||
cy.contains('Samples')
|
||||
.click()
|
||||
|
|
|
@ -602,10 +602,11 @@ export const getDatasourceSamples = async (
|
|||
datasourceType,
|
||||
datasourceId,
|
||||
force,
|
||||
jsonPayload,
|
||||
) => {
|
||||
const endpoint = `/api/v1/explore/samples?force=${force}&datasource_type=${datasourceType}&datasource_id=${datasourceId}`;
|
||||
const endpoint = `/datasource/samples?force=${force}&datasource_type=${datasourceType}&datasource_id=${datasourceId}`;
|
||||
try {
|
||||
const response = await SupersetClient.get({ endpoint });
|
||||
const response = await SupersetClient.post({ endpoint, jsonPayload });
|
||||
return response.json.result;
|
||||
} catch (err) {
|
||||
const clientError = await getClientErrorObject(err);
|
||||
|
|
|
@ -29,8 +29,8 @@ import { SamplesPane } from '../components';
|
|||
import { createSamplesPaneProps } from './fixture';
|
||||
|
||||
describe('SamplesPane', () => {
|
||||
fetchMock.get(
|
||||
'end:/api/v1/explore/samples?force=false&datasource_type=table&datasource_id=34',
|
||||
fetchMock.post(
|
||||
'end:/datasource/samples?force=false&datasource_type=table&datasource_id=34',
|
||||
{
|
||||
result: {
|
||||
data: [],
|
||||
|
@ -40,8 +40,8 @@ describe('SamplesPane', () => {
|
|||
},
|
||||
);
|
||||
|
||||
fetchMock.get(
|
||||
'end:/api/v1/explore/samples?force=true&datasource_type=table&datasource_id=35',
|
||||
fetchMock.post(
|
||||
'end:/datasource/samples?force=true&datasource_type=table&datasource_id=35',
|
||||
{
|
||||
result: {
|
||||
data: [
|
||||
|
@ -54,8 +54,8 @@ describe('SamplesPane', () => {
|
|||
},
|
||||
);
|
||||
|
||||
fetchMock.get(
|
||||
'end:/api/v1/explore/samples?force=false&datasource_type=table&datasource_id=36',
|
||||
fetchMock.post(
|
||||
'end:/datasource/samples?force=false&datasource_type=table&datasource_id=36',
|
||||
400,
|
||||
);
|
||||
|
||||
|
|
|
@ -21,9 +21,8 @@ from io import BytesIO
|
|||
from typing import Any
|
||||
from zipfile import is_zipfile, ZipFile
|
||||
|
||||
import simplejson
|
||||
import yaml
|
||||
from flask import make_response, request, Response, send_file
|
||||
from flask import request, Response, send_file
|
||||
from flask_appbuilder.api import expose, protect, rison, safe
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
from flask_babel import ngettext
|
||||
|
@ -46,13 +45,11 @@ from superset.datasets.commands.exceptions import (
|
|||
DatasetInvalidError,
|
||||
DatasetNotFoundError,
|
||||
DatasetRefreshFailedError,
|
||||
DatasetSamplesFailedError,
|
||||
DatasetUpdateFailedError,
|
||||
)
|
||||
from superset.datasets.commands.export import ExportDatasetsCommand
|
||||
from superset.datasets.commands.importers.dispatcher import ImportDatasetsCommand
|
||||
from superset.datasets.commands.refresh import RefreshDatasetCommand
|
||||
from superset.datasets.commands.samples import SamplesDatasetCommand
|
||||
from superset.datasets.commands.update import UpdateDatasetCommand
|
||||
from superset.datasets.dao import DatasetDAO
|
||||
from superset.datasets.filters import DatasetCertifiedFilter, DatasetIsNullOrEmptyFilter
|
||||
|
@ -63,7 +60,7 @@ from superset.datasets.schemas import (
|
|||
get_delete_ids_schema,
|
||||
get_export_ids_schema,
|
||||
)
|
||||
from superset.utils.core import json_int_dttm_ser, parse_boolean_string
|
||||
from superset.utils.core import parse_boolean_string
|
||||
from superset.views.base import DatasourceFilter, generate_download_headers
|
||||
from superset.views.base_api import (
|
||||
BaseSupersetModelRestApi,
|
||||
|
@ -93,7 +90,6 @@ class DatasetRestApi(BaseSupersetModelRestApi):
|
|||
"bulk_delete",
|
||||
"refresh",
|
||||
"related_objects",
|
||||
"samples",
|
||||
}
|
||||
list_columns = [
|
||||
"id",
|
||||
|
@ -775,65 +771,3 @@ class DatasetRestApi(BaseSupersetModelRestApi):
|
|||
)
|
||||
command.run()
|
||||
return self.response(200, message="OK")
|
||||
|
||||
@expose("/<pk>/samples")
|
||||
@protect()
|
||||
@safe
|
||||
@statsd_metrics
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.samples",
|
||||
log_to_statsd=False,
|
||||
)
|
||||
def samples(self, pk: int) -> Response:
|
||||
"""get samples from a Dataset
|
||||
---
|
||||
get:
|
||||
description: >-
|
||||
get samples from a Dataset
|
||||
parameters:
|
||||
- in: path
|
||||
schema:
|
||||
type: integer
|
||||
name: pk
|
||||
- in: query
|
||||
schema:
|
||||
type: boolean
|
||||
name: force
|
||||
responses:
|
||||
200:
|
||||
description: Dataset samples
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
result:
|
||||
$ref: '#/components/schemas/ChartDataResponseResult'
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
403:
|
||||
$ref: '#/components/responses/403'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
422:
|
||||
$ref: '#/components/responses/422'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
try:
|
||||
force = parse_boolean_string(request.args.get("force"))
|
||||
rv = SamplesDatasetCommand(pk, force).run()
|
||||
response_data = simplejson.dumps(
|
||||
{"result": rv},
|
||||
default=json_int_dttm_ser,
|
||||
ignore_nan=True,
|
||||
)
|
||||
resp = make_response(response_data, 200)
|
||||
resp.headers["Content-Type"] = "application/json; charset=utf-8"
|
||||
return resp
|
||||
except DatasetNotFoundError:
|
||||
return self.response_404()
|
||||
except DatasetForbiddenError:
|
||||
return self.response_403()
|
||||
except DatasetSamplesFailedError as ex:
|
||||
return self.response_400(message=str(ex))
|
||||
|
|
|
@ -1,80 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from superset import security_manager
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.common.chart_data import ChartDataResultType
|
||||
from superset.common.query_context_factory import QueryContextFactory
|
||||
from superset.common.utils.query_cache_manager import QueryCacheManager
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.constants import CacheRegion
|
||||
from superset.datasets.commands.exceptions import (
|
||||
DatasetForbiddenError,
|
||||
DatasetNotFoundError,
|
||||
DatasetSamplesFailedError,
|
||||
)
|
||||
from superset.datasets.dao import DatasetDAO
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.utils.core import QueryStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SamplesDatasetCommand(BaseCommand):
|
||||
def __init__(self, model_id: int, force: bool):
|
||||
self._model_id = model_id
|
||||
self._force = force
|
||||
self._model: Optional[SqlaTable] = None
|
||||
|
||||
def run(self) -> Dict[str, Any]:
|
||||
self.validate()
|
||||
if not self._model:
|
||||
raise DatasetNotFoundError()
|
||||
|
||||
qc_instance = QueryContextFactory().create(
|
||||
datasource={
|
||||
"type": self._model.type,
|
||||
"id": self._model.id,
|
||||
},
|
||||
queries=[{}],
|
||||
result_type=ChartDataResultType.SAMPLES,
|
||||
force=self._force,
|
||||
)
|
||||
results = qc_instance.get_payload()
|
||||
try:
|
||||
sample_data = results["queries"][0]
|
||||
error_msg = sample_data.get("error")
|
||||
if sample_data.get("status") == QueryStatus.FAILED and error_msg:
|
||||
cache_key = sample_data.get("cache_key")
|
||||
QueryCacheManager.delete(cache_key, region=CacheRegion.DATA)
|
||||
raise DatasetSamplesFailedError(error_msg)
|
||||
return sample_data
|
||||
except (IndexError, KeyError) as exc:
|
||||
raise DatasetSamplesFailedError from exc
|
||||
|
||||
def validate(self) -> None:
|
||||
# Validate/populate model exists
|
||||
self._model = DatasetDAO.find_by_id(self._model_id)
|
||||
if not self._model:
|
||||
raise DatasetNotFoundError()
|
||||
# Check ownership
|
||||
try:
|
||||
security_manager.raise_for_ownership(self._model)
|
||||
except SupersetSecurityException as ex:
|
||||
raise DatasetForbiddenError() from ex
|
|
@ -16,22 +16,14 @@
|
|||
# under the License.
|
||||
import logging
|
||||
|
||||
import simplejson
|
||||
from flask import g, make_response, request, Response
|
||||
from flask import g, request, Response
|
||||
from flask_appbuilder.api import BaseApi, expose, protect, safe
|
||||
|
||||
from superset.charts.commands.exceptions import ChartNotFoundError
|
||||
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
|
||||
from superset.dao.exceptions import DatasourceNotFound
|
||||
from superset.explore.commands.get import GetExploreCommand
|
||||
from superset.explore.commands.parameters import CommandParameters
|
||||
from superset.explore.commands.samples import SamplesDatasourceCommand
|
||||
from superset.explore.exceptions import (
|
||||
DatasetAccessDeniedError,
|
||||
DatasourceForbiddenError,
|
||||
DatasourceSamplesFailedError,
|
||||
WrongEndpointError,
|
||||
)
|
||||
from superset.explore.exceptions import DatasetAccessDeniedError, WrongEndpointError
|
||||
from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
|
||||
from superset.explore.schemas import ExploreContextSchema
|
||||
from superset.extensions import event_logger
|
||||
|
@ -39,16 +31,13 @@ from superset.temporary_cache.commands.exceptions import (
|
|||
TemporaryCacheAccessDeniedError,
|
||||
TemporaryCacheResourceNotFoundError,
|
||||
)
|
||||
from superset.utils.core import json_int_dttm_ser, parse_boolean_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExploreRestApi(BaseApi):
|
||||
method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
|
||||
include_route_methods = {RouteMethod.GET} | {
|
||||
"samples",
|
||||
}
|
||||
include_route_methods = {RouteMethod.GET}
|
||||
allow_browser_login = True
|
||||
class_permission_name = "Explore"
|
||||
resource_name = "explore"
|
||||
|
@ -146,70 +135,3 @@ class ExploreRestApi(BaseApi):
|
|||
return self.response(403, message=str(ex))
|
||||
except TemporaryCacheResourceNotFoundError as ex:
|
||||
return self.response(404, message=str(ex))
|
||||
|
||||
@expose("/samples", methods=["GET"])
|
||||
@protect()
|
||||
@safe
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.samples",
|
||||
log_to_statsd=False,
|
||||
)
|
||||
def samples(self) -> Response:
|
||||
"""get samples from a Datasource
|
||||
---
|
||||
get:
|
||||
description: >-
|
||||
get samples from a Datasource
|
||||
parameters:
|
||||
- in: path
|
||||
schema:
|
||||
type: integer
|
||||
name: pk
|
||||
- in: query
|
||||
schema:
|
||||
type: boolean
|
||||
name: force
|
||||
responses:
|
||||
200:
|
||||
description: Datasource samples
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
result:
|
||||
$ref: '#/components/schemas/ChartDataResponseResult'
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
403:
|
||||
$ref: '#/components/responses/403'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
422:
|
||||
$ref: '#/components/responses/422'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
try:
|
||||
force = parse_boolean_string(request.args.get("force"))
|
||||
rv = SamplesDatasourceCommand(
|
||||
user=g.user,
|
||||
datasource_type=request.args.get("datasource_type", type=str),
|
||||
datasource_id=request.args.get("datasource_id", type=int),
|
||||
force=force,
|
||||
).run()
|
||||
|
||||
response_data = simplejson.dumps(
|
||||
{"result": rv},
|
||||
default=json_int_dttm_ser,
|
||||
ignore_nan=True,
|
||||
)
|
||||
resp = make_response(response_data, 200)
|
||||
resp.headers["Content-Type"] = "application/json; charset=utf-8"
|
||||
return resp
|
||||
except DatasourceNotFound:
|
||||
return self.response_404()
|
||||
except DatasourceForbiddenError:
|
||||
return self.response_403()
|
||||
except DatasourceSamplesFailedError as ex:
|
||||
return self.response_400(message=str(ex))
|
||||
|
|
|
@ -1,93 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.common.chart_data import ChartDataResultType
|
||||
from superset.common.query_context_factory import QueryContextFactory
|
||||
from superset.common.utils.query_cache_manager import QueryCacheManager
|
||||
from superset.constants import CacheRegion
|
||||
from superset.dao.exceptions import DatasourceNotFound
|
||||
from superset.datasource.dao import Datasource, DatasourceDAO
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.explore.exceptions import (
|
||||
DatasourceForbiddenError,
|
||||
DatasourceSamplesFailedError,
|
||||
)
|
||||
from superset.utils.core import DatasourceType, QueryStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SamplesDatasourceCommand(BaseCommand):
|
||||
def __init__(
|
||||
self,
|
||||
user: User,
|
||||
datasource_id: Optional[int],
|
||||
datasource_type: Optional[str],
|
||||
force: bool,
|
||||
):
|
||||
self._actor = user
|
||||
self._datasource_id = datasource_id
|
||||
self._datasource_type = datasource_type
|
||||
self._force = force
|
||||
self._model: Optional[Datasource] = None
|
||||
|
||||
def run(self) -> Dict[str, Any]:
|
||||
self.validate()
|
||||
if not self._model:
|
||||
raise DatasourceNotFound()
|
||||
|
||||
qc_instance = QueryContextFactory().create(
|
||||
datasource={
|
||||
"type": self._model.type,
|
||||
"id": self._model.id,
|
||||
},
|
||||
queries=[{}],
|
||||
result_type=ChartDataResultType.SAMPLES,
|
||||
force=self._force,
|
||||
)
|
||||
results = qc_instance.get_payload()
|
||||
try:
|
||||
sample_data = results["queries"][0]
|
||||
error_msg = sample_data.get("error")
|
||||
if sample_data.get("status") == QueryStatus.FAILED and error_msg:
|
||||
cache_key = sample_data.get("cache_key")
|
||||
QueryCacheManager.delete(cache_key, region=CacheRegion.DATA)
|
||||
raise DatasourceSamplesFailedError(error_msg)
|
||||
return sample_data
|
||||
except (IndexError, KeyError) as exc:
|
||||
raise DatasourceSamplesFailedError from exc
|
||||
|
||||
def validate(self) -> None:
|
||||
# Validate/populate model exists
|
||||
if self._datasource_type and self._datasource_id:
|
||||
self._model = DatasourceDAO.get_datasource(
|
||||
session=db.session,
|
||||
datasource_type=DatasourceType(self._datasource_type),
|
||||
datasource_id=self._datasource_id,
|
||||
)
|
||||
|
||||
# Check ownership
|
||||
try:
|
||||
security_manager.raise_for_ownership(self._model)
|
||||
except SupersetSecurityException as ex:
|
||||
raise DatasourceForbiddenError() from ex
|
|
@ -14,11 +14,15 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any
|
||||
from typing import Any, Dict
|
||||
|
||||
from marshmallow import fields, post_load, Schema
|
||||
from marshmallow import fields, post_load, pre_load, Schema, validate
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from superset import app
|
||||
from superset.charts.schemas import ChartDataFilterSchema
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
|
||||
class ExternalMetadataParams(TypedDict):
|
||||
datasource_type: str
|
||||
|
@ -54,3 +58,27 @@ class ExternalMetadataSchema(Schema):
|
|||
schema_name=data.get("schema_name", ""),
|
||||
table_name=data["table_name"],
|
||||
)
|
||||
|
||||
|
||||
class SamplesPayloadSchema(Schema):
|
||||
filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
|
||||
|
||||
@pre_load
|
||||
# pylint: disable=no-self-use, unused-argument
|
||||
def handle_none(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
|
||||
if data is None:
|
||||
return {}
|
||||
return data
|
||||
|
||||
|
||||
class SamplesRequestSchema(Schema):
|
||||
datasource_type = fields.String(
|
||||
validate=validate.OneOf([e.value for e in DatasourceType]), required=True
|
||||
)
|
||||
datasource_id = fields.Integer(required=True)
|
||||
force = fields.Boolean(load_default=False)
|
||||
page = fields.Integer(load_default=1)
|
||||
per_page = fields.Integer(
|
||||
validate=validate.Range(min=1, max=app.config.get("SAMPLES_ROW_LIMIT", 1000)),
|
||||
load_default=app.config.get("SAMPLES_ROW_LIMIT", 1000),
|
||||
)
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from superset import app, db
|
||||
from superset.common.chart_data import ChartDataResultType
|
||||
from superset.common.query_context_factory import QueryContextFactory
|
||||
from superset.common.utils.query_cache_manager import QueryCacheManager
|
||||
from superset.constants import CacheRegion
|
||||
from superset.datasets.commands.exceptions import DatasetSamplesFailedError
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
from superset.utils.core import QueryStatus
|
||||
from superset.views.datasource.schemas import SamplesPayloadSchema
|
||||
|
||||
|
||||
def get_limit_clause(page: Optional[int], per_page: Optional[int]) -> Dict[str, int]:
|
||||
samples_row_limit = app.config.get("SAMPLES_ROW_LIMIT", 1000)
|
||||
limit = samples_row_limit
|
||||
offset = 0
|
||||
|
||||
if isinstance(page, int) and isinstance(per_page, int):
|
||||
limit = int(per_page)
|
||||
if limit < 0 or limit > samples_row_limit:
|
||||
# reset limit value if input is invalid
|
||||
limit = samples_row_limit
|
||||
|
||||
offset = max((int(page) - 1) * limit, 0)
|
||||
|
||||
return {"row_offset": offset, "row_limit": limit}
|
||||
|
||||
|
||||
def get_samples( # pylint: disable=too-many-arguments,too-many-locals
|
||||
datasource_type: str,
|
||||
datasource_id: int,
|
||||
force: bool = False,
|
||||
page: int = 1,
|
||||
per_page: int = 1000,
|
||||
payload: Optional[SamplesPayloadSchema] = None,
|
||||
) -> Dict[str, Any]:
|
||||
datasource = DatasourceDAO.get_datasource(
|
||||
session=db.session,
|
||||
datasource_type=datasource_type,
|
||||
datasource_id=datasource_id,
|
||||
)
|
||||
|
||||
limit_clause = get_limit_clause(page, per_page)
|
||||
|
||||
# todo(yongjie): Constructing count(*) and samples in the same query_context,
|
||||
# then remove query_type==SAMPLES
|
||||
# constructing samples query
|
||||
samples_instance = QueryContextFactory().create(
|
||||
datasource={
|
||||
"type": datasource.type,
|
||||
"id": datasource.id,
|
||||
},
|
||||
queries=[{**payload, **limit_clause} if payload else limit_clause],
|
||||
result_type=ChartDataResultType.SAMPLES,
|
||||
force=force,
|
||||
)
|
||||
|
||||
# constructing count(*) query
|
||||
count_star_metric = {
|
||||
"metrics": [
|
||||
{
|
||||
"expressionType": "SQL",
|
||||
"sqlExpression": "COUNT(*)",
|
||||
"label": "COUNT(*)",
|
||||
}
|
||||
]
|
||||
}
|
||||
count_star_instance = QueryContextFactory().create(
|
||||
datasource={
|
||||
"type": datasource.type,
|
||||
"id": datasource.id,
|
||||
},
|
||||
queries=[{**payload, **count_star_metric} if payload else count_star_metric],
|
||||
result_type=ChartDataResultType.FULL,
|
||||
force=force,
|
||||
)
|
||||
samples_results = samples_instance.get_payload()
|
||||
count_star_results = count_star_instance.get_payload()
|
||||
|
||||
try:
|
||||
sample_data = samples_results["queries"][0]
|
||||
count_star_data = count_star_results["queries"][0]
|
||||
failed_status = (
|
||||
sample_data.get("status") == QueryStatus.FAILED
|
||||
or count_star_data.get("status") == QueryStatus.FAILED
|
||||
)
|
||||
error_msg = sample_data.get("error") or count_star_data.get("error")
|
||||
if failed_status and error_msg:
|
||||
cache_key = sample_data.get("cache_key")
|
||||
QueryCacheManager.delete(cache_key, region=CacheRegion.DATA)
|
||||
raise DatasetSamplesFailedError(error_msg)
|
||||
|
||||
sample_data["page"] = page
|
||||
sample_data["per_page"] = per_page
|
||||
sample_data["total_count"] = count_star_data["data"][0]["COUNT(*)"]
|
||||
return sample_data
|
||||
except (IndexError, KeyError) as exc:
|
||||
raise DatasetSamplesFailedError from exc
|
|
@ -50,7 +50,10 @@ from superset.views.datasource.schemas import (
|
|||
ExternalMetadataParams,
|
||||
ExternalMetadataSchema,
|
||||
get_external_metadata_schema,
|
||||
SamplesPayloadSchema,
|
||||
SamplesRequestSchema,
|
||||
)
|
||||
from superset.views.datasource.utils import get_samples
|
||||
from superset.views.utils import sanitize_datasource_data
|
||||
|
||||
|
||||
|
@ -179,3 +182,24 @@ class Datasource(BaseSupersetView):
|
|||
except (NoResultFound, NoSuchTableError) as ex:
|
||||
raise DatasetNotFoundError() from ex
|
||||
return self.json_response(external_metadata)
|
||||
|
||||
@expose("/samples", methods=["POST"])
|
||||
@has_access_api
|
||||
@api
|
||||
@handle_api_exception
|
||||
def samples(self) -> FlaskResponse:
|
||||
try:
|
||||
params = SamplesRequestSchema().load(request.args)
|
||||
payload = SamplesPayloadSchema().load(request.json)
|
||||
except ValidationError as err:
|
||||
return json_error_response(err.messages, status=400)
|
||||
|
||||
rv = get_samples(
|
||||
datasource_type=params["datasource_type"],
|
||||
datasource_id=params["datasource_id"],
|
||||
force=params["force"],
|
||||
page=params["page"],
|
||||
per_page=params["per_page"],
|
||||
payload=payload,
|
||||
)
|
||||
return self.json_response({"result": rv})
|
||||
|
|
|
@ -206,3 +206,107 @@ def with_feature_flags(**mock_feature_flags):
|
|||
return functools.update_wrapper(wrapper, test_fn)
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def virtual_dataset():
|
||||
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||
|
||||
dataset = SqlaTable(
|
||||
table_name="virtual_dataset",
|
||||
sql=(
|
||||
"SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5 "
|
||||
"UNION ALL "
|
||||
"SELECT 1, 'b', 1.1, NULL, '2000-01-02 00:00:00' "
|
||||
"UNION ALL "
|
||||
"SELECT 2 as col1, 'c' as col2, 1.2, NULL, '2000-01-03 00:00:00' "
|
||||
"UNION ALL "
|
||||
"SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00' "
|
||||
"UNION ALL "
|
||||
"SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00' "
|
||||
"UNION ALL "
|
||||
"SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00' "
|
||||
"UNION ALL "
|
||||
"SELECT 6 as col1, 'g' as col2, 1.6, NULL, '2000-01-07 00:00:00' "
|
||||
"UNION ALL "
|
||||
"SELECT 7 as col1, 'h' as col2, 1.7, NULL, '2000-01-08 00:00:00' "
|
||||
"UNION ALL "
|
||||
"SELECT 8 as col1, 'i' as col2, 1.8, NULL, '2000-01-09 00:00:00' "
|
||||
"UNION ALL "
|
||||
"SELECT 9 as col1, 'j' as col2, 1.9, NULL, '2000-01-10 00:00:00' "
|
||||
),
|
||||
database=get_example_database(),
|
||||
)
|
||||
TableColumn(column_name="col1", type="INTEGER", table=dataset)
|
||||
TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
|
||||
TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
|
||||
TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
|
||||
# Different database dialect datetime type is not consistent, so temporarily use varchar
|
||||
TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
|
||||
|
||||
SqlMetric(metric_name="count", expression="count(*)", table=dataset)
|
||||
db.session.merge(dataset)
|
||||
|
||||
yield dataset
|
||||
|
||||
db.session.delete(dataset)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def physical_dataset():
|
||||
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||
|
||||
example_database = get_example_database()
|
||||
engine = example_database.get_sqla_engine()
|
||||
# sqlite can only execute one statement at a time
|
||||
engine.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS physical_dataset(
|
||||
col1 INTEGER,
|
||||
col2 VARCHAR(255),
|
||||
col3 DECIMAL(4,2),
|
||||
col4 VARCHAR(255),
|
||||
col5 VARCHAR(255)
|
||||
);
|
||||
"""
|
||||
)
|
||||
engine.execute(
|
||||
"""
|
||||
INSERT INTO physical_dataset values
|
||||
(0, 'a', 1.0, NULL, '2000-01-01 00:00:00'),
|
||||
(1, 'b', 1.1, NULL, '2000-01-02 00:00:00'),
|
||||
(2, 'c', 1.2, NULL, '2000-01-03 00:00:00'),
|
||||
(3, 'd', 1.3, NULL, '2000-01-04 00:00:00'),
|
||||
(4, 'e', 1.4, NULL, '2000-01-05 00:00:00'),
|
||||
(5, 'f', 1.5, NULL, '2000-01-06 00:00:00'),
|
||||
(6, 'g', 1.6, NULL, '2000-01-07 00:00:00'),
|
||||
(7, 'h', 1.7, NULL, '2000-01-08 00:00:00'),
|
||||
(8, 'i', 1.8, NULL, '2000-01-09 00:00:00'),
|
||||
(9, 'j', 1.9, NULL, '2000-01-10 00:00:00');
|
||||
"""
|
||||
)
|
||||
|
||||
dataset = SqlaTable(
|
||||
table_name="physical_dataset",
|
||||
database=example_database,
|
||||
)
|
||||
TableColumn(column_name="col1", type="INTEGER", table=dataset)
|
||||
TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
|
||||
TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
|
||||
TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
|
||||
TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
|
||||
SqlMetric(metric_name="count", expression="count(*)", table=dataset)
|
||||
db.session.merge(dataset)
|
||||
if example_database.backend == "sqlite":
|
||||
db.session.commit()
|
||||
|
||||
yield dataset
|
||||
|
||||
engine.execute(
|
||||
"""
|
||||
DROP TABLE physical_dataset;
|
||||
"""
|
||||
)
|
||||
db.session.delete(dataset)
|
||||
db.session.commit()
|
||||
|
|
|
@ -27,9 +27,7 @@ import pytest
|
|||
import yaml
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from superset.common.utils.query_cache_manager import QueryCacheManager
|
||||
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||
from superset.constants import CacheRegion
|
||||
from superset.dao.exceptions import (
|
||||
DAOCreateFailedError,
|
||||
DAODeleteFailedError,
|
||||
|
@ -2085,102 +2083,3 @@ class TestDatasetApi(SupersetTestCase):
|
|||
|
||||
db.session.delete(table_w_certification)
|
||||
db.session.commit()
|
||||
|
||||
@pytest.mark.usefixtures("create_datasets")
|
||||
def test_get_dataset_samples(self):
|
||||
"""
|
||||
Dataset API: Test get dataset samples
|
||||
"""
|
||||
if backend() == "sqlite":
|
||||
return
|
||||
|
||||
dataset = self.get_fixture_datasets()[0]
|
||||
|
||||
self.login(username="admin")
|
||||
uri = f"api/v1/dataset/{dataset.id}/samples"
|
||||
|
||||
# 1. should cache data
|
||||
# feeds data
|
||||
self.client.get(uri)
|
||||
# get from cache
|
||||
rv = self.client.get(uri)
|
||||
rv_data = json.loads(rv.data)
|
||||
assert rv.status_code == 200
|
||||
assert "result" in rv_data
|
||||
assert rv_data["result"]["cached_dttm"] is not None
|
||||
cache_key1 = rv_data["result"]["cache_key"]
|
||||
assert QueryCacheManager.has(cache_key1, region=CacheRegion.DATA)
|
||||
|
||||
# 2. should through cache
|
||||
uri2 = f"api/v1/dataset/{dataset.id}/samples?force=true"
|
||||
# feeds data
|
||||
self.client.get(uri2)
|
||||
# force query
|
||||
rv2 = self.client.get(uri2)
|
||||
rv_data2 = json.loads(rv2.data)
|
||||
assert rv_data2["result"]["cached_dttm"] is None
|
||||
cache_key2 = rv_data2["result"]["cache_key"]
|
||||
assert QueryCacheManager.has(cache_key2, region=CacheRegion.DATA)
|
||||
|
||||
# 3. data precision
|
||||
assert "colnames" in rv_data2["result"]
|
||||
assert "coltypes" in rv_data2["result"]
|
||||
assert "data" in rv_data2["result"]
|
||||
|
||||
eager_samples = dataset.database.get_df(
|
||||
f"select * from {dataset.table_name}"
|
||||
f' limit {self.app.config["SAMPLES_ROW_LIMIT"]}'
|
||||
).to_dict(orient="records")
|
||||
assert eager_samples == rv_data2["result"]["data"]
|
||||
|
||||
@pytest.mark.usefixtures("create_datasets")
|
||||
def test_get_dataset_samples_with_failed_cc(self):
|
||||
if backend() == "sqlite":
|
||||
return
|
||||
|
||||
dataset = self.get_fixture_datasets()[0]
|
||||
|
||||
self.login(username="admin")
|
||||
failed_column = TableColumn(
|
||||
column_name="DUMMY CC",
|
||||
type="VARCHAR(255)",
|
||||
table=dataset,
|
||||
expression="INCORRECT SQL",
|
||||
)
|
||||
uri = f"api/v1/dataset/{dataset.id}/samples"
|
||||
dataset.columns.append(failed_column)
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 400
|
||||
rv_data = json.loads(rv.data)
|
||||
assert "message" in rv_data
|
||||
if dataset.database.db_engine_spec.engine_name == "PostgreSQL":
|
||||
assert "INCORRECT SQL" in rv_data.get("message")
|
||||
|
||||
def test_get_dataset_samples_on_virtual_dataset(self):
|
||||
if backend() == "sqlite":
|
||||
return
|
||||
|
||||
virtual_dataset = SqlaTable(
|
||||
table_name="virtual_dataset",
|
||||
sql=("SELECT 'foo' as foo, 'bar' as bar"),
|
||||
database=get_example_database(),
|
||||
)
|
||||
TableColumn(column_name="foo", type="VARCHAR(255)", table=virtual_dataset)
|
||||
TableColumn(column_name="bar", type="VARCHAR(255)", table=virtual_dataset)
|
||||
SqlMetric(metric_name="count", expression="count(*)", table=virtual_dataset)
|
||||
|
||||
self.login(username="admin")
|
||||
uri = f"api/v1/dataset/{virtual_dataset.id}/samples"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 200
|
||||
rv_data = json.loads(rv.data)
|
||||
cache_key = rv_data["result"]["cache_key"]
|
||||
assert QueryCacheManager.has(cache_key, region=CacheRegion.DATA)
|
||||
|
||||
# remove original column in dataset
|
||||
virtual_dataset.sql = "SELECT 'foo' as foo"
|
||||
rv = self.client.get(uri)
|
||||
assert rv.status_code == 400
|
||||
|
||||
db.session.delete(virtual_dataset)
|
||||
db.session.commit()
|
||||
|
|
|
@ -23,13 +23,15 @@ import prison
|
|||
import pytest
|
||||
|
||||
from superset import app, db
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.common.utils.query_cache_manager import QueryCacheManager
|
||||
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||
from superset.constants import CacheRegion
|
||||
from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
|
||||
from superset.datasets.commands.exceptions import DatasetNotFoundError
|
||||
from superset.exceptions import SupersetGenericDBErrorException
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import DatasourceType, get_example_default_schema
|
||||
from superset.utils.database import get_example_database
|
||||
from superset.utils.core import backend, get_example_default_schema
|
||||
from superset.utils.database import get_example_database, get_main_database
|
||||
from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices,
|
||||
|
@ -416,3 +418,189 @@ class TestDatasource(SupersetTestCase):
|
|||
self.login(username="admin")
|
||||
resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False)
|
||||
self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType")
|
||||
|
||||
|
||||
def test_get_samples(test_client, login_as_admin, virtual_dataset):
|
||||
"""
|
||||
Dataset API: Test get dataset samples
|
||||
"""
|
||||
# 1. should cache data
|
||||
uri = (
|
||||
f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
|
||||
)
|
||||
# feeds data
|
||||
test_client.post(uri)
|
||||
# get from cache
|
||||
rv = test_client.post(uri)
|
||||
rv_data = json.loads(rv.data)
|
||||
assert rv.status_code == 200
|
||||
assert len(rv_data["result"]["data"]) == 10
|
||||
assert QueryCacheManager.has(
|
||||
rv_data["result"]["cache_key"],
|
||||
region=CacheRegion.DATA,
|
||||
)
|
||||
assert rv_data["result"]["is_cached"]
|
||||
|
||||
# 2. should read through cache data
|
||||
uri2 = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&force=true"
|
||||
# feeds data
|
||||
test_client.post(uri2)
|
||||
# force query
|
||||
rv2 = test_client.post(uri2)
|
||||
rv_data2 = json.loads(rv2.data)
|
||||
assert rv2.status_code == 200
|
||||
assert len(rv_data2["result"]["data"]) == 10
|
||||
assert QueryCacheManager.has(
|
||||
rv_data2["result"]["cache_key"],
|
||||
region=CacheRegion.DATA,
|
||||
)
|
||||
assert not rv_data2["result"]["is_cached"]
|
||||
|
||||
# 3. data precision
|
||||
assert "colnames" in rv_data2["result"]
|
||||
assert "coltypes" in rv_data2["result"]
|
||||
assert "data" in rv_data2["result"]
|
||||
|
||||
eager_samples = virtual_dataset.database.get_df(
|
||||
f"select * from ({virtual_dataset.sql}) as tbl"
|
||||
f' limit {app.config["SAMPLES_ROW_LIMIT"]}'
|
||||
)
|
||||
# the col3 is Decimal
|
||||
eager_samples["col3"] = eager_samples["col3"].apply(float)
|
||||
eager_samples = eager_samples.to_dict(orient="records")
|
||||
assert eager_samples == rv_data2["result"]["data"]
|
||||
|
||||
|
||||
def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset):
|
||||
TableColumn(
|
||||
column_name="DUMMY CC",
|
||||
type="VARCHAR(255)",
|
||||
table=virtual_dataset,
|
||||
expression="INCORRECT SQL",
|
||||
)
|
||||
db.session.merge(virtual_dataset)
|
||||
|
||||
uri = (
|
||||
f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
|
||||
)
|
||||
rv = test_client.post(uri)
|
||||
assert rv.status_code == 422
|
||||
|
||||
rv_data = json.loads(rv.data)
|
||||
assert "error" in rv_data
|
||||
if virtual_dataset.database.db_engine_spec.engine_name == "PostgreSQL":
|
||||
assert "INCORRECT SQL" in rv_data.get("error")
|
||||
|
||||
|
||||
def test_get_samples_on_physical_dataset(test_client, login_as_admin, physical_dataset):
|
||||
uri = (
|
||||
f"/datasource/samples?datasource_id={physical_dataset.id}&datasource_type=table"
|
||||
)
|
||||
rv = test_client.post(uri)
|
||||
assert rv.status_code == 200
|
||||
rv_data = json.loads(rv.data)
|
||||
assert QueryCacheManager.has(
|
||||
rv_data["result"]["cache_key"], region=CacheRegion.DATA
|
||||
)
|
||||
assert len(rv_data["result"]["data"]) == 10
|
||||
|
||||
|
||||
def test_get_samples_with_filters(test_client, login_as_admin, virtual_dataset):
|
||||
uri = (
|
||||
f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
|
||||
)
|
||||
rv = test_client.post(uri, json=None)
|
||||
assert rv.status_code == 200
|
||||
|
||||
rv = test_client.post(uri, json={})
|
||||
assert rv.status_code == 200
|
||||
|
||||
rv = test_client.post(uri, json={"foo": "bar"})
|
||||
assert rv.status_code == 400
|
||||
|
||||
rv = test_client.post(
|
||||
uri, json={"filters": [{"col": "col1", "op": "INVALID", "val": 0}]}
|
||||
)
|
||||
assert rv.status_code == 400
|
||||
|
||||
rv = test_client.post(
|
||||
uri,
|
||||
json={
|
||||
"filters": [
|
||||
{"col": "col2", "op": "==", "val": "a"},
|
||||
{"col": "col1", "op": "==", "val": 0},
|
||||
]
|
||||
},
|
||||
)
|
||||
assert rv.status_code == 200
|
||||
rv_data = json.loads(rv.data)
|
||||
assert rv_data["result"]["colnames"] == ["col1", "col2", "col3", "col4", "col5"]
|
||||
assert rv_data["result"]["rowcount"] == 1
|
||||
|
||||
# empty results
|
||||
rv = test_client.post(
|
||||
uri,
|
||||
json={
|
||||
"filters": [
|
||||
{"col": "col2", "op": "==", "val": "x"},
|
||||
]
|
||||
},
|
||||
)
|
||||
assert rv.status_code == 200
|
||||
rv_data = json.loads(rv.data)
|
||||
assert rv_data["result"]["colnames"] == []
|
||||
assert rv_data["result"]["rowcount"] == 0
|
||||
|
||||
|
||||
def test_get_samples_pagination(test_client, login_as_admin, virtual_dataset):
|
||||
# 1. default page, per_page and total_count
|
||||
uri = (
|
||||
f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"
|
||||
)
|
||||
rv = test_client.post(uri)
|
||||
rv_data = json.loads(rv.data)
|
||||
assert rv_data["result"]["page"] == 1
|
||||
assert rv_data["result"]["per_page"] == app.config["SAMPLES_ROW_LIMIT"]
|
||||
assert rv_data["result"]["total_count"] == 10
|
||||
|
||||
# 2. incorrect per_page
|
||||
per_pages = (app.config["SAMPLES_ROW_LIMIT"] + 1, 0, "xx")
|
||||
for per_page in per_pages:
|
||||
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page={per_page}"
|
||||
rv = test_client.post(uri)
|
||||
assert rv.status_code == 400
|
||||
|
||||
# 3. incorrect page or datasource_type
|
||||
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&page=xx"
|
||||
rv = test_client.post(uri)
|
||||
assert rv.status_code == 400
|
||||
|
||||
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=xx"
|
||||
rv = test_client.post(uri)
|
||||
assert rv.status_code == 400
|
||||
|
||||
# 4. turning pages
|
||||
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=1"
|
||||
rv = test_client.post(uri)
|
||||
rv_data = json.loads(rv.data)
|
||||
assert rv_data["result"]["page"] == 1
|
||||
assert rv_data["result"]["per_page"] == 2
|
||||
assert rv_data["result"]["total_count"] == 10
|
||||
assert [row["col1"] for row in rv_data["result"]["data"]] == [0, 1]
|
||||
|
||||
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=2"
|
||||
rv = test_client.post(uri)
|
||||
rv_data = json.loads(rv.data)
|
||||
assert rv_data["result"]["page"] == 2
|
||||
assert rv_data["result"]["per_page"] == 2
|
||||
assert rv_data["result"]["total_count"] == 10
|
||||
assert [row["col1"] for row in rv_data["result"]["data"]] == [2, 3]
|
||||
|
||||
# 5. Exceeding the maximum pages
|
||||
uri = f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table&per_page=2&page=6"
|
||||
rv = test_client.post(uri)
|
||||
rv_data = json.loads(rv.data)
|
||||
assert rv_data["result"]["page"] == 6
|
||||
assert rv_data["result"]["per_page"] == 2
|
||||
assert rv_data["result"]["total_count"] == 10
|
||||
assert [row["col1"] for row in rv_data["result"]["data"]] == []
|
||||
|
|
Loading…
Reference in New Issue