feat: generate label map on the backend (#21124)

This commit is contained in:
Yongjie Zhao 2022-08-22 21:00:02 +08:00 committed by GitHub
parent 756ed0e36a
commit 11bf7b9125
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 154 additions and 2 deletions

View File

@ -18,6 +18,7 @@ from __future__ import annotations
import copy
import logging
import re
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
import numpy as np
@ -57,6 +58,7 @@ from superset.utils.core import (
TIME_COMPARISON,
)
from superset.utils.date_parser import get_past_or_future, normalize_time_delta
from superset.utils.pandas_postprocessing.utils import unescape_separator
from superset.views.utils import get_viz
if TYPE_CHECKING:
@ -142,6 +144,17 @@ class QueryContextProcessor:
cache.error_message = str(ex)
cache.status = QueryStatus.FAILED
# the N-dimensional DataFrame has converteds into flat DataFrame
# by `flatten operator`, "comma" in the column is escaped by `escape_separator`
# the result DataFrame columns should be unescaped
label_map = {
unescape_separator(col): [
unescape_separator(col) for col in re.split(r"(?<!\\),\s", col)
]
for col in cache.df.columns.values
}
cache.df.columns = [unescape_separator(col) for col in cache.df.columns.values]
return {
"cache_key": cache_key,
"cached_dttm": cache.cache_dttm,
@ -157,6 +170,7 @@ class QueryContextProcessor:
"rowcount": len(cache.df.index),
"from_dttm": query_obj.from_dttm,
"to_dttm": query_obj.to_dttm,
"label_map": label_map,
}
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:

View File

@ -33,6 +33,10 @@ from superset.utils.pandas_postprocessing.resample import resample
from superset.utils.pandas_postprocessing.rolling import rolling
from superset.utils.pandas_postprocessing.select import select
from superset.utils.pandas_postprocessing.sort import sort
from superset.utils.pandas_postprocessing.utils import (
escape_separator,
unescape_separator,
)
__all__ = [
"aggregate",
@ -52,4 +56,6 @@ __all__ = [
"select",
"sort",
"flatten",
"escape_separator",
"unescape_separator",
]

View File

@ -22,6 +22,7 @@ from numpy.distutils.misc_util import is_sequence
from superset.utils.pandas_postprocessing.utils import (
_is_multi_index_on_columns,
escape_separator,
FLAT_COLUMN_SEPARATOR,
)
@ -86,8 +87,8 @@ def flatten(
_cells = []
for cell in series if is_sequence(series) else [series]:
if pd.notnull(cell):
# every cell should be converted to string
_cells.append(str(cell))
# every cell should be converted to string and escape comma
_cells.append(escape_separator(str(cell)))
_columns.append(FLAT_COLUMN_SEPARATOR.join(_cells))
df.columns = _columns

View File

@ -198,3 +198,13 @@ def _append_columns(
return _base_df
append_df = append_df.rename(columns=columns)
return pd.concat([base_df, append_df], axis="columns")
def escape_separator(plain_str: str, sep: str = FLAT_COLUMN_SEPARATOR) -> str:
char = sep.strip()
return plain_str.replace(char, "\\" + char)
def unescape_separator(escaped_str: str, sep: str = FLAT_COLUMN_SEPARATOR) -> str:
char = sep.strip()
return escaped_str.replace("\\" + char, char)

View File

@ -358,3 +358,30 @@ def physical_dataset():
for ds in dataset:
db.session.delete(ds)
db.session.commit()
@pytest.fixture
def virtual_dataset_comma_in_column_value():
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
dataset = SqlaTable(
table_name="virtual_dataset",
sql=(
"SELECT 'col1,row1' as col1, 'col2, row1' as col2 "
"UNION ALL "
"SELECT 'col1,row2' as col1, 'col2, row2' as col2 "
"UNION ALL "
"SELECT 'col1,row3' as col1, 'col2, row3' as col2 "
),
database=get_example_database(),
)
TableColumn(column_name="col1", type="VARCHAR(255)", table=dataset)
TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
SqlMetric(metric_name="count", expression="count(*)", table=dataset)
db.session.merge(dataset)
yield dataset
db.session.delete(dataset)
db.session.commit()

View File

@ -25,6 +25,7 @@ from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_context_factory import QueryContextFactory
from superset.common.query_object import QueryObject
from superset.connectors.sqla.models import SqlMetric
from superset.datasource.dao import DatasourceDAO
@ -35,6 +36,7 @@ from superset.utils.core import (
DatasourceType,
QueryStatus,
)
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@ -683,3 +685,46 @@ class TestQueryContext(SupersetTestCase):
row["sum__num__3 years later"]
== df_3_years_later.loc[index]["sum__num"]
)
def test_get_label_map(app_context, virtual_dataset_comma_in_column_value):
qc = QueryContextFactory().create(
datasource={
"type": virtual_dataset_comma_in_column_value.type,
"id": virtual_dataset_comma_in_column_value.id,
},
queries=[
{
"columns": ["col1", "col2"],
"metrics": ["count"],
"post_processing": [
{
"operation": "pivot",
"options": {
"aggregates": {"count": {"operator": "mean"}},
"columns": ["col2"],
"index": ["col1"],
},
},
{"operation": "flatten"},
],
}
],
result_type=ChartDataResultType.FULL,
force=True,
)
query_object = qc.queries[0]
df = qc.get_df_payload(query_object)["df"]
label_map = qc.get_df_payload(query_object)["label_map"]
assert list(df.columns.values) == [
"col1",
"count" + FLAT_COLUMN_SEPARATOR + "col2, row1",
"count" + FLAT_COLUMN_SEPARATOR + "col2, row2",
"count" + FLAT_COLUMN_SEPARATOR + "col2, row3",
]
assert label_map == {
"col1": ["col1"],
"count, col2, row1": ["count", "col2, row1"],
"count, col2, row2": ["count", "col2, row2"],
"count, col2, row3": ["count", "col2, row3"],
}

View File

@ -156,3 +156,22 @@ def test_flat_integer_column_name():
}
)
)
def test_escape_column_name():
index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
index.name = "__timestamp"
columns = pd.MultiIndex.from_arrays(
[
["level1,value1", "level1,value2", "level1,value3"],
["level2, value1", "level2, value2", "level2, value3"],
],
names=["level1", "level2"],
)
df = pd.DataFrame(index=index, columns=columns, data=1)
assert list(pp.flatten(df).columns.values) == [
"__timestamp",
"level1\\,value1" + FLAT_COLUMN_SEPARATOR + "level2\\, value1",
"level1\\,value2" + FLAT_COLUMN_SEPARATOR + "level2\\, value2",
"level1\\,value3" + FLAT_COLUMN_SEPARATOR + "level2\\, value3",
]

View File

@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.utils.pandas_postprocessing import escape_separator, unescape_separator
def test_escape_separator():
assert escape_separator(r" hell \world ") == r" hell \world "
assert unescape_separator(r" hell \world ") == r" hell \world "
escape_string = escape_separator("hello, world")
assert escape_string == r"hello\, world"
assert unescape_separator(escape_string) == "hello, world"
escape_string = escape_separator("hello,world")
assert escape_string == r"hello\,world"
assert unescape_separator(escape_string) == "hello,world"