feat: add contribution operation and fix cache_key bug (#10286)

* feat: add contribution operation and fix cache_key bug

* Add contribution schema
Ville Brofeldt 2020-07-10 17:06:05 +03:00 committed by GitHub
parent 7d4d2e7469
commit 14260f9843
7 changed files with 124 additions and 7 deletions

View File

@@ -395,6 +395,19 @@ class ChartDataSortOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
aggregates = ChartDataAggregateConfigField()
class ChartDataContributionOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
"""
Contribution operation config.
"""
orientation = fields.String(
description="Should cell values be calculated across the row or column.",
required=True,
validate=validate.OneOf(choices=("row", "column",)),
example="row",
)
class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
"""
Pivot operation config.
@@ -500,6 +513,7 @@ class ChartDataPostProcessingOperationSchema(Schema):
validate=validate.OneOf(
choices=(
"aggregate",
"contribution",
"cum",
"geodetic_parse",
"geohash_decode",
@@ -637,7 +651,7 @@ class ChartDataQueryObjectSchema(Schema):
"`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.",
)
post_processing = fields.List(
fields.Nested(ChartDataPostProcessingOperationSchema),
fields.Nested(ChartDataPostProcessingOperationSchema, allow_none=True),
description="Post processing operations to be applied to the result set. "
"Operations are applied to the result set in sequential order.",
)
@@ -812,6 +826,7 @@ CHART_DATA_SCHEMAS = (
# by Marshmallow<3, this is not currently possible.
ChartDataAdhocMetricSchema,
ChartDataAggregateOptionsSchema,
ChartDataContributionOptionsSchema,
ChartDataPivotOptionsSchema,
ChartDataRollingOptionsSchema,
ChartDataSelectOptionsSchema,
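
For orientation, a request body that exercises the new schema might look like the sketch below. It is illustrative only and assumes the existing "operation"/"options" field names of ChartDataPostProcessingOperationSchema:

post_processing = [
    {
        "operation": "contribution",
        "options": {"orientation": "row"},
    },
    None,  # null entries are accepted (allow_none=True) and later dropped by QueryObject
]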

View File

@@ -94,7 +94,7 @@ class QueryObject:
extras: Optional[Dict[str, Any]] = None,
columns: Optional[List[str]] = None,
orderby: Optional[List[List[str]]] = None,
post_processing: Optional[List[Dict[str, Any]]] = None,
post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
**kwargs: Any,
):
metrics = metrics or []
@@ -114,7 +114,9 @@
self.is_timeseries = is_timeseries
self.time_range = time_range
self.time_shift = utils.parse_human_timedelta(time_shift)
self.post_processing = post_processing or []
self.post_processing = [
post_proc for post_proc in post_processing or [] if post_proc
]
if not is_sip_38:
self.groupby = groupby or []
@@ -224,9 +226,9 @@
del cache_dict[k]
if self.time_range:
cache_dict["time_range"] = self.time_range
json_data = self.json_dumps(cache_dict, sort_keys=True)
if self.post_processing:
cache_dict["post_processing"] = self.post_processing
json_data = self.json_dumps(cache_dict, sort_keys=True)
return hashlib.md5(json_data.encode("utf-8")).hexdigest()
def json_dumps(self, obj: Any, sort_keys: bool = False) -> str:
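
The cache_key fix above is an ordering change: post_processing is now merged into the cache dict before the dict is serialized, so queries that differ only in their post-processing operations no longer collide on the same cache entry (and the None entries filtered out earlier never reach the key). A minimal standalone sketch of the corrected ordering, using simplified names rather than the actual QueryObject API:

import hashlib
import json

def cache_key(cache_dict: dict, post_processing: list) -> str:
    # Only a non-empty post_processing list contributes to the key.
    if post_processing:
        cache_dict["post_processing"] = post_processing
    # Serialize *after* post_processing has been merged in, so it is part of the digest.
    json_data = json.dumps(cache_dict, sort_keys=True, default=str)
    return hashlib.md5(json_data.encode("utf-8")).hexdigest()

# Queries differing only in post-processing now produce different keys.
assert cache_key({"metrics": ["count"]}, []) != cache_key(
    {"metrics": ["count"]}, [{"operation": "contribution"}]
)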

View File

@@ -1476,3 +1476,12 @@ class TemporalType(str, Enum):
TEXT = "TEXT"
TIME = "TIME"
TIMESTAMP = "TIMESTAMP"
class PostProcessingContributionOrientation(str, Enum):
"""
Calculate cell contribution to row/column total
"""
ROW = "row"
COLUMN = "column"

View File

@@ -15,15 +15,16 @@
# specific language governing permissions and limitations
# under the License.
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import geohash as geohash_lib
import numpy as np
from flask_babel import gettext as _
from geopy.point import Point
from pandas import DataFrame, NamedAgg
from pandas import DataFrame, NamedAgg, Series
from superset.exceptions import QueryObjectValidationError
from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
WHITELIST_NUMPY_FUNCTIONS = (
"average",
@@ -517,3 +518,29 @@ def geodetic_parse(
return _append_columns(df, geodetic_df, columns)
except ValueError:
raise QueryObjectValidationError(_("Invalid geodetic string"))
def contribution(
df: DataFrame, orientation: PostProcessingContributionOrientation
) -> DataFrame:
"""
Calculate cell contribution to row/column total.
:param df: DataFrame containing all-numeric data (temporal column ignored)
:param orientation: calculate by dividing cell with row/column total
:return: DataFrame with contributions, with temporal column at beginning if present
"""
temporal_series: Optional[Series] = None
contribution_df = df.copy()
if DTTM_ALIAS in df.columns:
temporal_series = cast(Series, contribution_df.pop(DTTM_ALIAS))
if orientation == PostProcessingContributionOrientation.ROW:
contribution_dft = contribution_df.T
contribution_df = (contribution_dft / contribution_dft.sum()).T
else:
contribution_df = contribution_df / contribution_df.sum()
if temporal_series is not None:
contribution_df.insert(0, DTTM_ALIAS, temporal_series)
return contribution_df
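
A quick usage sketch of the new operation; it simply mirrors the behaviour exercised by the unit test further down:

from pandas import DataFrame
from superset.utils.core import PostProcessingContributionOrientation
from superset.utils.pandas_postprocessing import contribution

df = DataFrame({"a": [1, 3], "b": [1, 9]})
# Row orientation: divide each cell by its row total -> a: [0.5, 0.25], b: [0.5, 0.75]
row_df = contribution(df, PostProcessingContributionOrientation.ROW)
# Column orientation: divide each cell by its column total -> a: [0.25, 0.75], b: [0.1, 0.9]
column_df = contribution(df, PostProcessingContributionOrientation.COLUMN)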

View File

@@ -69,3 +69,13 @@ class TestSchema(SupersetTestCase):
payload["queries"][0]["extras"]["time_grain_sqla"] = None
_ = ChartDataQueryContextSchema().load(payload)
def test_query_context_null_post_processing_op(self):
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["queries"][0]["post_processing"] = [None]
query_context = ChartDataQueryContextSchema().load(payload)
self.assertEqual(query_context.queries[0].post_processing, [])

View File

@@ -15,13 +15,15 @@
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
from datetime import datetime
import math
from typing import Any, List, Optional
from pandas import Series
from pandas import DataFrame, Series
from superset.exceptions import QueryObjectValidationError
from superset.utils import pandas_postprocessing as proc
from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
from .base_tests import SupersetTestCase
from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df
@@ -481,3 +483,28 @@ class TestPostProcessing(SupersetTestCase):
self.assertListEqual(
series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]),
)
def test_contribution(self):
df = DataFrame(
{
DTTM_ALIAS: [
datetime(2020, 7, 16, 14, 49),
datetime(2020, 7, 16, 14, 50),
],
"a": [1, 3],
"b": [1, 9],
}
)
# cell contribution across row
row_df = proc.contribution(df, PostProcessingContributionOrientation.ROW)
self.assertListEqual(df.columns.tolist(), [DTTM_ALIAS, "a", "b"])
self.assertListEqual(series_to_list(row_df["a"]), [0.5, 0.25])
self.assertListEqual(series_to_list(row_df["b"]), [0.5, 0.75])
# cell contribution across column without temporal column
df.pop(DTTM_ALIAS)
column_df = proc.contribution(df, PostProcessingContributionOrientation.COLUMN)
self.assertListEqual(df.columns.tolist(), ["a", "b"])
self.assertListEqual(series_to_list(column_df["a"]), [0.25, 0.75])
self.assertListEqual(series_to_list(column_df["b"]), [0.1, 0.9])

View File

@@ -99,6 +99,33 @@ class TestQueryContext(SupersetTestCase):
# the new cache_key should be different due to updated datasource
self.assertNotEqual(cache_key_original, cache_key_new)
def test_cache_key_changes_when_post_processing_is_updated(self):
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(
table.name, table.id, table.type, add_postprocessing_operations=True
)
# construct baseline cache_key from query_context with post processing operation
query_context = QueryContext(**payload)
query_object = query_context.queries[0]
cache_key_original = query_context.cache_key(query_object)
# ensure added None post_processing operation doesn't change cache_key
payload["queries"][0]["post_processing"].append(None)
query_context = QueryContext(**payload)
query_object = query_context.queries[0]
cache_key_with_null = query_context.cache_key(query_object)
self.assertEqual(cache_key_original, cache_key_with_null)
# ensure query without post processing operation is different
payload["queries"][0].pop("post_processing")
query_context = QueryContext(**payload)
query_object = query_context.queries[0]
cache_key_without_post_processing = query_context.cache_key(query_object)
self.assertNotEqual(cache_key_original, cache_key_without_post_processing)
def test_query_context_time_range_endpoints(self):
"""
Ensure that time_range_endpoints are populated automatically when missing