From 0de61df72f58db957891abd10c9eb5c9110a9572 Mon Sep 17 00:00:00 2001 From: Karol Kostrzewa Date: Thu, 21 Jan 2021 10:34:48 +0100 Subject: [PATCH] test: sqlite db engine spec (#12616) --- tests/db_engine_specs/sqlite_tests.py | 51 +++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/db_engine_specs/sqlite_tests.py b/tests/db_engine_specs/sqlite_tests.py index fd3001cdef..bb4a19f7c5 100644 --- a/tests/db_engine_specs/sqlite_tests.py +++ b/tests/db_engine_specs/sqlite_tests.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from unittest import mock + from superset.db_engine_specs.sqlite import SqliteEngineSpec from tests.db_engine_specs.base_tests import TestDbEngineSpec @@ -25,3 +27,52 @@ class TestSQliteDbEngineSpec(TestDbEngineSpec): self.assertEqual( SqliteEngineSpec.convert_dttm("TEXT", dttm), "'2019-01-02 03:04:05.678900'" ) + + def test_convert_dttm_lower(self): + dttm = self.get_dttm() + + self.assertEqual( + SqliteEngineSpec.convert_dttm("text", dttm), "'2019-01-02 03:04:05.678900'" + ) + + def test_convert_dttm_invalid_type(self): + dttm = self.get_dttm() + + self.assertEqual(SqliteEngineSpec.convert_dttm("other", dttm), None) + + def test_get_all_datasource_names_table(self): + database = mock.MagicMock() + database.get_all_schema_names.return_value = ["schema1"] + table_names = ["table1", "table2"] + get_tables = mock.MagicMock(return_value=table_names) + database.get_all_table_names_in_schema = get_tables + result = SqliteEngineSpec.get_all_datasource_names(database, "table") + assert result == table_names + get_tables.assert_called_once_with( + schema="schema1", + force=True, + cache=database.table_cache_enabled, + cache_timeout=database.table_cache_timeout, + ) + + def test_get_all_datasource_names_view(self): + database = mock.MagicMock() + database.get_all_schema_names.return_value = ["schema1"] + views_names = ["view1", "view2"] + get_views = mock.MagicMock(return_value=views_names) + database.get_all_view_names_in_schema = get_views + result = SqliteEngineSpec.get_all_datasource_names(database, "view") + assert result == views_names + get_views.assert_called_once_with( + schema="schema1", + force=True, + cache=database.table_cache_enabled, + cache_timeout=database.table_cache_timeout, + ) + + def test_get_all_datasource_names_invalid_type(self): + database = mock.MagicMock() + database.get_all_schema_names.return_value = ["schema1"] + invalid_type = "asdf" + with self.assertRaises(Exception): + SqliteEngineSpec.get_all_datasource_names(database, invalid_type)