From 8e439b1115481b6df7f8af616ac683f399d52893 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Fri, 17 Apr 2020 16:44:16 +0300 Subject: [PATCH] chore: Add OpenAPI docs to /api/v1/chart/data EP (#9556) * Add OpenAPI docs to /api/v1/chart/data EP * Add missing fields to QueryObject, fix linting and unit test errors * Fix unit test errors * abc * Fix errors uncovered by schema validation and add unit test for invalid payload * Add schema for response * Remove unnecessary pylint disable --- setup.cfg | 2 +- superset/charts/api.py | 84 +--- superset/charts/schemas.py | 512 +++++++++++++++++++++++- superset/common/query_object.py | 21 +- superset/connectors/base/models.py | 28 +- superset/connectors/druid/models.py | 71 ++-- superset/connectors/sqla/models.py | 48 ++- superset/examples/birth_names.py | 2 +- superset/examples/world_bank.py | 2 +- superset/typing.py | 2 + superset/utils/core.py | 49 ++- superset/utils/pandas_postprocessing.py | 9 +- tests/charts/api_tests.py | 13 +- 13 files changed, 695 insertions(+), 148 deletions(-) diff --git a/setup.cfg b/setup.cfg index b535c63a76..566633c192 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,7 +45,7 @@ combine_as_imports = true include_trailing_comma = true line_length = 88 known_first_party = superset -known_third_party =alembic,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml +known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml multi_line_output = 3 order_by_type = false diff --git a/superset/charts/api.py b/superset/charts/api.py index be0df118b2..be4f40747b 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -18,10 +18,11 @@ import logging from typing import Any, Dict import simplejson +from apispec import APISpec from flask import g, make_response, redirect, request, Response, url_for from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface -from flask_babel import ngettext +from flask_babel import gettext as _, ngettext from werkzeug.wrappers import Response as WerkzeugResponse from werkzeug.wsgi import FileWrapper @@ -41,12 +42,13 @@ from superset.charts.commands.exceptions import ( from superset.charts.commands.update import UpdateChartCommand from superset.charts.filters import ChartFilter, ChartNameOrDescriptionFilter from superset.charts.schemas import ( + CHART_DATA_SCHEMAS, + ChartDataQueryContextSchema, ChartPostSchema, ChartPutSchema, get_delete_ids_schema, thumbnail_query_schema, ) -from superset.common.query_context import QueryContext from superset.constants import RouteMethod 
from superset.exceptions import SupersetSecurityException from superset.extensions import event_logger, security_manager @@ -381,74 +383,21 @@ class ChartRestApi(BaseSupersetModelRestApi): Takes a query context constructed in the client and returns payload data response for the given query. requestBody: - description: Query context schema + description: >- + A query context consists of a datasource from which to fetch data + and one or many query objects. required: true content: application/json: schema: - type: object - properties: - datasource: - type: object - description: The datasource where the query will run - properties: - id: - type: integer - type: - type: string - queries: - type: array - items: - type: object - properties: - granularity: - type: string - groupby: - type: array - items: - type: string - metrics: - type: array - items: - type: object - filters: - type: array - items: - type: string - row_limit: - type: integer + $ref: "#/components/schemas/ChartDataQueryContextSchema" responses: 200: description: Query result content: application/json: schema: - type: array - items: - type: object - properties: - cache_key: - type: string - cached_dttm: - type: string - cache_timeout: - type: integer - error: - type: string - is_cached: - type: boolean - query: - type: string - status: - type: string - stacktrace: - type: string - rowcount: - type: integer - data: - type: array - items: - type: object + $ref: "#/components/schemas/ChartDataResponseSchema" 400: $ref: '#/components/responses/400' 500: @@ -457,7 +406,11 @@ class ChartRestApi(BaseSupersetModelRestApi): if not request.is_json: return self.response_400(message="Request is not JSON") try: - query_context = QueryContext(**request.json) + query_context, errors = ChartDataQueryContextSchema().load(request.json) + if errors: + return self.response_400( + message=_("Request is incorrect: %(error)s", error=errors) + ) except KeyError: return self.response_400(message="Request is incorrect") try: @@ -466,7 +419,7 @@ class ChartRestApi(BaseSupersetModelRestApi): return self.response_401() payload_json = query_context.get_payload() response_data = simplejson.dumps( - payload_json, default=json_int_dttm_ser, ignore_nan=True + {"result": payload_json}, default=json_int_dttm_ser, ignore_nan=True ) resp = make_response(response_data, 200) resp.headers["Content-Type"] = "application/json; charset=utf-8" @@ -533,3 +486,10 @@ class ChartRestApi(BaseSupersetModelRestApi): return Response( FileWrapper(screenshot), mimetype="image/png", direct_passthrough=True ) + + def add_apispec_components(self, api_spec: APISpec) -> None: + for chart_type in CHART_DATA_SCHEMAS: + api_spec.components.schema( + chart_type.__name__, schema=chart_type, + ) + super().add_apispec_components(api_spec) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index bf1b57b321..0a7035c6c0 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -14,11 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
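For context, a minimal sketch of how a request body is validated against the new schema. The datasource reference, columns and metrics below are hypothetical placeholders, and the marshmallow 2.x style `load()` call mirrors the handler above:

```python
# Minimal sketch of validating a chart data payload the same way the
# /api/v1/chart/data handler does. Datasource id/type, columns and metrics
# are hypothetical and must reference objects that exist in your instance.
from superset.charts.schemas import ChartDataQueryContextSchema

payload = {
    "datasource": {"id": 1, "type": "table"},  # hypothetical datasource reference
    "queries": [
        {
            "granularity": "P1D",
            "groupby": ["country", "gender"],
            "metrics": ["count"],
            "filters": [
                {"col": "country", "op": "IN", "val": ["China", "France", "Japan"]}
            ],
            "row_limit": 100,
        }
    ],
}

# marshmallow 2.x returns a (data, errors) tuple; `data` is a QueryContext
# instance thanks to the @post_load hook on ChartDataQueryContextSchema.
query_context, errors = ChartDataQueryContextSchema().load(payload)
if errors:
    raise ValueError(errors)

result = query_context.get_payload()  # same call the endpoint makes
```

The handler then wraps the resulting payload in a `result` list, which is what `ChartDataResponseSchema` describes.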
-from typing import Union +from typing import Any, Dict, Union -from marshmallow import fields, Schema, ValidationError +from marshmallow import fields, post_load, Schema, ValidationError from marshmallow.validate import Length +from superset.common.query_context import QueryContext from superset.exceptions import SupersetException from superset.utils import core as utils @@ -59,3 +60,510 @@ class ChartPutSchema(Schema): datasource_id = fields.Integer(allow_none=True) datasource_type = fields.String(allow_none=True) dashboards = fields.List(fields.Integer()) + + +class ChartDataColumnSchema(Schema): + column_name = fields.String( + description="The name of the target column", example="mycol", + ) + type = fields.String(description="Type of target column", example="BIGINT",) + + +class ChartDataAdhocMetricSchema(Schema): + """ + Ad-hoc metrics are used to define metrics outside the datasource. + """ + + expressionType = fields.String( + description="Simple or SQL metric", + required=True, + enum=["SIMPLE", "SQL"], + example="SQL", + ) + aggregate = fields.String( + description="Aggregation operator. Only required for simple expression types.", + required=False, + enum=["AVG", "COUNT", "COUNT_DISTINCT", "MAX", "MIN", "SUM"], + ) + column = fields.Nested(ChartDataColumnSchema) + sqlExpression = fields.String( + description="The metric as defined by a SQL aggregate expression. " + "Only required for SQL expression type.", + required=False, + example="SUM(weight * observations) / SUM(weight)", + ) + label = fields.String( + description="Label for the metric. Is automatically generated unless " + "hasCustomLabel is true, in which case label must be defined.", + required=False, + example="Weighted observations", + ) + hasCustomLabel = fields.Boolean( + description="When false, the label will be automatically generated based on " + "the aggregate expression. When true, a custom label has to be " + "specified.", + required=False, + example=True, + ) + optionName = fields.String( + description="Unique identifier. Can be any string value, as long as all " + "metrics have a unique identifier. If undefined, a random name " + "will be generated.", + required=False, + example="metric_aec60732-fac0-4b17-b736-93f1a5c93e30", + ) + + +class ChartDataAggregateConfigField(fields.Dict): + def __init__(self) -> None: + super().__init__( + description="The keys are the name of the aggregate column to be created, " + "and the values specify the details of how to apply the " + "aggregation. If an operator requires additional options, " + "these can be passed here to be unpacked in the operator call. The " + "following numpy operators are supported: average, argmin, argmax, cumsum, " + "cumprod, max, mean, median, nansum, nanmin, nanmax, nanmean, nanmedian, " + "min, percentile, prod, product, std, sum, var. Any options required by " + "the operator can be passed to the `options` object.\n" + "\n" + "In the example, a new column `first_quantile` is created based on values " + "in the column `my_col` using the `percentile` operator with " + "the `q=0.25` parameter.", + example={ + "first_quantile": { + "operator": "percentile", + "column": "my_col", + "options": {"q": 0.25}, + } + }, + ) + + +class ChartDataPostProcessingOperationOptionsSchema(Schema): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ChartDataAggregateOptionsSchema(ChartDataPostProcessingOperationOptionsSchema): + """ + Aggregate operation config. 
+    """
+
+    groupby = fields.List(
+        fields.String(
+            allow_none=False, description="Columns by which to group.",
+        ),
+        minLength=1,
+        required=True,
+    )
+    aggregates = ChartDataAggregateConfigField()
+
+
+class ChartDataRollingOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+    """
+    Rolling operation config.
+    """
+
+    columns = fields.Dict(
+        description="Columns on which to perform rolling, mapping source column to "
+        "target column. For instance, `{'y': 'y'}` will replace the "
+        "column `y` with the rolling value in `y`, while `{'y': 'y2'}` "
+        "will add a column `y2` based on rolling values calculated "
+        "from `y`, leaving the original column `y` unchanged.",
+        example={"weekly_rolling_sales": "sales"},
+    )
+    rolling_type = fields.String(
+        description="Type of rolling window. Any numpy function will work.",
+        enum=[
+            "average",
+            "argmin",
+            "argmax",
+            "cumsum",
+            "cumprod",
+            "max",
+            "mean",
+            "median",
+            "nansum",
+            "nanmin",
+            "nanmax",
+            "nanmean",
+            "nanmedian",
+            "min",
+            "percentile",
+            "prod",
+            "product",
+            "std",
+            "sum",
+            "var",
+        ],
+        required=True,
+        example="percentile",
+    )
+    window = fields.Integer(
+        description="Size of the rolling window in days.", required=True, example=7,
+    )
+    rolling_type_options = fields.Dict(
+        description="Optional options to pass to rolling method. Needed for "
+        "e.g. quantile operation.",
+        required=False,
+        example={},
+    )
+    center = fields.Boolean(
+        description="Should the label be at the center of the window. Default: `false`",
+        required=False,
+        example=False,
+    )
+    win_type = fields.String(
+        description="Type of window function. See "
+        "[SciPy window functions](https://docs.scipy.org/doc/scipy/reference"
+        "/signal.windows.html#module-scipy.signal.windows) "
+        "for more details. Some window functions require passing "
+        "additional parameters to `rolling_type_options`. For instance, "
+        "to use `gaussian`, the parameter `std` needs to be provided.",
+        required=False,
+        enum=[
+            "boxcar",
+            "triang",
+            "blackman",
+            "hamming",
+            "bartlett",
+            "parzen",
+            "bohman",
+            "blackmanharris",
+            "nuttall",
+            "barthann",
+            "kaiser",
+            "gaussian",
+            "general_gaussian",
+            "slepian",
+            "exponential",
+        ],
+    )
+    min_periods = fields.Integer(
+        description="The minimum number of periods required for a row to be included "
+        "in the result set.",
+        required=False,
+        example=7,
+    )
+
+
+class ChartDataSelectOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+    """
+    Select operation config.
+    """
+
+    columns = fields.List(
+        fields.String(),
+        description="Columns which to select from the input data, in the desired "
+        "order. If columns are renamed, the old column name should be "
+        "referenced here.",
+        example=["country", "gender", "age"],
+    )
+    rename = fields.List(
+        fields.Dict(),
+        description="Columns which to rename, mapping source column to target column. "
+        "For instance, `{'y': 'y2'}` will rename the column `y` to `y2`.",
+        example=[{"age": "average_age"}],
+    )
+
+
+class ChartDataSortOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+    """
+    Sort operation config.
+    """
+
+    columns = fields.Dict(
+        description="Columns by which to sort. The key specifies the column name, "
+        "the value specifies whether to sort in ascending order.",
+        example={"country": True, "gender": False},
+        required=True,
+    )
+    aggregates = ChartDataAggregateConfigField()
+
+
+class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+    """
+    Pivot operation config.
+    """
+
+    index = fields.List(
+        fields.String(
+            allow_none=False,
+            description="Columns to group by on the table index (=rows)",
+        ),
+        minLength=1,
+        required=True,
+    )
+    columns = fields.List(
+        fields.String(
+            allow_none=False, description="Columns to group by on the table columns",
+        ),
+        minLength=1,
+        required=True,
+    )
+    metric_fill_value = fields.Number(
+        required=False,
+        description="Value to replace missing values with in aggregate calculations.",
+    )
+    column_fill_value = fields.String(
+        required=False, description="Value to replace missing pivot column names with."
+    )
+    drop_missing_columns = fields.Boolean(
+        description="Do not include columns whose entries are all missing "
+        "(default: `true`).",
+        required=False,
+    )
+    marginal_distributions = fields.Boolean(
+        description="Add totals for row/column. (default: `false`)", required=False,
+    )
+    marginal_distribution_name = fields.String(
+        description="Name of marginal distribution row/column. (default: `All`)",
+        required=False,
+    )
+    aggregates = ChartDataAggregateConfigField()
+
+
+class ChartDataPostProcessingOperationSchema(Schema):
+    operation = fields.String(
+        description="Post processing operation type",
+        required=True,
+        enum=["aggregate", "pivot", "rolling", "select", "sort"],
+        example="aggregate",
+    )
+    options = fields.Nested(
+        ChartDataPostProcessingOperationOptionsSchema,
+        description="Options specifying how to perform the operation. Please refer "
+        "to the respective post processing operation option schemas. "
+        "For example, `ChartDataPivotOptionsSchema` specifies "
+        "the required options for the pivot operation.",
+        example={
+            "groupby": ["country", "gender"],
+            "aggregates": {
+                "age_q1": {
+                    "operator": "percentile",
+                    "column": "age",
+                    "options": {"q": 0.25},
+                },
+                "age_mean": {"operator": "mean", "column": "age",},
+            },
+        },
+    )
+
+
+class ChartDataFilterSchema(Schema):
+    col = fields.String(
+        description="The column to filter.", required=True, example="country"
+    )
+    op = fields.String(  # pylint: disable=invalid-name
+        description="The comparison operator.",
+        enum=[filter_op.value for filter_op in utils.FilterOperationType],
+        required=True,
+        example="IN",
+    )
+    val = fields.Raw(
+        description="The value or values to compare against. Can be a string, "
+        "integer, decimal or list, depending on the operator.",
+        example=["China", "France", "Japan"],
+    )
+
+
+class ChartDataExtrasSchema(Schema):
+
+    time_range_endpoints = fields.List(
+        fields.String(enum=["INCLUSIVE", "EXCLUSIVE"]),
+        description="A list with two values, stating if start/end should be "
+        "inclusive/exclusive.",
+        required=False,
+    )
+    relative_start = fields.String(
+        description="Start time for relative time deltas. "
+        'Default: `config["DEFAULT_RELATIVE_START_TIME"]`',
+        enum=["today", "now"],
+        required=False,
+    )
+    relative_end = fields.String(
+        description="End time for relative time deltas. "
+        'Default: `config["DEFAULT_RELATIVE_END_TIME"]`',
+        enum=["today", "now"],
+        required=False,
+    )
+
+
+class ChartDataQueryObjectSchema(Schema):
+    filters = fields.List(fields.Nested(ChartDataFilterSchema), required=False)
+    granularity = fields.String(
+        description="To what level of granularity should the temporal column be "
+        "aggregated. Supports "
+        "[ISO 8601](https://en.wikipedia.org/wiki/ISO_8601#Durations) "
+        "durations.",
+        enum=[
+            "PT1S",
+            "PT1M",
+            "PT5M",
+            "PT10M",
+            "PT15M",
+            "PT0.5H",
+            "PT1H",
+            "P1D",
+            "P1W",
+            "P1M",
+            "P0.25Y",
+            "P1Y",
+        ],
+        required=False,
+        example="P1D",
+    )
+    groupby = fields.List(
+        fields.String(description="Columns by which to group the query.",),
+    )
+    metrics = fields.List(
+        fields.Raw(),
+        description="Aggregate expressions. Metrics can be passed as both "
+        "references to datasource metrics (strings), or ad-hoc metrics "
+        "which are defined only within the query object. See "
+        "`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.",
+    )
+    post_processing = fields.List(
+        fields.Nested(ChartDataPostProcessingOperationSchema),
+        description="Post processing operations to be applied to the result set. "
+        "Operations are applied to the result set in sequential order.",
+        required=False,
+    )
+    time_range = fields.String(
+        description="A time range, expressed as a colon separated string "
+        "`since : until`. Valid formats for `since` and `until` are: \n"
+        "- ISO 8601\n"
+        "- X days/years/hours/day/year/weeks\n"
+        "- X days/years/hours/day/year/weeks ago\n"
+        "- X days/years/hours/day/year/weeks from now\n"
+        "\n"
+        "Additionally, the following freeform can be used:\n"
+        "\n"
+        "- Last day\n"
+        "- Last week\n"
+        "- Last month\n"
+        "- Last quarter\n"
+        "- Last year\n"
+        "- No filter\n"
+        "- Last X seconds/minutes/hours/days/weeks/months/years\n"
+        "- Next X seconds/minutes/hours/days/weeks/months/years\n",
+        required=False,
+        example="Last week",
+    )
+    time_shift = fields.String(
+        description="A human-readable date/time string. "
+        "Please refer to [parsedatetime](https://github.com/bear/parsedatetime) "
+        "documentation for details on valid values.",
+        required=False,
+    )
+    is_timeseries = fields.Boolean(
+        description="Is the `query_object` a timeseries.", required=False
+    )
+    timeseries_limit = fields.Integer(
+        description="Maximum row count for timeseries queries. Default: `0`",
+        required=False,
+    )
+    row_limit = fields.Integer(
+        description='Maximum row count. Default: `config["ROW_LIMIT"]`', required=False,
+    )
+    order_desc = fields.Boolean(
+        description="Reverse order. Default: `false`", required=False
+    )
+    extras = fields.Dict(
+        description="Extra parameters to add to the query. Default: `{}`",
+        required=False,
+    )
+    columns = fields.List(
+        fields.String(),
+        description="Columns which to select in the query.",
+        required=False,
+    )
+    orderby = fields.List(
+        fields.List(fields.Raw()),
+        description="Expects a list of lists where the first element is the column "
+        "to sort by, and the second element is a boolean specifying "
+        "whether to sort in ascending order.",
+        required=False,
+        example=[["my_col_1", False], ["my_col_2", True]],
+    )
+
+
+class ChartDataDatasourceSchema(Schema):
+    description = "Chart datasource"
+    id = fields.Integer(description="Datasource id", required=True,)
+    type = fields.String(description="Datasource type", enum=["druid", "sql"])
+
+
+class ChartDataQueryContextSchema(Schema):
+    datasource = fields.Nested(ChartDataDatasourceSchema)
+    queries = fields.List(fields.Nested(ChartDataQueryObjectSchema))
+
+    # pylint: disable=no-self-use
+    @post_load
+    def make_query_context(self, data: Dict[str, Any]) -> QueryContext:
+        query_context = QueryContext(**data)
+        return query_context
+
+    # pylint: enable=no-self-use
+
+
+class ChartDataResponseResult(Schema):
+    cache_key = fields.String(
+        description="Unique cache key for query object", required=True, allow_none=True,
+    )
+    cached_dttm = fields.String(
+        description="Cache timestamp", required=True, allow_none=True,
+    )
+    cache_timeout = fields.Integer(
+        description="Cache timeout in the following order: custom timeout, datasource "
+        "timeout, default config timeout.",
+        required=True,
+        allow_none=True,
+    )
+    error = fields.String(description="Error", allow_none=True,)
+    is_cached = fields.Boolean(
+        description="Is the result cached", required=True, allow_none=None,
+    )
+    query = fields.String(
+        description="The executed query statement", required=True, allow_none=False,
+    )
+    status = fields.String(
+        description="Status of the query",
+        enum=[
+            "stopped",
+            "failed",
+            "pending",
+            "running",
+            "scheduled",
+            "success",
+            "timed_out",
+        ],
+        allow_none=False,
+    )
+    stacktrace = fields.String(
+        description="Stacktrace if there was an error", allow_none=True,
+    )
+    rowcount = fields.Integer(
+        description="Number of rows in result set", allow_none=False,
+    )
+    data = fields.List(fields.Dict(), description="A list with results")
+
+
+class ChartDataResponseSchema(Schema):
+    result = fields.List(
+        fields.Nested(ChartDataResponseResult),
+        description="A list of results for each corresponding query in the request.",
+    )
+
+
+CHART_DATA_SCHEMAS = (
+    ChartDataQueryContextSchema,
+    ChartDataResponseSchema,
+    # TODO: These should optimally be included in the QueryContext schema as an `anyOf`
+    # in ChartDataPostProcessingOperationSchema.options, but since `anyOf` is not
+    # supported by Marshmallow<3, this is not currently possible.
+ ChartDataAdhocMetricSchema, + ChartDataAggregateOptionsSchema, + ChartDataPivotOptionsSchema, + ChartDataRollingOptionsSchema, + ChartDataSelectOptionsSchema, + ChartDataSortOptionsSchema, +) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 63158c6751..31a62418bf 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -47,9 +47,9 @@ class QueryObject: is_timeseries: bool time_shift: Optional[timedelta] groupby: List[str] - metrics: List[Union[Dict, str]] + metrics: List[Union[Dict[str, Any], str]] row_limit: int - filter: List[str] + filter: List[Dict[str, Any]] timeseries_limit: int timeseries_limit_metric: Optional[Dict] order_desc: bool @@ -61,9 +61,9 @@ class QueryObject: def __init__( self, granularity: str, - metrics: List[Union[Dict, str]], + metrics: List[Union[Dict[str, Any], str]], groupby: Optional[List[str]] = None, - filters: Optional[List[str]] = None, + filters: Optional[List[Dict[str, Any]]] = None, time_range: Optional[str] = None, time_shift: Optional[str] = None, is_timeseries: bool = False, @@ -75,14 +75,17 @@ class QueryObject: columns: Optional[List[str]] = None, orderby: Optional[List[List]] = None, post_processing: Optional[List[Dict[str, Any]]] = None, - relative_start: str = app.config["DEFAULT_RELATIVE_START_TIME"], - relative_end: str = app.config["DEFAULT_RELATIVE_END_TIME"], ): + extras = extras or {} is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE") self.granularity = granularity self.from_dttm, self.to_dttm = utils.get_since_until( - relative_start=relative_start, - relative_end=relative_end, + relative_start=extras.get( + "relative_start", app.config["DEFAULT_RELATIVE_START_TIME"] + ), + relative_end=extras.get( + "relative_end", app.config["DEFAULT_RELATIVE_END_TIME"] + ), time_range=time_range, time_shift=time_shift, ) @@ -106,7 +109,7 @@ class QueryObject: self.timeseries_limit = timeseries_limit self.timeseries_limit_metric = timeseries_limit_metric self.order_desc = order_desc - self.extras = extras or {} + self.extras = extras if app.config["SIP_15_ENABLED"] and "time_range_endpoints" not in self.extras: self.extras["time_range_endpoints"] = get_time_range_endpoints(form_data={}) diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index 8e1acc74ad..2b6e0d2630 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -25,6 +25,7 @@ from sqlalchemy.orm import foreign, Query, relationship from superset.constants import NULL_STRING from superset.models.helpers import AuditMixinNullable, ImportMixin, QueryResult from superset.models.slice import Slice +from superset.typing import FilterValue, FilterValues from superset.utils import core as utils METRIC_FORM_DATA_PARAMS = [ @@ -301,28 +302,33 @@ class BaseDatasource( @staticmethod def filter_values_handler( - values, target_column_is_numeric=False, is_list_target=False - ): - def handle_single_value(v): + values: Optional[FilterValues], + target_column_is_numeric: bool = False, + is_list_target: bool = False, + ) -> Optional[FilterValues]: + if values is None: + return None + + def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]: # backward compatibility with previous