From 14260f984334c0adedf813cd821f3fc92d3a2bae Mon Sep 17 00:00:00 2001
From: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
Date: Fri, 10 Jul 2020 17:06:05 +0300
Subject: [PATCH] feat: add contribution operation and fix cache_key bug
 (#10286)

* feat: add contribution operation and fix cache_key bug

* Add contribution schema

Note on the cache_key bug: `QueryObject.cache_key` previously serialized the
cache dict before `post_processing` was added to it, so post-processing
operations never affected the cache key; serialization now happens after all
relevant fields have been added.
---
 superset/charts/schemas.py              | 17 +++++++++++++-
 superset/common/query_object.py         |  8 ++++---
 superset/utils/core.py                  |  9 +++++++
 superset/utils/pandas_postprocessing.py | 31 +++++++++++++++++++++++--
 tests/charts/schema_tests.py            | 10 ++++++++
 tests/pandas_postprocessing_tests.py    | 29 ++++++++++++++++++++++-
 tests/query_context_tests.py            | 27 +++++++++++++++++++++
 7 files changed, 124 insertions(+), 7 deletions(-)

diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 8ab4859d6e..4f2e3f0200 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -395,6 +395,19 @@ class ChartDataSortOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
     aggregates = ChartDataAggregateConfigField()
 
 
+class ChartDataContributionOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+    """
+    Contribution operation config.
+    """
+
+    orientation = fields.String(
+        description="Whether cell values should be calculated across the row or column.",
+        required=True,
+        validate=validate.OneOf(choices=("row", "column",)),
+        example="row",
+    )
+
+
 class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
     """
     Pivot operation config.
@@ -500,6 +513,7 @@ class ChartDataPostProcessingOperationSchema(Schema):
         validate=validate.OneOf(
             choices=(
                 "aggregate",
+                "contribution",
                 "cum",
                 "geodetic_parse",
                 "geohash_decode",
@@ -637,7 +651,7 @@ class ChartDataQueryObjectSchema(Schema):
         "`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.",
     )
     post_processing = fields.List(
-        fields.Nested(ChartDataPostProcessingOperationSchema),
+        fields.Nested(ChartDataPostProcessingOperationSchema, allow_none=True),
         description="Post processing operations to be applied to the result set. "
         "Operations are applied to the result set in sequential order.",
     )
@@ -812,6 +826,7 @@ CHART_DATA_SCHEMAS = (
     # by Marshmallow<3, this is not currently possible.
     ChartDataAdhocMetricSchema,
     ChartDataAggregateOptionsSchema,
+    ChartDataContributionOptionsSchema,
     ChartDataPivotOptionsSchema,
     ChartDataRollingOptionsSchema,
     ChartDataSelectOptionsSchema,
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 8de21659d0..a2676b960e 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -94,7 +94,7 @@ class QueryObject:
         extras: Optional[Dict[str, Any]] = None,
         columns: Optional[List[str]] = None,
         orderby: Optional[List[List[str]]] = None,
-        post_processing: Optional[List[Dict[str, Any]]] = None,
+        post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
         **kwargs: Any,
     ):
         metrics = metrics or []
@@ -114,7 +114,9 @@ class QueryObject:
         self.is_timeseries = is_timeseries
         self.time_range = time_range
         self.time_shift = utils.parse_human_timedelta(time_shift)
-        self.post_processing = post_processing or []
+        self.post_processing = [
+            post_proc for post_proc in post_processing or [] if post_proc
+        ]
 
         if not is_sip_38:
             self.groupby = groupby or []
@@ -224,9 +226,9 @@ class QueryObject:
                 del cache_dict[k]
         if self.time_range:
             cache_dict["time_range"] = self.time_range
-        json_data = self.json_dumps(cache_dict, sort_keys=True)
         if self.post_processing:
             cache_dict["post_processing"] = self.post_processing
+        json_data = self.json_dumps(cache_dict, sort_keys=True)
         return hashlib.md5(json_data.encode("utf-8")).hexdigest()
 
     def json_dumps(self, obj: Any, sort_keys: bool = False) -> str:
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 9edee1c933..c464d78d60 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1476,3 +1476,12 @@ class TemporalType(str, Enum):
     TEXT = "TEXT"
     TIME = "TIME"
     TIMESTAMP = "TIMESTAMP"
+
+
+class PostProcessingContributionOrientation(str, Enum):
+    """
+    Calculate cell contribution to row/column total
+    """
+
+    ROW = "row"
+    COLUMN = "column"
diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py
index b6939775f4..12b49bc9e4 100644
--- a/superset/utils/pandas_postprocessing.py
+++ b/superset/utils/pandas_postprocessing.py
@@ -15,15 +15,16 @@
 # specific language governing permissions and limitations
 # under the License.
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
 import geohash as geohash_lib
 import numpy as np
 from flask_babel import gettext as _
 from geopy.point import Point
-from pandas import DataFrame, NamedAgg
+from pandas import DataFrame, NamedAgg, Series
 
 from superset.exceptions import QueryObjectValidationError
+from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
 
 WHITELIST_NUMPY_FUNCTIONS = (
     "average",
@@ -517,3 +518,29 @@ def geodetic_parse(
         return _append_columns(df, geodetic_df, columns)
     except ValueError:
         raise QueryObjectValidationError(_("Invalid geodetic string"))
+
+
+def contribution(
+    df: DataFrame, orientation: PostProcessingContributionOrientation
+) -> DataFrame:
+    """
+    Calculate cell contribution to row/column total.
+
+    :param df: DataFrame containing all-numeric data (temporal column ignored)
+    :param orientation: whether to divide each cell by its row or column total
+    :return: DataFrame with contributions, with temporal column at beginning if present
+    """
+    temporal_series: Optional[Series] = None
+    contribution_df = df.copy()
+    if DTTM_ALIAS in df.columns:
+        temporal_series = cast(Series, contribution_df.pop(DTTM_ALIAS))
+
+    if orientation == PostProcessingContributionOrientation.ROW:
+        contribution_dft = contribution_df.T
+        contribution_df = (contribution_dft / contribution_dft.sum()).T
+    else:
+        contribution_df = contribution_df / contribution_df.sum()
+
+    if temporal_series is not None:
+        contribution_df.insert(0, DTTM_ALIAS, temporal_series)
+    return contribution_df
diff --git a/tests/charts/schema_tests.py b/tests/charts/schema_tests.py
index 354ed823c4..ecb2c97f00 100644
--- a/tests/charts/schema_tests.py
+++ b/tests/charts/schema_tests.py
@@ -69,3 +69,13 @@ class TestSchema(SupersetTestCase):
 
         payload["queries"][0]["extras"]["time_grain_sqla"] = None
         _ = ChartDataQueryContextSchema().load(payload)
+
+    def test_query_context_null_post_processing_op(self):
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+
+        payload["queries"][0]["post_processing"] = [None]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        self.assertEqual(query_context.queries[0].post_processing, [])
diff --git a/tests/pandas_postprocessing_tests.py b/tests/pandas_postprocessing_tests.py
index 87d2cc1821..ea708349ea 100644
--- a/tests/pandas_postprocessing_tests.py
+++ b/tests/pandas_postprocessing_tests.py
@@ -15,13 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 # isort:skip_file
+from datetime import datetime
 import math
 from typing import Any, List, Optional
 
-from pandas import Series
+from pandas import DataFrame, Series
 
 from superset.exceptions import QueryObjectValidationError
 from superset.utils import pandas_postprocessing as proc
+from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
 from .base_tests import SupersetTestCase
 from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df
 
@@ -481,3 +483,28 @@ class TestPostProcessing(SupersetTestCase):
         self.assertListEqual(
             series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]),
         )
+
+    def test_contribution(self):
+        df = DataFrame(
+            {
+                DTTM_ALIAS: [
+                    datetime(2020, 7, 16, 14, 49),
+                    datetime(2020, 7, 16, 14, 50),
+                ],
+                "a": [1, 3],
+                "b": [1, 9],
+            }
+        )
+
+        # cell contribution across row
+        row_df = proc.contribution(df, PostProcessingContributionOrientation.ROW)
+        self.assertListEqual(df.columns.tolist(), [DTTM_ALIAS, "a", "b"])
+        self.assertListEqual(series_to_list(row_df["a"]), [0.5, 0.25])
+        self.assertListEqual(series_to_list(row_df["b"]), [0.5, 0.75])
+
+        # cell contribution across column without temporal column
+        df.pop(DTTM_ALIAS)
+        column_df = proc.contribution(df, PostProcessingContributionOrientation.COLUMN)
+        self.assertListEqual(df.columns.tolist(), ["a", "b"])
+        self.assertListEqual(series_to_list(column_df["a"]), [0.25, 0.75])
+        self.assertListEqual(series_to_list(column_df["b"]), [0.1, 0.9])
diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py
index 4b625b5a06..f816bcd592 100644
--- a/tests/query_context_tests.py
+++ b/tests/query_context_tests.py
@@ -99,6 +99,33 @@ class TestQueryContext(SupersetTestCase):
         # the new cache_key should be different due to updated datasource
         self.assertNotEqual(cache_key_original, cache_key_new)
 
+    def test_cache_key_changes_when_post_processing_is_updated(self):
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(
+            table.name, table.id, table.type, add_postprocessing_operations=True
+        )
+
+        # construct baseline cache_key from query_context with post processing operation
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_original = query_context.cache_key(query_object)
+
+        # ensure added None post_processing operation doesn't change cache_key
+        payload["queries"][0]["post_processing"].append(None)
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_with_null = query_context.cache_key(query_object)
+        self.assertEqual(cache_key_original, cache_key_with_null)
+
+        # ensure query without post processing operation is different
+        payload["queries"][0].pop("post_processing")
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_without_post_processing = query_context.cache_key(query_object)
+        self.assertNotEqual(cache_key_original, cache_key_without_post_processing)
+
     def test_query_context_time_range_endpoints(self):
         """
         Ensure that time_range_endpoints are populated automatically when missing
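
Example: a minimal usage sketch of the new contribution operation. This is a
hypothetical standalone script, not part of the patch; it assumes a Superset
checkout with this patch applied, and the input DataFrame and expected values
mirror test_contribution above.

    from datetime import datetime

    from pandas import DataFrame

    from superset.utils import pandas_postprocessing as proc
    from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation

    df = DataFrame(
        {
            DTTM_ALIAS: [datetime(2020, 7, 16, 14, 49), datetime(2020, 7, 16, 14, 50)],
            "a": [1, 3],
            "b": [1, 9],
        }
    )

    # ROW orientation: each cell is divided by its row total; the temporal
    # column is popped before the calculation and reinserted at the front.
    row_df = proc.contribution(df, PostProcessingContributionOrientation.ROW)
    print(row_df)  # a: [0.5, 0.25], b: [0.5, 0.75]

    # COLUMN orientation: each cell is divided by its column total.
    column_df = proc.contribution(df, PostProcessingContributionOrientation.COLUMN)
    print(column_df)  # a: [0.25, 0.75], b: [0.1, 0.9]

In a chart data API request the same operation would be expressed as an entry
in the query object's post_processing list, e.g.
{"operation": "contribution", "options": {"orientation": "row"}}; the
"options" key follows the existing post-processing schema conventions rather
than anything shown in this diff, so treat that shape as an assumption.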