From 11e0f4cb2d1779dd709de4500a52d428efb62382 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=90=E1=BA=B7ng=20Minh=20D=C5=A9ng?= Date: Mon, 19 Apr 2021 17:10:12 +0700 Subject: [PATCH] feat: TrinoEngineSpec.adjust_database_uri (#14122) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: TrinoEngine implement adjust_database_uri Signed-off-by: Đặng Minh Dũng * test: TrinoEngine implement adjust_database_uri Signed-off-by: Đặng Minh Dũng --- superset/db_engine_specs/trino.py | 15 ++++++++++++++- tests/db_engine_specs/trino_tests.py | 22 +++++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 5913be0b39..791d248ce3 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -16,8 +16,11 @@ # under the License. from datetime import datetime from typing import Optional +from urllib import parse -from superset.db_engine_specs import BaseEngineSpec +from sqlalchemy.engine.url import URL + +from superset.db_engine_specs.base import BaseEngineSpec from superset.utils import core as utils @@ -56,3 +59,13 @@ class TrinoEngineSpec(BaseEngineSpec): @classmethod def epoch_to_dttm(cls) -> str: return "from_unixtime({col})" + + @classmethod + def adjust_database_uri( + cls, uri: URL, selected_schema: Optional[str] = None + ) -> None: + database = uri.database + if selected_schema and database: + selected_schema = parse.quote(selected_schema, safe="") + database = database.split("/")[0] + "/" + selected_schema + uri.database = database diff --git a/tests/db_engine_specs/trino_tests.py b/tests/db_engine_specs/trino_tests.py index 2f827721fa..557d3bda26 100644 --- a/tests/db_engine_specs/trino_tests.py +++ b/tests/db_engine_specs/trino_tests.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from sqlalchemy.engine.url import URL from superset.db_engine_specs.trino import TrinoEngineSpec from tests.db_engine_specs.base_tests import TestDbEngineSpec -class TestPrestoDbEngineSpec(TestDbEngineSpec): +class TestTrinoDbEngineSpec(TestDbEngineSpec): def test_convert_dttm(self): dttm = self.get_dttm() @@ -32,3 +33,22 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): TrinoEngineSpec.convert_dttm("TIMESTAMP", dttm), "from_iso8601_timestamp('2019-01-02T03:04:05.678900')", ) + + def test_adjust_database_uri(self): + url = URL(drivername="trino", database="hive") + TrinoEngineSpec.adjust_database_uri(url, selected_schema="foobar") + self.assertEqual(url.database, "hive/foobar") + + def test_adjust_database_uri_when_database_contain_schema(self): + url = URL(drivername="trino", database="hive/default") + TrinoEngineSpec.adjust_database_uri(url, selected_schema="foobar") + self.assertEqual(url.database, "hive/foobar") + + def test_adjust_database_uri_when_selected_schema_is_none(self): + url = URL(drivername="trino", database="hive") + TrinoEngineSpec.adjust_database_uri(url, selected_schema=None) + self.assertEqual(url.database, "hive") + + url.database = "hive/default" + TrinoEngineSpec.adjust_database_uri(url, selected_schema=None) + self.assertEqual(url.database, "hive/default")