diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 5913be0b39..791d248ce3 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -16,8 +16,11 @@ # under the License. from datetime import datetime from typing import Optional +from urllib import parse -from superset.db_engine_specs import BaseEngineSpec +from sqlalchemy.engine.url import URL + +from superset.db_engine_specs.base import BaseEngineSpec from superset.utils import core as utils @@ -56,3 +59,13 @@ class TrinoEngineSpec(BaseEngineSpec): @classmethod def epoch_to_dttm(cls) -> str: return "from_unixtime({col})" + + @classmethod + def adjust_database_uri( + cls, uri: URL, selected_schema: Optional[str] = None + ) -> None: + database = uri.database + if selected_schema and database: + selected_schema = parse.quote(selected_schema, safe="") + database = database.split("/")[0] + "/" + selected_schema + uri.database = database diff --git a/tests/db_engine_specs/trino_tests.py b/tests/db_engine_specs/trino_tests.py index 2f827721fa..557d3bda26 100644 --- a/tests/db_engine_specs/trino_tests.py +++ b/tests/db_engine_specs/trino_tests.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from sqlalchemy.engine.url import URL from superset.db_engine_specs.trino import TrinoEngineSpec from tests.db_engine_specs.base_tests import TestDbEngineSpec -class TestPrestoDbEngineSpec(TestDbEngineSpec): +class TestTrinoDbEngineSpec(TestDbEngineSpec): def test_convert_dttm(self): dttm = self.get_dttm() @@ -32,3 +33,22 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): TrinoEngineSpec.convert_dttm("TIMESTAMP", dttm), "from_iso8601_timestamp('2019-01-02T03:04:05.678900')", ) + + def test_adjust_database_uri(self): + url = URL(drivername="trino", database="hive") + TrinoEngineSpec.adjust_database_uri(url, selected_schema="foobar") + self.assertEqual(url.database, "hive/foobar") + + def test_adjust_database_uri_when_database_contain_schema(self): + url = URL(drivername="trino", database="hive/default") + TrinoEngineSpec.adjust_database_uri(url, selected_schema="foobar") + self.assertEqual(url.database, "hive/foobar") + + def test_adjust_database_uri_when_selected_schema_is_none(self): + url = URL(drivername="trino", database="hive") + TrinoEngineSpec.adjust_database_uri(url, selected_schema=None) + self.assertEqual(url.database, "hive") + + url.database = "hive/default" + TrinoEngineSpec.adjust_database_uri(url, selected_schema=None) + self.assertEqual(url.database, "hive/default")