From 327a2817d35416f74902c889487d0aa6ff1969b8 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Fri, 4 Dec 2020 14:40:31 +0200 Subject: [PATCH] feat: add event and interval annotation support to chart data ep (#11665) * feat: add event and interval annotation support to chart data ep * add tests + refactor fixtures * use chart dao --- superset-frontend/src/chart/chartAction.js | 5 +- superset/charts/schemas.py | 5 +- superset/common/query_context.py | 100 ++++++++++++++++++-- superset/common/query_object.py | 30 +++++- superset/utils/core.py | 7 ++ superset/viz.py | 17 ++-- tests/annotation_layers/api_tests.py | 79 ++-------------- tests/annotation_layers/fixtures.py | 101 +++++++++++++++++++++ tests/charts/api_tests.py | 52 ++++++++++- tests/fixtures/query_context.py | 73 +++++++++++++++ 10 files changed, 372 insertions(+), 97 deletions(-) create mode 100644 tests/annotation_layers/fixtures.py diff --git a/superset-frontend/src/chart/chartAction.js b/superset-frontend/src/chart/chartAction.js index 16dd3f4167..f6329a80a5 100644 --- a/superset-frontend/src/chart/chartAction.js +++ b/superset-frontend/src/chart/chartAction.js @@ -412,7 +412,10 @@ export function exploreJSON( }); }); - const annotationLayers = formData.annotation_layers || []; + // only retrieve annotations when calling the legacy API + const annotationLayers = shouldUseLegacyApi(formData) + ? 
formData.annotation_layers || [] + : []; const isDashboardRequest = dashboardId > 0; return Promise.all([ diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index ca1497a80e..5e346fad37 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -23,6 +23,7 @@ from marshmallow.validate import Length, Range from superset.common.query_context import QueryContext from superset.utils import schema as utils from superset.utils.core import ( + AnnotationType, FilterOperator, PostProcessingBoxplotWhiskerType, PostProcessingContributionOrientation, @@ -783,9 +784,7 @@ class ChartDataExtrasSchema(Schema): class AnnotationLayerSchema(Schema): annotationType = fields.String( description="Type of annotation layer", - validate=validate.OneOf( - choices=("EVENT", "FORMULA", "INTERVAL", "TIME_SERIES",) - ), + validate=validate.OneOf(choices=[ann.value for ann in AnnotationType]), ) color = fields.String(description="Layer color", allow_none=True,) descriptionColumns = fields.List( diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 19e666866d..25113c36ae 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -25,14 +25,17 @@ import pandas as pd from flask_babel import gettext as _ from superset import app, db, is_feature_enabled +from superset.annotation_layers.dao import AnnotationLayerDAO +from superset.charts.dao import ChartDAO from superset.common.query_object import QueryObject from superset.connectors.base.models import BaseDatasource from superset.connectors.connector_registry import ConnectorRegistry -from superset.exceptions import QueryObjectValidationError +from superset.exceptions import QueryObjectValidationError, SupersetException from superset.extensions import cache_manager, security_manager from superset.stats_logger import BaseStatsLogger from superset.utils import core as utils from superset.utils.core import DTTM_ALIAS +from superset.views.utils import 
get_viz from superset.viz import set_and_log_cache config = app.config @@ -157,8 +160,7 @@ class QueryContext: query_obj.row_offset = 0 query_obj.columns = [o.column_name for o in self.datasource.columns] payload = self.get_df_payload(query_obj) - # TODO: implement - payload["annotation_data"] = [] + df = payload["df"] status = payload["status"] if status != utils.QueryStatus.FAILED: @@ -220,7 +222,79 @@ class QueryContext: ) return cache_key - def get_df_payload( # pylint: disable=too-many-statements + @staticmethod + def get_native_annotation_data(query_obj: QueryObject) -> Dict[str, Any]: + annotation_data = {} + annotation_layers = [ + layer + for layer in query_obj.annotation_layers + if layer["sourceType"] == "NATIVE" + ] + layer_ids = [layer["value"] for layer in annotation_layers] + layer_objects = { + layer_object.id: layer_object + for layer_object in AnnotationLayerDAO.find_by_ids(layer_ids) + } + + # annotations + for layer in annotation_layers: + layer_id = layer["value"] + layer_name = layer["name"] + columns = [ + "start_dttm", + "end_dttm", + "short_descr", + "long_descr", + "json_metadata", + ] + layer_object = layer_objects[layer_id] + records = [ + {column: getattr(annotation, column) for column in columns} + for annotation in layer_object.annotation + ] + result = {"columns": columns, "records": records} + annotation_data[layer_name] = result + return annotation_data + + @staticmethod + def get_viz_annotation_data( + annotation_layer: Dict[str, Any], force: bool + ) -> Dict[str, Any]: + chart = ChartDAO.find_by_id(annotation_layer["value"]) + if not chart: + raise QueryObjectValidationError("The chart does not exist") + form_data = chart.form_data.copy() # chart is known to exist here + try: + viz_obj = get_viz( + datasource_type=chart.datasource.type, + datasource_id=chart.datasource.id, + form_data=form_data, + force=force, + ) + payload = viz_obj.get_payload() + return payload["data"] + except SupersetException as ex: + raise
QueryObjectValidationError(utils.error_msg_from_exception(ex)) + + def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]: + """ + Collect data for the NATIVE and chart-backed (line/table) annotation layers. + :param query_obj: query object whose annotation_layers are resolved + :return: mapping of annotation layer name to that layer's data + """ + annotation_data: Dict[str, Any] = self.get_native_annotation_data(query_obj) + for annotation_layer in [ + layer + for layer in query_obj.annotation_layers + if layer["sourceType"] in ("line", "table") + ]: + name = annotation_layer["name"] + annotation_data[name] = self.get_viz_annotation_data( + annotation_layer, self.force + ) + return annotation_data + + def get_df_payload( # pylint: disable=too-many-statements,too-many-locals self, query_obj: QueryObject, **kwargs: Any ) -> Dict[str, Any]: """Handles caching around the df payload retrieval""" @@ -233,6 +307,7 @@ class QueryContext: cache_value = None status = None query = "" + annotation_data = {} error_message = None if cache_key and cache_manager.data_cache and not self.force: cache_value = cache_manager.data_cache.get(cache_key) @@ -241,6 +316,7 @@ class QueryContext: try: df = cache_value["df"] query = cache_value["query"] + annotation_data = cache_value.get("annotation_data", {}) status = utils.QueryStatus.SUCCESS is_loaded = True stats_logger.incr("loaded_from_cache") @@ -272,6 +348,8 @@ class QueryContext: query = query_result["query"] error_message = query_result["error_message"] df = query_result["df"] + annotation_data = self.get_annotation_data(query_obj) + if status != utils.QueryStatus.FAILED: stats_logger.incr("loaded_from_source") if not self.force: @@ -289,18 +367,20 @@ class QueryContext: if is_loaded and cache_key and status != utils.QueryStatus.FAILED: set_and_log_cache( - cache_key, - df, - query, - cached_dttm, - self.cache_timeout, - self.datasource.uid, + cache_key=cache_key, + df=df, + query=query, + annotation_data=annotation_data, + cached_dttm=cached_dttm, + cache_timeout=self.cache_timeout, + datasource_uid=self.datasource.uid, ) return { "cache_key": cache_key, "cached_dttm": cache_value["dttm"] if
cache_value is not None else None, "cache_timeout": self.cache_timeout, "df": df, + "annotation_data": annotation_data, "error": error_message, "is_cached": cache_value is not None, "query": query, diff --git a/superset/common/query_object.py b/superset/common/query_object.py index aa2d3147dc..6eb1231f2e 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -106,7 +106,12 @@ class QueryObject: metrics = metrics or [] extras = extras or {} is_sip_38 = is_feature_enabled("SIP_38_VIZ_REARCHITECTURE") - self.annotation_layers = annotation_layers + self.annotation_layers = [ + layer + for layer in annotation_layers + # formula annotations don't affect the payload, hence can be dropped + if layer["annotationType"] != "FORMULA" + ] self.applied_time_extras = applied_time_extras or {} self.granularity = granularity self.from_dttm, self.to_dttm = utils.get_since_until( @@ -236,10 +241,31 @@ class QueryObject: cache_dict["time_range"] = self.time_range if self.post_processing: cache_dict["post_processing"] = self.post_processing + + annotation_fields = [ + "annotationType", + "descriptionColumns", + "intervalEndColumn", + "name", + "overrides", + "sourceType", + "timeColumn", + "titleColumn", + "value", + ] + annotation_layers = [ + {field: layer[field] for field in annotation_fields if field in layer} + for layer in self.annotation_layers + ] + # only add to key if there are annotations present that affect the payload + if annotation_layers: + cache_dict["annotation_layers"] = annotation_layers + json_data = self.json_dumps(cache_dict, sort_keys=True) return hashlib.md5(json_data.encode("utf-8")).hexdigest() - def json_dumps(self, obj: Any, sort_keys: bool = False) -> str: + @staticmethod + def json_dumps(obj: Any, sort_keys: bool = False) -> str: return json.dumps( obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys ) diff --git a/superset/utils/core.py b/superset/utils/core.py index b4e3ef51e1..2d4febeb42 100644 --- 
a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1591,6 +1591,13 @@ class ExtraFiltersTimeColumnType(str, Enum): TIME_RANGE = "__time_range" +class AnnotationType(str, Enum): + FORMULA = "FORMULA" + INTERVAL = "INTERVAL" + EVENT = "EVENT" + TIME_SERIES = "TIME_SERIES" + + def is_test() -> bool: return strtobool(os.environ.get("SUPERSET_TESTENV", "false")) diff --git a/superset/viz.py b/superset/viz.py index ab08fe4e70..e63f2c0804 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -104,9 +104,12 @@ def set_and_log_cache( cached_dttm: str, cache_timeout: int, datasource_uid: Optional[str], + annotation_data: Optional[Dict[str, Any]] = None, ) -> None: try: - cache_value = dict(dttm=cached_dttm, df=df, query=query) + cache_value = dict( + dttm=cached_dttm, df=df, query=query, annotation_data=annotation_data or {} + ) stats_logger.incr("set_cache_key") cache_manager.data_cache.set(cache_key, cache_value, timeout=cache_timeout) @@ -587,12 +590,12 @@ class BaseViz: if is_loaded and cache_key and self.status != utils.QueryStatus.FAILED: set_and_log_cache( - cache_key, - df, - self.query, - cached_dttm, - self.cache_timeout, - self.datasource.uid, + cache_key=cache_key, + df=df, + query=self.query, + cached_dttm=cached_dttm, + cache_timeout=self.cache_timeout, + datasource_uid=self.datasource.uid, ) return { "cache_key": self._any_cache_key, diff --git a/tests/annotation_layers/api_tests.py b/tests/annotation_layers/api_tests.py index 2b0df4126b..0ee361bcfb 100644 --- a/tests/annotation_layers/api_tests.py +++ b/tests/annotation_layers/api_tests.py @@ -16,8 +16,6 @@ # under the License. 
# isort:skip_file """Unit tests for Superset""" -from datetime import datetime -from typing import Optional import json import pytest @@ -29,77 +27,17 @@ from superset import db from superset.models.annotations import Annotation, AnnotationLayer from tests.base_tests import SupersetTestCase - +from tests.annotation_layers.fixtures import ( + create_annotation_layers, + get_end_dttm, + get_start_dttm, +) ANNOTATION_LAYERS_COUNT = 10 ANNOTATIONS_COUNT = 5 class TestAnnotationLayerApi(SupersetTestCase): - def insert_annotation_layer( - self, name: str = "", descr: str = "" - ) -> AnnotationLayer: - annotation_layer = AnnotationLayer(name=name, descr=descr,) - db.session.add(annotation_layer) - db.session.commit() - return annotation_layer - - def insert_annotation( - self, - layer: AnnotationLayer, - short_descr: str, - long_descr: str, - json_metadata: Optional[str] = "", - start_dttm: Optional[datetime] = None, - end_dttm: Optional[datetime] = None, - ) -> Annotation: - annotation = Annotation( - layer=layer, - short_descr=short_descr, - long_descr=long_descr, - json_metadata=json_metadata, - start_dttm=start_dttm, - end_dttm=end_dttm, - ) - db.session.add(annotation) - db.session.commit() - return annotation - - @pytest.fixture() - def create_annotation_layers(self): - """ - Creates ANNOTATION_LAYERS_COUNT-1 layers with no annotations - and a final one with ANNOTATION_COUNT childs - :return: - """ - with self.create_app().app_context(): - annotation_layers = [] - annotations = [] - for cx in range(ANNOTATION_LAYERS_COUNT - 1): - annotation_layers.append( - self.insert_annotation_layer(name=f"name{cx}", descr=f"descr{cx}") - ) - layer_with_annotations = self.insert_annotation_layer( - "layer_with_annotations" - ) - annotation_layers.append(layer_with_annotations) - for cx in range(ANNOTATIONS_COUNT): - annotations.append( - self.insert_annotation( - layer_with_annotations, - short_descr=f"short_descr{cx}", - long_descr=f"long_descr{cx}", - ) - ) - yield 
annotation_layers - - # rollback changes - for annotation_layer in annotation_layers: - db.session.delete(annotation_layer) - for annotation in annotations: - db.session.delete(annotation) - db.session.commit() - @staticmethod def get_layer_with_annotation() -> AnnotationLayer: return ( @@ -421,9 +359,10 @@ class TestAnnotationLayerApi(SupersetTestCase): """ Annotation API: Test get annotation """ + annotation_id = 1 annotation = ( db.session.query(Annotation) - .filter(Annotation.short_descr == "short_descr1") + .filter(Annotation.short_descr == f"short_descr{annotation_id}") .one_or_none() ) @@ -436,12 +375,12 @@ class TestAnnotationLayerApi(SupersetTestCase): expected_result = { "id": annotation.id, - "end_dttm": None, + "end_dttm": get_end_dttm(annotation_id).isoformat(), "json_metadata": "", "layer": {"id": annotation.layer_id, "name": "layer_with_annotations"}, "long_descr": annotation.long_descr, "short_descr": annotation.short_descr, - "start_dttm": None, + "start_dttm": get_start_dttm(annotation_id).isoformat(), } data = json.loads(rv.data.decode("utf-8")) diff --git a/tests/annotation_layers/fixtures.py b/tests/annotation_layers/fixtures.py new file mode 100644 index 0000000000..d2960acad5 --- /dev/null +++ b/tests/annotation_layers/fixtures.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# isort:skip_file +import pytest +from datetime import datetime +from typing import Optional + +from superset import db +from superset.models.annotations import Annotation, AnnotationLayer + +from tests.test_app import app + + +ANNOTATION_LAYERS_COUNT = 10 +ANNOTATIONS_COUNT = 5 + + +def get_start_dttm(annotation_id: int) -> datetime: + return datetime(1990 + annotation_id, 1, 1) + + +def get_end_dttm(annotation_id: int) -> datetime: + return datetime(1990 + annotation_id, 7, 1) + + +def _insert_annotation_layer(name: str = "", descr: str = "") -> AnnotationLayer: + annotation_layer = AnnotationLayer(name=name, descr=descr,) + db.session.add(annotation_layer) + db.session.commit() + return annotation_layer + + +def _insert_annotation( + layer: AnnotationLayer, + short_descr: str, + long_descr: str, + json_metadata: Optional[str] = "", + start_dttm: Optional[datetime] = None, + end_dttm: Optional[datetime] = None, +) -> Annotation: + annotation = Annotation( + layer=layer, + short_descr=short_descr, + long_descr=long_descr, + json_metadata=json_metadata, + start_dttm=start_dttm, + end_dttm=end_dttm, + ) + db.session.add(annotation) + db.session.commit() + return annotation + + +@pytest.fixture() +def create_annotation_layers(): + """ + Creates ANNOTATION_LAYERS_COUNT-1 layers with no annotations + and a final one with ANNOTATIONS_COUNT child annotations + :return: yields the list of created annotation layers + """ + with app.app_context(): + annotation_layers = [] + annotations = [] + for cx in range(ANNOTATION_LAYERS_COUNT - 1): + annotation_layers.append( + _insert_annotation_layer(name=f"name{cx}", descr=f"descr{cx}") + ) + layer_with_annotations = _insert_annotation_layer("layer_with_annotations") + annotation_layers.append(layer_with_annotations) + for cx in range(ANNOTATIONS_COUNT): + annotations.append( + _insert_annotation( + layer_with_annotations, + short_descr=f"short_descr{cx}", +
long_descr=f"long_descr{cx}", + start_dttm=get_start_dttm(cx), + end_dttm=get_end_dttm(cx), + ) + ) + yield annotation_layers + + # rollback changes + for annotation_layer in annotation_layers: + db.session.delete(annotation_layer) + for annotation in annotations: + db.session.delete(annotation) + db.session.commit() diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index 511186db2d..7479769437 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -30,17 +30,18 @@ import yaml from sqlalchemy import and_ from sqlalchemy.sql import func -from superset.connectors.sqla.models import SqlaTable -from superset.utils.core import get_example_database -from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice from tests.test_app import app +from superset.connectors.sqla.models import SqlaTable +from superset.utils.core import AnnotationType, get_example_database from superset.connectors.connector_registry import ConnectorRegistry from superset.extensions import db, security_manager +from superset.models.annotations import AnnotationLayer from superset.models.core import Database, FavStar, FavStarClassName from superset.models.dashboard import Dashboard from superset.models.reports import ReportSchedule, ReportScheduleType from superset.models.slice import Slice from superset.utils import core as utils + from tests.base_api_tests import ApiOwnersTestCaseMixin from tests.base_tests import SupersetTestCase from tests.fixtures.importexport import ( @@ -50,7 +51,9 @@ from tests.fixtures.importexport import ( dataset_config, dataset_metadata_config, ) -from tests.fixtures.query_context import get_query_context +from tests.fixtures.query_context import get_query_context, ANNOTATION_LAYERS +from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice +from tests.annotation_layers.fixtures import create_annotation_layers CHART_DATA_URI = "api/v1/chart/data" CHARTS_FIXTURE_COUNT = 10 @@ -1383,3 +1386,44 @@ 
class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin): assert response == { "message": {"metadata.yaml": {"type": ["Must be equal to Slice."]}} } + + @pytest.mark.usefixtures("create_annotation_layers") + def test_chart_data_annotations(self): + """ + Chart data API: Test chart data query + """ + self.login(username="admin") + table = self.get_table_by_name("birth_names") + request_payload = get_query_context(table.name, table.id, table.type) + + annotation_layers = [] + request_payload["queries"][0]["annotation_layers"] = annotation_layers + + # formula + annotation_layers.append(ANNOTATION_LAYERS[AnnotationType.FORMULA]) + + # interval + interval_layer = ( + db.session.query(AnnotationLayer) + .filter(AnnotationLayer.name == "name1") + .one() + ) + interval = ANNOTATION_LAYERS[AnnotationType.INTERVAL] + interval["value"] = interval_layer.id + annotation_layers.append(interval) + + # event + event_layer = ( + db.session.query(AnnotationLayer) + .filter(AnnotationLayer.name == "name2") + .one() + ) + event = ANNOTATION_LAYERS[AnnotationType.EVENT] + event["value"] = event_layer.id + annotation_layers.append(event) + + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + # response should only contain interval and event data, not formula + self.assertEqual(len(data["result"][0]["annotation_data"]), 2) diff --git a/tests/fixtures/query_context.py b/tests/fixtures/query_context.py index 3881e57641..21bf2b5876 100644 --- a/tests/fixtures/query_context.py +++ b/tests/fixtures/query_context.py @@ -17,6 +17,8 @@ import copy from typing import Any, Dict, List +from superset.utils.core import AnnotationType + QUERY_OBJECTS = { "birth_names": { "extras": {"where": "", "time_range_endpoints": ["inclusive", "exclusive"]}, @@ -37,6 +39,77 @@ QUERY_OBJECTS = { } } +ANNOTATION_LAYERS = { + AnnotationType.FORMULA: { + "annotationType": "FORMULA", + "color": "#ff7f44", + 
"hideLine": False, + "name": "my formula", + "opacity": "", + "overrides": {"time_range": None}, + "show": True, + "showMarkers": False, + "sourceType": "", + "style": "solid", + "value": "3+x", + "width": 5, + }, + AnnotationType.EVENT: { + "name": "my event", + "annotationType": "EVENT", + "sourceType": "NATIVE", + "color": "#e04355", + "opacity": "", + "style": "solid", + "width": 5, + "showMarkers": False, + "hideLine": False, + "value": 1, + "overrides": {"time_range": None}, + "show": True, + "titleColumn": "", + "descriptionColumns": [], + "timeColumn": "", + "intervalEndColumn": "", + }, + AnnotationType.INTERVAL: { + "name": "my interval", + "annotationType": "INTERVAL", + "sourceType": "NATIVE", + "color": "#e04355", + "opacity": "", + "style": "solid", + "width": 1, + "showMarkers": False, + "hideLine": False, + "value": 1, + "overrides": {"time_range": None}, + "show": True, + "titleColumn": "", + "descriptionColumns": [], + "timeColumn": "", + "intervalEndColumn": "", + }, + AnnotationType.TIME_SERIES: { + "annotationType": "TIME_SERIES", + "color": None, + "descriptionColumns": [], + "hideLine": False, + "intervalEndColumn": "", + "name": "my line", + "opacity": "", + "overrides": {"time_range": None}, + "show": True, + "showMarkers": False, + "sourceType": "line", + "style": "dashed", + "timeColumn": "", + "titleColumn": "", + "value": 837, + "width": 5, + }, +} + POSTPROCESSING_OPERATIONS = { "birth_names": [ {