fix: Remove BASE_AXIS from pre-query (#29084)

This commit is contained in:
John Bodley 2024-06-05 10:27:26 -07:00 committed by GitHub
parent df0b1cb8ed
commit 17d7e7e5e1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 25 additions and 17 deletions

BIN
null_byte.csv Normal file

Binary file not shown.
1 A
2

View File

@ -60,7 +60,7 @@ from superset.utils.core import (
get_column_names_from_columns,
get_column_names_from_metrics,
get_metric_names,
get_xaxis_label,
get_x_axis_label,
normalize_dttm_col,
TIME_COMPARISON,
)
@ -399,7 +399,7 @@ class QueryContextProcessor:
for offset in query_object.time_offsets:
try:
# pylint: disable=line-too-long
# Since the xaxis is also a column name for the time filter, xaxis_label will be set as granularity
# Since the x-axis is also a column name for the time filter, x_axis_label will be set as granularity
# these query object are equivalent:
# 1) { granularity: 'dttm_col', time_range: '2020 : 2021', time_offsets: ['1 year ago']}
# 2) { columns: [
@ -414,9 +414,9 @@ class QueryContextProcessor:
)
query_object_clone.to_dttm = get_past_or_future(offset, outer_to_dttm)
xaxis_label = get_xaxis_label(query_object.columns)
x_axis_label = get_x_axis_label(query_object.columns)
query_object_clone.granularity = (
query_object_clone.granularity or xaxis_label
query_object_clone.granularity or x_axis_label
)
except ValueError as ex:
raise QueryObjectValidationError(str(ex)) from ex
@ -450,7 +450,7 @@ class QueryContextProcessor:
query_object_clone.filter = [
flt
for flt in query_object_clone.filter
if flt.get("col") != xaxis_label
if flt.get("col") != x_axis_label
]
# `offset` is added to the hash function

View File

@ -28,7 +28,7 @@ from superset.utils.core import (
DatasourceDict,
DatasourceType,
FilterOperator,
get_xaxis_label,
get_x_axis_label,
QueryObjectFilterClause,
)
@ -122,9 +122,9 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
# Use the temporal filter as the time range.
# if the temporal filters uses x-axis as the temporal filter
# then use it or use the first temporal filter
xaxis_label = get_xaxis_label(columns or [])
x_axis_label = get_x_axis_label(columns)
match_flt = [
flt for flt in temporal_flt if flt.get("col") == xaxis_label
flt for flt in temporal_flt if flt.get("col") == x_axis_label
]
if match_flt:
time_range = cast(str, match_flt[0].get("val"))

View File

@ -52,7 +52,7 @@ def get_since_until_from_query_object(
"""
this function will return since and until by tuple if
1) the time_range is in the query object.
2) the xaxis column is in the columns field
2) the x-axis column is in the columns field
and its corresponding `temporal_range` filter is in the adhoc filters.
:param query_object: a valid query object
:return: since and until by tuple

View File

@ -90,6 +90,7 @@ from superset.utils import core as utils, json
from superset.utils.core import (
GenericDataType,
get_column_name,
get_non_base_axis_columns,
get_user_id,
is_adhoc_column,
MediumText,
@ -2083,7 +2084,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
"filter": filter,
"orderby": orderby,
"extras": extras,
"columns": columns,
"columns": get_non_base_axis_columns(columns),
"order_desc": True,
}

View File

@ -1056,16 +1056,23 @@ def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]:
)
def is_base_axis(column: Column) -> bool:
return is_adhoc_column(column) and column.get("columnType") == "BASE_AXIS"
def get_base_axis_columns(columns: list[Column] | None) -> list[Column]:
return [column for column in columns or [] if is_base_axis(column)]
def get_non_base_axis_columns(columns: list[Column] | None) -> list[Column]:
return [column for column in columns or [] if not is_base_axis(column)]
def get_base_axis_labels(columns: list[Column] | None) -> tuple[str, ...]:
axis_cols = [
col
for col in columns or []
if is_adhoc_column(col) and col.get("columnType") == "BASE_AXIS"
]
return tuple(get_column_name(col) for col in axis_cols)
return tuple(get_column_name(column) for column in get_base_axis_columns(columns))
def get_xaxis_label(columns: list[Column] | None) -> str | None:
def get_x_axis_label(columns: list[Column] | None) -> str | None:
labels = get_base_axis_labels(columns)
return labels[0] if labels else None