# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file

import unittest
from typing import Any, Dict

from tests.base_tests import SupersetTestCase
from tests.test_app import app

from superset import db
from superset.connectors.sqla.models import SqlaTable
from superset.utils.core import get_or_create_db

# Example of a complete dttm-defaults mapping: a main datetime column plus
# per-column overrides (date format and/or SQL expression).
FULL_DTTM_DEFAULTS_EXAMPLE = {
    "main_dttm_col": "id",
    "dttm_columns": {
        "dttm": {
            "python_date_format": "epoch_s",
            "expression": "CAST(dttm as INTEGER)",
        },
        "id": {"python_date_format": "epoch_ms"},
        "month": {
            "python_date_format": "%Y-%m-%d",
            "expression": "CASE WHEN length(month) = 7 THEN month || '-01' ELSE month END",
        },
    },
}


def apply_dttm_defaults(table: SqlaTable, dttm_defaults: Dict[str, Any]) -> None:
    """Apply dttm defaults to the table, mutating it in place.

    Every column of ``table`` whose name is a key of
    ``dttm_defaults["dttm_columns"]`` is flagged as a datetime column and
    receives the configured ``python_date_format`` / ``expression`` when the
    column does not already have a truthy value for that attribute.
    ``table.main_dttm_col`` is set to ``dttm_defaults["main_dttm_col"]`` only
    when that name is one of the listed columns present on the table.

    :param table: table whose columns (and ``main_dttm_col``) are updated
    :param dttm_defaults: mapping with optional ``"main_dttm_col"`` and
        ``"dttm_columns"`` keys, shaped like ``FULL_DTTM_DEFAULTS_EXAMPLE``
    """
    # Both lookups are loop-invariant, so resolve them once up front.
    dttm_columns = dttm_defaults.get("dttm_columns", {})
    main_dttm_col = dttm_defaults.get("main_dttm_col")

    for dbcol in table.columns:
        # Skip non dttm columns.
        if dbcol.column_name not in dttm_columns:
            continue

        column_defaults = dttm_columns[dbcol.column_name]
        dbcol.is_dttm = True

        # Set table main_dttm_col.
        if dbcol.column_name == main_dttm_col:
            table.main_dttm_col = dbcol.column_name

        # Apply defaults if empty; values already on the column win.
        if not dbcol.python_date_format and "python_date_format" in column_defaults:
            dbcol.python_date_format = column_defaults["python_date_format"]
        if not dbcol.expression and "expression" in column_defaults:
            dbcol.expression = column_defaults["expression"]


class TestConfig(SupersetTestCase):
    """Checks ``apply_dttm_defaults`` when installed as ``SQLA_TABLE_MUTATOR``."""

    def setUp(self) -> None:
        self.login(username="admin")
        self._test_db_id = get_or_create_db(
            "column_test_db", app.config["SQLALCHEMY_DATABASE_URI"]
        ).id
        # Remember the configured mutator so tearDown can restore it.
        self._old_sqla_table_mutator = app.config["SQLA_TABLE_MUTATOR"]

    def createTable(self, dttm_defaults) -> None:
        """Add the ``logs`` table with ``dttm_defaults`` wired in as the mutator.

        Stores the resulting ``SqlaTable`` row on ``self._logs_table``.
        """
        app.config["SQLA_TABLE_MUTATOR"] = lambda t: apply_dttm_defaults(
            t, dttm_defaults
        )
        resp = self.client.post(
            "/tablemodelview/add",
            data=dict(database=self._test_db_id, table_name="logs"),
            follow_redirects=True,
        )
        self.assertEqual(resp.status_code, 200)
        self._logs_table = (
            db.session.query(SqlaTable).filter_by(table_name="logs").one()
        )

    def tearDown(self) -> None:
        app.config["SQLA_TABLE_MUTATOR"] = self._old_sqla_table_mutator
        # _logs_table only exists if createTable got as far as the query.
        if hasattr(self, "_logs_table"):
            db.session.delete(self._logs_table)
            db.session.delete(self._logs_table.database)
            db.session.commit()

    def _get_column(self, column_name: str):
        """Return the column of ``self._logs_table`` named ``column_name``.

        Fails the running test with a descriptive message when no such
        column exists.
        """
        matches = [
            c for c in self._logs_table.columns if c.column_name == column_name
        ]
        self.assertTrue(matches, f"column {column_name!r} not found on logs table")
        return matches[0]

    def test_main_dttm_col(self) -> None:
        # Make sure that dttm column is set properly.
        self.createTable({"main_dttm_col": "id", "dttm_columns": {"id": {}}})
        self.assertEqual(self._logs_table.main_dttm_col, "id")

    def test_main_dttm_col_nonexistent(self) -> None:
        self.createTable({"main_dttm_col": "nonexistent"})
        # Column doesn't exist, falls back to dttm.
        self.assertEqual(self._logs_table.main_dttm_col, "dttm")

    def test_main_dttm_col_nondttm(self) -> None:
        self.createTable({"main_dttm_col": "duration_ms"})
        # duration_ms is not dttm column, falls back to dttm.
        self.assertEqual(self._logs_table.main_dttm_col, "dttm")

    def test_python_date_format_by_column_name(self) -> None:
        table_defaults = {
            "dttm_columns": {
                "id": {"python_date_format": "epoch_ms"},
                "dttm": {"python_date_format": "epoch_s"},
                "duration_ms": {"python_date_format": "invalid"},
            }
        }
        self.createTable(table_defaults)

        id_col = self._get_column("id")
        self.assertTrue(id_col.is_dttm)
        self.assertEqual(id_col.python_date_format, "epoch_ms")

        dttm_col = self._get_column("dttm")
        self.assertTrue(dttm_col.is_dttm)
        self.assertEqual(dttm_col.python_date_format, "epoch_s")

        dms_col = self._get_column("duration_ms")
        self.assertTrue(dms_col.is_dttm)
        self.assertEqual(dms_col.python_date_format, "invalid")

    def test_expression_by_column_name(self) -> None:
        table_defaults = {
            "dttm_columns": {
                "dttm": {"expression": "CAST(dttm as INTEGER)"},
                "duration_ms": {"expression": "CAST(duration_ms as DOUBLE)"},
            }
        }
        self.createTable(table_defaults)

        dttm_col = self._get_column("dttm")
        self.assertTrue(dttm_col.is_dttm)
        self.assertEqual(dttm_col.expression, "CAST(dttm as INTEGER)")

        dms_col = self._get_column("duration_ms")
        self.assertEqual(dms_col.expression, "CAST(duration_ms as DOUBLE)")
        self.assertTrue(dms_col.is_dttm)

    def test_full_setting(self) -> None:
        self.createTable(FULL_DTTM_DEFAULTS_EXAMPLE)

        self.assertEqual(self._logs_table.main_dttm_col, "id")

        id_col = self._get_column("id")
        self.assertTrue(id_col.is_dttm)
        self.assertEqual(id_col.python_date_format, "epoch_ms")
        self.assertIsNone(id_col.expression)

        dttm_col = self._get_column("dttm")
        self.assertTrue(dttm_col.is_dttm)
        self.assertEqual(dttm_col.python_date_format, "epoch_s")
        self.assertEqual(dttm_col.expression, "CAST(dttm as INTEGER)")


if __name__ == "__main__":
    unittest.main()