# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections.abc import Iterator
from typing import Any, Optional
from unittest.mock import Mock, patch

from pytest import fixture, mark

from superset.common.query_object_factory import QueryObjectFactory
from tests.common.query_context_generator import QueryContextGenerator


def create_app_config() -> dict[str, Any]:
    """Return a fresh minimal app-config dict that the factory reads.

    A new dict is built on every call, so callers may mutate the result
    freely without affecting other tests.
    """
    return {
        "ROW_LIMIT": 5000,
        "DEFAULT_RELATIVE_START_TIME": "today",
        "DEFAULT_RELATIVE_END_TIME": "today",
        "SAMPLES_ROW_LIMIT": 1000,
        "SQL_MAX_ROW": 100000,
    }


@fixture
def app_config() -> dict[str, Any]:
    # create_app_config() already returns a brand-new dict per call, so the
    # previous defensive `.copy()` was redundant.
    return create_app_config()


@fixture
def session_factory() -> Mock:
    return Mock()


class SimpleDatasetColumn:
    """Lightweight stand-in for a dataset-column ORM object.

    All attributes are injected straight from ``col_params``; only the
    attributes the factory actually reads (``column_name``, ``is_dttm``,
    ``python_date_format``, ...) need to be supplied.
    """

    def __init__(self, col_params: dict[str, Any]):
        self.__dict__.update(col_params)


TEMPORAL_COLUMN_NAMES = ["temporal_column", "temporal_column_with_python_date_format"]

TEMPORAL_COLUMNS = {
    TEMPORAL_COLUMN_NAMES[0]: SimpleDatasetColumn(
        {
            "column_name": TEMPORAL_COLUMN_NAMES[0],
            "is_dttm": True,
            "python_date_format": None,
            "type": "string",
            "num_types": ["BIGINT"],
        }
    ),
    TEMPORAL_COLUMN_NAMES[1]: SimpleDatasetColumn(
        {
            "column_name": TEMPORAL_COLUMN_NAMES[1],
            "type": "BIGINT",
            "is_dttm": True,
            # "%Y" makes the factory convert epoch-ms filter values to years.
            "python_date_format": "%Y",
            "num_types": ["BIGINT"],
        }
    ),
}


@fixture
def connector_registry() -> Mock:
    """Mock datasource DAO whose datasource resolves the temporal columns above."""
    datasource_dao_mock = Mock(spec=["get_datasource"])
    # BUG FIX: configure through `.return_value` instead of calling the mock
    # (`get_datasource()`), so the fixture itself does not pollute the call
    # history that tests may later assert on.
    datasource_mock = datasource_dao_mock.get_datasource.return_value
    datasource_mock.get_column = Mock(
        side_effect=lambda col_name: TEMPORAL_COLUMNS[col_name]
        if col_name in TEMPORAL_COLUMN_NAMES
        else Mock()
    )
    datasource_mock.db_extra = None
    return datasource_dao_mock


def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
    """Test double for superset's ``apply_max_row_limit`` helper.

    Caps ``limit`` at ``max_limit`` (defaulting to the config's
    ``SQL_MAX_ROW``); a limit of 0 means "no explicit limit" and yields the
    cap itself.
    """
    if max_limit is None:
        max_limit = create_app_config()["SQL_MAX_ROW"]
    if limit != 0:
        return min(max_limit, limit)
    return max_limit


@fixture
def query_object_factory(
    app_config: dict[str, Any], connector_registry: Mock, session_factory: Mock
) -> Iterator[QueryObjectFactory]:
    # BUG FIX: the previous version assigned the module attribute directly
    # (`mod.apply_max_row_limit = apply_max_row_limit`) and never restored it,
    # leaking the patched function into every subsequent test in the session.
    # `patch` as a context manager restores the original on fixture teardown.
    with patch(
        "superset.common.query_object_factory.apply_max_row_limit",
        apply_max_row_limit,
    ):
        yield QueryObjectFactory(app_config, connector_registry, session_factory)


@fixture
def raw_query_context() -> dict[str, Any]:
    return QueryContextGenerator().generate("birth_names")


class TestQueryObjectFactory:
    def test_query_context_limit_and_offset_defaults(
        self,
        query_object_factory: QueryObjectFactory,
        raw_query_context: dict[str, Any],
    ):
        # With no explicit limit/offset, the factory falls back to
        # ROW_LIMIT (5000) and offset 0.
        raw_query_object = raw_query_context["queries"][0]
        raw_query_object.pop("row_limit", None)
        raw_query_object.pop("row_offset", None)
        query_object = query_object_factory.create(
            raw_query_context["result_type"], **raw_query_object
        )
        assert query_object.row_limit == 5000
        assert query_object.row_offset == 0

    def test_query_context_limit(
        self,
        query_object_factory: QueryObjectFactory,
        raw_query_context: dict[str, Any],
    ):
        raw_query_object = raw_query_context["queries"][0]
        raw_query_object["row_limit"] = 100
        raw_query_object["row_offset"] = 200
        query_object = query_object_factory.create(
            raw_query_context["result_type"], **raw_query_object
        )
        assert query_object.row_limit == 100
        assert query_object.row_offset == 200

    def test_query_context_null_post_processing_op(
        self,
        query_object_factory: QueryObjectFactory,
        raw_query_context: dict[str, Any],
    ):
        # None entries in post_processing must be filtered out, not kept.
        raw_query_object = raw_query_context["queries"][0]
        raw_query_object["post_processing"] = [None]
        query_object = query_object_factory.create(
            raw_query_context["result_type"], **raw_query_object
        )
        assert query_object.post_processing == []

    def test_query_context_no_python_date_format_filters(
        self,
        query_object_factory: QueryObjectFactory,
        raw_query_context: dict[str, Any],
    ):
        # A temporal column without python_date_format leaves the raw
        # epoch-ms value untouched.
        raw_query_object = raw_query_context["queries"][0]
        raw_query_object["filters"].append(
            {"col": TEMPORAL_COLUMN_NAMES[0], "op": "==", "val": 315532800000}
        )
        query_object = query_object_factory.create(
            raw_query_context["result_type"],
            raw_query_context["datasource"],
            **raw_query_object
        )
        assert query_object.filter[3]["val"] == 315532800000

    def test_query_context_python_date_format_filters(
        self,
        query_object_factory: QueryObjectFactory,
        raw_query_context: dict[str, Any],
    ):
        # python_date_format "%Y" converts 315532800000 ms (1980-01-01) -> 1980.
        raw_query_object = raw_query_context["queries"][0]
        raw_query_object["filters"].append(
            {"col": TEMPORAL_COLUMN_NAMES[1], "op": "==", "val": 315532800000}
        )
        query_object = query_object_factory.create(
            raw_query_context["result_type"],
            raw_query_context["datasource"],
            **raw_query_object
        )
        assert query_object.filter[3]["val"] == 1980

    def test_query_context_python_date_format_filters_list_of_values(
        self,
        query_object_factory: QueryObjectFactory,
        raw_query_context: dict[str, Any],
    ):
        # The same conversion must apply element-wise to list-valued filters.
        raw_query_object = raw_query_context["queries"][0]
        raw_query_object["filters"].append(
            {
                "col": TEMPORAL_COLUMN_NAMES[1],
                "op": "==",
                "val": [315532800000, 631152000000],
            }
        )
        query_object = query_object_factory.create(
            raw_query_context["result_type"],
            raw_query_context["datasource"],
            **raw_query_object
        )
        assert query_object.filter[3]["val"] == [1980, 1990]