mirror of https://github.com/apache/superset.git
feat: trino support server-cert (#16346)
Signed-off-by: Đặng Minh Dũng <dungdm93@live.com>
This commit is contained in:
parent
ff68502d31
commit
ebb34196f2
|
@ -15,7 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
from urllib import parse
|
||||
|
||||
import simplejson as json
|
||||
|
@ -24,6 +24,9 @@ from sqlalchemy.engine.url import make_url, URL
|
|||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.utils import core as utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
|
||||
|
||||
class TrinoEngineSpec(BaseEngineSpec):
|
||||
engine = "trino"
|
||||
|
@ -81,7 +84,6 @@ class TrinoEngineSpec(BaseEngineSpec):
|
|||
that can set the correct properties for impersonating users
|
||||
:param connect_args: config to be updated
|
||||
:param uri: URI string
|
||||
:param impersonate_user: Flag indicating if impersonation is enabled
|
||||
:param username: Effective username
|
||||
:return: None
|
||||
"""
|
||||
|
@ -116,9 +118,7 @@ class TrinoEngineSpec(BaseEngineSpec):
|
|||
Run a SQL query that estimates the cost of a given statement.
|
||||
|
||||
:param statement: A single SQL statement
|
||||
:param database: Database instance
|
||||
:param cursor: Cursor instance
|
||||
:param username: Effective username
|
||||
:return: JSON response from Trino
|
||||
"""
|
||||
sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement}"
|
||||
|
@ -183,3 +183,22 @@ class TrinoEngineSpec(BaseEngineSpec):
|
|||
cost.append(statement_cost)
|
||||
|
||||
return cost
|
||||
|
||||
@staticmethod
|
||||
def get_extra_params(database: "Database") -> Dict[str, Any]:
|
||||
"""
|
||||
Some databases require adding elements to connection parameters,
|
||||
like passing certificates to `extra`. This can be done here.
|
||||
|
||||
:param database: database instance from which to extract extras
|
||||
:raises CertificateException: If certificate is not valid/unparseable
|
||||
"""
|
||||
extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database)
|
||||
engine_params: Dict[str, Any] = extra.setdefault("engine_params", {})
|
||||
connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {})
|
||||
|
||||
if database.server_cert:
|
||||
connect_args["http_scheme"] = "https"
|
||||
connect_args["verify"] = utils.create_ssl_cert_file(database.server_cert)
|
||||
|
||||
return extra
|
||||
|
|
|
@ -14,6 +14,9 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from sqlalchemy.engine.url import URL
|
||||
|
||||
from superset.db_engine_specs.trino import TrinoEngineSpec
|
||||
|
@ -52,3 +55,35 @@ class TestTrinoDbEngineSpec(TestDbEngineSpec):
|
|||
url.database = "hive/default"
|
||||
TrinoEngineSpec.adjust_database_uri(url, selected_schema=None)
|
||||
self.assertEqual(url.database, "hive/default")
|
||||
|
||||
def test_get_extra_params(self):
|
||||
database = Mock()
|
||||
|
||||
database.extra = json.dumps({})
|
||||
database.server_cert = None
|
||||
extra = TrinoEngineSpec.get_extra_params(database)
|
||||
expected = {"engine_params": {"connect_args": {}}}
|
||||
self.assertEqual(extra, expected)
|
||||
|
||||
expected = {
|
||||
"first": 1,
|
||||
"engine_params": {"second": "two", "connect_args": {"third": "three"}},
|
||||
}
|
||||
database.extra = json.dumps(expected)
|
||||
database.server_cert = None
|
||||
extra = TrinoEngineSpec.get_extra_params(database)
|
||||
self.assertEqual(extra, expected)
|
||||
|
||||
@patch("superset.utils.core.create_ssl_cert_file")
|
||||
def test_get_extra_params_with_server_cert(self, create_ssl_cert_file_func: Mock):
|
||||
database = Mock()
|
||||
|
||||
database.extra = json.dumps({})
|
||||
database.server_cert = "TEST_CERT"
|
||||
create_ssl_cert_file_func.return_value = "/path/to/tls.crt"
|
||||
extra = TrinoEngineSpec.get_extra_params(database)
|
||||
|
||||
connect_args = extra.get("engine_params", {}).get("connect_args", {})
|
||||
self.assertEqual(connect_args.get("http_scheme"), "https")
|
||||
self.assertEqual(connect_args.get("verify"), "/path/to/tls.crt")
|
||||
create_ssl_cert_file_func.assert_called_once_with(database.server_cert)
|
||||
|
|
Loading…
Reference in New Issue