2022-08-12 00:48:29 -04:00
|
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
|
|
|
# or more contributor license agreements. See the NOTICE file
|
|
|
|
# distributed with this work for additional information
|
|
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
|
|
# to you under the Apache License, Version 2.0 (the
|
|
|
|
# "License"); you may not use this file except in compliance
|
|
|
|
# with the License. You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing,
|
|
|
|
# software distributed under the License is distributed on an
|
|
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
|
|
# KIND, either express or implied. See the License for the
|
|
|
|
# specific language governing permissions and limitations
|
|
|
|
# under the License.
|
|
|
|
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
|
2023-11-20 12:59:10 -05:00
|
|
|
import copy
|
2022-12-23 23:31:46 -05:00
|
|
|
import json
|
2023-01-31 08:54:07 -05:00
|
|
|
from datetime import datetime
|
2023-06-01 15:01:10 -04:00
|
|
|
from typing import Any, Optional
|
2023-01-31 08:54:07 -05:00
|
|
|
from unittest.mock import Mock, patch
|
2022-08-12 00:48:29 -04:00
|
|
|
|
2023-01-31 08:54:07 -05:00
|
|
|
import pandas as pd
|
2022-12-23 23:31:46 -05:00
|
|
|
import pytest
|
|
|
|
from pytest_mock import MockerFixture
|
2024-01-17 12:48:50 -05:00
|
|
|
from requests.exceptions import ConnectionError as RequestsConnectionError
|
2023-01-31 08:54:07 -05:00
|
|
|
from sqlalchemy import types
|
2024-01-17 12:48:50 -05:00
|
|
|
from trino.exceptions import TrinoExternalError, TrinoInternalError, TrinoUserError
|
2023-11-20 12:59:10 -05:00
|
|
|
from trino.sqlalchemy import datatype
|
2022-12-23 23:31:46 -05:00
|
|
|
|
2023-01-31 08:54:07 -05:00
|
|
|
import superset.config
|
|
|
|
from superset.constants import QUERY_CANCEL_KEY, QUERY_EARLY_CANCEL_KEY, USER_AGENT
|
2024-01-17 12:48:50 -05:00
|
|
|
from superset.db_engine_specs.exceptions import (
|
|
|
|
SupersetDBAPIConnectionError,
|
|
|
|
SupersetDBAPIDatabaseError,
|
|
|
|
SupersetDBAPIOperationalError,
|
|
|
|
SupersetDBAPIProgrammingError,
|
|
|
|
)
|
2024-04-18 10:42:53 -04:00
|
|
|
from superset.sql_parse import Table
|
2023-11-20 12:59:10 -05:00
|
|
|
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
|
2023-01-31 08:54:07 -05:00
|
|
|
from superset.utils.core import GenericDataType
|
|
|
|
from tests.unit_tests.db_engine_specs.utils import (
|
|
|
|
assert_column_spec,
|
|
|
|
assert_convert_dttm,
|
|
|
|
)
|
|
|
|
from tests.unit_tests.fixtures.common import dttm
|
|
|
|
|
|
|
|
|
2023-11-20 12:59:10 -05:00
|
|
|
def _assert_columns_equal(actual_cols, expected_cols) -> None:
|
|
|
|
"""
|
|
|
|
Assert equality of the given cols, bearing in mind sqlalchemy type
|
|
|
|
instances can't be compared for equality, so will have to be converted to
|
|
|
|
strings first.
|
|
|
|
"""
|
|
|
|
actual = copy.deepcopy(actual_cols)
|
|
|
|
expected = copy.deepcopy(expected_cols)
|
|
|
|
|
|
|
|
for col in actual:
|
|
|
|
col["type"] = str(col["type"])
|
|
|
|
|
|
|
|
for col in expected:
|
|
|
|
col["type"] = str(col["type"])
|
|
|
|
|
|
|
|
assert actual == expected
|
|
|
|
|
|
|
|
|
2023-01-31 08:54:07 -05:00
|
|
|
@pytest.mark.parametrize(
    "extra,expected",
    [
        # No custom extras: the Superset user agent is injected as ``source``.
        ({}, {"engine_params": {"connect_args": {"source": USER_AGENT}}}),
        # User-supplied extras (including a custom ``source``) pass through.
        (
            {
                "first": 1,
                "engine_params": {
                    "second": "two",
                    "connect_args": {"source": "foobar", "third": "three"},
                },
            },
            {
                "first": 1,
                "engine_params": {
                    "second": "two",
                    "connect_args": {"source": "foobar", "third": "three"},
                },
            },
        ),
    ],
)
def test_get_extra_params(extra: dict[str, Any], expected: dict[str, Any]) -> None:
    """``get_extra_params`` should merge database extras with Trino defaults."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    database = Mock()
    database.server_cert = None
    database.extra = json.dumps(extra)

    assert TrinoEngineSpec.get_extra_params(database) == expected
|
|
|
|
|
|
|
|
|
|
|
|
@patch("superset.utils.core.create_ssl_cert_file")
def test_get_extra_params_with_server_cert(mock_create_ssl_cert_file: Mock) -> None:
    """A configured server cert should force HTTPS and set the verify path."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    mock_create_ssl_cert_file.return_value = "/path/to/tls.crt"
    database = Mock()
    database.extra = json.dumps({})
    database.server_cert = "TEST_CERT"

    connect_args = (
        TrinoEngineSpec.get_extra_params(database)
        .get("engine_params", {})
        .get("connect_args", {})
    )

    assert connect_args.get("http_scheme") == "https"
    assert connect_args.get("verify") == "/path/to/tls.crt"
    mock_create_ssl_cert_file.assert_called_once_with(database.server_cert)
|
|
|
|
|
|
|
|
|
|
|
|
@patch("trino.auth.BasicAuthentication")
def test_auth_basic(mock_auth: Mock) -> None:
    """Basic auth params should be forwarded to ``BasicAuthentication``."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    auth_params = {"username": "username", "password": "password"}
    database = Mock()
    database.encrypted_extra = json.dumps(
        {"auth_method": "basic", "auth_params": auth_params}
    )

    params: dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database, params)

    # Configuring auth should also switch the connection over to HTTPS.
    assert params.setdefault("connect_args", {}).get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**auth_params)
|
|
|
|
|
|
|
|
|
|
|
|
@patch("trino.auth.KerberosAuthentication")
def test_auth_kerberos(mock_auth: Mock) -> None:
    """Kerberos auth params should be forwarded to ``KerberosAuthentication``."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    auth_params = {
        "service_name": "superset",
        "mutual_authentication": False,
        "delegate": True,
    }
    database = Mock()
    database.encrypted_extra = json.dumps(
        {"auth_method": "kerberos", "auth_params": auth_params}
    )

    params: dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database, params)

    # Configuring auth should also switch the connection over to HTTPS.
    assert params.setdefault("connect_args", {}).get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**auth_params)
|
|
|
|
|
|
|
|
|
|
|
|
@patch("trino.auth.CertificateAuthentication")
def test_auth_certificate(mock_auth: Mock) -> None:
    """Certificate auth params should be forwarded to ``CertificateAuthentication``."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    auth_params = {"cert": "/path/to/cert.pem", "key": "/path/to/key.pem"}
    database = Mock()
    database.encrypted_extra = json.dumps(
        {"auth_method": "certificate", "auth_params": auth_params}
    )

    params: dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database, params)

    # Configuring auth should also switch the connection over to HTTPS.
    assert params.setdefault("connect_args", {}).get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**auth_params)
|
|
|
|
|
|
|
|
|
|
|
|
@patch("trino.auth.JWTAuthentication")
def test_auth_jwt(mock_auth: Mock) -> None:
    """JWT auth params should be forwarded to ``JWTAuthentication``."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    auth_params = {"token": "jwt-token-string"}
    database = Mock()
    database.encrypted_extra = json.dumps(
        {"auth_method": "jwt", "auth_params": auth_params}
    )

    params: dict[str, Any] = {}
    TrinoEngineSpec.update_params_from_encrypted_extra(database, params)

    # Configuring auth should also switch the connection over to HTTPS.
    assert params.setdefault("connect_args", {}).get("http_scheme") == "https"
    mock_auth.assert_called_once_with(**auth_params)
|
|
|
|
|
|
|
|
|
|
|
|
def test_auth_custom_auth() -> None:
    """An allow-listed custom authentication class should be instantiated."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    auth_class = Mock()
    auth_params = {"params1": "params1", "params2": "params2"}
    database = Mock()
    database.encrypted_extra = json.dumps(
        {"auth_method": "custom_auth", "auth_params": auth_params}
    )

    # Register the custom auth class so the engine spec is allowed to use it.
    with patch.dict(
        "superset.config.ALLOWED_EXTRA_AUTHENTICATIONS",
        {"trino": {"custom_auth": auth_class}},
        clear=True,
    ):
        params: dict[str, Any] = {}
        TrinoEngineSpec.update_params_from_encrypted_extra(database, params)

    connect_args = params.setdefault("connect_args", {})
    assert connect_args.get("http_scheme") == "https"
    auth_class.assert_called_once_with(**auth_params)
|
|
|
|
|
|
|
|
|
|
|
|
def test_auth_custom_auth_denied() -> None:
    """
    A custom authentication class that is not listed in
    ``ALLOWED_EXTRA_AUTHENTICATIONS`` must be rejected with a ``ValueError``.
    """
    from superset.db_engine_specs.trino import TrinoEngineSpec

    database = Mock()
    auth_method = "my.module:TrinoAuthClass"
    auth_params = {"params1": "params1", "params2": "params2"}
    database.encrypted_extra = json.dumps(
        {"auth_method": auth_method, "auth_params": auth_params}
    )

    # Use patch.dict instead of assigning to the module attribute directly so
    # the empty allow-list is reverted after the test and does not leak into
    # other tests that run in the same process (mirrors test_auth_custom_auth).
    with patch.dict(
        "superset.config.ALLOWED_EXTRA_AUTHENTICATIONS",
        {},
        clear=True,
    ):
        with pytest.raises(ValueError) as excinfo:
            TrinoEngineSpec.update_params_from_encrypted_extra(database, {})

    assert str(excinfo.value) == (
        f"For security reason, custom authentication '{auth_method}' "
        f"must be listed in 'ALLOWED_EXTRA_AUTHENTICATIONS' config"
    )
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "native_type,sqla_type,attrs,generic_type,is_dttm",
    [
        # Boolean.
        ("BOOLEAN", types.Boolean, None, GenericDataType.BOOLEAN, False),
        # Numeric types.
        ("TINYINT", types.Integer, None, GenericDataType.NUMERIC, False),
        ("SMALLINT", types.SmallInteger, None, GenericDataType.NUMERIC, False),
        ("INTEGER", types.Integer, None, GenericDataType.NUMERIC, False),
        ("BIGINT", types.BigInteger, None, GenericDataType.NUMERIC, False),
        ("REAL", types.FLOAT, None, GenericDataType.NUMERIC, False),
        ("DOUBLE", types.FLOAT, None, GenericDataType.NUMERIC, False),
        ("DECIMAL", types.DECIMAL, None, GenericDataType.NUMERIC, False),
        # String types; parameterized variants carry a length attribute.
        ("VARCHAR", types.String, None, GenericDataType.STRING, False),
        ("VARCHAR(20)", types.VARCHAR, {"length": 20}, GenericDataType.STRING, False),
        ("CHAR", types.String, None, GenericDataType.STRING, False),
        ("CHAR(2)", types.CHAR, {"length": 2}, GenericDataType.STRING, False),
        ("JSON", types.JSON, None, GenericDataType.STRING, False),
        # Temporal types; all map to is_dttm=True.
        ("TIMESTAMP", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
        ("TIMESTAMP(3)", types.TIMESTAMP, None, GenericDataType.TEMPORAL, True),
        (
            "TIMESTAMP WITH TIME ZONE",
            types.TIMESTAMP,
            None,
            GenericDataType.TEMPORAL,
            True,
        ),
        (
            "TIMESTAMP(3) WITH TIME ZONE",
            types.TIMESTAMP,
            None,
            GenericDataType.TEMPORAL,
            True,
        ),
        ("DATE", types.Date, None, GenericDataType.TEMPORAL, True),
    ],
)
def test_get_column_spec(
    native_type: str,
    sqla_type: type[types.TypeEngine],
    attrs: Optional[dict[str, Any]],
    generic_type: GenericDataType,
    is_dttm: bool,
) -> None:
    """Native Trino types should map to the expected SQLA and generic types."""
    from superset.db_engine_specs.trino import TrinoEngineSpec as spec

    assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "target_type,expected_result",
    [
        ("TimeStamp", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("TimeStamp(3)", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("TimeStamp With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("TimeStamp(3) With Time Zone", "TIMESTAMP '2019-01-02 03:04:05.678900'"),
        ("Date", "DATE '2019-01-02'"),
        # Non-temporal targets produce no literal at all.
        ("Other", None),
    ],
)
def test_convert_dttm(
    target_type: str,
    expected_result: Optional[str],
    dttm: datetime,
) -> None:
    """``convert_dttm`` should render the expected Trino literal per target type."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    assert_convert_dttm(TrinoEngineSpec, target_type, expected_result, dttm)
|
|
|
|
|
|
|
|
|
2024-04-18 10:42:53 -04:00
|
|
|
def test_get_extra_table_metadata() -> None:
    """Partition columns and latest values should be surfaced in the metadata."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    db_mock = Mock()
    # The "partition" index advertises the table's partition columns.
    db_mock.get_indexes = Mock(
        return_value=[{"column_names": ["ds", "hour"], "name": "partition"}]
    )
    db_mock.get_extra = Mock(return_value={})
    db_mock.has_view_by_name = Mock(return_value=None)
    # Latest-partition query result.
    db_mock.get_df = Mock(return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]}))

    result = TrinoEngineSpec.get_extra_table_metadata(
        db_mock, Table("test_table", "test_schema")
    )

    assert result["partitions"]["cols"] == ["ds", "hour"]
    assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}
|
2022-12-23 23:31:46 -05:00
|
|
|
|
2022-08-12 00:48:29 -04:00
|
|
|
|
2023-01-31 08:54:07 -05:00
|
|
|
@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_success(engine_mock: Mock) -> None:
    """``cancel_query`` should report success when the cancel call doesn't raise."""
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    cursor_mock = engine_mock.return_value.__enter__.return_value
    assert TrinoEngineSpec.cancel_query(cursor_mock, Query(), "123") is True
|
|
|
|
|
|
|
|
|
2023-01-31 08:54:07 -05:00
|
|
|
@patch("sqlalchemy.engine.Engine.connect")
def test_cancel_query_failed(engine_mock: Mock) -> None:
    """``cancel_query`` should return False when cancellation raises."""
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    query = Query()
    # NOTE(review): this chained assignment binds ``cursor_mock`` to the
    # Exception *instance* itself (not a mock cursor), so any attribute use
    # inside cancel_query fails and the method returns False. Presumably
    # intentional as a "broken cursor" stand-in — confirm before restructuring.
    cursor_mock = engine_mock.raiseError.side_effect = Exception()
    assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False
|
2022-12-23 23:31:46 -05:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "initial_extra,final_extra",
    [
        # With no cancel key yet, the query is flagged for early cancellation.
        ({}, {QUERY_EARLY_CANCEL_KEY: True}),
        # An already-present cancel key is left untouched.
        ({QUERY_CANCEL_KEY: "my_key"}, {QUERY_CANCEL_KEY: "my_key"}),
    ],
)
def test_prepare_cancel_query(
    initial_extra: dict[str, Any],
    final_extra: dict[str, Any],
    mocker: MockerFixture,
) -> None:
    """``prepare_cancel_query`` should update the query's extra JSON as expected."""
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    query = Query(extra_json=json.dumps(initial_extra))
    TrinoEngineSpec.prepare_cancel_query(query=query)

    assert query.extra == final_extra
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("cancel_early", [True, False])
@patch("superset.db_engine_specs.trino.TrinoEngineSpec.cancel_query")
@patch("sqlalchemy.engine.Engine.connect")
def test_handle_cursor_early_cancel(
    engine_mock: Mock,
    cancel_query_mock: Mock,
    cancel_early: bool,
    mocker: MockerFixture,
) -> None:
    """A query flagged for early cancel should be cancelled from handle_cursor."""
    from superset.db_engine_specs.trino import TrinoEngineSpec
    from superset.models.sql_lab import Query

    trino_query_id = "myQueryId"
    cursor_mock = engine_mock.return_value.__enter__.return_value
    cursor_mock.query_id = trino_query_id

    query = Query()
    if cancel_early:
        # Simulate a cancel request arriving before the cursor existed.
        TrinoEngineSpec.prepare_cancel_query(query=query)

    TrinoEngineSpec.handle_cursor(cursor=cursor_mock, query=query)

    if cancel_early:
        # The cancel must use the id reported by the live cursor.
        assert cancel_query_mock.call_args[1]["cancel_query_id"] == trino_query_id
    else:
        assert cancel_query_mock.call_args is None
|
2023-10-19 09:38:13 -04:00
|
|
|
|
|
|
|
|
|
|
|
def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
    """Test that `execute_with_cursor` fetches query ID from the cursor"""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    expected_query_id = "myQueryId"
    mock_query = mocker.MagicMock()
    mock_cursor = mocker.MagicMock()
    mock_cursor.query_id = None

    # The cursor only learns its Trino query id once execution has started.
    def _set_query_id(*args, **kwargs):
        mock_cursor.query_id = expected_query_id

    mock_cursor.execute.side_effect = _set_query_id

    TrinoEngineSpec.execute_with_cursor(
        cursor=mock_cursor,
        sql="SELECT 1 FROM foo",
        query=mock_query,
    )

    mock_query.set_extra_json_key.assert_called_once_with(
        key=QUERY_CANCEL_KEY, value=expected_query_id
    )
|
2023-11-20 12:59:10 -05:00
|
|
|
|
|
|
|
|
|
|
|
def test_get_columns(mocker: MockerFixture):
    """Test that ROW columns are not expanded without expand_rows"""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    # Insertion order matters: it fixes the expected column order below.
    col_types = {
        "field1": datatype.parse_sqltype("row(a varchar, b date)"),
        "field2": datatype.parse_sqltype("row(r1 row(a varchar, b varchar))"),
        "field3": datatype.parse_sqltype("int"),
    }

    mock_inspector = mocker.MagicMock()
    mock_inspector.get_columns.return_value = [
        SQLAColumnType(name=name, type=type_, is_dttm=False)
        for name, type_ in col_types.items()
    ]

    actual = TrinoEngineSpec.get_columns(mock_inspector, "table", "schema")

    # Without expand_rows the ROW columns come back verbatim, unexpanded.
    expected = [
        ResultSetColumnType(name=name, column_name=name, type=type_, is_dttm=False)
        for name, type_ in col_types.items()
    ]

    _assert_columns_equal(actual, expected)
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_columns_expand_rows(mocker: MockerFixture):
    """Test that ROW columns are correctly expanded with expand_rows"""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    field1_type = datatype.parse_sqltype("row(a varchar, b date)")
    field2_type = datatype.parse_sqltype("row(r1 row(a varchar, b varchar))")
    field3_type = datatype.parse_sqltype("int")

    mock_inspector = mocker.MagicMock()
    mock_inspector.get_columns.return_value = [
        SQLAColumnType(name="field1", type=field1_type, is_dttm=False),
        SQLAColumnType(name="field2", type=field2_type, is_dttm=False),
        SQLAColumnType(name="field3", type=field3_type, is_dttm=False),
    ]

    actual = TrinoEngineSpec.get_columns(
        mock_inspector, "table", "schema", {"expand_rows": True}
    )

    def _expanded(name: str, type_, is_dttm: bool, query_as: str):
        # Builds one of the synthetic columns produced by ROW expansion.
        return ResultSetColumnType(
            name=name,
            column_name=name,
            type=type_,
            is_dttm=is_dttm,
            query_as=query_as,
        )

    expected = [
        ResultSetColumnType(
            name="field1", column_name="field1", type=field1_type, is_dttm=False
        ),
        _expanded("field1.a", types.VARCHAR(), False, '"field1"."a" AS "field1.a"'),
        _expanded("field1.b", types.DATE(), True, '"field1"."b" AS "field1.b"'),
        ResultSetColumnType(
            name="field2", column_name="field2", type=field2_type, is_dttm=False
        ),
        _expanded(
            "field2.r1",
            datatype.parse_sqltype("row(a varchar, b varchar)"),
            False,
            '"field2"."r1" AS "field2.r1"',
        ),
        _expanded(
            "field2.r1.a",
            types.VARCHAR(),
            False,
            '"field2"."r1"."a" AS "field2.r1.a"',
        ),
        _expanded(
            "field2.r1.b",
            types.VARCHAR(),
            False,
            '"field2"."r1"."b" AS "field2.r1.b"',
        ),
        ResultSetColumnType(
            name="field3", column_name="field3", type=field3_type, is_dttm=False
        ),
    ]

    _assert_columns_equal(actual, expected)
|
2024-01-10 19:37:18 -05:00
|
|
|
|
|
|
|
|
|
|
|
def test_get_indexes_no_table():
    """A missing table should yield an empty index list instead of raising."""
    from sqlalchemy.exc import NoSuchTableError

    from superset.db_engine_specs.trino import TrinoEngineSpec

    db_mock = Mock()
    inspector_mock = Mock()
    inspector_mock.get_indexes = Mock(
        side_effect=NoSuchTableError("The specified table does not exist.")
    )

    assert (
        TrinoEngineSpec.get_indexes(
            db_mock, inspector_mock, "test_table", "test_schema"
        )
        == []
    )
|
2024-01-17 12:48:50 -05:00
|
|
|
|
|
|
|
|
|
|
|
def test_get_dbapi_exception_mapping():
    """Trino/requests exceptions should map onto the Superset DBAPI hierarchy."""
    from superset.db_engine_specs.trino import TrinoEngineSpec

    mapping = TrinoEngineSpec.get_dbapi_exception_mapping()

    expected_pairs = {
        TrinoUserError: SupersetDBAPIProgrammingError,
        TrinoInternalError: SupersetDBAPIDatabaseError,
        TrinoExternalError: SupersetDBAPIOperationalError,
        RequestsConnectionError: SupersetDBAPIConnectionError,
    }
    for source_exc, superset_exc in expected_pairs.items():
        assert mapping.get(source_exc) == superset_exc

    # Unknown exception types are simply not mapped.
    assert mapping.get(Exception) is None
|