fix(dashboard): 500 error caused by data_for_slices API (#16053)

This commit is contained in:
Jesse Yang 2021-08-03 19:01:39 -07:00 committed by GitHub
parent 69c5cd7922
commit 490890de23
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 16 additions and 21 deletions

View File

@ -282,10 +282,9 @@ class BaseDatasource(
column_names = set()
for slc in slices:
form_data = slc.form_data
# pull out all required metrics from the form_data
for param in METRIC_FORM_DATA_PARAMS:
for metric in utils.get_iterable(form_data.get(param) or []):
for metric_param in METRIC_FORM_DATA_PARAMS:
for metric in utils.get_iterable(form_data.get(metric_param) or []):
metric_names.add(utils.get_metric_name(metric))
if utils.is_adhoc_metric(metric):
column_names.add(
@ -308,8 +307,8 @@ class BaseDatasource(
column_names.update(
column
for column in utils.get_iterable(form_data.get(param) or [])
for param in COLUMN_FORM_DATA_PARAMS
for column_param in COLUMN_FORM_DATA_PARAMS
for column in utils.get_iterable(form_data.get(column_param) or [])
)
filtered_metrics = [

View File

@ -1217,10 +1217,10 @@ def get_metric_name(metric: Metric) -> str:
def get_metric_names(metrics: Sequence[Metric]) -> List[str]:
return [get_metric_name(metric) for metric in metrics]
return [metric for metric in map(get_metric_name, metrics) if metric]
def get_main_metric_name(metrics: Sequence[Metric]) -> Optional[str]:
def get_first_metric_name(metrics: Sequence[Metric]) -> Optional[str]:
metric_labels = get_metric_names(metrics)
return metric_labels[0] if metric_labels else None
@ -1427,7 +1427,6 @@ def get_iterable(x: Any) -> List[Any]:
:param x: The object
:returns: An iterable representation
"""
return x if isinstance(x, list) else [x]
@ -1464,12 +1463,7 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
:param metrics: Ad-hoc metric
:return: column name if simple metric, otherwise None
"""
columns: List[str] = []
for metric in metrics:
column_name = get_column_name_from_metric(metric)
if column_name:
columns.append(column_name)
return columns
return [col for col in map(get_column_name_from_metric, metrics) if col]
def extract_dataframe_dtypes(df: pd.DataFrame) -> List[GenericDataType]:

View File

@ -1230,7 +1230,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
d = super().query_obj()
sort_by = self.form_data.get(
"timeseries_limit_metric"
) or utils.get_main_metric_name(d.get("metrics") or [])
) or utils.get_first_metric_name(d.get("metrics") or [])
is_asc = not self.form_data.get("order_desc")
if sort_by:
sort_by_label = utils.get_metric_name(sort_by)

View File

@ -497,13 +497,15 @@ class TestSqlaTableModel(SupersetTestCase):
slc = (
metadata_db.session.query(Slice)
.filter_by(
datasource_id=tbl.id,
datasource_type=tbl.type,
slice_name="Participants",
datasource_id=tbl.id, datasource_type=tbl.type, slice_name="Genders",
)
.first()
)
data_for_slices = tbl.data_for_slices([slc])
self.assertEqual(len(data_for_slices["columns"]), 0)
self.assertEqual(len(data_for_slices["metrics"]), 1)
self.assertEqual(len(data_for_slices["verbose_map"].keys()), 2)
assert len(data_for_slices["metrics"]) == 1
assert len(data_for_slices["columns"]) == 1
assert data_for_slices["metrics"][0]["metric_name"] == "sum__num"
assert data_for_slices["columns"][0]["column_name"] == "gender"
assert set(data_for_slices["verbose_map"].keys()) == set(
["__timestamp", "sum__num", "gender",]
)