fix: send CSV pivoted in reports (#16347)

This commit is contained in:
Beto Dealmeida 2021-08-18 19:36:48 -07:00 committed by GitHub
parent afb8bd5fe6
commit ec8d3b03e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 9 deletions

View File

@ -499,6 +499,12 @@ class ChartRestApi(BaseSupersetModelRestApi):
result_type = result["query_context"].result_type result_type = result["query_context"].result_type
result_format = result["query_context"].result_format result_format = result["query_context"].result_format
# Post-process the data so it matches the data presented in the chart.
# This is needed for sending reports based on text charts that do the
# post-processing of data, eg, the pivot table.
if result_type == ChartDataResultType.POST_PROCESSED:
result = apply_post_process(result, form_data)
if result_format == ChartDataResultFormat.CSV: if result_format == ChartDataResultFormat.CSV:
# Verify user has permission to export CSV file # Verify user has permission to export CSV file
if not security_manager.can_access("can_csv", "Superset"): if not security_manager.can_access("can_csv", "Superset"):
@ -509,12 +515,6 @@ class ChartRestApi(BaseSupersetModelRestApi):
return CsvResponse(data, headers=generate_download_headers("csv")) return CsvResponse(data, headers=generate_download_headers("csv"))
if result_format == ChartDataResultFormat.JSON: if result_format == ChartDataResultFormat.JSON:
# Post-process the data so it matches the data presented in the chart.
# This is needed for sending reports based on text charts that do the
# post-processing of data, eg, the pivot table.
if result_type == ChartDataResultType.POST_PROCESSED:
result = apply_post_process(result, form_data)
response_data = simplejson.dumps( response_data = simplejson.dumps(
{"result": result["queries"]}, {"result": result["queries"]},
default=json_int_dttm_ser, default=json_int_dttm_ser,

View File

@ -26,11 +26,17 @@ In order to do that, we reproduce the post-processing in Python
for these chart types. for these chart types.
""" """
from io import StringIO
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import pandas as pd import pandas as pd
from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name from superset.utils.core import (
ChartDataResultFormat,
DTTM_ALIAS,
extract_dataframe_dtypes,
get_metric_name,
)
def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]: def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
@ -276,7 +282,13 @@ def apply_post_process(
post_processor = post_processors[viz_type] post_processor = post_processors[viz_type]
for query in result["queries"]: for query in result["queries"]:
df = pd.DataFrame.from_dict(query["data"]) if query["result_format"] == ChartDataResultFormat.JSON:
df = pd.DataFrame.from_dict(query["data"])
elif query["result_format"] == ChartDataResultFormat.CSV:
df = pd.read_csv(StringIO(query["data"]))
else:
raise Exception(f"Result format {query['result_format']} not supported")
processed_df = post_processor(df, form_data) processed_df = post_processor(df, form_data)
query["colnames"] = list(processed_df.columns) query["colnames"] = list(processed_df.columns)
@ -298,6 +310,12 @@ def apply_post_process(
for index in processed_df.index for index in processed_df.index
] ]
query["data"] = processed_df.to_dict() if query["result_format"] == ChartDataResultFormat.JSON:
query["data"] = processed_df.to_dict()
elif query["result_format"] == ChartDataResultFormat.CSV:
buf = StringIO()
processed_df.to_csv(buf)
buf.seek(0)
query["data"] = buf.getvalue()
return result return result

View File

@ -104,6 +104,7 @@ def _get_full(
payload["indexnames"] = list(df.index) payload["indexnames"] = list(df.index)
payload["coltypes"] = extract_dataframe_dtypes(df) payload["coltypes"] = extract_dataframe_dtypes(df)
payload["data"] = query_context.get_data(df) payload["data"] = query_context.get_data(df)
payload["result_format"] = query_context.result_format
del payload["df"] del payload["df"]
filters = query_obj.filter filters = query_obj.filter