From c01d42fd9828c0833a6dfbeaeca3d89e3f16411f Mon Sep 17 00:00:00 2001
From: Yongjie Zhao
Date: Sat, 31 Jul 2021 09:02:04 +0100
Subject: [PATCH] fix: eliminate cartesian product columns in pivot operator (#15975)

* fix: eliminate cartesian product columns in pivot operator

* wip

* wip

* minor tip
---
 superset/utils/pandas_postprocessing.py       | 15 ++++++++++++
 .../pandas_postprocessing_tests.py            | 23 ++++++++++++++++++-
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py
index 75daba5881..0d8105c756 100644
--- a/superset/utils/pandas_postprocessing.py
+++ b/superset/utils/pandas_postprocessing.py
@@ -264,6 +264,15 @@ def pivot( # pylint: disable=too-many-arguments
     # Remove once/if support is added.
     aggfunc = {na.column: na.aggfunc for na in aggregate_funcs.values()}
 
+    # When dropna = False, the pivot_table function will calculate cartesian-product
+    # for MultiIndex.
+    # https://github.com/apache/superset/issues/15956
+    # https://github.com/pandas-dev/pandas/issues/18030
+    series_set = set()
+    if not drop_missing_columns and columns:
+        for row in df[columns].itertuples():
+            metrics_and_series = tuple(aggfunc.keys()) + tuple(row[1:])
+            series_set.add(str(metrics_and_series))
     df = df.pivot_table(
         values=aggfunc.keys(),
         index=index,
@@ -275,6 +284,12 @@ def pivot( # pylint: disable=too-many-arguments
         margins_name=marginal_distribution_name,
     )
 
+    if not drop_missing_columns and len(series_set) > 0 and not df.empty:
+        for col in df.columns:
+            series = str(col)
+            if series not in series_set:
+                df = df.drop(col, axis=PandasAxis.COLUMN)
+
     if combine_value_with_metric:
         df = df.stack(0).unstack()
 
diff --git a/tests/integration_tests/pandas_postprocessing_tests.py b/tests/integration_tests/pandas_postprocessing_tests.py
index 5cb7d55113..57bcc1fb92 100644
--- a/tests/integration_tests/pandas_postprocessing_tests.py
+++ b/tests/integration_tests/pandas_postprocessing_tests.py
@@ -20,7 +20,8 @@ from importlib.util import find_spec
 import math
 from typing import Any, List, Optional
 
-from pandas import DataFrame, Series, Timestamp
+import numpy as np
+from pandas import DataFrame, Series, Timestamp, to_datetime
 import pytest
 
 from superset.exceptions import QueryObjectValidationError
@@ -256,6 +257,26 @@ class TestPostProcessing(SupersetTestCase):
             aggregates={"idx_nulls": {}},
         )
 
+    def test_pivot_eliminate_cartesian_product_columns(self):
+        mock_df = DataFrame(
+            {
+                "dttm": to_datetime(["2019-01-01", "2019-01-01"]),
+                "a": [0, 1],
+                "b": [0, 1],
+                "metric": [9, np.NAN],
+            }
+        )
+
+        df = proc.pivot(
+            df=mock_df,
+            index=["dttm"],
+            columns=["a", "b"],
+            aggregates={"metric": {"operator": "mean"}},
+            drop_missing_columns=False,
+        )
+        self.assertEqual(list(df.columns), ["dttm", "0, 0", "1, 1"])
+        self.assertTrue(np.isnan(df["1, 1"][0]))
+
     def test_aggregate(self):
         aggregates = {
             "asc sum": {"column": "asc_idx", "operator": "sum"},