From 11bf7b9125eefd93796a46d964c3f027fbc9ce4d Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Mon, 22 Aug 2022 21:00:02 +0800 Subject: [PATCH] feat: generate label map on the backend (#21124) --- superset/common/query_context_processor.py | 14 ++++++ .../utils/pandas_postprocessing/__init__.py | 6 +++ .../utils/pandas_postprocessing/flatten.py | 5 ++- superset/utils/pandas_postprocessing/utils.py | 10 +++++ tests/integration_tests/conftest.py | 27 +++++++++++ .../integration_tests/query_context_tests.py | 45 +++++++++++++++++++ .../pandas_postprocessing/test_flatten.py | 19 ++++++++ .../pandas_postprocessing/test_utils.py | 30 +++++++++++++ 8 files changed, 154 insertions(+), 2 deletions(-) create mode 100644 tests/unit_tests/pandas_postprocessing/test_utils.py diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 2978eeace4..b253caa6b9 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -18,6 +18,7 @@ from __future__ import annotations import copy import logging +import re from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union import numpy as np @@ -57,6 +58,7 @@ from superset.utils.core import ( TIME_COMPARISON, ) from superset.utils.date_parser import get_past_or_future, normalize_time_delta +from superset.utils.pandas_postprocessing.utils import unescape_separator from superset.views.utils import get_viz if TYPE_CHECKING: @@ -142,6 +144,17 @@ class QueryContextProcessor: cache.error_message = str(ex) cache.status = QueryStatus.FAILED + # the N-dimensional DataFrame has converteds into flat DataFrame + # by `flatten operator`, "comma" in the column is escaped by `escape_separator` + # the result DataFrame columns should be unescaped + label_map = { + unescape_separator(col): [ + unescape_separator(col) for col in re.split(r"(? Optional[str]: diff --git a/superset/utils/pandas_postprocessing/__init__.py b/superset/utils/pandas_postprocessing/__init__.py index 7902cb3232..e66a52f655 100644 --- a/superset/utils/pandas_postprocessing/__init__.py +++ b/superset/utils/pandas_postprocessing/__init__.py @@ -33,6 +33,10 @@ from superset.utils.pandas_postprocessing.resample import resample from superset.utils.pandas_postprocessing.rolling import rolling from superset.utils.pandas_postprocessing.select import select from superset.utils.pandas_postprocessing.sort import sort +from superset.utils.pandas_postprocessing.utils import ( + escape_separator, + unescape_separator, +) __all__ = [ "aggregate", @@ -52,4 +56,6 @@ __all__ = [ "select", "sort", "flatten", + "escape_separator", + "unescape_separator", ] diff --git a/superset/utils/pandas_postprocessing/flatten.py b/superset/utils/pandas_postprocessing/flatten.py index 2874ac5797..db783c4bed 100644 --- a/superset/utils/pandas_postprocessing/flatten.py +++ b/superset/utils/pandas_postprocessing/flatten.py @@ -22,6 +22,7 @@ from numpy.distutils.misc_util import is_sequence from superset.utils.pandas_postprocessing.utils import ( _is_multi_index_on_columns, + escape_separator, FLAT_COLUMN_SEPARATOR, ) @@ -86,8 +87,8 @@ def flatten( _cells = [] for cell in series if is_sequence(series) else [series]: if pd.notnull(cell): - # every cell should be converted to string - _cells.append(str(cell)) + # every cell should be converted to string and escape comma + _cells.append(escape_separator(str(cell))) _columns.append(FLAT_COLUMN_SEPARATOR.join(_cells)) df.columns = _columns diff --git a/superset/utils/pandas_postprocessing/utils.py b/superset/utils/pandas_postprocessing/utils.py index 3d14f643c5..bff62dcb64 100644 --- a/superset/utils/pandas_postprocessing/utils.py +++ b/superset/utils/pandas_postprocessing/utils.py @@ -198,3 +198,13 @@ def _append_columns( return _base_df append_df = append_df.rename(columns=columns) return pd.concat([base_df, append_df], axis="columns") + + +def escape_separator(plain_str: str, sep: str = FLAT_COLUMN_SEPARATOR) -> str: + char = sep.strip() + return plain_str.replace(char, "\\" + char) + + +def unescape_separator(escaped_str: str, sep: str = FLAT_COLUMN_SEPARATOR) -> str: + char = sep.strip() + return escaped_str.replace("\\" + char, char) diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 549a987db1..aaa156b5b4 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -358,3 +358,30 @@ def physical_dataset(): for ds in dataset: db.session.delete(ds) db.session.commit() + + +@pytest.fixture +def virtual_dataset_comma_in_column_value(): + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + + dataset = SqlaTable( + table_name="virtual_dataset", + sql=( + "SELECT 'col1,row1' as col1, 'col2, row1' as col2 " + "UNION ALL " + "SELECT 'col1,row2' as col1, 'col2, row2' as col2 " + "UNION ALL " + "SELECT 'col1,row3' as col1, 'col2, row3' as col2 " + ), + database=get_example_database(), + ) + TableColumn(column_name="col1", type="VARCHAR(255)", table=dataset) + TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset) + + SqlMetric(metric_name="count", expression="count(*)", table=dataset) + db.session.merge(dataset) + + yield dataset + + db.session.delete(dataset) + db.session.commit() diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index abd5d2be8b..b17072f6bc 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -25,6 +25,7 @@ from superset import db from superset.charts.schemas import ChartDataQueryContextSchema from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.common.query_context import QueryContext +from superset.common.query_context_factory import QueryContextFactory from superset.common.query_object import QueryObject from superset.connectors.sqla.models import SqlMetric from superset.datasource.dao import DatasourceDAO @@ -35,6 +36,7 @@ from superset.utils.core import ( DatasourceType, QueryStatus, ) +from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, @@ -683,3 +685,46 @@ class TestQueryContext(SupersetTestCase): row["sum__num__3 years later"] == df_3_years_later.loc[index]["sum__num"] ) + + +def test_get_label_map(app_context, virtual_dataset_comma_in_column_value): + qc = QueryContextFactory().create( + datasource={ + "type": virtual_dataset_comma_in_column_value.type, + "id": virtual_dataset_comma_in_column_value.id, + }, + queries=[ + { + "columns": ["col1", "col2"], + "metrics": ["count"], + "post_processing": [ + { + "operation": "pivot", + "options": { + "aggregates": {"count": {"operator": "mean"}}, + "columns": ["col2"], + "index": ["col1"], + }, + }, + {"operation": "flatten"}, + ], + } + ], + result_type=ChartDataResultType.FULL, + force=True, + ) + query_object = qc.queries[0] + df = qc.get_df_payload(query_object)["df"] + label_map = qc.get_df_payload(query_object)["label_map"] + assert list(df.columns.values) == [ + "col1", + "count" + FLAT_COLUMN_SEPARATOR + "col2, row1", + "count" + FLAT_COLUMN_SEPARATOR + "col2, row2", + "count" + FLAT_COLUMN_SEPARATOR + "col2, row3", + ] + assert label_map == { + "col1": ["col1"], + "count, col2, row1": ["count", "col2, row1"], + "count, col2, row2": ["count", "col2, row2"], + "count, col2, row3": ["count", "col2, row3"], + } diff --git a/tests/unit_tests/pandas_postprocessing/test_flatten.py b/tests/unit_tests/pandas_postprocessing/test_flatten.py index 78a2e3eea4..fea84f7b9f 100644 --- a/tests/unit_tests/pandas_postprocessing/test_flatten.py +++ b/tests/unit_tests/pandas_postprocessing/test_flatten.py @@ -156,3 +156,22 @@ def test_flat_integer_column_name(): } ) ) + + +def test_escape_column_name(): + index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"]) + index.name = "__timestamp" + columns = pd.MultiIndex.from_arrays( + [ + ["level1,value1", "level1,value2", "level1,value3"], + ["level2, value1", "level2, value2", "level2, value3"], + ], + names=["level1", "level2"], + ) + df = pd.DataFrame(index=index, columns=columns, data=1) + assert list(pp.flatten(df).columns.values) == [ + "__timestamp", + "level1\\,value1" + FLAT_COLUMN_SEPARATOR + "level2\\, value1", + "level1\\,value2" + FLAT_COLUMN_SEPARATOR + "level2\\, value2", + "level1\\,value3" + FLAT_COLUMN_SEPARATOR + "level2\\, value3", + ] diff --git a/tests/unit_tests/pandas_postprocessing/test_utils.py b/tests/unit_tests/pandas_postprocessing/test_utils.py new file mode 100644 index 0000000000..058cefcd6c --- /dev/null +++ b/tests/unit_tests/pandas_postprocessing/test_utils.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from superset.utils.pandas_postprocessing import escape_separator, unescape_separator + + +def test_escape_separator(): + assert escape_separator(r" hell \world ") == r" hell \world " + assert unescape_separator(r" hell \world ") == r" hell \world " + + escape_string = escape_separator("hello, world") + assert escape_string == r"hello\, world" + assert unescape_separator(escape_string) == "hello, world" + + escape_string = escape_separator("hello,world") + assert escape_string == r"hello\,world" + assert unescape_separator(escape_string) == "hello,world"