feat: TrinoEngineSpec.adjust_database_uri (#14122)

* feat: TrinoEngine implement adjust_database_uri

Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>

* test: TrinoEngine implement adjust_database_uri

Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
This commit is contained in:
Đặng Minh Dũng 2021-04-19 17:10:12 +07:00 committed by GitHub
parent ca359402bd
commit 11e0f4cb2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 2 deletions

View File

@ -16,8 +16,11 @@
# under the License.
from datetime import datetime
from typing import Optional
from urllib import parse
from superset.db_engine_specs import BaseEngineSpec
from sqlalchemy.engine.url import URL
from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils
@ -56,3 +59,13 @@ class TrinoEngineSpec(BaseEngineSpec):
@classmethod
def epoch_to_dttm(cls) -> str:
return "from_unixtime({col})"
@classmethod
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> None:
database = uri.database
if selected_schema and database:
selected_schema = parse.quote(selected_schema, safe="")
database = database.split("/")[0] + "/" + selected_schema
uri.database = database

View File

@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from sqlalchemy.engine.url import URL
from superset.db_engine_specs.trino import TrinoEngineSpec
from tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestPrestoDbEngineSpec(TestDbEngineSpec):
class TestTrinoDbEngineSpec(TestDbEngineSpec):
def test_convert_dttm(self):
dttm = self.get_dttm()
@ -32,3 +33,22 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
TrinoEngineSpec.convert_dttm("TIMESTAMP", dttm),
"from_iso8601_timestamp('2019-01-02T03:04:05.678900')",
)
def test_adjust_database_uri(self):
url = URL(drivername="trino", database="hive")
TrinoEngineSpec.adjust_database_uri(url, selected_schema="foobar")
self.assertEqual(url.database, "hive/foobar")
def test_adjust_database_uri_when_database_contain_schema(self):
url = URL(drivername="trino", database="hive/default")
TrinoEngineSpec.adjust_database_uri(url, selected_schema="foobar")
self.assertEqual(url.database, "hive/foobar")
def test_adjust_database_uri_when_selected_schema_is_none(self):
url = URL(drivername="trino", database="hive")
TrinoEngineSpec.adjust_database_uri(url, selected_schema=None)
self.assertEqual(url.database, "hive")
url.database = "hive/default"
TrinoEngineSpec.adjust_database_uri(url, selected_schema=None)
self.assertEqual(url.database, "hive/default")