mirror of https://github.com/apache/superset.git
feat: generate label map on the backend (#21124)
This commit is contained in:
parent
756ed0e36a
commit
11bf7b9125
|
@ -18,6 +18,7 @@ from __future__ import annotations
|
|||
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
|
@ -57,6 +58,7 @@ from superset.utils.core import (
|
|||
TIME_COMPARISON,
|
||||
)
|
||||
from superset.utils.date_parser import get_past_or_future, normalize_time_delta
|
||||
from superset.utils.pandas_postprocessing.utils import unescape_separator
|
||||
from superset.views.utils import get_viz
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -142,6 +144,17 @@ class QueryContextProcessor:
|
|||
cache.error_message = str(ex)
|
||||
cache.status = QueryStatus.FAILED
|
||||
|
||||
# the N-dimensional DataFrame has been converted into a flat DataFrame
|
||||
# by `flatten operator`, "comma" in the column is escaped by `escape_separator`
|
||||
# the result DataFrame columns should be unescaped
|
||||
label_map = {
|
||||
unescape_separator(col): [
|
||||
unescape_separator(col) for col in re.split(r"(?<!\\),\s", col)
|
||||
]
|
||||
for col in cache.df.columns.values
|
||||
}
|
||||
cache.df.columns = [unescape_separator(col) for col in cache.df.columns.values]
|
||||
|
||||
return {
|
||||
"cache_key": cache_key,
|
||||
"cached_dttm": cache.cache_dttm,
|
||||
|
@ -157,6 +170,7 @@ class QueryContextProcessor:
|
|||
"rowcount": len(cache.df.index),
|
||||
"from_dttm": query_obj.from_dttm,
|
||||
"to_dttm": query_obj.to_dttm,
|
||||
"label_map": label_map,
|
||||
}
|
||||
|
||||
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
|
||||
|
|
|
@ -33,6 +33,10 @@ from superset.utils.pandas_postprocessing.resample import resample
|
|||
from superset.utils.pandas_postprocessing.rolling import rolling
|
||||
from superset.utils.pandas_postprocessing.select import select
|
||||
from superset.utils.pandas_postprocessing.sort import sort
|
||||
from superset.utils.pandas_postprocessing.utils import (
|
||||
escape_separator,
|
||||
unescape_separator,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"aggregate",
|
||||
|
@ -52,4 +56,6 @@ __all__ = [
|
|||
"select",
|
||||
"sort",
|
||||
"flatten",
|
||||
"escape_separator",
|
||||
"unescape_separator",
|
||||
]
|
||||
|
|
|
@ -22,6 +22,7 @@ from numpy.distutils.misc_util import is_sequence
|
|||
|
||||
from superset.utils.pandas_postprocessing.utils import (
|
||||
_is_multi_index_on_columns,
|
||||
escape_separator,
|
||||
FLAT_COLUMN_SEPARATOR,
|
||||
)
|
||||
|
||||
|
@ -86,8 +87,8 @@ def flatten(
|
|||
_cells = []
|
||||
for cell in series if is_sequence(series) else [series]:
|
||||
if pd.notnull(cell):
|
||||
# every cell should be converted to string
|
||||
_cells.append(str(cell))
|
||||
# every cell should be converted to string and escape comma
|
||||
_cells.append(escape_separator(str(cell)))
|
||||
_columns.append(FLAT_COLUMN_SEPARATOR.join(_cells))
|
||||
|
||||
df.columns = _columns
|
||||
|
|
|
@ -198,3 +198,13 @@ def _append_columns(
|
|||
return _base_df
|
||||
append_df = append_df.rename(columns=columns)
|
||||
return pd.concat([base_df, append_df], axis="columns")
|
||||
|
||||
|
||||
def escape_separator(plain_str: str, sep: str = FLAT_COLUMN_SEPARATOR) -> str:
    """Backslash-escape every occurrence of the separator character in *plain_str*.

    Only the stripped separator character (e.g. "," for ", ") is escaped, so
    the result can later be split on unescaped separators.
    """
    token = sep.strip()
    return plain_str.replace(token, f"\\{token}")
|
||||
|
||||
|
||||
def unescape_separator(escaped_str: str, sep: str = FLAT_COLUMN_SEPARATOR) -> str:
    """Reverse ``escape_separator``: drop the backslash before each separator character."""
    token = sep.strip()
    return escaped_str.replace(f"\\{token}", token)
|
||||
|
|
|
@ -358,3 +358,30 @@ def physical_dataset():
|
|||
for ds in dataset:
|
||||
db.session.delete(ds)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
def virtual_dataset_comma_in_column_value():
    """Virtual dataset whose column values embed commas.

    Used to exercise separator escaping/unescaping in flatten / label-map
    logic; torn down (deleted and committed) after the test finishes.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn

    # Three rows; every value contains a comma ("col1,rowN" / "col2, rowN").
    dataset = SqlaTable(
        table_name="virtual_dataset",
        sql="UNION ALL ".join(
            f"SELECT 'col1,row{n}' as col1, 'col2, row{n}' as col2 "
            for n in (1, 2, 3)
        ),
        database=get_example_database(),
    )
    for column_name in ("col1", "col2"):
        TableColumn(column_name=column_name, type="VARCHAR(255)", table=dataset)

    SqlMetric(metric_name="count", expression="count(*)", table=dataset)
    db.session.merge(dataset)

    yield dataset

    # teardown: remove the fixture dataset once the test is done
    db.session.delete(dataset)
    db.session.commit()
|
||||
|
|
|
@ -25,6 +25,7 @@ from superset import db
|
|||
from superset.charts.schemas import ChartDataQueryContextSchema
|
||||
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
|
||||
from superset.common.query_context import QueryContext
|
||||
from superset.common.query_context_factory import QueryContextFactory
|
||||
from superset.common.query_object import QueryObject
|
||||
from superset.connectors.sqla.models import SqlMetric
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
|
@ -35,6 +36,7 @@ from superset.utils.core import (
|
|||
DatasourceType,
|
||||
QueryStatus,
|
||||
)
|
||||
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices,
|
||||
|
@ -683,3 +685,46 @@ class TestQueryContext(SupersetTestCase):
|
|||
row["sum__num__3 years later"]
|
||||
== df_3_years_later.loc[index]["sum__num"]
|
||||
)
|
||||
|
||||
|
||||
def test_get_label_map(app_context, virtual_dataset_comma_in_column_value):
    """Flattened pivot columns are unescaped in the result DataFrame, and
    ``label_map`` maps each flat column name to its original component labels.

    Fix over the original: fetch the payload ONCE — the previous version
    called ``qc.get_df_payload(query_object)`` twice, re-running the query
    (guaranteed, since ``force=True`` bypasses the cache) and risking the two
    results drifting apart.
    """
    qc = QueryContextFactory().create(
        datasource={
            "type": virtual_dataset_comma_in_column_value.type,
            "id": virtual_dataset_comma_in_column_value.id,
        },
        queries=[
            {
                "columns": ["col1", "col2"],
                "metrics": ["count"],
                "post_processing": [
                    {
                        "operation": "pivot",
                        "options": {
                            "aggregates": {"count": {"operator": "mean"}},
                            "columns": ["col2"],
                            "index": ["col1"],
                        },
                    },
                    {"operation": "flatten"},
                ],
            }
        ],
        result_type=ChartDataResultType.FULL,
        force=True,
    )
    query_object = qc.queries[0]
    # single round-trip: both df and label_map come from the same payload
    payload = qc.get_df_payload(query_object)
    df = payload["df"]
    label_map = payload["label_map"]

    # DataFrame columns are unescaped (no backslashes before commas)
    assert list(df.columns.values) == [
        "col1",
        "count" + FLAT_COLUMN_SEPARATOR + "col2, row1",
        "count" + FLAT_COLUMN_SEPARATOR + "col2, row2",
        "count" + FLAT_COLUMN_SEPARATOR + "col2, row3",
    ]
    # label_map splits each flat column back into its component labels
    assert label_map == {
        "col1": ["col1"],
        "count, col2, row1": ["count", "col2, row1"],
        "count, col2, row2": ["count", "col2, row2"],
        "count, col2, row3": ["count", "col2, row3"],
    }
|
||||
|
|
|
@ -156,3 +156,22 @@ def test_flat_integer_column_name():
|
|||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_escape_column_name():
    """``flatten`` backslash-escapes separator characters inside MultiIndex labels."""
    idx = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
    idx.name = "__timestamp"
    # both levels contain commas: "level1,valueN" and "level2, valueN"
    multi_cols = pd.MultiIndex.from_arrays(
        [
            [f"level1,value{n}" for n in (1, 2, 3)],
            [f"level2, value{n}" for n in (1, 2, 3)],
        ],
        names=["level1", "level2"],
    )
    df = pd.DataFrame(index=idx, columns=multi_cols, data=1)

    expected = ["__timestamp"] + [
        f"level1\\,value{n}" + FLAT_COLUMN_SEPARATOR + f"level2\\, value{n}"
        for n in (1, 2, 3)
    ]
    assert list(pp.flatten(df).columns.values) == expected
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from superset.utils.pandas_postprocessing import escape_separator, unescape_separator
|
||||
|
||||
|
||||
def test_escape_separator():
    """escape/unescape round-trips; backslashes without the separator are untouched."""
    # a string with no separator passes through both functions unchanged
    untouched = r" hell \world "
    assert escape_separator(untouched) == untouched
    assert unescape_separator(untouched) == untouched

    # round-trip, with and without a space after the comma
    for plain, expected in (
        ("hello, world", r"hello\, world"),
        ("hello,world", r"hello\,world"),
    ):
        escaped = escape_separator(plain)
        assert escaped == expected
        assert unescape_separator(escaped) == plain
|
Loading…
Reference in New Issue