chore: Add OpenAPI docs to /api/v1/chart/data EP (#9556)

* Add OpenAPI docs to /api/v1/chart/data EP

* Add missing fields to QueryObject, fix linting and unit test errors

* Fix unit test errors

* abc

* Fix errors uncovered by schema validation and add unit test for invalid payload

* Add schema for response

* Remove unnecessary pylint disable
Ville Brofeldt 2020-04-17 16:44:16 +03:00 committed by GitHub
parent 427d2a05e5
commit 8e439b1115
13 changed files with 695 additions and 148 deletions

View File

@ -45,7 +45,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
known_third_party =alembic,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false

View File

@ -18,10 +18,11 @@ import logging
from typing import Any, Dict
import simplejson
from apispec import APISpec
from flask import g, make_response, redirect, request, Response, url_for
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext
from flask_babel import gettext as _, ngettext
from werkzeug.wrappers import Response as WerkzeugResponse
from werkzeug.wsgi import FileWrapper
@ -41,12 +42,13 @@ from superset.charts.commands.exceptions import (
from superset.charts.commands.update import UpdateChartCommand
from superset.charts.filters import ChartFilter, ChartNameOrDescriptionFilter
from superset.charts.schemas import (
CHART_DATA_SCHEMAS,
ChartDataQueryContextSchema,
ChartPostSchema,
ChartPutSchema,
get_delete_ids_schema,
thumbnail_query_schema,
)
from superset.common.query_context import QueryContext
from superset.constants import RouteMethod
from superset.exceptions import SupersetSecurityException
from superset.extensions import event_logger, security_manager
@ -381,74 +383,21 @@ class ChartRestApi(BaseSupersetModelRestApi):
Takes a query context constructed in the client and returns payload data
response for the given query.
requestBody:
description: Query context schema
description: >-
A query context consists of a datasource from which to fetch data
and one or many query objects.
required: true
content:
application/json:
schema:
type: object
properties:
datasource:
type: object
description: The datasource where the query will run
properties:
id:
type: integer
type:
type: string
queries:
type: array
items:
type: object
properties:
granularity:
type: string
groupby:
type: array
items:
type: string
metrics:
type: array
items:
type: object
filters:
type: array
items:
type: string
row_limit:
type: integer
$ref: "#/components/schemas/ChartDataQueryContextSchema"
responses:
200:
description: Query result
content:
application/json:
schema:
type: array
items:
type: object
properties:
cache_key:
type: string
cached_dttm:
type: string
cache_timeout:
type: integer
error:
type: string
is_cached:
type: boolean
query:
type: string
status:
type: string
stacktrace:
type: string
rowcount:
type: integer
data:
type: array
items:
type: object
$ref: "#/components/schemas/ChartDataResponseSchema"
400:
$ref: '#/components/responses/400'
500:
@ -457,7 +406,11 @@ class ChartRestApi(BaseSupersetModelRestApi):
if not request.is_json:
return self.response_400(message="Request is not JSON")
try:
query_context = QueryContext(**request.json)
query_context, errors = ChartDataQueryContextSchema().load(request.json)
if errors:
return self.response_400(
message=_("Request is incorrect: %(error)s", error=errors)
)
except KeyError:
return self.response_400(message="Request is incorrect")
try:
@ -466,7 +419,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
return self.response_401()
payload_json = query_context.get_payload()
response_data = simplejson.dumps(
payload_json, default=json_int_dttm_ser, ignore_nan=True
{"result": payload_json}, default=json_int_dttm_ser, ignore_nan=True
)
resp = make_response(response_data, 200)
resp.headers["Content-Type"] = "application/json; charset=utf-8"
@ -533,3 +486,10 @@ class ChartRestApi(BaseSupersetModelRestApi):
return Response(
FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True
)
def add_apispec_components(self, api_spec: APISpec) -> None:
for chart_type in CHART_DATA_SCHEMAS:
api_spec.components.schema(
chart_type.__name__, schema=chart_type,
)
super().add_apispec_components(api_spec)
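
Taken together, the endpoint now validates the posted query context against ChartDataQueryContextSchema and wraps results in a `result` envelope. A minimal sketch of exercising it from a client follows; the host, session handling and datasource id are illustrative assumptions, not part of this change.

import requests

session = requests.Session()  # assumes an already authenticated Superset session
payload = {
    "datasource": {"id": 1, "type": "table"},  # illustrative datasource reference
    "queries": [
        {
            "granularity": "P1D",
            "groupby": ["gender"],
            "metrics": ["count"],
            "filters": [{"col": "state", "op": "IN", "val": ["CA", "FL"]}],
            "row_limit": 100,
        }
    ],
}
resp = session.post("http://localhost:8088/api/v1/chart/data", json=payload)
if resp.status_code == 400:
    print(resp.json())  # schema validation errors surfaced by the new 400 handling
else:
    result = resp.json()["result"]  # responses are now wrapped in a `result` list
    print(result[0]["rowcount"])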

View File

@ -14,11 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Union
from typing import Any, Dict, Union
from marshmallow import fields, Schema, ValidationError
from marshmallow import fields, post_load, Schema, ValidationError
from marshmallow.validate import Length
from superset.common.query_context import QueryContext
from superset.exceptions import SupersetException
from superset.utils import core as utils
@ -59,3 +60,510 @@ class ChartPutSchema(Schema):
datasource_id = fields.Integer(allow_none=True)
datasource_type = fields.String(allow_none=True)
dashboards = fields.List(fields.Integer())
class ChartDataColumnSchema(Schema):
column_name = fields.String(
description="The name of the target column", example="mycol",
)
type = fields.String(description="Type of target column", example="BIGINT",)
class ChartDataAdhocMetricSchema(Schema):
"""
Ad-hoc metrics are used to define metrics outside the datasource.
"""
expressionType = fields.String(
description="Simple or SQL metric",
required=True,
enum=["SIMPLE", "SQL"],
example="SQL",
)
aggregate = fields.String(
description="Aggregation operator. Only required for simple expression types.",
required=False,
enum=["AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM"],
)
column = fields.Nested(ChartDataColumnSchema)
sqlExpression = fields.String(
description="The metric as defined by a SQL aggregate expression. "
"Only required for SQL expression type.",
required=False,
example="SUM(weight * observations) / SUM(weight)",
)
label = fields.String(
description="Label for the metric. Is automatically generated unless "
"hasCustomLabel is true, in which case label must be defined.",
required=False,
example="Weighted observations",
)
hasCustomLabel = fields.Boolean(
description="When false, the label will be automatically generated based on "
"the aggregate expression. When true, a custom label has to be "
"specified.",
required=False,
example=True,
)
optionName = fields.String(
description="Unique identifier. Can be any string value, as long as all "
"metrics have a unique identifier. If undefined, a random name "
"will be generated.",
required=False,
example="metric_aec60732-fac0-4b17-b736-93f1a5c93e30",
)
class ChartDataAggregateConfigField(fields.Dict):
def __init__(self) -> None:
super().__init__(
description="The keys are the name of the aggregate column to be created, "
"and the values specify the details of how to apply the "
"aggregation. If an operator requires additional options, "
"these can be passed here to be unpacked in the operator call. The "
"following numpy operators are supported: average, argmin, argmax, cumsum, "
"cumprod, max, mean, median, nansum, nanmin, nanmax, nanmean, nanmedian, "
"min, percentile, prod, product, std, sum, var. Any options required by "
"the operator can be passed to the `options` object.\n"
"\n"
"In the example, a new column `first_quantile` is created based on values "
"in the column `my_col` using the `percentile` operator with "
"the `q=0.25` parameter.",
example={
"first_quantile": {
"operator": "percentile",
"column": "my_col",
"options": {"q": 0.25},
}
},
)
class ChartDataPostProcessingOperationOptionsSchema(Schema):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
class ChartDataAggregateOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
"""
Aggregate operation config.
"""
groupby = fields.List(
fields.String(
allow_none=False, description="Columns by which to group",
),
minLength=1,
required=True,
)
aggregates = ChartDataAggregateConfigField()
class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
"""
Rolling operation config.
"""
columns = fields.Dict(
description="Columns on which to perform rolling, mapping source column to "
"target column. For instance, `{'y': 'y'}` will replace the "
"column `y` with the rolling value in `y`, while `{'y': 'y2'}` "
"will add a column `y2` based on rolling values calculated "
"from `y`, leaving the original column `y` unchanged.",
example={"weekly_rolling_sales": "sales"},
)
rolling_type = fields.String(
description="Type of rolling window. Any numpy function will work.",
enum=[
"average",
"argmin",
"argmax",
"cumsum",
"cumprod",
"max",
"mean",
"median",
"nansum",
"nanmin",
"nanmax",
"nanmean",
"nanmedian",
"min",
"percentile",
"prod",
"product",
"std",
"sum",
"var",
],
required=True,
example="percentile",
)
window = fields.Integer(
description="Size of the rolling window in days.", required=True, example=7,
)
rolling_type_options = fields.Dict(
description="Optional options to pass to rolling method. Needed for "
"e.g. quantile operation.",
required=False,
example={},
)
center = fields.Boolean(
description="Should the label be at the center of the window. Default: `false`",
required=False,
example=False,
)
win_type = fields.String(
description="Type of window function. See "
"[SciPy window functions](https://docs.scipy.org/doc/scipy/reference"
"/signal.windows.html#module-scipy.signal.windows) "
"for more details. Some window functions require passing "
"additional parameters to `rolling_type_options`. For instance, "
"to use `gaussian`, the parameter `std` needs to be provided.",
required=False,
enum=[
"boxcar",
"triang",
"blackman",
"hamming",
"bartlett",
"parzen",
"bohman",
"blackmanharris",
"nuttall",
"barthann",
"kaiser",
"gaussian",
"general_gaussian",
"slepian",
"exponential",
],
)
min_periods = fields.Integer(
description="The minimum amount of periods required for a row to be included "
"in the result set.",
required=False,
example=7,
)
class ChartDataSelectOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
"""
Select operation config.
"""
columns = fields.List(
fields.String(),
description="Columns which to select from the input data, in the desired "
"order. If columns are renamed, the old column name should be "
"referenced here.",
example=["country", "gender", "age"],
)
rename = fields.List(
fields.Dict(),
description="columns which to rename, mapping source column to target column. "
"For instance, `{'y': 'y2'}` will rename the column `y` to `y2`.",
example=[{"age": "average_age"}],
)
class ChartDataSortOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
"""
Sort operation config.
"""
columns = fields.Dict(
description="columns by by which to sort. The key specifies the column name, "
"value specifies if sorting in ascending order.",
example={"country": True, "gender": False},
required=True,
)
aggregates = ChartDataAggregateConfigField()
class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
"""
Pivot operation config.
"""
index = fields.List(
fields.String(
allow_none=False,
description="Columns to group by on the table index (=rows)",
),
minLength=1,
required=True,
)
columns = fields.List(
fields.String(
allow_none=False, description="Columns to group by on the table columns",
),
minLength=1,
required=True,
)
metric_fill_value = fields.Number(
required=False,
description="Value to replace missing values with in aggregate calculations.",
)
column_fill_value = fields.String(
required=False, description="Value to replace missing pivot columns names with."
)
drop_missing_columns = fields.Boolean(
description="Do not include columns whose entries are all missing "
"(default: `true`).",
required=False,
)
marginal_distributions = fields.Boolean(
description="Add totals for row/column. (default: `false`)", required=False,
)
marginal_distribution_name = fields.String(
description="Name of marginal distribution row/column. (default: `All`)",
required=False,
)
aggregates = ChartDataAggregateConfigField()
class ChartDataPostProcessingOperationSchema(Schema):
operation = fields.String(
description="Post processing operation type",
required=True,
enum=["aggregate", "pivot", "rolling", "select", "sort"],
example="aggregate",
)
options = fields.Nested(
ChartDataPostProcessingOperationOptionsSchema,
description="Options specifying how to perform the operation. Please refer "
"to the respective post processing operation option schemas. "
"For example, `ChartDataPostProcessingOperationOptions` specifies "
"the required options for the pivot operation.",
example={
"groupby": ["country", "gender"],
"aggregates": {
"age_q1": {
"operator": "percentile",
"column": "age",
"options": {"q": 0.25},
},
"age_mean": {"operator": "mean", "column": "age",},
},
},
)
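
To make the split between `operation` and `options` concrete, a complete `post_processing` entry pairing an aggregate with a sort could look like the sketch below (column names are illustrative):

post_processing = [
    {
        "operation": "aggregate",
        "options": {
            "groupby": ["country", "gender"],
            "aggregates": {
                "age_q1": {"operator": "percentile", "column": "age", "options": {"q": 0.25}}
            },
        },
    },
    {
        "operation": "sort",
        # keys are column names, values indicate ascending (True) or descending (False)
        "options": {"columns": {"age_q1": False}},
    },
]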
class ChartDataFilterSchema(Schema):
col = fields.String(
description="The column to filter.", required=True, example="country"
)
op = fields.String( # pylint: disable=invalid-name
description="The comparison operator.",
enum=[filter_op.value for filter_op in utils.FilterOperationType],
required=True,
example="IN",
)
val = fields.Raw(
description="The value or values to compare against. Can be a string, "
"integer, decimal or list, depending on the operator.",
example=["China", "France", "Japan"],
)
class ChartDataExtrasSchema(Schema):
time_range_endpoints = fields.List(
fields.String(enum=["INCLUSIVE", "EXCLUSIVE"]),
description="A list with two values, stating if start/end should be "
"inclusive/exclusive.",
required=False,
)
relative_start = fields.String(
description="Start time for relative time deltas. "
'Default: `config["DEFAULT_RELATIVE_START_TIME"]`',
enum=["today", "now"],
required=False,
)
relative_end = fields.String(
description="End time for relative time deltas. "
'Default: `config["DEFAULT_RELATIVE_END_TIME"]`',
enum=["today", "now"],
required=False,
)
class ChartDataQueryObjectSchema(Schema):
filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
granularity = fields.String(
description="To what level of granularity should the temporal column be "
"aggregated. Supports "
"[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) "
"durations.",
enum=[
"PT1S",
"PT1M",
"PT5M",
"PT10M",
"PT15M",
"PT0.5H",
"PT1H",
"P1D",
"P1W",
"P1M",
"P0.25Y",
"P1Y",
],
required=False,
example="P1D",
)
groupby = fields.List(
fields.String(description="Columns by which to group the query.",),
)
metrics = fields.List(
fields.Raw(),
description="Aggregate expressions. Metrics can be passed as both "
"references to datasource metrics (strings), or ad-hoc metrics"
"which are defined only within the query object. See "
"`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.",
)
post_processing = fields.List(
fields.Nested(ChartDataPostProcessingOperationSchema),
description="Post processing operations to be applied to the result set. "
"Operations are applied to the result set in sequential order.",
required=False,
)
time_range = fields.String(
description="A time rage, either expressed as a colon separated string "
"`since : until`. Valid formats for `since` and `until` are: \n"
"- ISO 8601\n"
"- X days/years/hours/day/year/weeks\n"
"- X days/years/hours/day/year/weeks ago\n"
"- X days/years/hours/day/year/weeks from now\n"
"\n"
"Additionally, the following freeform can be used:\n"
"\n"
"- Last day\n"
"- Last week\n"
"- Last month\n"
"- Last quarter\n"
"- Last year\n"
"- No filter\n"
"- Last X seconds/minutes/hours/days/weeks/months/years\n"
"- Next X seconds/minutes/hours/days/weeks/months/years\n",
required=False,
example="Last week",
)
time_shift = fields.String(
description="A human-readable date/time string. "
"Please refer to [parsdatetime](https://github.com/bear/parsedatetime) "
"documentation for details on valid values.",
required=False,
)
is_timeseries = fields.Boolean(
description="Is the `query_object` a timeseries.", required=False
)
timeseries_limit = fields.Integer(
description="Maximum row count for timeseries queries. Default: `0`",
required=False,
)
row_limit = fields.Integer(
description='Maximum row count. Default: `config["ROW_LIMIT"]`', required=False,
)
order_desc = fields.Boolean(
description="Reverse order. Default: `false`", required=False
)
extras = fields.Dict(description="Extra parameters to add to the query. Default: `{}`", required=False)
columns = fields.List(fields.String(), description="Columns which to select.", required=False,)
orderby = fields.List(
fields.List(fields.Raw()),
description="Expects a list of lists where the first element is the column "
"name which to sort by, and the second element is a boolean ",
required=False,
example=[["my_col_1", False], ["my_col_2", True]],
)
class ChartDataDatasourceSchema(Schema):
description = "Chart datasource"
id = fields.Integer(description="Datasource id", required=True,)
type = fields.String(description="Datasource type", enum=["druid", "sql"])
class ChartDataQueryContextSchema(Schema):
datasource = fields.Nested(ChartDataDatasourceSchema)
queries = fields.List(fields.Nested(ChartDataQueryObjectSchema))
# pylint: disable=no-self-use
@post_load
def make_query_context(self, data: Dict[str, Any]) -> QueryContext:
query_context = QueryContext(**data)
return query_context
# pylint: enable=no-self-use
class ChartDataResponseResult(Schema):
cache_key = fields.String(
description="Unique cache key for query object", required=True, allow_none=True,
)
cached_dttm = fields.String(
description="Cache timestamp", required=True, allow_none=True,
)
cache_timeout = fields.Integer(
description="Cache timeout in following order: custom timeout, datasource "
"timeout, default config timeout.",
required=True,
allow_none=True,
)
error = fields.String(description="Error", allow_none=True,)
is_cached = fields.Boolean(
description="Is the result cached", required=True, allow_none=None,
)
query = fields.String(
description="The executed query statement", required=True, allow_none=False,
)
status = fields.String(
description="Status of the query",
enum=[
"stopped",
"failed",
"pending",
"running",
"scheduled",
"success",
"timed_out",
],
allow_none=False,
)
stacktrace = fields.String(
description="Stacktrace if there was an error", allow_none=True,
)
rowcount = fields.Integer(
description="Amount of rows in result set", allow_none=False,
)
data = fields.List(fields.Dict(), description="A list with results")
class ChartDataResponseSchema(Schema):
result = fields.List(
fields.Nested(ChartDataResponseResult),
description="A list of results for each corresponding query in the request.",
)
CHART_DATA_SCHEMAS = (
ChartDataQueryContextSchema,
ChartDataResponseSchema,
# TODO: These should optimally be included in the QueryContext schema as an `anyOf`
#  in ChartDataPostProcessingOperation.options, but since `anyOf` is not supported
#  by Marshmallow<3, this is not currently possible.
ChartDataAdhocMetricSchema,
ChartDataAggregateOptionsSchema,
ChartDataPivotOptionsSchema,
ChartDataRollingOptionsSchema,
ChartDataSelectOptionsSchema,
ChartDataSortOptionsSchema,
)
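
As a rough sketch of how these schemas are consumed (mirroring the endpoint change in api.py, and assuming a running Superset app context with an existing datasource), loading a payload returns a QueryContext plus any validation errors under Marshmallow 2 semantics:

from superset.charts.schemas import ChartDataQueryContextSchema

payload = {
    "datasource": {"id": 1, "type": "table"},  # illustrative id/type
    "queries": [{"granularity": "P1W", "metrics": ["count"], "groupby": ["gender"]}],
}
query_context, errors = ChartDataQueryContextSchema().load(payload)
if errors:
    print(errors)  # e.g. type errors for a malformed datasource, as exercised in the new unit test
else:
    print(query_context.queries[0].row_limit)  # a fully constructed QueryContext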

View File

@ -47,9 +47,9 @@ class QueryObject:
is_timeseries: bool
time_shift: Optional[timedelta]
groupby: List[str]
metrics: List[Union[Dict, str]]
metrics: List[Union[Dict[str, Any], str]]
row_limit: int
filter: List[str]
filter: List[Dict[str, Any]]
timeseries_limit: int
timeseries_limit_metric: Optional[Dict]
order_desc: bool
@ -61,9 +61,9 @@ class QueryObject:
def __init__(
self,
granularity: str,
metrics: List[Union[Dict, str]],
metrics: List[Union[Dict[str, Any], str]],
groupby: Optional[List[str]] = None,
filters: Optional[List[str]] = None,
filters: Optional[List[Dict[str, Any]]] = None,
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
is_timeseries: bool = False,
@ -75,14 +75,17 @@ class QueryObject:
columns: Optional[List[str]] = None,
orderby: Optional[List[List]] = None,
post_processing: Optional[List[Dict[str, Any]]] = None,
relative_start: str = app.config["DEFAULT_RELATIVE_START_TIME"],
relative_end: str = app.config["DEFAULT_RELATIVE_END_TIME"],
):
extras = extras or {}
is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE")
self.granularity = granularity
self.from_dttm, self.to_dttm = utils.get_since_until(
relative_start=relative_start,
relative_end=relative_end,
relative_start=extras.get(
"relative_start", app.config["DEFAULT_RELATIVE_START_TIME"]
),
relative_end=extras.get(
"relative_end", app.config["DEFAULT_RELATIVE_END_TIME"]
),
time_range=time_range,
time_shift=time_shift,
)
@ -106,7 +109,7 @@ class QueryObject:
self.timeseries_limit = timeseries_limit
self.timeseries_limit_metric = timeseries_limit_metric
self.order_desc = order_desc
self.extras = extras or {}
self.extras = extras
if app.config["SIP_15_ENABLED"] and "time_range_endpoints" not in self.extras:
self.extras["time_range_endpoints"] = get_time_range_endpoints(form_data={})

View File

@ -25,6 +25,7 @@ from sqlalchemy.orm import foreign, Query, relationship
from superset.constants import NULL_STRING
from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
from superset.models.slice import Slice
from superset.typing import FilterValue, FilterValues
from superset.utils import core as utils
METRIC_FORM_DATA_PARAMS = [
@ -301,28 +302,33 @@ class BaseDatasource(
@staticmethod
def filter_values_handler(
values, target_column_is_numeric=False, is_list_target=False
):
def handle_single_value(v):
values: Optional[FilterValues],
target_column_is_numeric: bool = False,
is_list_target: bool = False,
) -> Optional[FilterValues]:
if values is None:
return None
def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]:
# backward compatibility with previous <select> components
if isinstance(v, str):
v = v.strip("\t\n'\"")
if isinstance(value, str):
value = value.strip("\t\n'\"")
if target_column_is_numeric:
# For backwards compatibility and edge cases
# where a column data type might have changed
v = utils.string_to_num(v)
if v == NULL_STRING:
value = utils.cast_to_num(value)
if value == NULL_STRING:
return None
elif v == "<empty string>":
elif value == "<empty string>":
return ""
return v
return value
if isinstance(values, (list, tuple)):
values = [handle_single_value(v) for v in values]
values = [handle_single_value(v) for v in values] # type: ignore
else:
values = handle_single_value(values)
if is_list_target and not isinstance(values, (tuple, list)):
values = [values]
values = [values] # type: ignore
elif not is_list_target and isinstance(values, (tuple, list)):
if values:
values = values[0]
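
The reworked handler normalizes scalars and lists alike; a rough sketch of the intended behavior (a static method on BaseDatasource, with "<NULL>" assumed to be the NULL_STRING marker):

from superset.connectors.base.models import BaseDatasource

# quotes and surrounding whitespace are stripped; numeric targets are cast via cast_to_num
BaseDatasource.filter_values_handler("'5'", target_column_is_numeric=True)  # -> 5
# the NULL marker maps to None and "<empty string>" maps to ""
BaseDatasource.filter_values_handler(["<NULL>", "<empty string>", "x"])  # -> [None, "", "x"]
# scalars are wrapped in a list when the target expects one
BaseDatasource.filter_values_handler("CA", is_list_target=True)  # -> ["CA"]
# None now short-circuits to None
BaseDatasource.filter_values_handler(None)  # -> None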

View File

@ -24,7 +24,7 @@ from copy import deepcopy
from datetime import datetime, timedelta
from distutils.version import LooseVersion
from multiprocessing.pool import ThreadPool
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import cast, Dict, Iterable, List, Optional, Set, Tuple, Union
import pandas as pd
import sqlalchemy as sa
@ -54,6 +54,7 @@ from superset.constants import NULL_STRING
from superset.exceptions import SupersetException
from superset.models.core import Database
from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult
from superset.typing import FilterValues
from superset.utils import core as utils, import_datasource
try:
@ -80,7 +81,12 @@ except ImportError:
pass
try:
from superset.utils.core import DimSelector, DTTM_ALIAS, flasher
from superset.utils.core import (
DimSelector,
DTTM_ALIAS,
FilterOperationType,
flasher,
)
except ImportError:
pass
@ -1483,13 +1489,20 @@ class DruidDatasource(Model, BaseDatasource):
"""Given Superset filter data structure, returns pydruid Filter(s)"""
filters = None
for flt in raw_filters:
col = flt.get("col")
op = flt.get("op")
eq = flt.get("val")
col: Optional[str] = flt.get("col")
op: Optional[str] = flt["op"].upper() if "op" in flt else None
eq: Optional[FilterValues] = flt.get("val")
if (
not col
or not op
or (eq is None and op not in ("IS NULL", "IS NOT NULL"))
or (
eq is None
and op
not in (
FilterOperationType.IS_NULL.value,
FilterOperationType.IS_NOT_NULL.value,
)
)
):
continue
@ -1503,7 +1516,10 @@ class DruidDatasource(Model, BaseDatasource):
cond = None
is_numeric_col = col in num_cols
is_list_target = op in ("in", "not in")
is_list_target = op in (
FilterOperationType.IN.value,
FilterOperationType.NOT_IN.value,
)
eq = cls.filter_values_handler(
eq,
is_list_target=is_list_target,
@ -1512,15 +1528,16 @@ class DruidDatasource(Model, BaseDatasource):
# For these two ops, could have used Dimension,
# but it doesn't support extraction functions
if op == "==":
if op == FilterOperationType.EQUALS.value:
cond = Filter(
dimension=col, value=eq, extraction_function=extraction_fn
)
elif op == "!=":
elif op == FilterOperationType.NOT_EQUALS.value:
cond = ~Filter(
dimension=col, value=eq, extraction_function=extraction_fn
)
elif op in ("in", "not in"):
elif is_list_target:
eq = cast(list, eq)
fields = []
# ignore the filter if it has no value
if not len(eq):
@ -1540,9 +1557,9 @@ class DruidDatasource(Model, BaseDatasource):
for s in eq:
fields.append(Dimension(col) == s)
cond = Filter(type="or", fields=fields)
if op == "not in":
if op == FilterOperationType.NOT_IN.value:
cond = ~cond
elif op == "regex":
elif op == FilterOperationType.REGEX.value:
cond = Filter(
extraction_function=extraction_fn,
type="regex",
@ -1552,7 +1569,7 @@ class DruidDatasource(Model, BaseDatasource):
# For the ops below, could have used pydruid's Bound,
# but it doesn't support extraction functions
elif op == ">=":
elif op == FilterOperationType.GREATER_THAN_OR_EQUALS.value:
cond = Bound(
extraction_function=extraction_fn,
dimension=col,
@ -1562,7 +1579,7 @@ class DruidDatasource(Model, BaseDatasource):
upper=None,
ordering=cls._get_ordering(is_numeric_col),
)
elif op == "<=":
elif op == FilterOperationType.LESS_THAN_OR_EQUALS.value:
cond = Bound(
extraction_function=extraction_fn,
dimension=col,
@ -1572,7 +1589,7 @@ class DruidDatasource(Model, BaseDatasource):
upper=eq,
ordering=cls._get_ordering(is_numeric_col),
)
elif op == ">":
elif op == FilterOperationType.GREATER_THAN.value:
cond = Bound(
extraction_function=extraction_fn,
lowerStrict=True,
@ -1582,7 +1599,7 @@ class DruidDatasource(Model, BaseDatasource):
upper=None,
ordering=cls._get_ordering(is_numeric_col),
)
elif op == "<":
elif op == FilterOperationType.LESS_THAN.value:
cond = Bound(
extraction_function=extraction_fn,
upperStrict=True,
@ -1592,9 +1609,9 @@ class DruidDatasource(Model, BaseDatasource):
upper=eq,
ordering=cls._get_ordering(is_numeric_col),
)
elif op == "IS NULL":
elif op == FilterOperationType.IS_NULL.value:
cond = Filter(dimension=col, value="")
elif op == "IS NOT NULL":
elif op == FilterOperationType.IS_NOT_NULL.value:
cond = ~Filter(dimension=col, value="")
if filters:
@ -1610,21 +1627,25 @@ class DruidDatasource(Model, BaseDatasource):
def _get_having_obj(self, col: str, op: str, eq: str) -> "Having":
cond = None
if op == "==":
if op == FilterOperationType.EQUALS.value:
if col in self.column_names:
cond = DimSelector(dimension=col, value=eq)
else:
cond = Aggregation(col) == eq
elif op == ">":
elif op == FilterOperationType.GREATER_THAN.value:
cond = Aggregation(col) > eq
elif op == "<":
elif op == FilterOperationType.LESS_THAN.value:
cond = Aggregation(col) < eq
return cond
def get_having_filters(self, raw_filters: List[Dict]) -> "Having":
filters = None
reversed_op_map = {"!=": "==", ">=": "<", "<=": ">"}
reversed_op_map = {
FilterOperationType.NOT_EQUALS.value: FilterOperationType.EQUALS.value,
FilterOperationType.GREATER_THAN_OR_EQUALS.value: FilterOperationType.LESS_THAN.value,
FilterOperationType.LESS_THAN_OR_EQUALS.value: FilterOperationType.GREATER_THAN.value,
}
for flt in raw_filters:
if not all(f in flt for f in ["col", "op", "val"]):
@ -1633,7 +1654,11 @@ class DruidDatasource(Model, BaseDatasource):
op = flt["op"]
eq = flt["val"]
cond = None
if op in ["==", ">", "<"]:
if op in [
FilterOperationType.EQUALS.value,
FilterOperationType.GREATER_THAN.value,
FilterOperationType.LESS_THAN.value,
]:
cond = self._get_having_obj(col, op, eq)
elif op in reversed_op_map:
cond = ~self._get_having_obj(col, reversed_op_map[op], eq)
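
For reference, the raw filter dictionaries consumed here are the same structures described by ChartDataFilterSchema; a hedged sketch of typical inputs (column names and values illustrative):

raw_filters = [
    {"col": "country", "op": "IN", "val": ["France", "Japan"]},  # becomes an OR of Dimension filters
    {"col": "revenue", "op": ">=", "val": 100},                  # becomes a Bound filter
    {"col": "comment", "op": "IS NULL"},                         # `val` may be omitted for null checks
]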

View File

@ -843,43 +843,53 @@ class SqlaTable(Model, BaseDatasource):
if not all([flt.get(s) for s in ["col", "op"]]):
continue
col = flt["col"]
op = flt["op"]
op = flt["op"].upper()
col_obj = cols.get(col)
if col_obj:
is_list_target = op in ("in", "not in")
is_list_target = op in (
utils.FilterOperationType.IN.value,
utils.FilterOperationType.NOT_IN.value,
)
eq = self.filter_values_handler(
flt.get("val"),
values=flt.get("val"),
target_column_is_numeric=col_obj.is_numeric,
is_list_target=is_list_target,
)
if op in ("in", "not in"):
if op in (
utils.FilterOperationType.IN.value,
utils.FilterOperationType.NOT_IN.value,
):
cond = col_obj.get_sqla_col().in_(eq)
if NULL_STRING in eq:
cond = or_(cond, col_obj.get_sqla_col() == None)
if op == "not in":
if isinstance(eq, str) and NULL_STRING in eq:
cond = or_(cond, col_obj.get_sqla_col() == None)
if op == utils.FilterOperationType.NOT_IN.value:
cond = ~cond
where_clause_and.append(cond)
else:
if col_obj.is_numeric:
eq = utils.string_to_num(flt["val"])
if op == "==":
eq = utils.cast_to_num(flt["val"])
if op == utils.FilterOperationType.EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() == eq)
elif op == "!=":
elif op == utils.FilterOperationType.NOT_EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() != eq)
elif op == ">":
elif op == utils.FilterOperationType.GREATER_THAN.value:
where_clause_and.append(col_obj.get_sqla_col() > eq)
elif op == "<":
elif op == utils.FilterOperationType.LESS_THAN.value:
where_clause_and.append(col_obj.get_sqla_col() < eq)
elif op == ">=":
elif op == utils.FilterOperationType.GREATER_THAN_OR_EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() >= eq)
elif op == "<=":
elif op == utils.FilterOperationType.LESS_THAN_OR_EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() <= eq)
elif op == "LIKE":
elif op == utils.FilterOperationType.LIKE.value:
where_clause_and.append(col_obj.get_sqla_col().like(eq))
elif op == "IS NULL":
where_clause_and.append(col_obj.get_sqla_col() == None)
elif op == "IS NOT NULL":
where_clause_and.append(col_obj.get_sqla_col() != None)
elif op == utils.FilterOperationType.IS_NULL.value:
where_clause_and.append(col_obj.get_sqla_col() == None)
elif op == utils.FilterOperationType.IS_NOT_NULL.value:
where_clause_and.append(col_obj.get_sqla_col() != None)
else:
raise Exception(
_("Invalid filter operation type: %(op)s", op=op)
)
where_clause_and += self._get_sqla_row_level_filters(template_processor)
if extras:

View File

@ -189,7 +189,7 @@ def load_birth_names(only_metadata: bool = False, force: bool = False) -> None:
"expressionType": "SIMPLE",
"filterOptionName": "2745eae5",
"comparator": ["other"],
"operator": "not in",
"operator": "NOT IN",
"subject": "state",
}
],

View File

@ -249,7 +249,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals
"AMA",
"PLW",
],
"operator": "not in",
"operator": "NOT IN",
"subject": "country_code",
}
],

View File

@ -25,4 +25,6 @@ DbapiDescriptionRow = Tuple[
]
DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, ...]]
DbapiResult = List[Union[List[Any], Tuple[Any, ...]]]
FilterValue = Union[float, int, str]
FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue, ...]]
VizData = Optional[Union[List[Any], Dict[Any, Any]]]

View File

@ -198,28 +198,30 @@ def parse_js_uri_path_item(
return unquote_plus(item) if unquote and item else item
def string_to_num(s: str):
"""Converts a string to an int/float
def cast_to_num(value: Union[float, int, str]) -> Optional[Union[float, int]]:
"""Casts a value to an int/float
Returns ``None`` if it can't be converted
>>> string_to_num('5')
>>> cast_to_num('5')
5
>>> string_to_num('5.2')
>>> cast_to_num('5.2')
5.2
>>> string_to_num(10)
>>> cast_to_num(10)
10
>>> string_to_num(10.1)
>>> cast_to_num(10.1)
10.1
>>> string_to_num('this is not a string') is None
>>> cast_to_num('this is not a string') is None
True
:param value: value to be converted to numeric representation
:returns: value cast to `int` if value is all digits, `float` if `value` is a
decimal value, and `None` if it can't be converted
"""
if isinstance(s, (int, float)):
return s
if s.isdigit():
return int(s)
if isinstance(value, (int, float)):
return value
if value.isdigit():
return int(value)
try:
return float(s)
return float(value)
except ValueError:
return None
@ -1346,3 +1348,22 @@ class DbColumnType(Enum):
NUMERIC = 0
STRING = 1
TEMPORAL = 2
class FilterOperationType(str, Enum):
"""
Filter operation type
"""
EQUALS = "=="
NOT_EQUALS = "!="
GREATER_THAN = ">"
LESS_THAN = "<"
GREATER_THAN_OR_EQUALS = ">="
LESS_THAN_OR_EQUALS = "<="
LIKE = "LIKE"
IS_NULL = "IS NULL"
IS_NOT_NULL = "IS NOT NULL"
IN = "IN"
NOT_IN = "NOT IN"
REGEX = "REGEX"

View File

@ -96,7 +96,7 @@ def _get_aggregate_funcs(
aggregators. Currently only numpy aggregators are supported.
:param df: DataFrame on which to perform aggregate operation.
:param aggregates: Mapping from column name to aggregat config.
:param aggregates: Mapping from column name to aggregate config.
:return: Mapping from metric name to function that takes a single input argument.
"""
agg_funcs: Dict[str, NamedAgg] = {}
@ -276,12 +276,13 @@ def rolling( # pylint: disable=too-many-arguments
on rolling values calculated from `y`, leaving the original column `y`
unchanged.
:param rolling_type: Type of rolling window. Any numpy function will work.
:param window: Size of the window.
:param rolling_type_options: Optional options to pass to rolling method. Needed
for e.g. quantile operation.
:param center: Should the label be at the center of the window.
:param win_type: Type of window function.
:param window: Size of the window.
:param min_periods:
:param min_periods: The minimum amount of periods required for a row to be included
in the result set.
:return: DataFrame with the rolling columns
:raises ChartDataValidationError: If the request is incorrect
"""
@ -332,7 +333,7 @@ def select(
:param df: DataFrame on which the rolling period will be based.
:param columns: Columns which to select from the DataFrame, in the desired order.
If columns are renamed, the new column name should be referenced
If columns are renamed, the old column name should be referenced
here.
:param rename: columns which to rename, mapping source column to target column.
For instance, `{'y': 'y2'}` will rename the column `y` to
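
Tying the docstring fixes back to usage, a hedged sketch of the rolling helper follows; the parameter names are taken from the docstring above, and the module path superset.utils.pandas_postprocessing is assumed:

import pandas as pd
from superset.utils.pandas_postprocessing import rolling

df = pd.DataFrame({"__timestamp": pd.date_range("2020-01-01", periods=10), "y": range(10)})
rolled = rolling(
    df=df,
    columns={"y": "y_7d"},  # keep `y`, add a 7-period rolling mean as `y_7d`
    rolling_type="mean",
    window=7,
    min_periods=1,
)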

View File

@ -657,7 +657,18 @@ class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
rv = self.client.post(uri, json=query_context)
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data[0]["rowcount"], 100)
self.assertEqual(data["result"][0]["rowcount"], 100)
def test_invalid_chart_data(self):
"""
Query API: Test chart data query with invalid payload
"""
self.login(username="admin")
query_context = self._get_query_context()
query_context["datasource"] = "abc"
uri = "api/v1/chart/data"
rv = self.client.post(uri, json=query_context)
self.assertEqual(rv.status_code, 400)
def test_query_exec_not_allowed(self):
"""