mirror of https://github.com/apache/superset.git
feat(chart-data-api): make pivoted columns flattenable (#10255)
* feat(chart-data-api): make pivoted columns flattenable * Linting + improve tests
This commit is contained in:
parent
4252770d50
commit
baeacc3c56
|
@ -414,8 +414,6 @@ class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema)
|
|||
fields.String(
|
||||
allow_none=False, description="Columns to group by on the table columns",
|
||||
),
|
||||
minLength=1,
|
||||
required=True,
|
||||
)
|
||||
metric_fill_value = fields.Number(
|
||||
description="Value to replace missing values with in aggregate calculations.",
|
||||
|
|
|
@ -72,13 +72,38 @@ WHITELIST_CUMULATIVE_FUNCTIONS = (
|
|||
)
|
||||
|
||||
|
||||
def _flatten_column_after_pivot(
|
||||
column: Union[str, Tuple[str, ...]], aggregates: Dict[str, Dict[str, Any]]
|
||||
) -> str:
|
||||
"""
|
||||
Function for flattening column names into a single string. This step is necessary
|
||||
to be able to properly serialize a DataFrame. If the column is a string, return
|
||||
element unchanged. For multi-element columns, join column elements with a comma,
|
||||
with the exception of pivots made with a single aggregate, in which case the
|
||||
aggregate column name is omitted.
|
||||
|
||||
:param column: single element from `DataFrame.columns`
|
||||
:param aggregates: aggregates
|
||||
:return:
|
||||
"""
|
||||
if isinstance(column, str):
|
||||
return column
|
||||
if len(column) == 1:
|
||||
return column[0]
|
||||
if len(aggregates) == 1 and len(column) > 1:
|
||||
# drop aggregate for single aggregate pivots with multiple groupings
|
||||
# from column name (aggregates always come first in column name)
|
||||
column = column[1:]
|
||||
return ", ".join(column)
|
||||
|
||||
|
||||
def validate_column_args(*argnames: str) -> Callable[..., Any]:
|
||||
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def wrapped(df: DataFrame, **options: Any) -> Any:
|
||||
columns = df.columns.tolist()
|
||||
for name in argnames:
|
||||
if name in options and not all(
|
||||
elem in columns for elem in options[name]
|
||||
elem in columns for elem in options.get(name) or []
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
_("Referenced columns not available in DataFrame.")
|
||||
|
@ -154,14 +179,15 @@ def _append_columns(
|
|||
def pivot( # pylint: disable=too-many-arguments
|
||||
df: DataFrame,
|
||||
index: List[str],
|
||||
columns: List[str],
|
||||
aggregates: Dict[str, Dict[str, Any]],
|
||||
columns: Optional[List[str]] = None,
|
||||
metric_fill_value: Optional[Any] = None,
|
||||
column_fill_value: Optional[str] = None,
|
||||
drop_missing_columns: Optional[bool] = True,
|
||||
combine_value_with_metric: bool = False,
|
||||
marginal_distributions: Optional[bool] = None,
|
||||
marginal_distribution_name: Optional[str] = None,
|
||||
flatten_columns: bool = True,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Perform a pivot operation on a DataFrame.
|
||||
|
@ -179,6 +205,7 @@ def pivot( # pylint: disable=too-many-arguments
|
|||
:param marginal_distributions: Add totals for row/column. Default to False
|
||||
:param marginal_distribution_name: Name of row/column with marginal distribution.
|
||||
Default to 'All'.
|
||||
:param flatten_columns: Convert column names to strings
|
||||
:return: A pivot table
|
||||
:raises QueryObjectValidationError: If the request is incorrect
|
||||
"""
|
||||
|
@ -186,10 +213,6 @@ def pivot( # pylint: disable=too-many-arguments
|
|||
raise QueryObjectValidationError(
|
||||
_("Pivot operation requires at least one index")
|
||||
)
|
||||
if not columns:
|
||||
raise QueryObjectValidationError(
|
||||
_("Pivot operation requires at least one column")
|
||||
)
|
||||
if not aggregates:
|
||||
raise QueryObjectValidationError(
|
||||
_("Pivot operation must include at least one aggregate")
|
||||
|
@ -218,6 +241,13 @@ def pivot( # pylint: disable=too-many-arguments
|
|||
if combine_value_with_metric:
|
||||
df = df.stack(0).unstack()
|
||||
|
||||
# Make index regular column
|
||||
if flatten_columns:
|
||||
df.columns = [
|
||||
_flatten_column_after_pivot(col, aggregates) for col in df.columns
|
||||
]
|
||||
# return index as regular column
|
||||
df.reset_index(level=0, inplace=True)
|
||||
return df
|
||||
|
||||
|
||||
|
|
|
@ -26,6 +26,12 @@ from superset.utils import pandas_postprocessing as proc
|
|||
from .base_tests import SupersetTestCase
|
||||
from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df
|
||||
|
||||
AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}}
|
||||
AGGREGATES_MULTIPLE = {
|
||||
"idx_nulls": {"operator": "sum"},
|
||||
"asc_idx": {"operator": "mean"},
|
||||
}
|
||||
|
||||
|
||||
def series_to_list(series: Series) -> List[Any]:
|
||||
"""
|
||||
|
@ -57,33 +63,99 @@ def round_floats(
|
|||
|
||||
|
||||
class TestPostProcessing(SupersetTestCase):
|
||||
def test_pivot(self):
|
||||
aggregates = {"idx_nulls": {"operator": "sum"}}
|
||||
def test_flatten_column_after_pivot(self):
    """
    Check the strings produced when flattening pivoted column labels
    """
    # (aggregates, column label, expected flattened name)
    cases = [
        # single aggregate cases
        (AGGREGATES_SINGLE, "idx_nulls", "idx_nulls"),
        (AGGREGATES_SINGLE, ("idx_nulls", "col1"), "col1"),
        (AGGREGATES_SINGLE, ("idx_nulls", "col1", "col2"), "col1, col2"),
        # multiple aggregate cases
        (
            AGGREGATES_MULTIPLE,
            ("idx_nulls", "asc_idx", "col1"),
            "idx_nulls, asc_idx, col1",
        ),
        (
            AGGREGATES_MULTIPLE,
            ("idx_nulls", "asc_idx", "col1", "col2"),
            "idx_nulls, asc_idx, col1, col2",
        ),
    ]
    for aggregates, column, expected in cases:
        flattened = proc._flatten_column_after_pivot(
            aggregates=aggregates, column=column,
        )
        self.assertEqual(flattened, expected)
|
||||
|
||||
def test_pivot_without_columns(self):
    """
    Pivoting with only an index should return the index and aggregate columns
    """
    result = proc.pivot(
        df=categories_df, index=["name"], aggregates=AGGREGATES_SINGLE,
    )
    column_names = result.columns.tolist()
    self.assertListEqual(column_names, ["name", "idx_nulls"])
    self.assertEqual(len(result), 101)
    totals = result.sum()
    self.assertEqual(totals[1], 1050)
|
||||
|
||||
def test_pivot_with_single_column(self):
|
||||
"""
|
||||
Make sure pivot with single column returns correct DataFrame
|
||||
"""
|
||||
df = proc.pivot(
|
||||
df=categories_df,
|
||||
index=["name"],
|
||||
columns=["category"],
|
||||
aggregates=aggregates,
|
||||
aggregates=AGGREGATES_SINGLE,
|
||||
)
|
||||
self.assertListEqual(
|
||||
df.columns.tolist(),
|
||||
[("idx_nulls", "cat0"), ("idx_nulls", "cat1"), ("idx_nulls", "cat2")],
|
||||
df.columns.tolist(), ["name", "cat0", "cat1", "cat2"],
|
||||
)
|
||||
self.assertEqual(len(df), 101)
|
||||
self.assertEqual(df.sum()[0], 315)
|
||||
self.assertEqual(df.sum()[1], 315)
|
||||
|
||||
# regular pivot
|
||||
df = proc.pivot(
|
||||
df=categories_df,
|
||||
index=["dept"],
|
||||
columns=["category"],
|
||||
aggregates=aggregates,
|
||||
aggregates=AGGREGATES_SINGLE,
|
||||
)
|
||||
self.assertListEqual(
|
||||
df.columns.tolist(), ["dept", "cat0", "cat1", "cat2"],
|
||||
)
|
||||
self.assertEqual(len(df), 5)
|
||||
|
||||
# fill value
|
||||
def test_pivot_with_multiple_columns(self):
    """
    Pivoting on two columns should yield one output column per combination
    """
    result = proc.pivot(
        df=categories_df,
        index=["name"],
        columns=["category", "dept"],
        aggregates=AGGREGATES_SINGLE,
    )
    expected_width = 1 + 3 * 5  # index + possible permutations
    self.assertEqual(len(result.columns), expected_width)
|
||||
|
||||
def test_pivot_fill_values(self):
|
||||
"""
|
||||
Make sure pivot with fill values returns correct DataFrame
|
||||
"""
|
||||
df = proc.pivot(
|
||||
df=categories_df,
|
||||
index=["name"],
|
||||
|
@ -91,7 +163,20 @@ class TestPostProcessing(SupersetTestCase):
|
|||
metric_fill_value=1,
|
||||
aggregates={"idx_nulls": {"operator": "sum"}},
|
||||
)
|
||||
self.assertEqual(df.sum()[0], 382)
|
||||
self.assertEqual(df.sum()[1], 382)
|
||||
|
||||
def test_pivot_exceptions(self):
|
||||
"""
|
||||
Make sure pivot raises correct Exceptions
|
||||
"""
|
||||
# Missing index
|
||||
self.assertRaises(
|
||||
TypeError,
|
||||
proc.pivot,
|
||||
df=categories_df,
|
||||
columns=["dept"],
|
||||
aggregates=AGGREGATES_SINGLE,
|
||||
)
|
||||
|
||||
# invalid index reference
|
||||
self.assertRaises(
|
||||
|
@ -100,7 +185,7 @@ class TestPostProcessing(SupersetTestCase):
|
|||
df=categories_df,
|
||||
index=["abc"],
|
||||
columns=["dept"],
|
||||
aggregates=aggregates,
|
||||
aggregates=AGGREGATES_SINGLE,
|
||||
)
|
||||
|
||||
# invalid column reference
|
||||
|
@ -110,7 +195,7 @@ class TestPostProcessing(SupersetTestCase):
|
|||
df=categories_df,
|
||||
index=["dept"],
|
||||
columns=["abc"],
|
||||
aggregates=aggregates,
|
||||
aggregates=AGGREGATES_SINGLE,
|
||||
)
|
||||
|
||||
# invalid aggregate options
|
||||
|
|
Loading…
Reference in New Issue