From a60891702a157fa24c4e26eced914469d9ac773f Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Wed, 5 Jun 2024 10:27:26 -0700 Subject: [PATCH] fix: Remove BASE_AXIS from pre-query (#29084) (cherry picked from commit 17d7e7e5e192d003f9655e1ad7498f0f1966f659) --- null_byte.csv | Bin 0 -> 6 bytes superset/common/query_context_processor.py | 10 ++++----- superset/common/query_object_factory.py | 6 ++--- superset/common/utils/time_range_utils.py | 6 ++--- superset/models/helpers.py | 3 ++- superset/utils/core.py | 21 ++++++++++++------ .../unit_tests/db_engine_specs/test_presto.py | 1 + .../unit_tests/db_engine_specs/test_trino.py | 2 ++ 8 files changed, 30 insertions(+), 19 deletions(-) create mode 100644 null_byte.csv diff --git a/null_byte.csv b/null_byte.csv new file mode 100644 index 0000000000000000000000000000000000000000..55132aaa6398b76cf42aa1473f9959dd09b08b03 GIT binary patch literal 6 NcmZ?dQesfz0ssP(0Ga>* literal 0 HcmV?d00001 diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index c47e295e96..65cadba84f 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -59,7 +59,7 @@ from superset.utils.core import ( get_column_names_from_columns, get_column_names_from_metrics, get_metric_names, - get_xaxis_label, + get_x_axis_label, normalize_dttm_col, TIME_COMPARISON, ) @@ -403,7 +403,7 @@ class QueryContextProcessor: for offset in query_object.time_offsets: try: # pylint: disable=line-too-long - # Since the xaxis is also a column name for the time filter, xaxis_label will be set as granularity + # Since the x-axis is also a column name for the time filter, x_axis_label will be set as granularity # these query object are equivalent: # 1) { granularity: 'dttm_col', time_range: '2020 : 2021', time_offsets: ['1 year ago']} # 2) { columns: [ @@ -418,9 +418,9 @@ class QueryContextProcessor: ) query_object_clone.to_dttm = get_past_or_future(offset, outer_to_dttm) - xaxis_label = get_xaxis_label(query_object.columns) + x_axis_label = get_x_axis_label(query_object.columns) query_object_clone.granularity = ( - query_object_clone.granularity or xaxis_label + query_object_clone.granularity or x_axis_label ) except ValueError as ex: raise QueryObjectValidationError(str(ex)) from ex @@ -432,7 +432,7 @@ class QueryContextProcessor: query_object_clone.filter = [ flt for flt in query_object_clone.filter - if flt.get("col") != xaxis_label + if flt.get("col") != x_axis_label ] # `offset` is added to the hash function diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py index fe4cca3f48..63d1ef0966 100644 --- a/superset/common/query_object_factory.py +++ b/superset/common/query_object_factory.py @@ -28,7 +28,7 @@ from superset.utils.core import ( DatasourceDict, DatasourceType, FilterOperator, - get_xaxis_label, + get_x_axis_label, QueryObjectFilterClause, ) @@ -122,9 +122,9 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods # Use the temporal filter as the time range. # if the temporal filters uses x-axis as the temporal filter # then use it or use the first temporal filter - xaxis_label = get_xaxis_label(columns or []) + x_axis_label = get_x_axis_label(columns) match_flt = [ - flt for flt in temporal_flt if flt.get("col") == xaxis_label + flt for flt in temporal_flt if flt.get("col") == x_axis_label ] if match_flt: time_range = cast(str, match_flt[0].get("val")) diff --git a/superset/common/utils/time_range_utils.py b/superset/common/utils/time_range_utils.py index 5f9139c047..417ff52627 100644 --- a/superset/common/utils/time_range_utils.py +++ b/superset/common/utils/time_range_utils.py @@ -21,7 +21,7 @@ from typing import Any, cast from superset import app from superset.common.query_object import QueryObject -from superset.utils.core import FilterOperator, get_xaxis_label +from superset.utils.core import FilterOperator, get_x_axis_label from superset.utils.date_parser import get_since_until @@ -49,7 +49,7 @@ def get_since_until_from_query_object( """ this function will return since and until by tuple if 1) the time_range is in the query object. - 2) the xaxis column is in the columns field + 2) the x-axis column is in the columns field and its corresponding `temporal_range` filter is in the adhoc filters. :param query_object: a valid query object :return: since and until by tuple @@ -65,7 +65,7 @@ def get_since_until_from_query_object( for flt in query_object.filter: if ( flt.get("op") == FilterOperator.TEMPORAL_RANGE.value - and flt.get("col") == get_xaxis_label(query_object.columns) + and flt.get("col") == get_x_axis_label(query_object.columns) and isinstance(flt.get("val"), str) ): time_range = cast(str, flt.get("val")) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 51b771ff61..0a7df8a195 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -87,6 +87,7 @@ from superset.utils import core as utils from superset.utils.core import ( GenericDataType, get_column_name, + get_non_base_axis_columns, get_user_id, is_adhoc_column, MediumText, @@ -2070,7 +2071,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods "filter": filter, "orderby": orderby, "extras": extras, - "columns": columns, + "columns": get_non_base_axis_columns(columns), "order_desc": True, } diff --git a/superset/utils/core.py b/superset/utils/core.py index 6649f34717..e2c1e50f17 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1177,16 +1177,23 @@ def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]: ) +def is_base_axis(column: Column) -> bool: + return is_adhoc_column(column) and column.get("columnType") == "BASE_AXIS" + + +def get_base_axis_columns(columns: list[Column] | None) -> list[Column]: + return [column for column in columns or [] if is_base_axis(column)] + + +def get_non_base_axis_columns(columns: list[Column] | None) -> list[Column]: + return [column for column in columns or [] if not is_base_axis(column)] + + def get_base_axis_labels(columns: list[Column] | None) -> tuple[str, ...]: - axis_cols = [ - col - for col in columns or [] - if is_adhoc_column(col) and col.get("columnType") == "BASE_AXIS" - ] - return tuple(get_column_name(col) for col in axis_cols) + return tuple(get_column_name(column) for column in get_base_axis_columns(columns)) -def get_xaxis_label(columns: list[Column] | None) -> str | None: +def get_x_axis_label(columns: list[Column] | None) -> str | None: labels = get_base_axis_labels(columns) return labels[0] if labels else None diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index 5c91bc87fc..544ce7cdc5 100644 --- a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -135,6 +135,7 @@ def test_where_latest_partition( PrestoEngineSpec.where_latest_partition( # type: ignore database=mock.MagicMock(), table_name="table", + schema="schema", query=sql.select(text("* FROM table")), columns=[ { diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index df86b4701f..22c5f10649 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -44,6 +44,7 @@ from tests.unit_tests.db_engine_specs.utils import ( assert_column_spec, assert_convert_dttm, ) +from tests.unit_tests.fixtures.common import dttm def _assert_columns_equal(actual_cols, expected_cols) -> None: @@ -575,6 +576,7 @@ def test_where_latest_partition( TrinoEngineSpec.where_latest_partition( # type: ignore database=MagicMock(), table_name="table", + schema="schema", query=sql.select(text("* FROM table")), columns=[ {