feat: send post-processed data in reports (#15953)

* feat: send post-processed data in reports

* Fix tests and lint

* Use enums

* Limit Slack message to 4k chars
This commit is contained in:
Beto Dealmeida 2021-07-30 09:37:16 -07:00 committed by GitHub
parent cc704dd53a
commit 2d61f15153
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 229 additions and 322 deletions

View File

@ -52,7 +52,7 @@ from superset.charts.commands.importers.dispatcher import ImportChartsCommand
from superset.charts.commands.update import UpdateChartCommand from superset.charts.commands.update import UpdateChartCommand
from superset.charts.dao import ChartDAO from superset.charts.dao import ChartDAO
from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter
from superset.charts.post_processing import post_processors from superset.charts.post_processing import apply_post_process
from superset.charts.schemas import ( from superset.charts.schemas import (
CHART_SCHEMAS, CHART_SCHEMAS,
ChartPostSchema, ChartPostSchema,
@ -482,10 +482,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
return self.response_422(message=str(ex)) return self.response_422(message=str(ex))
def send_chart_response( def send_chart_response(
self, self, result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
result: Dict[Any, Any],
viz_type: Optional[str] = None,
form_data: Optional[Dict[str, Any]] = None,
) -> Response: ) -> Response:
result_type = result["query_context"].result_type result_type = result["query_context"].result_type
result_format = result["query_context"].result_format result_format = result["query_context"].result_format
@ -495,10 +492,9 @@ class ChartRestApi(BaseSupersetModelRestApi):
# post-processing of data, eg, the pivot table. # post-processing of data, eg, the pivot table.
if ( if (
result_type == ChartDataResultType.POST_PROCESSED result_type == ChartDataResultType.POST_PROCESSED
and viz_type in post_processors and result_format == ChartDataResultFormat.CSV
): ):
post_process = post_processors[viz_type] result = apply_post_process(result, form_data)
result = post_process(result, form_data)
if result_format == ChartDataResultFormat.CSV: if result_format == ChartDataResultFormat.CSV:
# Verify user has permission to export CSV file # Verify user has permission to export CSV file
@ -525,7 +521,6 @@ class ChartRestApi(BaseSupersetModelRestApi):
self, self,
command: ChartDataCommand, command: ChartDataCommand,
force_cached: bool = False, force_cached: bool = False,
viz_type: Optional[str] = None,
form_data: Optional[Dict[str, Any]] = None, form_data: Optional[Dict[str, Any]] = None,
) -> Response: ) -> Response:
try: try:
@ -535,7 +530,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
except ChartDataQueryFailedError as exc: except ChartDataQueryFailedError as exc:
return self.response_400(message=exc.message) return self.response_400(message=exc.message)
return self.send_chart_response(result, viz_type, form_data) return self.send_chart_response(result, form_data)
@expose("/<int:pk>/data/", methods=["GET"]) @expose("/<int:pk>/data/", methods=["GET"])
@protect() @protect()
@ -637,9 +632,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
except (TypeError, json.decoder.JSONDecodeError): except (TypeError, json.decoder.JSONDecodeError):
form_data = {} form_data = {}
return self.get_data_response( return self.get_data_response(command, form_data=form_data)
command, viz_type=chart.viz_type, form_data=form_data
)
@expose("/data", methods=["POST"]) @expose("/data", methods=["POST"])
@protect() @protect()

View File

@ -26,6 +26,7 @@ In order to do that, we reproduce the post-processing in Python
for these chart types. for these chart types.
""" """
from io import StringIO
from typing import Any, Callable, Dict, Optional, Union from typing import Any, Callable, Dict, Optional, Union
import pandas as pd import pandas as pd
@ -40,17 +41,10 @@ def sql_like_sum(series: pd.Series) -> pd.Series:
return series.sum(min_count=1) return series.sum(min_count=1)
def pivot_table( def pivot_table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame:
result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None
) -> Dict[Any, Any]:
""" """
Pivot table. Pivot table.
""" """
for query in result["queries"]:
data = query["data"]
df = pd.DataFrame(data)
form_data = form_data or {}
if form_data.get("granularity") == "all" and DTTM_ALIAS in df: if form_data.get("granularity") == "all" and DTTM_ALIAS in df:
del df[DTTM_ALIAS] del df[DTTM_ALIAS]
@ -88,18 +82,7 @@ def pivot_table(
# flatten column names # flatten column names
df.columns = [" ".join(column) for column in df.columns] df.columns = [" ".join(column) for column in df.columns]
# re-arrange data into a list of dicts return df
data = []
for i in df.index:
row = {col: df[col][i] for col in df.columns}
row[df.index.name] = i
data.append(row)
query["data"] = data
query["colnames"] = list(df.columns)
query["coltypes"] = extract_dataframe_dtypes(df)
query["rowcount"] = len(df.index)
return result
def list_unique_values(series: pd.Series) -> str: def list_unique_values(series: pd.Series) -> str:
@ -134,16 +117,11 @@ pivot_v2_aggfunc_map = {
def pivot_table_v2( # pylint: disable=too-many-branches def pivot_table_v2( # pylint: disable=too-many-branches
result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None, df: pd.DataFrame, form_data: Dict[str, Any]
) -> Dict[Any, Any]: ) -> pd.DataFrame:
""" """
Pivot table v2. Pivot table v2.
""" """
for query in result["queries"]:
data = query["data"]
df = pd.DataFrame(data)
form_data = form_data or {}
if form_data.get("granularity_sqla") == "all" and DTTM_ALIAS in df: if form_data.get("granularity_sqla") == "all" and DTTM_ALIAS in df:
del df[DTTM_ALIAS] del df[DTTM_ALIAS]
@ -201,21 +179,37 @@ def pivot_table_v2( # pylint: disable=too-many-branches
# flatten column names # flatten column names
df.columns = [" ".join(column) for column in df.columns] df.columns = [" ".join(column) for column in df.columns]
# re-arrange data into a list of dicts return df
data = []
for i in df.index:
row = {col: df[col][i] for col in df.columns}
row[df.index.name] = i
data.append(row)
query["data"] = data
query["colnames"] = list(df.columns)
query["coltypes"] = extract_dataframe_dtypes(df)
query["rowcount"] = len(df.index)
return result
post_processors = { post_processors = {
"pivot_table": pivot_table, "pivot_table": pivot_table,
"pivot_table_v2": pivot_table_v2, "pivot_table_v2": pivot_table_v2,
} }
def apply_post_process(
    result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
) -> Dict[Any, Any]:
    """
    Apply the chart-type specific post-processor to every query in ``result``.

    The processor is selected from ``post_processors`` using the ``viz_type``
    entry of ``form_data``. When no processor is registered for that type
    (or ``form_data`` is empty/missing) ``result`` is returned untouched.

    Each query's ``data`` is read as CSV text, transformed, and written back
    as CSV text; ``colnames``, ``coltypes`` and ``rowcount`` are refreshed to
    describe the transformed frame.

    :param result: chart data payload holding a ``"queries"`` list; the query
        dicts are updated in place
    :param form_data: chart form data, forwarded to the post-processor
    :returns: the same ``result`` object that was passed in
    """
    options = form_data or {}

    processor = post_processors.get(options.get("viz_type"))
    if processor is None:
        return result

    for query in result["queries"]:
        source_df = pd.read_csv(StringIO(query["data"]))
        transformed = processor(source_df, options)

        # ``to_csv`` without a target returns the CSV document as a string
        query["data"] = transformed.to_csv()
        query["colnames"] = list(transformed.columns)
        query["coltypes"] = extract_dataframe_dtypes(transformed)
        query["rowcount"] = len(transformed.index)

    return result

View File

@ -55,7 +55,7 @@ logger = logging.getLogger(__name__)
class Slice( class Slice(
Model, AuditMixinNullable, ImportExportMixin Model, AuditMixinNullable, ImportExportMixin
): # pylint: disable=too-many-public-methods ): # pylint: disable=too-many-public-methods, too-many-instance-attributes
"""A slice is essentially a report or a view on data""" """A slice is essentially a report or a view on data"""

View File

@ -64,6 +64,7 @@ from superset.reports.notifications import create_notification
from superset.reports.notifications.base import NotificationContent from superset.reports.notifications.base import NotificationContent
from superset.reports.notifications.exceptions import NotificationError from superset.reports.notifications.exceptions import NotificationError
from superset.utils.celery import session_scope from superset.utils.celery import session_scope
from superset.utils.core import ChartDataResultFormat, ChartDataResultType
from superset.utils.csv import get_chart_csv_data from superset.utils.csv import get_chart_csv_data
from superset.utils.screenshots import ( from superset.utils.screenshots import (
BaseScreenshot, BaseScreenshot,
@ -146,7 +147,8 @@ class BaseReportState:
return get_url_path( return get_url_path(
"ChartRestApi.get_data", "ChartRestApi.get_data",
pk=self._report_schedule.chart_id, pk=self._report_schedule.chart_id,
format="csv", format=ChartDataResultFormat.CSV.value,
type=ChartDataResultType.POST_PROCESSED.value,
) )
return get_url_path( return get_url_path(
"Superset.slice", "Superset.slice",

View File

@ -17,7 +17,6 @@
# under the License. # under the License.
import json import json
import logging import logging
import textwrap
from io import IOBase from io import IOBase
from typing import Optional, Union from typing import Optional, Union
@ -34,8 +33,8 @@ from superset.reports.notifications.exceptions import NotificationError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Slack only shows ~25 lines in the code block section # Slack only allows Markdown messages up to 4k chars
MAXIMUM_ROWS_IN_CODE_SECTION = 21 MAXIMUM_MESSAGE_SIZE = 4000
class SlackNotification(BaseNotification): # pylint: disable=too-few-public-methods class SlackNotification(BaseNotification): # pylint: disable=too-few-public-methods
@ -48,43 +47,7 @@ class SlackNotification(BaseNotification): # pylint: disable=too-few-public-met
def _get_channel(self) -> str: def _get_channel(self) -> str:
return json.loads(self._recipient.recipient_config_json)["target"] return json.loads(self._recipient.recipient_config_json)["target"]
@staticmethod def _message_template(self, table: str = "") -> str:
def _error_template(name: str, description: str, text: str) -> str:
return textwrap.dedent(
__(
"""
*%(name)s*\n
%(description)s\n
Error: %(text)s
""",
name=name,
description=description,
text=text,
)
)
def _get_body(self) -> str:
if self._content.text:
return self._error_template(
self._content.name, self._content.description or "", self._content.text
)
# Convert Pandas dataframe into a nice ASCII table
if self._content.embedded_data is not None:
df = self._content.embedded_data
truncated = len(df) > MAXIMUM_ROWS_IN_CODE_SECTION
message = "(table was truncated)" if truncated else ""
if truncated:
df = df[:MAXIMUM_ROWS_IN_CODE_SECTION].fillna("")
# add a last row with '...' for values
df = df.append({k: "..." for k in df.columns}, ignore_index=True)
tabulated = tabulate(df, headers="keys", showindex=False)
table = f"```\n{tabulated}\n```\n\n{message}"
else:
table = ""
return __( return __(
"""*%(name)s* """*%(name)s*
@ -100,6 +63,63 @@ class SlackNotification(BaseNotification): # pylint: disable=too-few-public-met
table=table, table=table,
) )
# Build the Slack message body used when a report carries error text:
# the report name wrapped in `*...*`, then the description, then an
# "Error: ..." line. `_get_body` calls this when `self._content.text` is set.
# NOTE(review): `__` is presumably the translation helper, which also
# interpolates the %(name)s / %(description)s / %(text)s placeholders from
# the keyword arguments -- confirm against the module's imports.
@staticmethod
def _error_template(name: str, description: str, text: str) -> str:
return __(
"""*%(name)s*
%(description)s
Error: %(text)s
""",
name=name,
description=description,
text=text,
)
def _get_body(self) -> str:
    """
    Build the text body of the Slack message.

    * If the content carries error text, the error template is returned.
    * If there is no embedded data, the plain message template is returned.
    * Otherwise the embedded dataframe is rendered as a code-block table.
      When the message would exceed ``MAXIMUM_MESSAGE_SIZE`` the table is cut
      down to the largest number of leading rows that still fits, followed by
      a row of ``...`` placeholders and a "(table was truncated)" note. If not
      even one data row fits, the table is omitted entirely.

    :returns: the message body to post
    """
    if self._content.text:
        return self._error_template(
            self._content.name, self._content.description or "", self._content.text
        )

    if self._content.embedded_data is None:
        return self._message_template()

    # Embed data in the message
    df = self._content.embedded_data

    def truncated_table(row_count: int) -> str:
        """Render the first ``row_count`` rows plus a trailing '...' row."""
        # NOTE: ``DataFrame.append`` is deprecated in recent pandas releases
        # (removed in 2.0); kept here to avoid adding an import to the module.
        truncated_df = df[:row_count].fillna("")
        truncated_df = truncated_df.append(
            {k: "..." for k in df.columns}, ignore_index=True
        )
        tabulated = tabulate(truncated_df, headers="keys", showindex=False)
        return f"```\n{tabulated}\n```\n\n(table was truncated)"

    # Slack Markdown only works on messages shorter than 4k chars, so we might
    # need to truncate the data. ``table`` always holds the last candidate that
    # fit; it stays empty when even a single data row is too large. (The
    # previous check, ``len(truncated_df) > 0``, was always true because the
    # placeholder row had already been appended, so a table containing only
    # "..." was sent in that case.)
    table = ""
    for i in range(len(df) - 1):
        candidate = truncated_table(i + 1)
        if len(self._message_template(candidate)) > MAXIMUM_MESSAGE_SIZE:
            break
        table = candidate
    # Every truncated candidate fit (or there was at most one row): send the
    # full data.
    # NOTE(review): the full table itself is not measured against the limit.
    else:
        tabulated = tabulate(df, headers="keys", showindex=False)
        table = f"```\n{tabulated}\n```"

    return self._message_template(table)
def _get_inline_file(self) -> Optional[Union[str, IOBase, bytes]]: def _get_inline_file(self) -> Optional[Union[str, IOBase, bytes]]:
if self._content.csv: if self._content.csv:
return self._content.csv return self._content.csv

View File

@ -18,7 +18,7 @@
import copy import copy
from typing import Any, Dict from typing import Any, Dict
from superset.charts.post_processing import pivot_table, pivot_table_v2 from superset.charts.post_processing import apply_post_process
from superset.utils.core import GenericDataType, QueryStatus from superset.utils.core import GenericDataType, QueryStatus
RESULT: Dict[str, Any] = { RESULT: Dict[str, Any] = {
@ -51,30 +51,30 @@ LIMIT 50000;
GenericDataType.STRING, GenericDataType.STRING,
GenericDataType.NUMERIC, GenericDataType.NUMERIC,
], ],
"data": [ "data": """state,gender,Births
{"state": "OH", "gender": "boy", "Births": int("2376385")}, OH,boy,2376385
{"state": "TX", "gender": "girl", "Births": int("2313186")}, TX,girl,2313186
{"state": "MA", "gender": "boy", "Births": int("1285126")}, MA,boy,1285126
{"state": "MA", "gender": "girl", "Births": int("842146")}, MA,girl,842146
{"state": "PA", "gender": "boy", "Births": int("2390275")}, PA,boy,2390275
{"state": "NY", "gender": "boy", "Births": int("3543961")}, NY,boy,3543961
{"state": "FL", "gender": "boy", "Births": int("1968060")}, FL,boy,1968060
{"state": "TX", "gender": "boy", "Births": int("3311985")}, TX,boy,3311985
{"state": "NJ", "gender": "boy", "Births": int("1486126")}, NJ,boy,1486126
{"state": "CA", "gender": "girl", "Births": int("3567754")}, CA,girl,3567754
{"state": "CA", "gender": "boy", "Births": int("5430796")}, CA,boy,5430796
{"state": "IL", "gender": "girl", "Births": int("1614427")}, IL,girl,1614427
{"state": "FL", "gender": "girl", "Births": int("1312593")}, FL,girl,1312593
{"state": "NY", "gender": "girl", "Births": int("2280733")}, NY,girl,2280733
{"state": "NJ", "gender": "girl", "Births": int("992702")}, NJ,girl,992702
{"state": "MI", "gender": "girl", "Births": int("1326229")}, MI,girl,1326229
{"state": "other", "gender": "girl", "Births": int("15058341")}, other,girl,15058341
{"state": "other", "gender": "boy", "Births": int("22044909")}, other,boy,22044909
{"state": "MI", "gender": "boy", "Births": int("1938321")}, MI,boy,1938321
{"state": "IL", "gender": "boy", "Births": int("2357411")}, IL,boy,2357411
{"state": "PA", "gender": "girl", "Births": int("1615383")}, PA,girl,1615383
{"state": "OH", "gender": "girl", "Births": int("1622814")}, OH,girl,1622814
], """,
"applied_filters": [], "applied_filters": [],
"rejected_filters": [], "rejected_filters": [],
} }
@ -113,7 +113,7 @@ def test_pivot_table():
"viz_type": "pivot_table", "viz_type": "pivot_table",
} }
result = copy.deepcopy(RESULT) result = copy.deepcopy(RESULT)
assert pivot_table(result, form_data) == { assert apply_post_process(result, form_data) == {
"query_context": None, "query_context": None,
"queries": [ "queries": [
{ {
@ -165,53 +165,11 @@ LIMIT 50000;
GenericDataType.NUMERIC, GenericDataType.NUMERIC,
GenericDataType.NUMERIC, GenericDataType.NUMERIC,
], ],
"data": [ "data": """gender,Births CA,Births FL,Births IL,Births MA,Births MI,Births NJ,Births NY,Births OH,Births PA,Births TX,Births other,Births All
{ boy,5430796,1968060,2357411,1285126,1938321,1486126,3543961,2376385,2390275,3311985,22044909,48133355
"Births CA": 5430796, girl,3567754,1312593,1614427,842146,1326229,992702,2280733,1622814,1615383,2313186,15058341,32546308
"Births FL": 1968060, All,8998550,3280653,3971838,2127272,3264550,2478828,5824694,3999199,4005658,5625171,37103250,80679663
"Births IL": 2357411, """,
"Births MA": 1285126,
"Births MI": 1938321,
"Births NJ": 1486126,
"Births NY": 3543961,
"Births OH": 2376385,
"Births PA": 2390275,
"Births TX": 3311985,
"Births other": 22044909,
"Births All": 48133355,
"gender": "boy",
},
{
"Births CA": 3567754,
"Births FL": 1312593,
"Births IL": 1614427,
"Births MA": 842146,
"Births MI": 1326229,
"Births NJ": 992702,
"Births NY": 2280733,
"Births OH": 1622814,
"Births PA": 1615383,
"Births TX": 2313186,
"Births other": 15058341,
"Births All": 32546308,
"gender": "girl",
},
{
"Births CA": 8998550,
"Births FL": 3280653,
"Births IL": 3971838,
"Births MA": 2127272,
"Births MI": 3264550,
"Births NJ": 2478828,
"Births NY": 5824694,
"Births OH": 3999199,
"Births PA": 4005658,
"Births TX": 5625171,
"Births other": 37103250,
"Births All": 80679663,
"gender": "All",
},
],
"applied_filters": [], "applied_filters": [],
"rejected_filters": [], "rejected_filters": [],
} }
@ -255,7 +213,7 @@ def test_pivot_table_v2():
"viz_type": "pivot_table_v2", "viz_type": "pivot_table_v2",
} }
result = copy.deepcopy(RESULT) result = copy.deepcopy(RESULT)
assert pivot_table_v2(result, form_data) == { assert apply_post_process(result, form_data) == {
"query_context": None, "query_context": None,
"queries": [ "queries": [
{ {
@ -285,80 +243,20 @@ LIMIT 50000;
GenericDataType.NUMERIC, GenericDataType.NUMERIC,
GenericDataType.NUMERIC, GenericDataType.NUMERIC,
], ],
"data": [ "data": """state,All Births,boy Births,girl Births
{ All,1.0,0.5965983645717509,0.40340163542824914
"All Births": 1.0, CA,1.0,0.6035190113962805,0.3964809886037195
"boy Births": 0.5965983645717509, FL,1.0,0.5998988615985903,0.4001011384014097
"girl Births": 0.40340163542824914, IL,1.0,0.5935315085862012,0.40646849141379887
"state": "All", MA,1.0,0.6041192663655611,0.3958807336344389
}, MI,1.0,0.5937482960898133,0.4062517039101867
{ NJ,1.0,0.5995276800165239,0.40047231998347604
"All Births": 1.0, NY,1.0,0.6084372844307357,0.39156271556926425
"boy Births": 0.6035190113962805, OH,1.0,0.5942152416021308,0.40578475839786915
"girl Births": 0.3964809886037195, PA,1.0,0.596724682935987,0.40327531706401293
"state": "CA", TX,1.0,0.5887794344385264,0.41122056556147357
}, other,1.0,0.5941503507105172,0.40584964928948275
{ """,
"All Births": 1.0,
"boy Births": 0.5998988615985903,
"girl Births": 0.4001011384014097,
"state": "FL",
},
{
"All Births": 1.0,
"boy Births": 0.5935315085862012,
"girl Births": 0.40646849141379887,
"state": "IL",
},
{
"All Births": 1.0,
"boy Births": 0.6041192663655611,
"girl Births": 0.3958807336344389,
"state": "MA",
},
{
"All Births": 1.0,
"boy Births": 0.5937482960898133,
"girl Births": 0.4062517039101867,
"state": "MI",
},
{
"All Births": 1.0,
"boy Births": 0.5995276800165239,
"girl Births": 0.40047231998347604,
"state": "NJ",
},
{
"All Births": 1.0,
"boy Births": 0.6084372844307357,
"girl Births": 0.39156271556926425,
"state": "NY",
},
{
"All Births": 1.0,
"boy Births": 0.5942152416021308,
"girl Births": 0.40578475839786915,
"state": "OH",
},
{
"All Births": 1.0,
"boy Births": 0.596724682935987,
"girl Births": 0.40327531706401293,
"state": "PA",
},
{
"All Births": 1.0,
"boy Births": 0.5887794344385264,
"girl Births": 0.41122056556147357,
"state": "TX",
},
{
"All Births": 1.0,
"boy Births": 0.5941503507105172,
"girl Births": 0.40584964928948275,
"state": "other",
},
],
"applied_filters": [], "applied_filters": [],
"rejected_filters": [], "rejected_filters": [],
} }