feat: send post-processed data in reports (#15953)

* feat: send post-processed data in reports

* Fix tests and lint

* Use enums

* Limit Slack message to 4k chars
Author: Beto Dealmeida · 2021-07-30 09:37:16 -07:00 · committed by GitHub
parent cc704dd53a
commit 2d61f15153
6 changed files with 229 additions and 322 deletions

superset/charts/api.py

@@ -52,7 +52,7 @@ from superset.charts.commands.importers.dispatcher import ImportChartsCommand
 from superset.charts.commands.update import UpdateChartCommand
 from superset.charts.dao import ChartDAO
 from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter
-from superset.charts.post_processing import post_processors
+from superset.charts.post_processing import apply_post_process
 from superset.charts.schemas import (
     CHART_SCHEMAS,
     ChartPostSchema,
@@ -482,10 +482,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
             return self.response_422(message=str(ex))

     def send_chart_response(
-        self,
-        result: Dict[Any, Any],
-        viz_type: Optional[str] = None,
-        form_data: Optional[Dict[str, Any]] = None,
+        self, result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
     ) -> Response:
         result_type = result["query_context"].result_type
         result_format = result["query_context"].result_format
@@ -495,10 +492,9 @@ class ChartRestApi(BaseSupersetModelRestApi):
         # post-processing of data, e.g., the pivot table.
         if (
             result_type == ChartDataResultType.POST_PROCESSED
-            and viz_type in post_processors
             and result_format == ChartDataResultFormat.CSV
         ):
-            post_process = post_processors[viz_type]
-            result = post_process(result, form_data)
+            result = apply_post_process(result, form_data)

         if result_format == ChartDataResultFormat.CSV:
             # Verify user has permission to export CSV file
@@ -525,7 +521,6 @@ class ChartRestApi(BaseSupersetModelRestApi):
         self,
         command: ChartDataCommand,
         force_cached: bool = False,
-        viz_type: Optional[str] = None,
         form_data: Optional[Dict[str, Any]] = None,
     ) -> Response:
         try:
@@ -535,7 +530,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
         except ChartDataQueryFailedError as exc:
             return self.response_400(message=exc.message)

-        return self.send_chart_response(result, viz_type, form_data)
+        return self.send_chart_response(result, form_data)

     @expose("/<int:pk>/data/", methods=["GET"])
     @protect()
@@ -637,9 +632,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
         except (TypeError, json.decoder.JSONDecodeError):
             form_data = {}

-        return self.get_data_response(
-            command, viz_type=chart.viz_type, form_data=form_data
-        )
+        return self.get_data_response(command, form_data=form_data)

     @expose("/data", methods=["POST"])
     @protect()
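
With this change a client can ask the chart-data endpoint for CSV that has already been run through the chart's post-processor. A hypothetical request (host, chart id, and auth are made up; the two query parameters are the ones the report executor sends, per the execute.py diff below):

import requests  # illustrative client; authentication omitted

# Ask for CSV that has been run through the viz-specific post-processor
# (e.g. the pivot applied for pivot tables) rather than the raw query data.
response = requests.get(
    "https://superset.example.org/api/v1/chart/123/data/",
    params={"format": "csv", "type": "post_processed"},
)
print(response.text)  # CSV shaped like the rendered chart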

superset/charts/post_processing.py

@@ -26,6 +26,7 @@ In order to do that, we reproduce the post-processing in Python
 for these chart types.
 """
+from io import StringIO
 from typing import Any, Callable, Dict, Optional, Union

 import pandas as pd
@@ -40,66 +41,48 @@ def sql_like_sum(series: pd.Series) -> pd.Series:
     return series.sum(min_count=1)

-def pivot_table(
-    result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None
-) -> Dict[Any, Any]:
+def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame:
     """
     Pivot table.
     """
-    for query in result["queries"]:
-        data = query["data"]
-        df = pd.DataFrame(data)
-        form_data = form_data or {}
-
-        if form_data.get("granularity") == "all" and DTTM_ALIAS in df:
-            del df[DTTM_ALIAS]
+    if form_data.get("granularity") == "all" and DTTM_ALIAS in df:
+        del df[DTTM_ALIAS]

-        metrics = [get_metric_name(m) for m in form_data["metrics"]]
-        aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {}
-        for metric in metrics:
-            aggfunc = form_data.get("pandas_aggfunc") or "sum"
-            if pd.api.types.is_numeric_dtype(df[metric]):
-                if aggfunc == "sum":
-                    aggfunc = sql_like_sum
-            elif aggfunc not in {"min", "max"}:
-                aggfunc = "max"
-            aggfuncs[metric] = aggfunc
+    metrics = [get_metric_name(m) for m in form_data["metrics"]]
+    aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {}
+    for metric in metrics:
+        aggfunc = form_data.get("pandas_aggfunc") or "sum"
+        if pd.api.types.is_numeric_dtype(df[metric]):
+            if aggfunc == "sum":
+                aggfunc = sql_like_sum
+        elif aggfunc not in {"min", "max"}:
+            aggfunc = "max"
+        aggfuncs[metric] = aggfunc

-        groupby = form_data.get("groupby") or []
-        columns = form_data.get("columns") or []
-        if form_data.get("transpose_pivot"):
-            groupby, columns = columns, groupby
+    groupby = form_data.get("groupby") or []
+    columns = form_data.get("columns") or []
+    if form_data.get("transpose_pivot"):
+        groupby, columns = columns, groupby

-        df = df.pivot_table(
-            index=groupby,
-            columns=columns,
-            values=metrics,
-            aggfunc=aggfuncs,
-            margins=form_data.get("pivot_margins"),
-        )
+    df = df.pivot_table(
+        index=groupby,
+        columns=columns,
+        values=metrics,
+        aggfunc=aggfuncs,
+        margins=form_data.get("pivot_margins"),
+    )

-        # Re-order the columns adhering to the metric ordering.
-        df = df[metrics]
+    # Re-order the columns adhering to the metric ordering.
+    df = df[metrics]

-        # Display metrics side by side with each column
-        if form_data.get("combine_metric"):
-            df = df.stack(0).unstack().reindex(level=-1, columns=metrics)
+    # Display metrics side by side with each column
+    if form_data.get("combine_metric"):
+        df = df.stack(0).unstack().reindex(level=-1, columns=metrics)

-        # flatten column names
-        df.columns = [" ".join(column) for column in df.columns]
+    # flatten column names
+    df.columns = [" ".join(column) for column in df.columns]

-        # re-arrange data into a list of dicts
-        data = []
-        for i in df.index:
-            row = {col: df[col][i] for col in df.columns}
-            row[df.index.name] = i
-            data.append(row)
-
-        query["data"] = data
-        query["colnames"] = list(df.columns)
-        query["coltypes"] = extract_dataframe_dtypes(df)
-        query["rowcount"] = len(df.index)
-
-    return result
+    return df


 def list_unique_values(series: pd.Series) -> str:
@@ -134,88 +117,99 @@ pivot_v2_aggfunc_map = {

 def pivot_table_v2(  # pylint: disable=too-many-branches
-    result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
-) -> Dict[Any, Any]:
+    df: pd.DataFrame, form_data: Dict[str, Any]
+) -> pd.DataFrame:
     """
     Pivot table v2.
     """
-    for query in result["queries"]:
-        data = query["data"]
-        df = pd.DataFrame(data)
-        form_data = form_data or {}
-
-        if form_data.get("granularity_sqla") == "all" and DTTM_ALIAS in df:
-            del df[DTTM_ALIAS]
+    if form_data.get("granularity_sqla") == "all" and DTTM_ALIAS in df:
+        del df[DTTM_ALIAS]

-        # TODO (betodealmeida): implement metricsLayout
-        metrics = [get_metric_name(m) for m in form_data["metrics"]]
-        aggregate_function = form_data.get("aggregateFunction", "Sum")
-        groupby = form_data.get("groupbyRows") or []
-        columns = form_data.get("groupbyColumns") or []
-        if form_data.get("transposePivot"):
-            groupby, columns = columns, groupby
+    # TODO (betodealmeida): implement metricsLayout
+    metrics = [get_metric_name(m) for m in form_data["metrics"]]
+    aggregate_function = form_data.get("aggregateFunction", "Sum")
+    groupby = form_data.get("groupbyRows") or []
+    columns = form_data.get("groupbyColumns") or []
+    if form_data.get("transposePivot"):
+        groupby, columns = columns, groupby

-        df = df.pivot_table(
-            index=groupby,
-            columns=columns,
-            values=metrics,
-            aggfunc=pivot_v2_aggfunc_map[aggregate_function],
-            margins=True,
-        )
+    df = df.pivot_table(
+        index=groupby,
+        columns=columns,
+        values=metrics,
+        aggfunc=pivot_v2_aggfunc_map[aggregate_function],
+        margins=True,
+    )

-        # The pandas `pivot_table` method either brings both row/column
-        # totals, or none at all. We pass `margins=True` to get both, and
-        # remove any dimension that was not requested.
-        if not form_data.get("rowTotals"):
-            df.drop(df.columns[len(df.columns) - 1], axis=1, inplace=True)
-        if not form_data.get("colTotals"):
-            df = df[:-1]
+    # The pandas `pivot_table` method either brings both row/column
+    # totals, or none at all. We pass `margins=True` to get both, and
+    # remove any dimension that was not requested.
+    if not form_data.get("rowTotals"):
+        df.drop(df.columns[len(df.columns) - 1], axis=1, inplace=True)
+    if not form_data.get("colTotals"):
+        df = df[:-1]

-        # Compute fractions, if needed. If `colTotals` or `rowTotals` are
-        # present we need to adjust for including them in the sum
-        if aggregate_function.endswith(" as Fraction of Total"):
-            total = df.sum().sum()
-            df = df.astype(total.dtypes) / total
-            if form_data.get("colTotals"):
-                df *= 2
-            if form_data.get("rowTotals"):
-                df *= 2
-        elif aggregate_function.endswith(" as Fraction of Columns"):
-            total = df.sum(axis=0)
-            df = df.astype(total.dtypes).div(total, axis=1)
-            if form_data.get("colTotals"):
-                df *= 2
-        elif aggregate_function.endswith(" as Fraction of Rows"):
-            total = df.sum(axis=1)
-            df = df.astype(total.dtypes).div(total, axis=0)
-            if form_data.get("rowTotals"):
-                df *= 2
+    # Compute fractions, if needed. If `colTotals` or `rowTotals` are
+    # present we need to adjust for including them in the sum
+    if aggregate_function.endswith(" as Fraction of Total"):
+        total = df.sum().sum()
+        df = df.astype(total.dtypes) / total
+        if form_data.get("colTotals"):
+            df *= 2
+        if form_data.get("rowTotals"):
+            df *= 2
+    elif aggregate_function.endswith(" as Fraction of Columns"):
+        total = df.sum(axis=0)
+        df = df.astype(total.dtypes).div(total, axis=1)
+        if form_data.get("colTotals"):
+            df *= 2
+    elif aggregate_function.endswith(" as Fraction of Rows"):
+        total = df.sum(axis=1)
+        df = df.astype(total.dtypes).div(total, axis=0)
+        if form_data.get("rowTotals"):
+            df *= 2

-        # Re-order the columns adhering to the metric ordering.
-        df = df[metrics]
+    # Re-order the columns adhering to the metric ordering.
+    df = df[metrics]

-        # Display metrics side by side with each column
-        if form_data.get("combineMetric"):
-            df = df.stack(0).unstack().reindex(level=-1, columns=metrics)
+    # Display metrics side by side with each column
+    if form_data.get("combineMetric"):
+        df = df.stack(0).unstack().reindex(level=-1, columns=metrics)

-        # flatten column names
-        df.columns = [" ".join(column) for column in df.columns]
+    # flatten column names
+    df.columns = [" ".join(column) for column in df.columns]

-        # re-arrange data into a list of dicts
-        data = []
-        for i in df.index:
-            row = {col: df[col][i] for col in df.columns}
-            row[df.index.name] = i
-            data.append(row)
-
-        query["data"] = data
-        query["colnames"] = list(df.columns)
-        query["coltypes"] = extract_dataframe_dtypes(df)
-        query["rowcount"] = len(df.index)
-
-    return result
+    return df


 post_processors = {
     "pivot_table": pivot_table,
     "pivot_table_v2": pivot_table_v2,
 }
+
+
+def apply_post_process(
+    result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
+) -> Dict[Any, Any]:
+    form_data = form_data or {}
+    viz_type = form_data.get("viz_type")
+    if viz_type not in post_processors:
+        return result
+
+    post_processor = post_processors[viz_type]
+
+    for query in result["queries"]:
+        df = pd.read_csv(StringIO(query["data"]))
+        processed_df = post_processor(df, form_data)
+
+        buf = StringIO()
+        processed_df.to_csv(buf)
+        buf.seek(0)
+
+        query["data"] = buf.getvalue()
+        query["colnames"] = list(processed_df.columns)
+        query["coltypes"] = extract_dataframe_dtypes(processed_df)
+        query["rowcount"] = len(processed_df.index)
+
+    return result
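
The new apply_post_process round-trips each query payload through pandas: parse the CSV string, run the viz-specific transform from the post_processors registry, and serialize back to CSV. A minimal standalone sketch of that round-trip using plain pandas (the payload shape and the pivot spec here are illustrative, not Superset's actual registry):

from io import StringIO

import pandas as pd

# Illustrative payload: each query's "data" now arrives as a CSV string.
result = {
    "queries": [
        {"data": "state,gender,Births\nOH,boy,2376385\nOH,girl,1622814\n"},
    ],
}

for query in result["queries"]:
    df = pd.read_csv(StringIO(query["data"]))  # CSV string -> DataFrame
    # Stand-in for the viz-specific transform picked by viz_type.
    pivoted = df.pivot_table(
        index="gender", columns="state", values="Births", aggfunc="sum"
    )
    buf = StringIO()
    pivoted.to_csv(buf)  # DataFrame -> CSV string
    query["data"] = buf.getvalue()
    query["rowcount"] = len(pivoted.index)

print(result["queries"][0]["data"])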

superset/models/slice.py

@@ -55,7 +55,7 @@ logger = logging.getLogger(__name__)

 class Slice(
     Model, AuditMixinNullable, ImportExportMixin
-):  # pylint: disable=too-many-public-methods
+):  # pylint: disable=too-many-public-methods, too-many-instance-attributes

     """A slice is essentially a report or a view on data"""

superset/reports/commands/execute.py

@@ -64,6 +64,7 @@ from superset.reports.notifications import create_notification
 from superset.reports.notifications.base import NotificationContent
 from superset.reports.notifications.exceptions import NotificationError
 from superset.utils.celery import session_scope
+from superset.utils.core import ChartDataResultFormat, ChartDataResultType
 from superset.utils.csv import get_chart_csv_data
 from superset.utils.screenshots import (
     BaseScreenshot,
@@ -146,7 +147,8 @@ class BaseReportState:
             return get_url_path(
                 "ChartRestApi.get_data",
                 pk=self._report_schedule.chart_id,
-                format="csv",
+                format=ChartDataResultFormat.CSV.value,
+                type=ChartDataResultType.POST_PROCESSED.value,
             )
         return get_url_path(
             "Superset.slice",

superset/reports/notifications/slack.py

@@ -17,7 +17,6 @@
 # under the License.
 import json
 import logging
-import textwrap
 from io import IOBase
 from typing import Optional, Union
@@ -34,8 +33,8 @@ from superset.reports.notifications.exceptions import NotificationError

 logger = logging.getLogger(__name__)

-# Slack only shows ~25 lines in the code block section
-MAXIMUM_ROWS_IN_CODE_SECTION = 21
+# Slack only allows Markdown messages up to 4k chars
+MAXIMUM_MESSAGE_SIZE = 4000


 class SlackNotification(BaseNotification):  # pylint: disable=too-few-public-methods
@@ -48,43 +47,7 @@ class SlackNotification(BaseNotification):  # pylint: disable=too-few-public-met
     def _get_channel(self) -> str:
         return json.loads(self._recipient.recipient_config_json)["target"]

-    @staticmethod
-    def _error_template(name: str, description: str, text: str) -> str:
-        return textwrap.dedent(
-            __(
-                """
-                *%(name)s*\n
-                %(description)s\n
-                Error: %(text)s
-                """,
-                name=name,
-                description=description,
-                text=text,
-            )
-        )
-
-    def _get_body(self) -> str:
-        if self._content.text:
-            return self._error_template(
-                self._content.name, self._content.description or "", self._content.text
-            )
-
-        # Convert Pandas dataframe into a nice ASCII table
-        if self._content.embedded_data is not None:
-            df = self._content.embedded_data
-            truncated = len(df) > MAXIMUM_ROWS_IN_CODE_SECTION
-            message = "(table was truncated)" if truncated else ""
-            if truncated:
-                df = df[:MAXIMUM_ROWS_IN_CODE_SECTION].fillna("")
-                # add a last row with '...' for values
-                df = df.append({k: "..." for k in df.columns}, ignore_index=True)
-            tabulated = tabulate(df, headers="keys", showindex=False)
-            table = f"```\n{tabulated}\n```\n\n{message}"
-        else:
-            table = ""
-
+    def _message_template(self, table: str = "") -> str:
         return __(
             """*%(name)s*
@@ -93,13 +56,70 @@ class SlackNotification(BaseNotification):  # pylint: disable=too-few-public-met

             <%(url)s|Explore in Superset>

             %(table)s
             """,
             name=self._content.name,
             description=self._content.description or "",
             url=self._content.url,
             table=table,
         )

+    @staticmethod
+    def _error_template(name: str, description: str, text: str) -> str:
+        return __(
+            """*%(name)s*
+
+            %(description)s
+
+            Error: %(text)s
+            """,
+            name=name,
+            description=description,
+            text=text,
+        )
+
+    def _get_body(self) -> str:
+        if self._content.text:
+            return self._error_template(
+                self._content.name, self._content.description or "", self._content.text
+            )
+
+        if self._content.embedded_data is None:
+            return self._message_template()
+
+        # Embed data in the message
+        df = self._content.embedded_data
+
+        # Slack Markdown only works on messages shorter than 4k chars, so we might
+        # need to truncate the data
+        for i in range(len(df) - 1):
+            truncated_df = df[: i + 1].fillna("")
+            truncated_df = truncated_df.append(
+                {k: "..." for k in df.columns}, ignore_index=True
+            )
+            tabulated = tabulate(truncated_df, headers="keys", showindex=False)
+            table = f"```\n{tabulated}\n```\n\n(table was truncated)"
+            message = self._message_template(table)
+            if len(message) > MAXIMUM_MESSAGE_SIZE:
+                # Decrement i and build a message that is under the limit
+                truncated_df = df[:i].fillna("")
+                truncated_df = truncated_df.append(
+                    {k: "..." for k in df.columns}, ignore_index=True
+                )
+                tabulated = tabulate(truncated_df, headers="keys", showindex=False)
+                table = (
+                    f"```\n{tabulated}\n```\n\n(table was truncated)"
+                    if len(truncated_df) > 0
+                    else ""
+                )
+                break
+        else:
+            # Send full data
+            tabulated = tabulate(df, headers="keys", showindex=False)
+            table = f"```\n{tabulated}\n```"
+
+        return self._message_template(table)
+
     def _get_inline_file(self) -> Optional[Union[str, IOBase, bytes]]:
         if self._content.csv:
             return self._content.csv
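
The truncation loop above grows the embedded table one row at a time and backs off a step once the rendered message crosses the limit. A self-contained sketch of the same strategy (toy data; the limit is shrunk so truncation actually triggers, message_for is a stand-in for _message_template, and pd.concat replaces the now-deprecated DataFrame.append):

import pandas as pd
from tabulate import tabulate

# Slack's real limit is ~4000 chars; shrunk so the toy data triggers it.
MAXIMUM_MESSAGE_SIZE = 200


def message_for(table: str) -> str:
    # Stand-in for SlackNotification._message_template.
    return f"*Births report*\n\n<https://example.org|Explore in Superset>\n\n{table}"


df = pd.DataFrame(
    {
        "state": ["OH", "TX", "MA", "PA", "NY", "CA"],
        "Births": [2376385, 3311985, 1285126, 2390275, 3543961, 5430796],
    }
)
ellipsis_row = pd.DataFrame([{col: "..." for col in df.columns}])

for i in range(len(df) - 1):
    # Grow the embedded table one row at a time, ending in a "..." marker row.
    candidate = pd.concat([df[: i + 1].fillna(""), ellipsis_row], ignore_index=True)
    tabulated = tabulate(candidate, headers="keys", showindex=False)
    table = f"```\n{tabulated}\n```\n\n(table was truncated)"
    if len(message_for(table)) > MAXIMUM_MESSAGE_SIZE:
        # One row too many: back off to the previous prefix and stop.
        candidate = pd.concat([df[:i].fillna(""), ellipsis_row], ignore_index=True)
        tabulated = tabulate(candidate, headers="keys", showindex=False)
        table = f"```\n{tabulated}\n```\n\n(table was truncated)" if i > 0 else ""
        break
else:
    # Everything fits: send the full table with no truncation notice.
    table = f"```\n{tabulate(df, headers='keys', showindex=False)}\n```"

print(message_for(table))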

tests/charts/test_post_processing.py

@@ -18,7 +18,7 @@
 import copy
 from typing import Any, Dict

-from superset.charts.post_processing import pivot_table, pivot_table_v2
+from superset.charts.post_processing import apply_post_process
 from superset.utils.core import GenericDataType, QueryStatus

 RESULT: Dict[str, Any] = {
@@ -51,30 +51,30 @@ LIMIT 50000;
                 GenericDataType.STRING,
                 GenericDataType.NUMERIC,
             ],
-            "data": [
-                {"state": "OH", "gender": "boy", "Births": int("2376385")},
-                {"state": "TX", "gender": "girl", "Births": int("2313186")},
-                {"state": "MA", "gender": "boy", "Births": int("1285126")},
-                {"state": "MA", "gender": "girl", "Births": int("842146")},
-                {"state": "PA", "gender": "boy", "Births": int("2390275")},
-                {"state": "NY", "gender": "boy", "Births": int("3543961")},
-                {"state": "FL", "gender": "boy", "Births": int("1968060")},
-                {"state": "TX", "gender": "boy", "Births": int("3311985")},
-                {"state": "NJ", "gender": "boy", "Births": int("1486126")},
-                {"state": "CA", "gender": "girl", "Births": int("3567754")},
-                {"state": "CA", "gender": "boy", "Births": int("5430796")},
-                {"state": "IL", "gender": "girl", "Births": int("1614427")},
-                {"state": "FL", "gender": "girl", "Births": int("1312593")},
-                {"state": "NY", "gender": "girl", "Births": int("2280733")},
-                {"state": "NJ", "gender": "girl", "Births": int("992702")},
-                {"state": "MI", "gender": "girl", "Births": int("1326229")},
-                {"state": "other", "gender": "girl", "Births": int("15058341")},
-                {"state": "other", "gender": "boy", "Births": int("22044909")},
-                {"state": "MI", "gender": "boy", "Births": int("1938321")},
-                {"state": "IL", "gender": "boy", "Births": int("2357411")},
-                {"state": "PA", "gender": "girl", "Births": int("1615383")},
-                {"state": "OH", "gender": "girl", "Births": int("1622814")},
-            ],
+            "data": """state,gender,Births
+OH,boy,2376385
+TX,girl,2313186
+MA,boy,1285126
+MA,girl,842146
+PA,boy,2390275
+NY,boy,3543961
+FL,boy,1968060
+TX,boy,3311985
+NJ,boy,1486126
+CA,girl,3567754
+CA,boy,5430796
+IL,girl,1614427
+FL,girl,1312593
+NY,girl,2280733
+NJ,girl,992702
+MI,girl,1326229
+other,girl,15058341
+other,boy,22044909
+MI,boy,1938321
+IL,boy,2357411
+PA,girl,1615383
+OH,girl,1622814
+""",
             "applied_filters": [],
             "rejected_filters": [],
         }
@@ -113,7 +113,7 @@ def test_pivot_table():
         "viz_type": "pivot_table",
     }
     result = copy.deepcopy(RESULT)
-    assert pivot_table(result, form_data) == {
+    assert apply_post_process(result, form_data) == {
         "query_context": None,
         "queries": [
             {
@@ -165,53 +165,11 @@ LIMIT 50000;
                     GenericDataType.NUMERIC,
                     GenericDataType.NUMERIC,
                 ],
-                "data": [
-                    {
-                        "Births CA": 5430796,
-                        "Births FL": 1968060,
-                        "Births IL": 2357411,
-                        "Births MA": 1285126,
-                        "Births MI": 1938321,
-                        "Births NJ": 1486126,
-                        "Births NY": 3543961,
-                        "Births OH": 2376385,
-                        "Births PA": 2390275,
-                        "Births TX": 3311985,
-                        "Births other": 22044909,
-                        "Births All": 48133355,
-                        "gender": "boy",
-                    },
-                    {
-                        "Births CA": 3567754,
-                        "Births FL": 1312593,
-                        "Births IL": 1614427,
-                        "Births MA": 842146,
-                        "Births MI": 1326229,
-                        "Births NJ": 992702,
-                        "Births NY": 2280733,
-                        "Births OH": 1622814,
-                        "Births PA": 1615383,
-                        "Births TX": 2313186,
-                        "Births other": 15058341,
-                        "Births All": 32546308,
-                        "gender": "girl",
-                    },
-                    {
-                        "Births CA": 8998550,
-                        "Births FL": 3280653,
-                        "Births IL": 3971838,
-                        "Births MA": 2127272,
-                        "Births MI": 3264550,
-                        "Births NJ": 2478828,
-                        "Births NY": 5824694,
-                        "Births OH": 3999199,
-                        "Births PA": 4005658,
-                        "Births TX": 5625171,
-                        "Births other": 37103250,
-                        "Births All": 80679663,
-                        "gender": "All",
-                    },
-                ],
+                "data": """gender,Births CA,Births FL,Births IL,Births MA,Births MI,Births NJ,Births NY,Births OH,Births PA,Births TX,Births other,Births All
+boy,5430796,1968060,2357411,1285126,1938321,1486126,3543961,2376385,2390275,3311985,22044909,48133355
+girl,3567754,1312593,1614427,842146,1326229,992702,2280733,1622814,1615383,2313186,15058341,32546308
+All,8998550,3280653,3971838,2127272,3264550,2478828,5824694,3999199,4005658,5625171,37103250,80679663
+""",
                 "applied_filters": [],
                 "rejected_filters": [],
             }
@@ -255,7 +213,7 @@ def test_pivot_table_v2():
         "viz_type": "pivot_table_v2",
     }
     result = copy.deepcopy(RESULT)
-    assert pivot_table_v2(result, form_data) == {
+    assert apply_post_process(result, form_data) == {
         "query_context": None,
         "queries": [
             {
@@ -285,80 +243,20 @@ LIMIT 50000;
                     GenericDataType.NUMERIC,
                     GenericDataType.NUMERIC,
                 ],
-                "data": [
-                    {
-                        "All Births": 1.0,
-                        "boy Births": 0.5965983645717509,
-                        "girl Births": 0.40340163542824914,
-                        "state": "All",
-                    },
-                    {
-                        "All Births": 1.0,
-                        "boy Births": 0.6035190113962805,
-                        "girl Births": 0.3964809886037195,
-                        "state": "CA",
-                    },
-                    {
-                        "All Births": 1.0,
-                        "boy Births": 0.5998988615985903,
-                        "girl Births": 0.4001011384014097,
-                        "state": "FL",
-                    },
-                    {
-                        "All Births": 1.0,
-                        "boy Births": 0.5935315085862012,
-                        "girl Births": 0.40646849141379887,
-                        "state": "IL",
-                    },
-                    {
-                        "All Births": 1.0,
-                        "boy Births": 0.6041192663655611,
-                        "girl Births": 0.3958807336344389,
-                        "state": "MA",
-                    },
-                    {
-                        "All Births": 1.0,
-                        "boy Births": 0.5937482960898133,
-                        "girl Births": 0.4062517039101867,
-                        "state": "MI",
-                    },
-                    {
-                        "All Births": 1.0,
-                        "boy Births": 0.5995276800165239,
-                        "girl Births": 0.40047231998347604,
-                        "state": "NJ",
-                    },
-                    {
-                        "All Births": 1.0,
-                        "boy Births": 0.6084372844307357,
-                        "girl Births": 0.39156271556926425,
-                        "state": "NY",
-                    },
-                    {
-                        "All Births": 1.0,
-                        "boy Births": 0.5942152416021308,
-                        "girl Births": 0.40578475839786915,
-                        "state": "OH",
-                    },
-                    {
-                        "All Births": 1.0,
-                        "boy Births": 0.596724682935987,
-                        "girl Births": 0.40327531706401293,
-                        "state": "PA",
-                    },
-                    {
-                        "All Births": 1.0,
-                        "boy Births": 0.5887794344385264,
-                        "girl Births": 0.41122056556147357,
-                        "state": "TX",
-                    },
-                    {
-                        "All Births": 1.0,
-                        "boy Births": 0.5941503507105172,
-                        "girl Births": 0.40584964928948275,
-                        "state": "other",
-                    },
-                ],
+                "data": """state,All Births,boy Births,girl Births
+All,1.0,0.5965983645717509,0.40340163542824914
+CA,1.0,0.6035190113962805,0.3964809886037195
+FL,1.0,0.5998988615985903,0.4001011384014097
+IL,1.0,0.5935315085862012,0.40646849141379887
+MA,1.0,0.6041192663655611,0.3958807336344389
+MI,1.0,0.5937482960898133,0.4062517039101867
+NJ,1.0,0.5995276800165239,0.40047231998347604
+NY,1.0,0.6084372844307357,0.39156271556926425
+OH,1.0,0.5942152416021308,0.40578475839786915
+PA,1.0,0.596724682935987,0.40327531706401293
+TX,1.0,0.5887794344385264,0.41122056556147357
+other,1.0,0.5941503507105172,0.40584964928948275
+""",
                 "applied_filters": [],
                 "rejected_filters": [],
             }
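
Since query["data"] is now a CSV string, these assertions compare payloads byte for byte. When a mismatch is hard to read, a hypothetical helper (not part of this commit) can diff the parsed frames instead:

from io import StringIO

import pandas as pd


def assert_csv_equal(expected_csv: str, actual_csv: str) -> None:
    # Comparing parsed frames surfaces cell-level differences instead of
    # one giant string diff when an assertion on query["data"] fails.
    pd.testing.assert_frame_equal(
        pd.read_csv(StringIO(expected_csv)), pd.read_csv(StringIO(actual_csv))
    )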