mirror of https://github.com/apache/superset.git

feat: add contribution operation and fix cache_key bug (#10286)

* feat: add contribution operation and fix cache_key bug
* Add contribution schema

parent 7d4d2e7469
commit 14260f9843
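
This change set adds a "contribution" post-processing operation to the chart data API and makes QueryObject.cache_key account for post-processing steps. A minimal sketch of a request body exercising the new operation; the surrounding payload shape is an assumption based on the schemas in this diff, and only the post_processing entry is specific to this commit:

# Hypothetical chart-data request body for POST /api/v1/chart/data.
query_payload = {
    "datasource": {"id": 1, "type": "table"},
    "queries": [
        {
            "metrics": ["count"],
            "groupby": ["gender"],  # illustrative grouping column
            "post_processing": [
                {
                    # new operation added by this commit
                    "operation": "contribution",
                    # orientation is required and must be "row" or "column"
                    "options": {"orientation": "row"},
                }
            ],
        }
    ],
}
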
@@ -395,6 +395,19 @@ class ChartDataSortOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
     aggregates = ChartDataAggregateConfigField()
 
 
+class ChartDataContributionOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+    """
+    Contribution operation config.
+    """
+
+    orientation = fields.String(
+        description="Should cell values be calculated across the row or column.",
+        required=True,
+        validate=validate.OneOf(choices=("row", "column",)),
+        example="row",
+    )
+
+
 class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
     """
     Pivot operation config.
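
To see what the new options schema accepts and rejects, here is a standalone sketch; the hypothetical ContributionOptionsSketch class mirrors only the orientation field above, not Superset's real class hierarchy:

from marshmallow import Schema, ValidationError, fields, validate

class ContributionOptionsSketch(Schema):
    # mirrors the orientation field added in the hunk above
    orientation = fields.String(
        required=True,
        validate=validate.OneOf(choices=("row", "column")),
    )

ContributionOptionsSketch().load({"orientation": "row"})  # deserializes cleanly
try:
    ContributionOptionsSketch().load({"orientation": "diagonal"})
except ValidationError as err:
    print(err.messages)  # {'orientation': ['Must be one of: row, column.']}
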
@@ -500,6 +513,7 @@ class ChartDataPostProcessingOperationSchema(Schema):
         validate=validate.OneOf(
             choices=(
                 "aggregate",
+                "contribution",
                 "cum",
                 "geodetic_parse",
                 "geohash_decode",
@@ -637,7 +651,7 @@ class ChartDataQueryObjectSchema(Schema):
         "`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.",
     )
     post_processing = fields.List(
-        fields.Nested(ChartDataPostProcessingOperationSchema),
+        fields.Nested(ChartDataPostProcessingOperationSchema, allow_none=True),
         description="Post processing operations to be applied to the result set. "
         "Operations are applied to the result set in sequential order.",
     )
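
The switch to allow_none=True is what lets a null post-processing entry pass deserialization instead of failing the whole request; the QueryObject changes further down then discard such entries. A minimal marshmallow sketch of the behavior, with hypothetical OpSketch/QuerySketch schemas standing in for the real ones:

from marshmallow import Schema, fields

class OpSketch(Schema):
    operation = fields.String()

class QuerySketch(Schema):
    # allow_none applies to each list element, so a [None] entry deserializes
    post_processing = fields.List(fields.Nested(OpSketch, allow_none=True))

result = QuerySketch().load({"post_processing": [None, {"operation": "contribution"}]})
print(result)  # {'post_processing': [None, {'operation': 'contribution'}]}
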
@@ -812,6 +826,7 @@ CHART_DATA_SCHEMAS = (
     # by Marshmallow<3, this is not currently possible.
     ChartDataAdhocMetricSchema,
     ChartDataAggregateOptionsSchema,
+    ChartDataContributionOptionsSchema,
     ChartDataPivotOptionsSchema,
     ChartDataRollingOptionsSchema,
     ChartDataSelectOptionsSchema,
@@ -94,7 +94,7 @@ class QueryObject:
         extras: Optional[Dict[str, Any]] = None,
         columns: Optional[List[str]] = None,
         orderby: Optional[List[List[str]]] = None,
-        post_processing: Optional[List[Dict[str, Any]]] = None,
+        post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
         **kwargs: Any,
     ):
         metrics = metrics or []
@@ -114,7 +114,9 @@ class QueryObject:
         self.is_timeseries = is_timeseries
         self.time_range = time_range
         self.time_shift = utils.parse_human_timedelta(time_shift)
-        self.post_processing = post_processing or []
+        self.post_processing = [
+            post_proc for post_proc in post_processing or [] if post_proc
+        ]
         if not is_sip_38:
             self.groupby = groupby or []
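
The new comprehension drops falsy post-processing entries (None as well as empty dicts) at construction time, so a null operation that passed schema validation never reaches the processing pipeline. Its effect in isolation:

post_processing = [{"operation": "contribution"}, None, {}]
# falsy entries are filtered out, matching the constructor change above
cleaned = [post_proc for post_proc in post_processing or [] if post_proc]
print(cleaned)  # [{'operation': 'contribution'}]
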
@@ -224,9 +226,9 @@ class QueryObject:
             del cache_dict[k]
         if self.time_range:
             cache_dict["time_range"] = self.time_range
-        json_data = self.json_dumps(cache_dict, sort_keys=True)
+        if self.post_processing:
+            cache_dict["post_processing"] = self.post_processing
+        json_data = self.json_dumps(cache_dict, sort_keys=True)
         return hashlib.md5(json_data.encode("utf-8")).hexdigest()
 
     def json_dumps(self, obj: Any, sort_keys: bool = False) -> str:
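
This is the cache_key bug named in the commit title: post_processing never entered the hashed dictionary, so two queries differing only in their post-processing steps produced the same cache key and could serve each other's cached results. A self-contained sketch of the fixed computation, with json.dumps standing in for QueryObject.json_dumps:

import hashlib
import json

def cache_key_sketch(cache_dict):
    # post_processing now participates in the hash when present
    json_data = json.dumps(cache_dict, sort_keys=True)
    return hashlib.md5(json_data.encode("utf-8")).hexdigest()

base = {"metrics": ["count"], "time_range": "Last week"}
with_contribution = {**base, "post_processing": [{"operation": "contribution"}]}
assert cache_key_sketch(base) != cache_key_sketch(with_contribution)
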
@@ -1476,3 +1476,12 @@ class TemporalType(str, Enum):
     TEXT = "TEXT"
     TIME = "TIME"
     TIMESTAMP = "TIMESTAMP"
+
+
+class PostProcessingContributionOrientation(str, Enum):
+    """
+    Calculate cell contribution to row/column total
+    """
+
+    ROW = "row"
+    COLUMN = "column"
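
Since the enum subclasses str, its members compare equal to the raw "row"/"column" strings carried in API payloads, so schema output can flow straight into the pandas code. A quick illustration, with the enum re-created locally so the snippet runs standalone:

from enum import Enum

class PostProcessingContributionOrientation(str, Enum):
    ROW = "row"
    COLUMN = "column"

# str subclassing makes members interchangeable with raw payload strings
assert PostProcessingContributionOrientation.ROW == "row"
assert PostProcessingContributionOrientation("column") is PostProcessingContributionOrientation.COLUMN
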
@@ -15,15 +15,16 @@
 # specific language governing permissions and limitations
 # under the License.
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
 import geohash as geohash_lib
 import numpy as np
 from flask_babel import gettext as _
 from geopy.point import Point
-from pandas import DataFrame, NamedAgg
+from pandas import DataFrame, NamedAgg, Series
 
 from superset.exceptions import QueryObjectValidationError
+from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
 
 WHITELIST_NUMPY_FUNCTIONS = (
     "average",
@@ -517,3 +518,29 @@ def geodetic_parse(
         return _append_columns(df, geodetic_df, columns)
     except ValueError:
         raise QueryObjectValidationError(_("Invalid geodetic string"))
+
+
+def contribution(
+    df: DataFrame, orientation: PostProcessingContributionOrientation
+) -> DataFrame:
+    """
+    Calculate cell contribution to row/column total.
+
+    :param df: DataFrame containing all-numeric data (temporal column ignored)
+    :param orientation: calculate by dividing cell with row/column total
+    :return: DataFrame with contributions, with temporal column at beginning if present
+    """
+    temporal_series: Optional[Series] = None
+    contribution_df = df.copy()
+    if DTTM_ALIAS in df.columns:
+        temporal_series = cast(Series, contribution_df.pop(DTTM_ALIAS))
+
+    if orientation == PostProcessingContributionOrientation.ROW:
+        contribution_dft = contribution_df.T
+        contribution_df = (contribution_dft / contribution_dft.sum()).T
+    else:
+        contribution_df = contribution_df / contribution_df.sum()
+
+    if temporal_series is not None:
+        contribution_df.insert(0, DTTM_ALIAS, temporal_series)
+    return contribution_df
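
A worked example of the arithmetic in contribution, using the same numbers as the test further down: row orientation divides each cell by its row total, column orientation by its column total, and a temporal column (DTTM_ALIAS), if present, is set aside and re-inserted unchanged.

from pandas import DataFrame

df = DataFrame({"a": [1, 3], "b": [1, 9]})

# row orientation: row 0 totals 2 -> [0.5, 0.5]; row 1 totals 12 -> [0.25, 0.75]
row_contrib = (df.T / df.T.sum()).T

# column orientation: column a totals 4 -> [0.25, 0.75]; column b totals 10 -> [0.1, 0.9]
column_contrib = df / df.sum()

print(row_contrib)
print(column_contrib)
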
@@ -69,3 +69,13 @@ class TestSchema(SupersetTestCase):
 
         payload["queries"][0]["extras"]["time_grain_sqla"] = None
         _ = ChartDataQueryContextSchema().load(payload)
+
+    def test_query_context_null_post_processing_op(self):
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+
+        payload["queries"][0]["post_processing"] = [None]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        self.assertEqual(query_context.queries[0].post_processing, [])
@@ -15,13 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 # isort:skip_file
+from datetime import datetime
 import math
 from typing import Any, List, Optional
 
-from pandas import Series
+from pandas import DataFrame, Series
 
 from superset.exceptions import QueryObjectValidationError
 from superset.utils import pandas_postprocessing as proc
+from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
 
 from .base_tests import SupersetTestCase
 from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df
@@ -481,3 +483,28 @@ class TestPostProcessing(SupersetTestCase):
         self.assertListEqual(
             series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]),
         )
+
+    def test_contribution(self):
+        df = DataFrame(
+            {
+                DTTM_ALIAS: [
+                    datetime(2020, 7, 16, 14, 49),
+                    datetime(2020, 7, 16, 14, 50),
+                ],
+                "a": [1, 3],
+                "b": [1, 9],
+            }
+        )
+
+        # cell contribution across row
+        row_df = proc.contribution(df, PostProcessingContributionOrientation.ROW)
+        self.assertListEqual(df.columns.tolist(), [DTTM_ALIAS, "a", "b"])
+        self.assertListEqual(series_to_list(row_df["a"]), [0.5, 0.25])
+        self.assertListEqual(series_to_list(row_df["b"]), [0.5, 0.75])
+
+        # cell contribution across column without temporal column
+        df.pop(DTTM_ALIAS)
+        column_df = proc.contribution(df, PostProcessingContributionOrientation.COLUMN)
+        self.assertListEqual(df.columns.tolist(), ["a", "b"])
+        self.assertListEqual(series_to_list(column_df["a"]), [0.25, 0.75])
+        self.assertListEqual(series_to_list(column_df["b"]), [0.1, 0.9])
@@ -99,6 +99,33 @@ class TestQueryContext(SupersetTestCase):
         # the new cache_key should be different due to updated datasource
         self.assertNotEqual(cache_key_original, cache_key_new)
 
+    def test_cache_key_changes_when_post_processing_is_updated(self):
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(
+            table.name, table.id, table.type, add_postprocessing_operations=True
+        )
+
+        # construct baseline cache_key from query_context with post processing operation
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_original = query_context.cache_key(query_object)
+
+        # ensure added None post_processing operation doesn't change cache_key
+        payload["queries"][0]["post_processing"].append(None)
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_with_null = query_context.cache_key(query_object)
+        self.assertEqual(cache_key_original, cache_key_with_null)
+
+        # ensure query without post processing operation is different
+        payload["queries"][0].pop("post_processing")
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_without_post_processing = query_context.cache_key(query_object)
+        self.assertNotEqual(cache_key_original, cache_key_without_post_processing)
+
     def test_query_context_time_range_endpoints(self):
         """
         Ensure that time_range_endpoints are populated automatically when missing