fix: eliminate cartesian product columns in pivot operator (#15975)

* fix: eliminate cartesian product columns in pivot operator

* wip

* wip

* minor tip
This commit is contained in:
Yongjie Zhao 2021-07-31 09:02:04 +01:00 committed by GitHub
parent b73d7baedf
commit c01d42fd98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 1 deletions

View File

@ -264,6 +264,15 @@ def pivot( # pylint: disable=too-many-arguments
# Remove once/if support is added. # Remove once/if support is added.
aggfunc = {na.column: na.aggfunc for na in aggregate_funcs.values()} aggfunc = {na.column: na.aggfunc for na in aggregate_funcs.values()}
# When dropna = False, the pivot_table function will calculate cartesian-product
# for MultiIndex.
# https://github.com/apache/superset/issues/15956
# https://github.com/pandas-dev/pandas/issues/18030
series_set = set()
if not drop_missing_columns and columns:
for row in df[columns].itertuples():
metrics_and_series = tuple(aggfunc.keys()) + tuple(row[1:])
series_set.add(str(metrics_and_series))
df = df.pivot_table( df = df.pivot_table(
values=aggfunc.keys(), values=aggfunc.keys(),
index=index, index=index,
@ -275,6 +284,12 @@ def pivot( # pylint: disable=too-many-arguments
margins_name=marginal_distribution_name, margins_name=marginal_distribution_name,
) )
if not drop_missing_columns and len(series_set) > 0 and not df.empty:
for col in df.columns:
series = str(col)
if series not in series_set:
df = df.drop(col, axis=PandasAxis.COLUMN)
if combine_value_with_metric: if combine_value_with_metric:
df = df.stack(0).unstack() df = df.stack(0).unstack()

View File

@ -20,7 +20,8 @@ from importlib.util import find_spec
import math import math
from typing import Any, List, Optional from typing import Any, List, Optional
from pandas import DataFrame, Series, Timestamp import numpy as np
from pandas import DataFrame, Series, Timestamp, to_datetime
import pytest import pytest
from superset.exceptions import QueryObjectValidationError from superset.exceptions import QueryObjectValidationError
@ -256,6 +257,26 @@ class TestPostProcessing(SupersetTestCase):
aggregates={"idx_nulls": {}}, aggregates={"idx_nulls": {}},
) )
def test_pivot_eliminate_cartesian_product_columns(self):
mock_df = DataFrame(
{
"dttm": to_datetime(["2019-01-01", "2019-01-01"]),
"a": [0, 1],
"b": [0, 1],
"metric": [9, np.NAN],
}
)
df = proc.pivot(
df=mock_df,
index=["dttm"],
columns=["a", "b"],
aggregates={"metric": {"operator": "mean"}},
drop_missing_columns=False,
)
self.assertEqual(list(df.columns), ["dttm", "0, 0", "1, 1"])
self.assertTrue(np.isnan(df["1, 1"][0]))
def test_aggregate(self): def test_aggregate(self):
aggregates = { aggregates = {
"asc sum": {"column": "asc_idx", "operator": "sum"}, "asc sum": {"column": "asc_idx", "operator": "sum"},