From baeacc3c560dbd2ac9543912ca5559e112118d68 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Wed, 8 Jul 2020 13:35:53 +0300 Subject: [PATCH] feat(chart-data-api): make pivoted columns flattenable (#10255) * feat(chart-data-api): make pivoted columns flattenable * Linting + improve tests --- superset/charts/schemas.py | 2 - superset/utils/pandas_postprocessing.py | 42 +++++++-- tests/pandas_postprocessing_tests.py | 111 +++++++++++++++++++++--- 3 files changed, 134 insertions(+), 21 deletions(-) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 60cec21b2b..8ab4859d6e 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -414,8 +414,6 @@ class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema) fields.String( allow_none=False, description="Columns to group by on the table columns", ), - minLength=1, - required=True, ) metric_fill_value = fields.Number( description="Value to replace missing values with in aggregate calculations.", diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py index e62b393895..b6939775f4 100644 --- a/superset/utils/pandas_postprocessing.py +++ b/superset/utils/pandas_postprocessing.py @@ -72,13 +72,38 @@ WHITELIST_CUMULATIVE_FUNCTIONS = ( ) +def _flatten_column_after_pivot( + column: Union[str, Tuple[str, ...]], aggregates: Dict[str, Dict[str, Any]] +) -> str: + """ + Function for flattening column names into a single string. This step is necessary + to be able to properly serialize a DataFrame. If the column is a string, return + element unchanged. For multi-element columns, join column elements with a comma, + with the exception of pivots made with a single aggregate, in which case the + aggregate column name is omitted. + + :param column: single element from `DataFrame.columns` + :param aggregates: aggregates + :return: + """ + if isinstance(column, str): + return column + if len(column) == 1: + return column[0] + if len(aggregates) == 1 and len(column) > 1: + # drop aggregate for single aggregate pivots with multiple groupings + # from column name (aggregates always come first in column name) + column = column[1:] + return ", ".join(column) + + def validate_column_args(*argnames: str) -> Callable[..., Any]: def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: def wrapped(df: DataFrame, **options: Any) -> Any: columns = df.columns.tolist() for name in argnames: if name in options and not all( - elem in columns for elem in options[name] + elem in columns for elem in options.get(name) or [] ): raise QueryObjectValidationError( _("Referenced columns not available in DataFrame.") @@ -154,14 +179,15 @@ def _append_columns( def pivot( # pylint: disable=too-many-arguments df: DataFrame, index: List[str], - columns: List[str], aggregates: Dict[str, Dict[str, Any]], + columns: Optional[List[str]] = None, metric_fill_value: Optional[Any] = None, column_fill_value: Optional[str] = None, drop_missing_columns: Optional[bool] = True, combine_value_with_metric: bool = False, marginal_distributions: Optional[bool] = None, marginal_distribution_name: Optional[str] = None, + flatten_columns: bool = True, ) -> DataFrame: """ Perform a pivot operation on a DataFrame. @@ -179,6 +205,7 @@ def pivot( # pylint: disable=too-many-arguments :param marginal_distributions: Add totals for row/column. Default to False :param marginal_distribution_name: Name of row/column with marginal distribution. Default to 'All'. + :param flatten_columns: Convert column names to strings :return: A pivot table :raises ChartDataValidationError: If the request in incorrect """ @@ -186,10 +213,6 @@ def pivot( # pylint: disable=too-many-arguments raise QueryObjectValidationError( _("Pivot operation requires at least one index") ) - if not columns: - raise QueryObjectValidationError( - _("Pivot operation requires at least one column") - ) if not aggregates: raise QueryObjectValidationError( _("Pivot operation must include at least one aggregate") @@ -218,6 +241,13 @@ def pivot( # pylint: disable=too-many-arguments if combine_value_with_metric: df = df.stack(0).unstack() + # Make index regular column + if flatten_columns: + df.columns = [ + _flatten_column_after_pivot(col, aggregates) for col in df.columns + ] + # return index as regular column + df.reset_index(level=0, inplace=True) return df diff --git a/tests/pandas_postprocessing_tests.py b/tests/pandas_postprocessing_tests.py index 839b227170..87d2cc1821 100644 --- a/tests/pandas_postprocessing_tests.py +++ b/tests/pandas_postprocessing_tests.py @@ -26,6 +26,12 @@ from superset.utils import pandas_postprocessing as proc from .base_tests import SupersetTestCase from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df +AGGREGATES_SINGLE = {"idx_nulls": {"operator": "sum"}} +AGGREGATES_MULTIPLE = { + "idx_nulls": {"operator": "sum"}, + "asc_idx": {"operator": "mean"}, +} + def series_to_list(series: Series) -> List[Any]: """ @@ -57,33 +63,99 @@ def round_floats( class TestPostProcessing(SupersetTestCase): - def test_pivot(self): - aggregates = {"idx_nulls": {"operator": "sum"}} + def test_flatten_column_after_pivot(self): + """ + Test pivot column flattening function + """ + # single aggregate cases + self.assertEqual( + proc._flatten_column_after_pivot( + aggregates=AGGREGATES_SINGLE, column="idx_nulls", + ), + "idx_nulls", + ) + self.assertEqual( + proc._flatten_column_after_pivot( + aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1"), + ), + "col1", + ) + self.assertEqual( + proc._flatten_column_after_pivot( + aggregates=AGGREGATES_SINGLE, column=("idx_nulls", "col1", "col2"), + ), + "col1, col2", + ) - # regular pivot + # Multiple aggregate cases + self.assertEqual( + proc._flatten_column_after_pivot( + aggregates=AGGREGATES_MULTIPLE, column=("idx_nulls", "asc_idx", "col1"), + ), + "idx_nulls, asc_idx, col1", + ) + self.assertEqual( + proc._flatten_column_after_pivot( + aggregates=AGGREGATES_MULTIPLE, + column=("idx_nulls", "asc_idx", "col1", "col2"), + ), + "idx_nulls, asc_idx, col1, col2", + ) + + def test_pivot_without_columns(self): + """ + Make sure pivot without columns returns correct DataFrame + """ + df = proc.pivot(df=categories_df, index=["name"], aggregates=AGGREGATES_SINGLE,) + self.assertListEqual( + df.columns.tolist(), ["name", "idx_nulls"], + ) + self.assertEqual(len(df), 101) + self.assertEqual(df.sum()[1], 1050) + + def test_pivot_with_single_column(self): + """ + Make sure pivot with single column returns correct DataFrame + """ df = proc.pivot( df=categories_df, index=["name"], columns=["category"], - aggregates=aggregates, + aggregates=AGGREGATES_SINGLE, ) self.assertListEqual( - df.columns.tolist(), - [("idx_nulls", "cat0"), ("idx_nulls", "cat1"), ("idx_nulls", "cat2")], + df.columns.tolist(), ["name", "cat0", "cat1", "cat2"], ) self.assertEqual(len(df), 101) - self.assertEqual(df.sum()[0], 315) + self.assertEqual(df.sum()[1], 315) - # regular pivot df = proc.pivot( df=categories_df, index=["dept"], columns=["category"], - aggregates=aggregates, + aggregates=AGGREGATES_SINGLE, + ) + self.assertListEqual( + df.columns.tolist(), ["dept", "cat0", "cat1", "cat2"], ) self.assertEqual(len(df), 5) - # fill value + def test_pivot_with_multiple_columns(self): + """ + Make sure pivot with multiple columns returns correct DataFrame + """ + df = proc.pivot( + df=categories_df, + index=["name"], + columns=["category", "dept"], + aggregates=AGGREGATES_SINGLE, + ) + self.assertEqual(len(df.columns), 1 + 3 * 5) # index + possible permutations + + def test_pivot_fill_values(self): + """ + Make sure pivot with fill values returns correct DataFrame + """ df = proc.pivot( df=categories_df, index=["name"], @@ -91,7 +163,20 @@ class TestPostProcessing(SupersetTestCase): metric_fill_value=1, aggregates={"idx_nulls": {"operator": "sum"}}, ) - self.assertEqual(df.sum()[0], 382) + self.assertEqual(df.sum()[1], 382) + + def test_pivot_exceptions(self): + """ + Make sure pivot raises correct Exceptions + """ + # Missing index + self.assertRaises( + TypeError, + proc.pivot, + df=categories_df, + columns=["dept"], + aggregates=AGGREGATES_SINGLE, + ) # invalid index reference self.assertRaises( @@ -100,7 +185,7 @@ class TestPostProcessing(SupersetTestCase): df=categories_df, index=["abc"], columns=["dept"], - aggregates=aggregates, + aggregates=AGGREGATES_SINGLE, ) # invalid column reference @@ -110,7 +195,7 @@ class TestPostProcessing(SupersetTestCase): df=categories_df, index=["dept"], columns=["abc"], - aggregates=aggregates, + aggregates=AGGREGATES_SINGLE, ) # invalid aggregate options