From ac8e54d9094aaba05a87167da611b169464064da Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Tue, 17 Aug 2021 14:41:22 -0700 Subject: [PATCH] fix: improve pivot post-processing (#16289) * fix: improve pivot post-processing * Add tests * Trim space from column name --- superset/charts/post_processing.py | 265 ++++-- .../unit_tests/charts/test_post_processing.py | 764 +++++++++++++++++- 2 files changed, 919 insertions(+), 110 deletions(-) diff --git a/superset/charts/post_processing.py b/superset/charts/post_processing.py index b67d8705e3..4919907ded 100644 --- a/superset/charts/post_processing.py +++ b/superset/charts/post_processing.py @@ -27,60 +27,151 @@ for these chart types. """ from io import StringIO -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple import pandas as pd from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name -def sql_like_sum(series: pd.Series) -> pd.Series: +def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]: """ - A SUM aggregation function that mimics the behavior from SQL. + Sort columns when combining metrics. + + MultiIndex labels have the metric name as the last element in the + tuple. We want to sort these according to the list of passed metrics. """ - return series.sum(min_count=1) + parts: List[Any] = list(label) + metric = parts[-1] + parts[-1] = metrics.index(metric) + return tuple(parts) -def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame: - """ - Pivot table. - """ - if form_data.get("granularity") == "all" and DTTM_ALIAS in df: - del df[DTTM_ALIAS] +def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-statements, too-many-branches + df: pd.DataFrame, + rows: List[str], + columns: List[str], + metrics: List[str], + aggfunc: str = "Sum", + transpose_pivot: bool = False, + combine_metrics: bool = False, + show_rows_total: bool = False, + show_columns_total: bool = False, + apply_metrics_on_rows: bool = False, +) -> pd.DataFrame: + metric_name = f"Total ({aggfunc})" - metrics = [get_metric_name(m) for m in form_data["metrics"]] - aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {} - for metric in metrics: - aggfunc = form_data.get("pandas_aggfunc") or "sum" - if pd.api.types.is_numeric_dtype(df[metric]): - if aggfunc == "sum": - aggfunc = sql_like_sum - elif aggfunc not in {"min", "max"}: - aggfunc = "max" - aggfuncs[metric] = aggfunc + if transpose_pivot: + rows, columns = columns, rows - groupby = form_data.get("groupby") or [] - columns = form_data.get("columns") or [] - if form_data.get("transpose_pivot"): - groupby, columns = columns, groupby + # to apply the metrics on the rows we pivot the dataframe, apply the + # metrics to the columns, and pivot the dataframe back before + # returning it + if apply_metrics_on_rows: + rows, columns = columns, rows + axis = {"columns": 0, "rows": 1} + else: + axis = {"columns": 1, "rows": 0} - df = df.pivot_table( - index=groupby, - columns=columns, - values=metrics, - aggfunc=aggfuncs, - margins=form_data.get("pivot_margins"), - ) + # pivot data; we'll compute totals and subtotals later + if rows or columns: + df = df.pivot_table( + index=rows, + columns=columns, + values=metrics, + aggfunc=pivot_v2_aggfunc_map[aggfunc], + margins=False, + ) + else: + # if there's no rows nor columns we have a single value; update + # the index with the metric name so it shows up in the table + df.index = pd.Index([*df.index[:-1], metric_name], name="metric") - # Display metrics side by side with each column - if form_data.get("combine_metric"): - df = df.stack(0).unstack().reindex(level=-1, columns=metrics) + # if no rows were passed the metrics will be in the rows, so we + # need to move them back to columns + if columns and not rows: + df = df.stack().to_frame().T + df = df[metrics] + df.index = pd.Index([*df.index[:-1], metric_name], name="metric") - # flatten column names - df.columns = [ - " ".join(str(name) for name in column) if isinstance(column, tuple) else column - for column in df.columns - ] + # combining metrics changes the column hierarchy, moving the metric + # from the top to the bottom, eg: + # + # ('SUM(col)', 'age', 'name') => ('age', 'name', 'SUM(col)') + if combine_metrics and isinstance(df.columns, pd.MultiIndex): + # move metrics to the lowest level + new_order = [*range(1, df.columns.nlevels), 0] + df = df.reorder_levels(new_order, axis=1) + + # sort columns, combining metrics for each group + decorated_columns = [(col, i) for i, col in enumerate(df.columns)] + grouped_columns = sorted( + decorated_columns, key=lambda t: get_column_key(t[0], metrics) + ) + indexes = [i for col, i in grouped_columns] + df = df[df.columns[indexes]] + elif rows: + # if metrics were not combined we sort the dataframe by the list + # of metrics defined by the user + df = df[metrics] + + # compute fractions, if needed + if aggfunc.endswith(" as Fraction of Total"): + total = df.sum().sum() + df = df.astype(total.dtypes) / total + elif aggfunc.endswith(" as Fraction of Columns"): + total = df.sum(axis=axis["rows"]) + df = df.astype(total.dtypes).div(total, axis=axis["columns"]) + elif aggfunc.endswith(" as Fraction of Rows"): + total = df.sum(axis=axis["columns"]) + df = df.astype(total.dtypes).div(total, axis=axis["rows"]) + + if show_rows_total: + # convert to a MultiIndex to simplify logic + if not isinstance(df.columns, pd.MultiIndex): + df.columns = pd.MultiIndex.from_tuples([(str(i),) for i in df.columns]) + + # add subtotal for each group and overall total; we start from the + # overall group, and iterate deeper into subgroups + groups = df.columns + for level in range(df.columns.nlevels): + subgroups = {group[:level] for group in groups} + for subgroup in subgroups: + slice_ = df.columns.get_loc(subgroup) + subtotal = pivot_v2_aggfunc_map[aggfunc](df.iloc[:, slice_], axis=1) + depth = df.columns.nlevels - len(subgroup) - 1 + total = metric_name if level == 0 else "Subtotal" + subtotal_name = tuple([*subgroup, total, *([""] * depth)]) + # insert column after subgroup + df.insert(int(slice_.stop), subtotal_name, subtotal) + + if rows and show_columns_total: + # convert to a MultiIndex to simplify logic + if not isinstance(df.index, pd.MultiIndex): + df.index = pd.MultiIndex.from_tuples([(str(i),) for i in df.index]) + + # add subtotal for each group and overall total; we start from the + # overall group, and iterate deeper into subgroups + groups = df.index + for level in range(df.index.nlevels): + subgroups = {group[:level] for group in groups} + for subgroup in subgroups: + slice_ = df.index.get_loc(subgroup) + subtotal = pivot_v2_aggfunc_map[aggfunc]( + df.iloc[slice_, :].apply(pd.to_numeric), axis=0 + ) + depth = df.index.nlevels - len(subgroup) - 1 + total = metric_name if level == 0 else "Subtotal" + subtotal.name = tuple([*subgroup, total, *([""] * depth)]) + # insert row after subgroup + df = pd.concat( + [df[: slice_.stop], subtotal.to_frame().T, df[slice_.stop :]] + ) + + # if we want to apply the metrics on the rows we need to pivot the + # dataframe back + if apply_metrics_on_rows: + df = df.T return df @@ -125,61 +216,49 @@ def pivot_table_v2( # pylint: disable=too-many-branches if form_data.get("granularity_sqla") == "all" and DTTM_ALIAS in df: del df[DTTM_ALIAS] - # TODO (betodealmeida): implement metricsLayout - metrics = [get_metric_name(m) for m in form_data["metrics"]] - aggregate_function = form_data.get("aggregateFunction", "Sum") - groupby = form_data.get("groupbyRows") or [] - columns = form_data.get("groupbyColumns") or [] - if form_data.get("transposePivot"): - groupby, columns = columns, groupby - - df = df.pivot_table( - index=groupby, - columns=columns, - values=metrics, - aggfunc=pivot_v2_aggfunc_map[aggregate_function], - margins=True, + return pivot_df( + df, + rows=form_data.get("groupbyRows") or [], + columns=form_data.get("groupbyColumns") or [], + metrics=[get_metric_name(m) for m in form_data["metrics"]], + aggfunc=form_data.get("aggregateFunction", "Sum"), + transpose_pivot=bool(form_data.get("transposePivot")), + combine_metrics=bool(form_data.get("combineMetric")), + show_rows_total=bool(form_data.get("rowTotals")), + show_columns_total=bool(form_data.get("colTotals")), + apply_metrics_on_rows=form_data.get("metricsLayout") == "ROWS", ) - # The pandas `pivot_table` method either brings both row/column - # totals, or none at all. We pass `margin=True` to get both, and - # remove any dimension that was not requests. - if columns and not form_data.get("rowTotals"): - df.drop(df.columns[len(df.columns) - 1], axis=1, inplace=True) - if groupby and not form_data.get("colTotals"): - df = df[:-1] - # Compute fractions, if needed. If `colTotals` or `rowTotals` are - # present we need to adjust for including them in the sum - if aggregate_function.endswith(" as Fraction of Total"): - total = df.sum().sum() - df = df.astype(total.dtypes) / total - if form_data.get("colTotals"): - df *= 2 - if form_data.get("rowTotals"): - df *= 2 - elif aggregate_function.endswith(" as Fraction of Columns"): - total = df.sum(axis=0) - df = df.astype(total.dtypes).div(total, axis=1) - if form_data.get("colTotals"): - df *= 2 - elif aggregate_function.endswith(" as Fraction of Rows"): - total = df.sum(axis=1) - df = df.astype(total.dtypes).div(total, axis=0) - if form_data.get("rowTotals"): - df *= 2 +def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame: + """ + Pivot table (v1). + """ + if form_data.get("granularity") == "all" and DTTM_ALIAS in df: + del df[DTTM_ALIAS] - # Display metrics side by side with each column - if form_data.get("combineMetric"): - df = df.stack(0).unstack().reindex(level=-1, columns=metrics) + # v1 func names => v2 func names + func_map = { + "sum": "Sum", + "mean": "Average", + "min": "Minimum", + "max": "Maximum", + "std": "Sample Standard Deviation", + "var": "Sample Variance", + } - # flatten column names - df.columns = [ - " ".join(str(name) for name in column) if isinstance(column, tuple) else column - for column in df.columns - ] - - return df + return pivot_df( + df, + rows=form_data.get("groupby") or [], + columns=form_data.get("columns") or [], + metrics=[get_metric_name(m) for m in form_data["metrics"]], + aggfunc=func_map.get(form_data.get("pandas_aggfunc", "sum"), "Sum"), + transpose_pivot=bool(form_data.get("transpose_pivot")), + combine_metrics=bool(form_data.get("combine_metric")), + show_rows_total=bool(form_data.get("pivot_margins")), + show_columns_total=bool(form_data.get("pivot_margins")), + apply_metrics_on_rows=False, + ) post_processors = { @@ -203,6 +282,14 @@ def apply_post_process( df = pd.read_csv(StringIO(query["data"])) processed_df = post_processor(df, form_data) + # flatten column names + processed_df.columns = [ + " ".join(str(name) for name in column).strip() + if isinstance(column, tuple) + else column + for column in processed_df.columns + ] + buf = StringIO() processed_df.to_csv(buf) buf.seek(0) diff --git a/tests/unit_tests/charts/test_post_processing.py b/tests/unit_tests/charts/test_post_processing.py index dc2f9a1dd6..9463577996 100644 --- a/tests/unit_tests/charts/test_post_processing.py +++ b/tests/unit_tests/charts/test_post_processing.py @@ -18,7 +18,9 @@ import copy from typing import Any, Dict -from superset.charts.post_processing import apply_post_process +import pandas as pd + +from superset.charts.post_processing import apply_post_process, pivot_df from superset.utils.core import GenericDataType, QueryStatus RESULT: Dict[str, Any] = { @@ -149,7 +151,8 @@ LIMIT 50000; "Births PA", "Births TX", "Births other", - "Births All", + "Births Subtotal", + "Total (Sum)", ], "coltypes": [ GenericDataType.NUMERIC, @@ -164,11 +167,12 @@ LIMIT 50000; GenericDataType.NUMERIC, GenericDataType.NUMERIC, GenericDataType.NUMERIC, + GenericDataType.NUMERIC, ], - "data": """gender,Births CA,Births FL,Births IL,Births MA,Births MI,Births NJ,Births NY,Births OH,Births PA,Births TX,Births other,Births All -boy,5430796,1968060,2357411,1285126,1938321,1486126,3543961,2376385,2390275,3311985,22044909,48133355 -girl,3567754,1312593,1614427,842146,1326229,992702,2280733,1622814,1615383,2313186,15058341,32546308 -All,8998550,3280653,3971838,2127272,3264550,2478828,5824694,3999199,4005658,5625171,37103250,80679663 + "data": """,Births CA,Births FL,Births IL,Births MA,Births MI,Births NJ,Births NY,Births OH,Births PA,Births TX,Births other,Births Subtotal,Total (Sum) +boy,5430796,1968060,2357411,1285126,1938321,1486126,3543961,2376385,2390275,3311985,22044909,48133355,48133355 +girl,3567754,1312593,1614427,842146,1326229,992702,2280733,1622814,1615383,2313186,15058341,32546308,32546308 +Total (Sum),8998550,3280653,3971838,2127272,3264550,2478828,5824694,3999199,4005658,5625171,37103250,80679663,80679663 """, "applied_filters": [], "rejected_filters": [], @@ -199,7 +203,7 @@ def test_pivot_table_v2(): "optionName": "metric_11", } ], - "metricsLayout": "ROWS", + "metricsLayout": "COLUMNS", "rowOrder": "key_a_to_z", "rowTotals": True, "row_limit": 50000, @@ -237,28 +241,746 @@ LIMIT 50000; "status": QueryStatus.SUCCESS, "stacktrace": None, "rowcount": 12, - "colnames": ["All Births", "boy Births", "girl Births"], + "colnames": [ + "boy Births", + "boy Subtotal", + "girl Births", + "girl Subtotal", + "Total (Sum as Fraction of Rows)", + ], "coltypes": [ GenericDataType.NUMERIC, GenericDataType.NUMERIC, GenericDataType.NUMERIC, + GenericDataType.NUMERIC, + GenericDataType.NUMERIC, ], - "data": """state,All Births,boy Births,girl Births -All,1.0,0.5965983645717509,0.40340163542824914 -CA,1.0,0.6035190113962805,0.3964809886037195 -FL,1.0,0.5998988615985903,0.4001011384014097 -IL,1.0,0.5935315085862012,0.40646849141379887 -MA,1.0,0.6041192663655611,0.3958807336344389 -MI,1.0,0.5937482960898133,0.4062517039101867 -NJ,1.0,0.5995276800165239,0.40047231998347604 -NY,1.0,0.6084372844307357,0.39156271556926425 -OH,1.0,0.5942152416021308,0.40578475839786915 -PA,1.0,0.596724682935987,0.40327531706401293 -TX,1.0,0.5887794344385264,0.41122056556147357 -other,1.0,0.5941503507105172,0.40584964928948275 + "data": """,boy Births,boy Subtotal,girl Births,girl Subtotal,Total (Sum as Fraction of Rows) +CA,0.6035190113962805,0.6035190113962805,0.3964809886037195,0.3964809886037195,1.0 +FL,0.5998988615985903,0.5998988615985903,0.4001011384014097,0.4001011384014097,1.0 +IL,0.5935315085862012,0.5935315085862012,0.40646849141379887,0.40646849141379887,1.0 +MA,0.6041192663655611,0.6041192663655611,0.3958807336344389,0.3958807336344389,1.0 +MI,0.5937482960898133,0.5937482960898133,0.4062517039101867,0.4062517039101867,1.0 +NJ,0.5995276800165239,0.5995276800165239,0.40047231998347604,0.40047231998347604,1.0 +NY,0.6084372844307357,0.6084372844307357,0.39156271556926425,0.39156271556926425,1.0 +OH,0.5942152416021308,0.5942152416021308,0.40578475839786915,0.40578475839786915,1.0 +PA,0.596724682935987,0.596724682935987,0.40327531706401293,0.40327531706401293,1.0 +TX,0.5887794344385264,0.5887794344385264,0.41122056556147357,0.41122056556147357,1.0 +other,0.5941503507105172,0.5941503507105172,0.40584964928948275,0.40584964928948275,1.0 +Total (Sum as Fraction of Rows),6.576651618170867,6.576651618170867,4.423348381829133,4.423348381829133,11.0 """, "applied_filters": [], "rejected_filters": [], } ], } + + +def test_pivot_df_no_cols_no_rows_single_metric(): + """ + Pivot table when no cols/rows and 1 metric are selected. + """ + # when no cols/rows are selected there are no groupbys in the query, + # and the data has only the metric(s) + df = pd.DataFrame.from_dict({"SUM(num)": {0: 80679663}}) + assert ( + df.to_markdown() + == """ +| | SUM(num) | +|---:|------------:| +| 0 | 8.06797e+07 | + """.strip() + ) + + pivoted = pivot_df( + df, + rows=[], + columns=[], + metrics=["SUM(num)"], + aggfunc="Sum", + transpose_pivot=False, + combine_metrics=False, + show_rows_total=False, + show_columns_total=False, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| metric | SUM(num) | +|:------------|------------:| +| Total (Sum) | 8.06797e+07 | + """.strip() + ) + + # tranpose_pivot and combine_metrics do nothing in this case + pivoted = pivot_df( + df, + rows=[], + columns=[], + metrics=["SUM(num)"], + aggfunc="Sum", + transpose_pivot=True, + combine_metrics=True, + show_rows_total=False, + show_columns_total=False, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| metric | SUM(num) | +|:------------|------------:| +| Total (Sum) | 8.06797e+07 | + """.strip() + ) + + # apply_metrics_on_rows will pivot the table, moving the metrics + # to rows + pivoted = pivot_df( + df, + rows=[], + columns=[], + metrics=["SUM(num)"], + aggfunc="Sum", + transpose_pivot=True, + combine_metrics=True, + show_rows_total=False, + show_columns_total=False, + apply_metrics_on_rows=True, + ) + assert ( + pivoted.to_markdown() + == """ +| | Total (Sum) | +|:---------|--------------:| +| SUM(num) | 8.06797e+07 | + """.strip() + ) + + # showing totals + pivoted = pivot_df( + df, + rows=[], + columns=[], + metrics=["SUM(num)"], + aggfunc="Sum", + transpose_pivot=True, + combine_metrics=True, + show_rows_total=True, + show_columns_total=True, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| metric | ('SUM(num)',) | ('Total (Sum)',) | +|:------------|----------------:|-------------------:| +| Total (Sum) | 8.06797e+07 | 8.06797e+07 | + """.strip() + ) + + +def test_pivot_df_no_cols_no_rows_two_metrics(): + """ + Pivot table when no cols/rows and 2 metrics are selected. + """ + # when no cols/rows are selected there are no groupbys in the query, + # and the data has only the metrics + df = pd.DataFrame.from_dict({"SUM(num)": {0: 80679663}, "MAX(num)": {0: 37296}}) + assert ( + df.to_markdown() + == """ +| | SUM(num) | MAX(num) | +|---:|------------:|-----------:| +| 0 | 8.06797e+07 | 37296 | + """.strip() + ) + + pivoted = pivot_df( + df, + rows=[], + columns=[], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=False, + combine_metrics=False, + show_rows_total=False, + show_columns_total=False, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| metric | SUM(num) | MAX(num) | +|:------------|------------:|-----------:| +| Total (Sum) | 8.06797e+07 | 37296 | + """.strip() + ) + + # tranpose_pivot and combine_metrics do nothing in this case + pivoted = pivot_df( + df, + rows=[], + columns=[], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=True, + combine_metrics=True, + show_rows_total=False, + show_columns_total=False, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| metric | SUM(num) | MAX(num) | +|:------------|------------:|-----------:| +| Total (Sum) | 8.06797e+07 | 37296 | + """.strip() + ) + + # apply_metrics_on_rows will pivot the table, moving the metrics + # to rows + pivoted = pivot_df( + df, + rows=[], + columns=[], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=True, + combine_metrics=True, + show_rows_total=False, + show_columns_total=False, + apply_metrics_on_rows=True, + ) + assert ( + pivoted.to_markdown() + == """ +| | Total (Sum) | +|:---------|----------------:| +| SUM(num) | 8.06797e+07 | +| MAX(num) | 37296 | + """.strip() + ) + + # when showing totals we only add a column, since adding a row + # would be redundant + pivoted = pivot_df( + df, + rows=[], + columns=[], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=True, + combine_metrics=True, + show_rows_total=True, + show_columns_total=True, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| metric | ('SUM(num)',) | ('MAX(num)',) | ('Total (Sum)',) | +|:------------|----------------:|----------------:|-------------------:| +| Total (Sum) | 8.06797e+07 | 37296 | 8.0717e+07 | + """.strip() + ) + + +def test_pivot_df_single_row_two_metrics(): + """ + Pivot table when a single column and 2 metrics are selected. + """ + df = pd.DataFrame.from_dict( + { + "gender": {0: "girl", 1: "boy"}, + "SUM(num)": {0: 118065, 1: 47123}, + "MAX(num)": {0: 2588, 1: 1280}, + } + ) + assert ( + df.to_markdown() + == """ +| | gender | SUM(num) | MAX(num) | +|---:|:---------|-----------:|-----------:| +| 0 | girl | 118065 | 2588 | +| 1 | boy | 47123 | 1280 | + """.strip() + ) + + pivoted = pivot_df( + df, + rows=["gender"], + columns=[], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=False, + combine_metrics=False, + show_rows_total=False, + show_columns_total=False, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| gender | SUM(num) | MAX(num) | +|:---------|-----------:|-----------:| +| boy | 47123 | 1280 | +| girl | 118065 | 2588 | + """.strip() + ) + + # transpose_pivot + pivoted = pivot_df( + df, + rows=["gender"], + columns=[], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=True, + combine_metrics=False, + show_rows_total=False, + show_columns_total=False, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| metric | ('SUM(num)', 'boy') | ('SUM(num)', 'girl') | ('MAX(num)', 'boy') | ('MAX(num)', 'girl') | +|:------------|----------------------:|-----------------------:|----------------------:|-----------------------:| +| Total (Sum) | 47123 | 118065 | 1280 | 2588 | + """.strip() + ) + + # combine_metrics does nothing in this case + pivoted = pivot_df( + df, + rows=["gender"], + columns=[], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=False, + combine_metrics=True, + show_rows_total=False, + show_columns_total=False, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| gender | SUM(num) | MAX(num) | +|:---------|-----------:|-----------:| +| boy | 47123 | 1280 | +| girl | 118065 | 2588 | + """.strip() + ) + + # show totals + pivoted = pivot_df( + df, + rows=["gender"], + columns=[], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=False, + combine_metrics=False, + show_rows_total=True, + show_columns_total=True, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| | ('SUM(num)',) | ('MAX(num)',) | ('Total (Sum)',) | +|:-----------------|----------------:|----------------:|-------------------:| +| ('boy',) | 47123 | 1280 | 48403 | +| ('girl',) | 118065 | 2588 | 120653 | +| ('Total (Sum)',) | 165188 | 3868 | 169056 | + """.strip() + ) + + # apply_metrics_on_rows + pivoted = pivot_df( + df, + rows=["gender"], + columns=[], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=False, + combine_metrics=False, + show_rows_total=True, + show_columns_total=True, + apply_metrics_on_rows=True, + ) + assert ( + pivoted.to_markdown() + == """ +| | Total (Sum) | +|:-------------------------|--------------:| +| ('SUM(num)', 'boy') | 47123 | +| ('SUM(num)', 'girl') | 118065 | +| ('SUM(num)', 'Subtotal') | 165188 | +| ('MAX(num)', 'boy') | 1280 | +| ('MAX(num)', 'girl') | 2588 | +| ('MAX(num)', 'Subtotal') | 3868 | +| ('Total (Sum)', '') | 169056 | + """.strip() + ) + + # apply_metrics_on_rows with combine_metrics + pivoted = pivot_df( + df, + rows=["gender"], + columns=[], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=False, + combine_metrics=True, + show_rows_total=True, + show_columns_total=True, + apply_metrics_on_rows=True, + ) + assert ( + pivoted.to_markdown() + == """ +| | Total (Sum) | +|:---------------------|--------------:| +| ('boy', 'SUM(num)') | 47123 | +| ('boy', 'MAX(num)') | 1280 | +| ('boy', 'Subtotal') | 48403 | +| ('girl', 'SUM(num)') | 118065 | +| ('girl', 'MAX(num)') | 2588 | +| ('girl', 'Subtotal') | 120653 | +| ('Total (Sum)', '') | 169056 | + """.strip() + ) + + +def test_pivot_df_complex(): + """ + Pivot table when a column, rows and 2 metrics are selected. + """ + df = pd.DataFrame.from_dict( + { + "state": { + 0: "CA", + 1: "CA", + 2: "CA", + 3: "FL", + 4: "CA", + 5: "CA", + 6: "FL", + 7: "FL", + 8: "FL", + 9: "CA", + 10: "FL", + 11: "FL", + }, + "gender": { + 0: "girl", + 1: "boy", + 2: "girl", + 3: "girl", + 4: "girl", + 5: "girl", + 6: "boy", + 7: "girl", + 8: "girl", + 9: "boy", + 10: "boy", + 11: "girl", + }, + "name": { + 0: "Amy", + 1: "Edward", + 2: "Sophia", + 3: "Amy", + 4: "Cindy", + 5: "Dawn", + 6: "Edward", + 7: "Sophia", + 8: "Dawn", + 9: "Tony", + 10: "Tony", + 11: "Cindy", + }, + "SUM(num)": { + 0: 45426, + 1: 31290, + 2: 18859, + 3: 14740, + 4: 14149, + 5: 11403, + 6: 9395, + 7: 7181, + 8: 5089, + 9: 3765, + 10: 2673, + 11: 1218, + }, + "MAX(num)": { + 0: 2227, + 1: 1280, + 2: 2588, + 3: 854, + 4: 842, + 5: 1157, + 6: 389, + 7: 1187, + 8: 461, + 9: 598, + 10: 247, + 11: 217, + }, + } + ) + assert ( + df.to_markdown() + == """ +| | state | gender | name | SUM(num) | MAX(num) | +|---:|:--------|:---------|:-------|-----------:|-----------:| +| 0 | CA | girl | Amy | 45426 | 2227 | +| 1 | CA | boy | Edward | 31290 | 1280 | +| 2 | CA | girl | Sophia | 18859 | 2588 | +| 3 | FL | girl | Amy | 14740 | 854 | +| 4 | CA | girl | Cindy | 14149 | 842 | +| 5 | CA | girl | Dawn | 11403 | 1157 | +| 6 | FL | boy | Edward | 9395 | 389 | +| 7 | FL | girl | Sophia | 7181 | 1187 | +| 8 | FL | girl | Dawn | 5089 | 461 | +| 9 | CA | boy | Tony | 3765 | 598 | +| 10 | FL | boy | Tony | 2673 | 247 | +| 11 | FL | girl | Cindy | 1218 | 217 | + """.strip() + ) + + pivoted = pivot_df( + df, + rows=["gender", "name"], + columns=["state"], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=False, + combine_metrics=False, + show_rows_total=False, + show_columns_total=False, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| | ('SUM(num)', 'CA') | ('SUM(num)', 'FL') | ('MAX(num)', 'CA') | ('MAX(num)', 'FL') | +|:-------------------|---------------------:|---------------------:|---------------------:|---------------------:| +| ('boy', 'Edward') | 31290 | 9395 | 1280 | 389 | +| ('boy', 'Tony') | 3765 | 2673 | 598 | 247 | +| ('girl', 'Amy') | 45426 | 14740 | 2227 | 854 | +| ('girl', 'Cindy') | 14149 | 1218 | 842 | 217 | +| ('girl', 'Dawn') | 11403 | 5089 | 1157 | 461 | +| ('girl', 'Sophia') | 18859 | 7181 | 2588 | 1187 | + """.strip() + ) + + # transpose_pivot + pivoted = pivot_df( + df, + rows=["gender", "name"], + columns=["state"], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=True, + combine_metrics=False, + show_rows_total=False, + show_columns_total=False, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| state | ('SUM(num)', 'boy', 'Edward') | ('SUM(num)', 'boy', 'Tony') | ('SUM(num)', 'girl', 'Amy') | ('SUM(num)', 'girl', 'Cindy') | ('SUM(num)', 'girl', 'Dawn') | ('SUM(num)', 'girl', 'Sophia') | ('MAX(num)', 'boy', 'Edward') | ('MAX(num)', 'boy', 'Tony') | ('MAX(num)', 'girl', 'Amy') | ('MAX(num)', 'girl', 'Cindy') | ('MAX(num)', 'girl', 'Dawn') | ('MAX(num)', 'girl', 'Sophia') | +|:--------|--------------------------------:|------------------------------:|------------------------------:|--------------------------------:|-------------------------------:|---------------------------------:|--------------------------------:|------------------------------:|------------------------------:|--------------------------------:|-------------------------------:|---------------------------------:| +| CA | 31290 | 3765 | 45426 | 14149 | 11403 | 18859 | 1280 | 598 | 2227 | 842 | 1157 | 2588 | +| FL | 9395 | 2673 | 14740 | 1218 | 5089 | 7181 | 389 | 247 | 854 | 217 | 461 | 1187 | + """.strip() + ) + + # combine_metrics + pivoted = pivot_df( + df, + rows=["gender", "name"], + columns=["state"], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=False, + combine_metrics=True, + show_rows_total=False, + show_columns_total=False, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| | ('CA', 'SUM(num)') | ('CA', 'MAX(num)') | ('FL', 'SUM(num)') | ('FL', 'MAX(num)') | +|:-------------------|---------------------:|---------------------:|---------------------:|---------------------:| +| ('boy', 'Edward') | 31290 | 1280 | 9395 | 389 | +| ('boy', 'Tony') | 3765 | 598 | 2673 | 247 | +| ('girl', 'Amy') | 45426 | 2227 | 14740 | 854 | +| ('girl', 'Cindy') | 14149 | 842 | 1218 | 217 | +| ('girl', 'Dawn') | 11403 | 1157 | 5089 | 461 | +| ('girl', 'Sophia') | 18859 | 2588 | 7181 | 1187 | + """.strip() + ) + + # show totals + pivoted = pivot_df( + df, + rows=["gender", "name"], + columns=["state"], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=False, + combine_metrics=False, + show_rows_total=True, + show_columns_total=True, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| | ('SUM(num)', 'CA') | ('SUM(num)', 'FL') | ('SUM(num)', 'Subtotal') | ('MAX(num)', 'CA') | ('MAX(num)', 'FL') | ('MAX(num)', 'Subtotal') | ('Total (Sum)', '') | +|:---------------------|---------------------:|---------------------:|---------------------------:|---------------------:|---------------------:|---------------------------:|----------------------:| +| ('boy', 'Edward') | 31290 | 9395 | 40685 | 1280 | 389 | 1669 | 42354 | +| ('boy', 'Tony') | 3765 | 2673 | 6438 | 598 | 247 | 845 | 7283 | +| ('boy', 'Subtotal') | 35055 | 12068 | 47123 | 1878 | 636 | 2514 | 49637 | +| ('girl', 'Amy') | 45426 | 14740 | 60166 | 2227 | 854 | 3081 | 63247 | +| ('girl', 'Cindy') | 14149 | 1218 | 15367 | 842 | 217 | 1059 | 16426 | +| ('girl', 'Dawn') | 11403 | 5089 | 16492 | 1157 | 461 | 1618 | 18110 | +| ('girl', 'Sophia') | 18859 | 7181 | 26040 | 2588 | 1187 | 3775 | 29815 | +| ('girl', 'Subtotal') | 89837 | 28228 | 118065 | 6814 | 2719 | 9533 | 127598 | +| ('Total (Sum)', '') | 124892 | 40296 | 165188 | 8692 | 3355 | 12047 | 177235 | + """.strip() + ) + + # apply_metrics_on_rows + pivoted = pivot_df( + df, + rows=["gender", "name"], + columns=["state"], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=False, + combine_metrics=False, + show_rows_total=False, + show_columns_total=False, + apply_metrics_on_rows=True, + ) + assert ( + pivoted.to_markdown() + == """ +| | CA | FL | +|:-------------------------------|------:|------:| +| ('SUM(num)', 'boy', 'Edward') | 31290 | 9395 | +| ('SUM(num)', 'boy', 'Tony') | 3765 | 2673 | +| ('SUM(num)', 'girl', 'Amy') | 45426 | 14740 | +| ('SUM(num)', 'girl', 'Cindy') | 14149 | 1218 | +| ('SUM(num)', 'girl', 'Dawn') | 11403 | 5089 | +| ('SUM(num)', 'girl', 'Sophia') | 18859 | 7181 | +| ('MAX(num)', 'boy', 'Edward') | 1280 | 389 | +| ('MAX(num)', 'boy', 'Tony') | 598 | 247 | +| ('MAX(num)', 'girl', 'Amy') | 2227 | 854 | +| ('MAX(num)', 'girl', 'Cindy') | 842 | 217 | +| ('MAX(num)', 'girl', 'Dawn') | 1157 | 461 | +| ('MAX(num)', 'girl', 'Sophia') | 2588 | 1187 | + """.strip() + ) + + # apply_metrics_on_rows with combine_metrics + pivoted = pivot_df( + df, + rows=["gender", "name"], + columns=["state"], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=False, + combine_metrics=True, + show_rows_total=False, + show_columns_total=False, + apply_metrics_on_rows=True, + ) + assert ( + pivoted.to_markdown() + == """ +| | CA | FL | +|:-------------------------------|------:|------:| +| ('boy', 'Edward', 'SUM(num)') | 31290 | 9395 | +| ('boy', 'Edward', 'MAX(num)') | 1280 | 389 | +| ('boy', 'Tony', 'SUM(num)') | 3765 | 2673 | +| ('boy', 'Tony', 'MAX(num)') | 598 | 247 | +| ('girl', 'Amy', 'SUM(num)') | 45426 | 14740 | +| ('girl', 'Amy', 'MAX(num)') | 2227 | 854 | +| ('girl', 'Cindy', 'SUM(num)') | 14149 | 1218 | +| ('girl', 'Cindy', 'MAX(num)') | 842 | 217 | +| ('girl', 'Dawn', 'SUM(num)') | 11403 | 5089 | +| ('girl', 'Dawn', 'MAX(num)') | 1157 | 461 | +| ('girl', 'Sophia', 'SUM(num)') | 18859 | 7181 | +| ('girl', 'Sophia', 'MAX(num)') | 2588 | 1187 | + """.strip() + ) + + # everything + pivoted = pivot_df( + df, + rows=["gender", "name"], + columns=["state"], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum", + transpose_pivot=True, + combine_metrics=True, + show_rows_total=True, + show_columns_total=True, + apply_metrics_on_rows=True, + ) + assert ( + pivoted.to_markdown() + == """ +| | ('boy', 'Edward') | ('boy', 'Tony') | ('boy', 'Subtotal') | ('girl', 'Amy') | ('girl', 'Cindy') | ('girl', 'Dawn') | ('girl', 'Sophia') | ('girl', 'Subtotal') | ('Total (Sum)', '') | +|:--------------------|--------------------:|------------------:|----------------------:|------------------:|--------------------:|-------------------:|---------------------:|-----------------------:|----------------------:| +| ('CA', 'SUM(num)') | 31290 | 3765 | 35055 | 45426 | 14149 | 11403 | 18859 | 89837 | 124892 | +| ('CA', 'MAX(num)') | 1280 | 598 | 1878 | 2227 | 842 | 1157 | 2588 | 6814 | 8692 | +| ('CA', 'Subtotal') | 32570 | 4363 | 36933 | 47653 | 14991 | 12560 | 21447 | 96651 | 133584 | +| ('FL', 'SUM(num)') | 9395 | 2673 | 12068 | 14740 | 1218 | 5089 | 7181 | 28228 | 40296 | +| ('FL', 'MAX(num)') | 389 | 247 | 636 | 854 | 217 | 461 | 1187 | 2719 | 3355 | +| ('FL', 'Subtotal') | 9784 | 2920 | 12704 | 15594 | 1435 | 5550 | 8368 | 30947 | 43651 | +| ('Total (Sum)', '') | 42354 | 7283 | 49637 | 63247 | 16426 | 18110 | 29815 | 127598 | 177235 | + """.strip() + ) + + # fraction + pivoted = pivot_df( + df, + rows=["gender", "name"], + columns=["state"], + metrics=["SUM(num)", "MAX(num)"], + aggfunc="Sum as Fraction of Columns", + transpose_pivot=False, + combine_metrics=False, + show_rows_total=False, + show_columns_total=True, + apply_metrics_on_rows=False, + ) + assert ( + pivoted.to_markdown() + == """ +| | ('SUM(num)', 'CA') | ('SUM(num)', 'FL') | ('MAX(num)', 'CA') | ('MAX(num)', 'FL') | +|:-------------------------------------------|---------------------:|---------------------:|---------------------:|---------------------:| +| ('boy', 'Edward') | 0.250536 | 0.23315 | 0.147262 | 0.115946 | +| ('boy', 'Tony') | 0.030146 | 0.0663341 | 0.0687989 | 0.0736215 | +| ('boy', 'Subtotal') | 0.280683 | 0.299484 | 0.216061 | 0.189568 | +| ('girl', 'Amy') | 0.363722 | 0.365793 | 0.256213 | 0.254545 | +| ('girl', 'Cindy') | 0.11329 | 0.0302263 | 0.0968707 | 0.0646796 | +| ('girl', 'Dawn') | 0.0913029 | 0.12629 | 0.133111 | 0.137407 | +| ('girl', 'Sophia') | 0.151002 | 0.178206 | 0.297745 | 0.3538 | +| ('girl', 'Subtotal') | 0.719317 | 0.700516 | 0.783939 | 0.810432 | +| ('Total (Sum as Fraction of Columns)', '') | 1 | 1 | 1 | 1 | + """.strip() + )