feat(ssh-tunnelling): Setup SSH Tunneling Commands for Database Connections (#21912)

Co-authored-by: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com>
Co-authored-by: Elizabeth Thompson <eschutho@gmail.com>
This commit is contained in:
Hugh A. Miles II 2023-01-03 17:22:42 -05:00 committed by GitHub
parent a7a4561550
commit ebaad10d6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 1905 additions and 47 deletions

View File

@ -19,6 +19,8 @@ babel==2.9.1
# via flask-babel
backoff==1.11.1
# via apache-superset
bcrypt==4.0.1
# via paramiko
billiard==3.6.4.0
# via celery
bleach==3.3.1
@ -57,7 +59,9 @@ cron-descriptor==1.2.24
croniter==1.0.15
# via apache-superset
cryptography==3.4.7
# via apache-superset
# via
# apache-superset
# paramiko
deprecation==2.1.0
# via apache-superset
dnspython==2.1.0
@ -167,6 +171,8 @@ packaging==21.3
# deprecation
pandas==1.5.2
# via apache-superset
paramiko==2.11.0
# via sshtunnel
parsedatetime==2.6
# via apache-superset
pgsanity==0.2.9
@ -188,6 +194,8 @@ pyjwt==2.4.0
# flask-jwt-extended
pymeeus==0.5.11
# via convertdate
pynacl==1.5.0
# via paramiko
pyparsing==3.0.6
# via
# apache-superset
@ -231,6 +239,7 @@ six==1.16.0
# flask-talisman
# isodate
# jsonschema
# paramiko
# polyline
# prison
# pyrsistent
@ -252,6 +261,8 @@ sqlalchemy-utils==0.38.3
# flask-appbuilder
sqlparse==0.4.3
# via apache-superset
sshtunnel==0.4.0
# via apache-superset
tabulate==0.8.9
# via apache-superset
typing-extensions==4.4.0

View File

@ -113,6 +113,7 @@ setup(
"PyJWT>=2.4.0, <3.0",
"redis",
"selenium>=3.141.0",
"sshtunnel>=0.4.0, <0.5",
"simplejson>=3.15.0",
"slack_sdk>=3.1.1, <4",
"sqlalchemy>=1.4, <2",

View File

@ -476,8 +476,30 @@ DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
"DRILL_TO_DETAIL": False,
"DATAPANEL_CLOSED_BY_DEFAULT": False,
"HORIZONTAL_FILTER_BAR": False,
# Allow users to enable ssh tunneling when creating a DB.
# Users must check whether the DB engine supports SSH Tunnels
# otherwise enabling this flag won't have any effect on the DB.
"SSH_TUNNELING": False,
}
# ------------------------------
# SSH Tunnel
# ------------------------------
# Allow users to set the host used when connecting to the SSH Tunnel
# as localhost and any other alias (0.0.0.0)
# ----------------------------------------------------------------------
# |
# -------------+ | +----------+
# LOCAL | | | REMOTE | :22 SSH
# CLIENT | <== SSH ========> | SERVER | :8080 web service
# -------------+ | +----------+
# |
# FIREWALL (only port 22 is open)
# ----------------------------------------------------------------------
SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager"
SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
# Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars.
DEFAULT_FEATURE_FLAGS.update(
{
@ -1506,7 +1528,7 @@ elif importlib.util.find_spec("superset_config") and not is_test():
try:
# pylint: disable=import-error,wildcard-import,unused-wildcard-import
import superset_config
from superset_config import * # type:ignore
from superset_config import * # type: ignore
print(f"Loaded your LOCAL configuration at [{superset_config.__file__}]")
except Exception:

View File

@ -139,6 +139,7 @@ MODEL_API_RW_METHOD_PERMISSION_MAP = {
"validate_sql": "read",
"get_data": "read",
"samples": "read",
"delete_ssh_tunnel": "write",
}
EXTRA_FORM_DATA_APPEND_KEYS = {

View File

@ -72,6 +72,11 @@ from superset.databases.schemas import (
ValidateSQLRequest,
ValidateSQLResponse,
)
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelDeleteFailedError,
SSHTunnelNotFoundError,
)
from superset.databases.utils import get_table_metadata
from superset.db_engine_specs import get_available_engine_specs
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
@ -80,6 +85,7 @@ from superset.extensions import security_manager
from superset.models.core import Database
from superset.superset_typing import FlaskResponse
from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item
from superset.utils.ssh_tunnel import mask_password_info
from superset.views.base import json_errors_response
from superset.views.base_api import (
BaseSupersetModelRestApi,
@ -107,6 +113,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"available",
"validate_parameters",
"validate_sql",
"delete_ssh_tunnel",
}
resource_name = "database"
class_permission_name = "Database"
@ -219,6 +226,47 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
ValidateSQLResponse,
)
@expose("/<int:pk>", methods=["GET"])
@protect()
@safe
def get(self, pk: int, **kwargs: Any) -> Response:
"""Get a database
---
get:
description: >-
Get a database
parameters:
- in: path
schema:
type: integer
description: The database id
name: pk
responses:
200:
description: Database
content:
application/json:
schema:
type: object
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
data = self.get_headless(pk, **kwargs)
try:
if ssh_tunnel := DatabaseDAO.get_ssh_tunnel(pk):
payload = data.json
payload["result"]["ssh_tunnel"] = ssh_tunnel.data
return payload
return data
except SupersetException as ex:
return self.response(ex.status, message=ex.message)
@expose("/", methods=["POST"])
@protect()
@safe
@ -280,6 +328,12 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
if new_model.driver:
item["driver"] = new_model.driver
# Return SSH Tunnel and hide passwords if any
if item.get("ssh_tunnel"):
item["ssh_tunnel"] = mask_password_info(
new_model.ssh_tunnel # pylint: disable=no-member
)
return self.response(201, id=new_model.id, result=item)
except DatabaseInvalidError as ex:
return self.response_422(message=ex.normalized_messages())
@ -361,6 +415,9 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
item["sqlalchemy_uri"] = changed_model.sqlalchemy_uri
if changed_model.parameters:
item["parameters"] = changed_model.parameters
# Return SSH Tunnel and hide passwords if any
if item.get("ssh_tunnel"):
item["ssh_tunnel"] = mask_password_info(changed_model.ssh_tunnel)
return self.response(200, id=changed_model.id, result=item)
except DatabaseNotFoundError:
return self.response_404()
@ -1206,3 +1263,57 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
command = ValidateDatabaseParametersCommand(payload)
command.run()
return self.response(200, message="OK")
@expose("/<int:pk>/ssh_tunnel/", methods=["DELETE"])
@protect()
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
f".delete_ssh_tunnel",
log_to_statsd=False,
)
def delete_ssh_tunnel(self, pk: int) -> Response:
"""Deletes a SSH Tunnel
---
delete:
description: >-
Deletes a SSH Tunnel.
parameters:
- in: path
schema:
type: integer
name: pk
responses:
200:
description: SSH Tunnel deleted
content:
application/json:
schema:
type: object
properties:
message:
type: string
401:
$ref: '#/components/responses/401'
403:
$ref: '#/components/responses/403'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
try:
DeleteSSHTunnelCommand(pk).run()
return self.response(200, message="OK")
except SSHTunnelNotFoundError:
return self.response_404()
except SSHTunnelDeleteFailedError as ex:
logger.error(
"Error deleting SSH Tunnel %s: %s",
self.__class__.__name__,
str(ex),
exc_info=True,
)
return self.response_422(message=str(ex))

View File

@ -31,6 +31,11 @@ from superset.databases.commands.exceptions import (
)
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelCreateFailedError,
SSHTunnelInvalidError,
)
from superset.exceptions import SupersetErrorsException
from superset.extensions import db, event_logger, security_manager
@ -71,12 +76,35 @@ class CreateDatabaseCommand(BaseCommand):
database = DatabaseDAO.create(self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
ssh_tunnel = None
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
try:
# So database.id is not None
db.session.flush()
ssh_tunnel = CreateSSHTunnelCommand(
database.id, ssh_tunnel_properties
).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
# So we can show the original message
raise ex
except Exception as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
)
raise DatabaseCreateFailedError() from ex
# adding a new database we always want to force refresh schema list
schemas = database.get_all_schema_names(cache=False)
schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
for schema in schemas:
security_manager.add_permission_view_menu(
"schema_access", security_manager.get_schema_perm(database, schema)
)
db.session.commit()
except DAOCreateFailedError as ex:
db.session.rollback()

View File

@ -32,6 +32,7 @@ from superset.databases.commands.exceptions import (
DatabaseTestConnectionUnexpectedError,
)
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.errors import ErrorLevel, SupersetErrorType
from superset.exceptions import (
@ -90,6 +91,10 @@ class TestConnectionDatabaseCommand(BaseCommand):
database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
# Generate tunnel if present in the properties
if ssh_tunnel := self._properties.get("ssh_tunnel"):
ssh_tunnel = SSHTunnel(**ssh_tunnel)
event_logger.log_with_context(
action="test_connection_attempt",
engine=database.db_engine_spec.__name__,
@ -99,7 +104,9 @@ class TestConnectionDatabaseCommand(BaseCommand):
with closing(engine.raw_connection()) as conn:
return engine.dialect.do_ping(conn)
with database.get_sqla_engine_with_context() as engine:
with database.get_sqla_engine_with_context(
override_ssh_tunnel=ssh_tunnel
) as engine:
try:
alive = func_timeout(
app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(),

View File

@ -21,7 +21,7 @@ from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAOUpdateFailedError
from superset.dao.exceptions import DAOCreateFailedError, DAOUpdateFailedError
from superset.databases.commands.exceptions import (
DatabaseConnectionFailedError,
DatabaseExistsValidationError,
@ -30,6 +30,12 @@ from superset.databases.commands.exceptions import (
DatabaseUpdateFailedError,
)
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelCreateFailedError,
SSHTunnelInvalidError,
)
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
from superset.extensions import db, security_manager
from superset.models.core import Database
from superset.utils.core import DatasourceType
@ -94,10 +100,33 @@ class UpdateDatabaseCommand(BaseCommand):
security_manager.add_permission_view_menu(
"schema_access", security_manager.get_schema_perm(database, schema)
)
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
if existing_ssh_tunnel_model is None:
# We couldn't found an existing tunnel so we need to create one
try:
CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
# So we can show the original message
raise ex
except Exception as ex:
raise DatabaseUpdateFailedError() from ex
else:
# We found an existing tunnel so we need to update it
try:
UpdateSSHTunnelCommand(
existing_ssh_tunnel_model.id, ssh_tunnel_properties
).run()
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
# So we can show the original message
raise ex
except Exception as ex:
raise DatabaseUpdateFailedError() from ex
db.session.commit()
except DAOUpdateFailedError as ex:
logger.exception(ex.exception)
except (DAOUpdateFailedError, DAOCreateFailedError) as ex:
raise DatabaseUpdateFailedError() from ex
return database

View File

@ -19,6 +19,7 @@ from typing import Any, Dict, Optional
from superset.dao.base import BaseDAO
from superset.databases.filters import DatabaseFilter
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.extensions import db
from superset.models.core import Database
from superset.models.dashboard import Dashboard
@ -124,3 +125,13 @@ class DatabaseDAO(BaseDAO):
return dict(
charts=charts, dashboards=dashboards, sqllab_tab_states=sqllab_tab_states
)
@classmethod
def get_ssh_tunnel(cls, database_id: int) -> Optional[SSHTunnel]:
ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == database_id)
.one_or_none()
)
return ssh_tunnel

View File

@ -365,6 +365,19 @@ class DatabaseValidateParametersSchema(Schema):
)
class DatabaseSSHTunnel(Schema):
server_address = fields.String()
server_port = fields.Integer()
username = fields.String()
# Basic Authentication
password = fields.String(required=False)
# password protected private key authentication
private_key = fields.String(required=False)
private_key_password = fields.String(required=False)
class DatabasePostSchema(Schema, DatabaseParametersSchemaMixin):
class Meta: # pylint: disable=too-few-public-methods
unknown = EXCLUDE
@ -409,6 +422,7 @@ class DatabasePostSchema(Schema, DatabaseParametersSchemaMixin):
is_managed_externally = fields.Boolean(allow_none=True, default=False)
external_url = fields.String(allow_none=True)
uuid = fields.String(required=False)
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin):
@ -454,6 +468,7 @@ class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin):
)
is_managed_externally = fields.Boolean(allow_none=True, default=False)
external_url = fields.String(allow_none=True)
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin):
@ -482,6 +497,8 @@ class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin):
validate=[Length(1, 1024), sqlalchemy_uri_validator],
)
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
class TableMetadataOptionsResponseSchema(Schema):
deferrable = fields.Bool()

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,92 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAOCreateFailedError
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelCreateFailedError,
SSHTunnelInvalidError,
SSHTunnelRequiredFieldValidationError,
)
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
from superset.extensions import db, event_logger
logger = logging.getLogger(__name__)
class CreateSSHTunnelCommand(BaseCommand):
def __init__(self, database_id: int, data: Dict[str, Any]):
self._properties = data.copy()
self._properties["database_id"] = database_id
def run(self) -> Model:
try:
# Start nested transaction since we are always creating the tunnel
# through a DB command (Create or Update). Without this, we cannot
# safely rollback changes to databases if any, i.e, things like
# test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail
db.session.begin_nested()
self.validate()
tunnel = SSHTunnelDAO.create(self._properties, commit=False)
except DAOCreateFailedError as ex:
# Rollback nested transaction
db.session.rollback()
raise SSHTunnelCreateFailedError() from ex
except SSHTunnelInvalidError as ex:
# Rollback nested transaction
db.session.rollback()
raise ex
return tunnel
def validate(self) -> None:
# TODO(hughhh): check to make sure the server port is not localhost
# using the config.SSH_TUNNEL_MANAGER
exceptions: List[ValidationError] = []
database_id: Optional[int] = self._properties.get("database_id")
server_address: Optional[str] = self._properties.get("server_address")
server_port: Optional[int] = self._properties.get("server_port")
username: Optional[str] = self._properties.get("username")
private_key: Optional[str] = self._properties.get("private_key")
private_key_password: Optional[str] = self._properties.get(
"private_key_password"
)
if not database_id:
exceptions.append(SSHTunnelRequiredFieldValidationError("database_id"))
if not server_address:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
if not server_port:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_port"))
if not username:
exceptions.append(SSHTunnelRequiredFieldValidationError("username"))
if private_key_password and private_key is None:
exceptions.append(SSHTunnelRequiredFieldValidationError("private_key"))
if exceptions:
exception = SSHTunnelInvalidError()
exception.add_list(exceptions)
event_logger.log_with_context(
action="ssh_tunnel_creation_failed.{}.{}".format(
exception.__class__.__name__,
".".join(exception.get_list_classnames()),
)
)
raise exception

View File

@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Optional
from flask_appbuilder.models.sqla import Model
from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAODeleteFailedError
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelDeleteFailedError,
SSHTunnelNotFoundError,
)
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
logger = logging.getLogger(__name__)
class DeleteSSHTunnelCommand(BaseCommand):
def __init__(self, model_id: int):
self._model_id = model_id
self._model: Optional[SSHTunnel] = None
def run(self) -> Model:
self.validate()
try:
ssh_tunnel = SSHTunnelDAO.delete(self._model)
except DAODeleteFailedError as ex:
raise SSHTunnelDeleteFailedError() from ex
return ssh_tunnel
def validate(self) -> None:
# Validate/populate model exists
self._model = SSHTunnelDAO.find_by_id(self._model_id)
if not self._model:
raise SSHTunnelNotFoundError()

View File

@ -0,0 +1,54 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_babel import lazy_gettext as _
from marshmallow import ValidationError
from superset.commands.exceptions import (
CommandException,
CommandInvalidError,
DeleteFailedError,
UpdateFailedError,
)
class SSHTunnelDeleteFailedError(DeleteFailedError):
message = _("SSH Tunnel could not be deleted.")
class SSHTunnelNotFoundError(CommandException):
status = 404
message = _("SSH Tunnel not found.")
class SSHTunnelInvalidError(CommandInvalidError):
message = _("SSH Tunnel parameters are invalid.")
class SSHTunnelUpdateFailedError(UpdateFailedError):
message = _("SSH Tunnel could not be updated.")
class SSHTunnelCreateFailedError(CommandException):
message = _("Creating SSH Tunnel failed for an unknown reason")
class SSHTunnelRequiredFieldValidationError(ValidationError):
def __init__(self, field_name: str) -> None:
super().__init__(
[_("Field is required")],
field_name=field_name,
)

View File

@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, Optional
from flask_appbuilder.models.sqla import Model
from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAOUpdateFailedError
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelInvalidError,
SSHTunnelNotFoundError,
SSHTunnelRequiredFieldValidationError,
SSHTunnelUpdateFailedError,
)
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
logger = logging.getLogger(__name__)
class UpdateSSHTunnelCommand(BaseCommand):
def __init__(self, model_id: int, data: Dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[SSHTunnel] = None
def run(self) -> Model:
self.validate()
try:
tunnel = SSHTunnelDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
raise SSHTunnelUpdateFailedError() from ex
return tunnel
def validate(self) -> None:
# Validate/populate model exists
self._model = SSHTunnelDAO.find_by_id(self._model_id)
if not self._model:
raise SSHTunnelNotFoundError()
private_key: Optional[str] = self._properties.get("private_key")
private_key_password: Optional[str] = self._properties.get(
"private_key_password"
)
if private_key_password and private_key is None:
exception = SSHTunnelInvalidError()
exception.add(SSHTunnelRequiredFieldValidationError("private_key"))
raise exception

View File

@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from superset.dao.base import BaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
logger = logging.getLogger(__name__)
class SSHTunnelDAO(BaseDAO):
model_cls = SSHTunnel

View File

@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict
import sqlalchemy as sa
from flask import current_app
from flask_appbuilder import Model
from sqlalchemy.orm import backref, relationship
from sqlalchemy_utils import EncryptedType
from superset.models.core import Database
from superset.models.helpers import (
AuditMixinNullable,
ExtraJSONMixin,
ImportExportMixin,
)
app_config = current_app.config
class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
"""
A ssh tunnel configuration in a database.
"""
__tablename__ = "ssh_tunnels"
id = sa.Column(sa.Integer, primary_key=True)
database_id = sa.Column(
sa.Integer, sa.ForeignKey("dbs.id"), nullable=False, unique=True
)
database: Database = relationship(
"Database",
backref=backref("ssh_tunnels", uselist=False, cascade="all, delete-orphan"),
foreign_keys=[database_id],
)
server_address = sa.Column(sa.Text)
server_port = sa.Column(sa.Integer)
username = sa.Column(EncryptedType(sa.String, app_config["SECRET_KEY"]))
# basic authentication
password = sa.Column(
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
)
# password protected pkey authentication
private_key = sa.Column(
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
)
private_key_password = sa.Column(
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
)
@property
def data(self) -> Dict[str, Any]:
return {
"server_address": self.server_address,
"server_port": self.server_port,
"username": self.username,
}

View File

@ -193,6 +193,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
engine_aliases: Set[str] = set()
drivers: Dict[str, str] = {}
default_driver: Optional[str] = None
allow_ssh_tunneling = False
_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}

View File

@ -166,6 +166,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
engine = "postgresql"
engine_aliases = {"postgres"}
allow_ssh_tunneling = True
default_driver = "psycopg2"
sqlalchemy_uri_placeholder = (

View File

@ -28,6 +28,7 @@ from flask_talisman import Talisman
from flask_wtf.csrf import CSRFProtect
from werkzeug.local import LocalProxy
from superset.extensions.ssh import SSHManagerFactory
from superset.utils.async_query_manager import AsyncQueryManager
from superset.utils.cache_manager import CacheManager
from superset.utils.encrypt import EncryptedFieldFactory
@ -127,3 +128,4 @@ profiling = ProfilingExtension()
results_backend_manager = ResultsBackendManager()
security_manager = LocalProxy(lambda: appbuilder.sm)
talisman = Talisman()
ssh_manager_factory = SSHManagerFactory()

View File

@ -0,0 +1,88 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import importlib
from typing import TYPE_CHECKING
from flask import Flask
from sshtunnel import open_tunnel, SSHTunnelForwarder
from superset.databases.utils import make_url_safe
if TYPE_CHECKING:
from superset.databases.ssh_tunnel.models import SSHTunnel
class SSHManager:
def __init__(self, app: Flask) -> None:
super().__init__()
self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
def build_sqla_url( # pylint: disable=no-self-use
self, sqlalchemy_url: str, server: SSHTunnelForwarder
) -> str:
# override any ssh tunnel configuration object
url = make_url_safe(sqlalchemy_url)
return url.set(
host=server.local_bind_address[0],
port=server.local_bind_port,
)
def create_tunnel(
self,
ssh_tunnel: "SSHTunnel",
sqlalchemy_database_uri: str,
) -> SSHTunnelForwarder:
url = make_url_safe(sqlalchemy_database_uri)
params = {
"ssh_address_or_host": ssh_tunnel.server_address,
"ssh_port": ssh_tunnel.server_port,
"ssh_username": ssh_tunnel.username,
"remote_bind_address": (url.host, url.port), # bind_port, bind_host
"local_bind_address": (self.local_bind_address,),
}
if ssh_tunnel.password:
params["ssh_password"] = ssh_tunnel.password
elif ssh_tunnel.private_key:
params["private_key"] = ssh_tunnel.private_key
params["private_key_password"] = ssh_tunnel.private_key_password
return open_tunnel(**params)
class SSHManagerFactory:
def __init__(self) -> None:
self._ssh_manager = None
def init_app(self, app: Flask) -> None:
ssh_manager_fqclass = app.config["SSH_TUNNEL_MANAGER_CLASS"]
ssh_manager_classname = ssh_manager_fqclass[
ssh_manager_fqclass.rfind(".") + 1 :
]
ssh_manager_module_name = ssh_manager_fqclass[
0 : ssh_manager_fqclass.rfind(".")
]
ssh_manager_class = getattr(
importlib.import_module(ssh_manager_module_name), ssh_manager_classname
)
self._ssh_manager = ssh_manager_class(app)
@property
def instance(self) -> SSHManager:
return self._ssh_manager # type: ignore

View File

@ -45,6 +45,7 @@ from superset.extensions import (
migrate,
profiling,
results_backend_manager,
ssh_manager_factory,
talisman,
)
from superset.security import SupersetSecurityManager
@ -417,6 +418,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
self.configure_data_sources()
self.configure_auth_provider()
self.configure_async_queries()
self.configure_ssh_manager()
# Hook that provides administrators a handle on the Flask APP
# after initialization
@ -474,6 +476,9 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
def configure_auth_provider(self) -> None:
machine_auth_provider_factory.init_app(self.superset_app)
def configure_ssh_manager(self) -> None:
ssh_manager_factory.init_app(self.superset_app)
def setup_event_logger(self) -> None:
_event_logger["event_logger"] = get_event_logger_from_cfg_value(
self.superset_app.config.get("EVENT_LOGGER", DBEventLogger())

View File

@ -0,0 +1,89 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""create_ssh_tunnel_credentials_tbl
Revision ID: f3c2d8ec8595
Revises: 4ce1d9b25135
Create Date: 2022-10-20 10:48:08.722861
"""
# revision identifiers, used by Alembic.
revision = "f3c2d8ec8595"
down_revision = "4ce1d9b25135"
from uuid import uuid4
import sqlalchemy as sa
from alembic import op
from sqlalchemy_utils import UUIDType
from superset import app
from superset.extensions import encrypted_field_factory
app_config = app.config
def upgrade():
op.create_table(
"ssh_tunnels",
# AuditMixinNullable
sa.Column("created_on", sa.DateTime(), nullable=True),
sa.Column("changed_on", sa.DateTime(), nullable=True),
sa.Column("created_by_fk", sa.Integer(), nullable=True),
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
# ExtraJSONMixin
sa.Column("extra_json", sa.Text(), nullable=True),
# ImportExportMixin
sa.Column(
"uuid",
UUIDType(binary=True),
primary_key=False,
default=uuid4,
unique=True,
index=True,
),
# SSHTunnelCredentials
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column(
"database_id",
sa.INTEGER(),
sa.ForeignKey("dbs.id"),
unique=True,
index=True,
),
sa.Column("server_address", sa.String(256)),
sa.Column("server_port", sa.INTEGER()),
sa.Column("username", encrypted_field_factory.create(sa.String(256))),
sa.Column(
"password", encrypted_field_factory.create(sa.String(256)), nullable=True
),
sa.Column(
"private_key",
encrypted_field_factory.create(sa.String(1024)),
nullable=True,
),
sa.Column(
"private_key_password",
encrypted_field_factory.create(sa.String(256)),
nullable=True,
),
)
def downgrade():
op.drop_table("ssh_tunnels")

View File

@ -21,10 +21,10 @@ import json
import logging
import textwrap
from ast import literal_eval
from contextlib import closing, contextmanager
from contextlib import closing, contextmanager, nullcontext
from copy import deepcopy
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING
import numpy
import pandas as pd
@ -57,7 +57,12 @@ from superset import app, db_engine_specs
from superset.constants import PASSWORD_MASK
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.base import MetricType, TimeGrain
from superset.extensions import cache_manager, encrypted_field_factory, security_manager
from superset.extensions import (
cache_manager,
encrypted_field_factory,
security_manager,
ssh_manager_factory,
)
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
from superset.result_set import SupersetResultSet
from superset.utils import cache as cache_util, core as utils
@ -71,6 +76,9 @@ log_query = config["QUERY_LOGGER"]
metadata = Model.metadata # pylint: disable=no-member
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from superset.databases.ssh_tunnel.models import SSHTunnel
DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"]
@ -373,17 +381,48 @@ class Database(
schema: Optional[str] = None,
nullpool: bool = True,
source: Optional[utils.QuerySource] = None,
override_ssh_tunnel: Optional["SSHTunnel"] = None,
) -> Engine:
yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
from superset.databases.dao import ( # pylint: disable=import-outside-toplevel
DatabaseDAO,
)
sqlalchemy_uri = self.sqlalchemy_uri_decrypted
engine_context = nullcontext()
ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel(
database_id=self.id
)
if ssh_tunnel:
# if ssh_tunnel is available build engine with information
engine_context = ssh_manager_factory.instance.create_tunnel(
ssh_tunnel=ssh_tunnel,
sqlalchemy_database_uri=self.sqlalchemy_uri_decrypted,
)
with engine_context as server_context:
if ssh_tunnel:
sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
sqlalchemy_uri, server_context
)
yield self._get_sqla_engine(
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)
def _get_sqla_engine(
self,
schema: Optional[str] = None,
nullpool: bool = True,
source: Optional[utils.QuerySource] = None,
sqlalchemy_uri: Optional[str] = None,
) -> Engine:
extra = self.get_extra()
sqlalchemy_url = make_url_safe(self.sqlalchemy_uri_decrypted)
sqlalchemy_url = make_url_safe(
sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
)
sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
effective_username = self.get_effective_user(sqlalchemy_url)
# If using MySQL or Presto for example, will set url.username
@ -423,7 +462,6 @@ class Database(
sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
sqlalchemy_url, params, effective_username, security_manager, source
)
try:
return create_engine(sqlalchemy_url, **params)
except Exception as ex:
@ -477,7 +515,7 @@ class Database(
security_manager,
)
with closing(engine.raw_connection()) as conn:
with self.get_raw_connection(schema=schema) as conn:
cursor = conn.cursor()
for sql_ in sqls[:-1]:
_log_query(sql_)
@ -574,14 +612,16 @@ class Database(
:return: The table/schema pairs
"""
try:
return {
(table, schema)
for table in self.db_engine_spec.get_table_names(
database=self,
inspector=self.inspector,
schema=schema,
)
}
with self.get_inspector_with_context() as inspector:
tables = {
(table, schema)
for table in self.db_engine_spec.get_table_names(
database=self,
inspector=inspector,
schema=schema,
)
}
return tables
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
@ -608,17 +648,27 @@ class Database(
:return: set of views
"""
try:
return {
(view, schema)
for view in self.db_engine_spec.get_view_names(
database=self,
inspector=self.inspector,
schema=schema,
)
}
with self.get_inspector_with_context() as inspector:
return {
(view, schema)
for view in self.db_engine_spec.get_view_names(
database=self,
inspector=inspector,
schema=schema,
)
}
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
@contextmanager
def get_inspector_with_context(
self, ssh_tunnel: Optional["SSHTunnel"] = None
) -> Inspector:
with self.get_sqla_engine_with_context(
override_ssh_tunnel=ssh_tunnel
) as engine:
yield sqla.inspect(engine)
@cache_util.memoized_func(
key="db:{self.id}:schema_list",
cache=cache_manager.cache,
@ -628,6 +678,7 @@ class Database(
cache: bool = False,
cache_timeout: Optional[int] = None,
force: bool = False,
ssh_tunnel: Optional["SSHTunnel"] = None,
) -> List[str]:
"""Parameters need to be passed as keyword arguments.
@ -640,7 +691,8 @@ class Database(
:return: schema list
"""
try:
return self.db_engine_spec.get_schema_names(self.inspector)
with self.get_inspector_with_context(ssh_tunnel=ssh_tunnel) as inspector:
return self.db_engine_spec.get_schema_names(inspector)
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
@ -703,43 +755,49 @@ class Database(
def get_table_comment(
self, table_name: str, schema: Optional[str] = None
) -> Optional[str]:
return self.db_engine_spec.get_table_comment(self.inspector, table_name, schema)
with self.get_inspector_with_context() as inspector:
return self.db_engine_spec.get_table_comment(inspector, table_name, schema)
def get_columns(
self, table_name: str, schema: Optional[str] = None
) -> List[Dict[str, Any]]:
return self.db_engine_spec.get_columns(self.inspector, table_name, schema)
with self.get_inspector_with_context() as inspector:
return self.db_engine_spec.get_columns(inspector, table_name, schema)
def get_metrics(
self,
table_name: str,
schema: Optional[str] = None,
) -> List[MetricType]:
return self.db_engine_spec.get_metrics(self, self.inspector, table_name, schema)
with self.get_inspector_with_context() as inspector:
return self.db_engine_spec.get_metrics(self, inspector, table_name, schema)
def get_indexes(
self, table_name: str, schema: Optional[str] = None
) -> List[Dict[str, Any]]:
indexes = self.inspector.get_indexes(table_name, schema)
return self.db_engine_spec.normalize_indexes(indexes)
with self.get_inspector_with_context() as inspector:
indexes = inspector.get_indexes(table_name, schema)
return self.db_engine_spec.normalize_indexes(indexes)
def get_pk_constraint(
self, table_name: str, schema: Optional[str] = None
) -> Dict[str, Any]:
pk_constraint = self.inspector.get_pk_constraint(table_name, schema) or {}
with self.get_inspector_with_context() as inspector:
pk_constraint = inspector.get_pk_constraint(table_name, schema) or {}
def _convert(value: Any) -> Any:
try:
return utils.base_json_conv(value)
except TypeError:
return None
def _convert(value: Any) -> Any:
try:
return utils.base_json_conv(value)
except TypeError:
return None
return {key: _convert(value) for key, value in pk_constraint.items()}
return {key: _convert(value) for key, value in pk_constraint.items()}
def get_foreign_keys(
self, table_name: str, schema: Optional[str] = None
) -> List[Dict[str, Any]]:
return self.inspector.get_foreign_keys(table_name, schema)
with self.get_inspector_with_context() as inspector:
return inspector.get_foreign_keys(table_name, schema)
def get_schema_access_for_file_upload( # pylint: disable=invalid-name
self,

View File

@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict
from superset.constants import PASSWORD_MASK
def mask_password_info(ssh_tunnel: Dict[str, Any]) -> Dict[str, Any]:
if ssh_tunnel.pop("password", None) is not None:
ssh_tunnel["password"] = PASSWORD_MASK
if ssh_tunnel.pop("private_key", None) is not None:
ssh_tunnel["private_key"] = PASSWORD_MASK
if ssh_tunnel.pop("private_key_password", None) is not None:
ssh_tunnel["private_key_password"] = PASSWORD_MASK
return ssh_tunnel

View File

@ -28,7 +28,7 @@ from __future__ import annotations
from typing import Callable, TYPE_CHECKING
from unittest.mock import MagicMock, Mock, PropertyMock
from flask import Flask
from flask import current_app, Flask
from flask.ctx import AppContext
from pytest import fixture
@ -40,6 +40,7 @@ from tests.example_data.data_loading.pandas.pands_data_loading_conf import (
from tests.example_data.data_loading.pandas.table_df_convertor import (
TableToDfConvertorImpl,
)
from tests.integration_tests.test_app import app
SUPPORT_DATETIME_TYPE = "support_datetime_type"
@ -70,8 +71,9 @@ def example_db_provider() -> Callable[[], Database]:
@fixture(scope="session")
def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine:
with example_db_provider().get_sqla_engine_with_context() as engine:
return engine
with app.app_context():
with example_db_provider().get_sqla_engine_with_context() as engine:
return engine
@fixture(scope="session")

View File

@ -35,6 +35,8 @@ from sqlalchemy.sql import func
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.db_engine_specs.redshift import RedshiftEngineSpec
@ -280,6 +282,314 @@ class TestDatabaseApi(SupersetTestCase):
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_create_database_with_ssh_tunnel(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test create with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_update_database_with_ssh_tunnel(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test update with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_update_ssh_tunnel_via_database_api(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test update with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
initial_ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
updated_ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "Test",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": initial_ssh_tunnel_properties,
}
database_data_with_ssh_tunnel_update = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": updated_ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
self.assertEqual(model_ssh_tunnel.username, "foo")
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
self.assertEqual(model_ssh_tunnel.username, "Test")
self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1")
self.assertEqual(model_ssh_tunnel.server_port, 8080)
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_cascade_delete_ssh_tunnel(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test create with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test create with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
}
database_data = {
"database_name": "test-db-failure-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
fail_message = {"message": "SSH Tunnel parameters are invalid."}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
self.assertEqual(response, fail_message)
# Cleanup
model = (
db.session.query(Database)
.filter(Database.database_name == "test-db-failure-ssh-tunnel")
.one_or_none()
)
# the DB should not be created
assert model is None
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_get_database_returns_related_ssh_tunnel(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test GET Database returns its related SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
response_ssh_tunnel = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "XXXXXXXXXX",
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel)
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
def test_create_database_invalid_configuration_method(self):
"""
Database API: Test create with an invalid configuration method.

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock, skip
from unittest.mock import patch
import pytest
from superset import security_manager
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelInvalidError,
SSHTunnelNotFoundError,
)
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
from tests.integration_tests.base_tests import SupersetTestCase
class TestCreateSSHTunnelCommand(SupersetTestCase):
@mock.patch("superset.utils.core.g")
def test_create_invalid_database_id(self, mock_g):
mock_g.user = security_manager.find_user("admin")
command = CreateSSHTunnelCommand(
None,
{
"server_address": "127.0.0.1",
"server_port": 5432,
"username": "test_user",
},
)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
class TestUpdateSSHTunnelCommand(SupersetTestCase):
@mock.patch("superset.utils.core.g")
def test_update_ssh_tunnel_not_found(self, mock_g):
mock_g.user = security_manager.find_user("admin")
# We have not created a SSH Tunnel yet so id = 1 is invalid
command = UpdateSSHTunnelCommand(
1,
{
"server_address": "127.0.0.1",
"server_port": 5432,
"username": "test_user",
},
)
with pytest.raises(SSHTunnelNotFoundError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel not found.")
class TestDeleteSSHTunnelCommand(SupersetTestCase):
@mock.patch("superset.utils.core.g")
def test_delete_ssh_tunnel_not_found(self, mock_g):
mock_g.user = security_manager.find_user("admin")
# We have not created a SSH Tunnel yet so id = 1 is invalid
command = DeleteSSHTunnelCommand(1)
with pytest.raises(SSHTunnelNotFoundError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel not found.")

View File

@ -191,3 +191,147 @@ def test_non_zip_import(client: Any, full_api_access: None) -> None:
}
]
}
def test_delete_ssh_tunnel(
mocker: MockFixture,
app: Any,
session: Session,
client: Any,
full_api_access: None,
) -> None:
"""
Test that we can delete SSH Tunnel
"""
with app.app_context():
from superset.databases.api import DatabaseRestApi
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
DatabaseRestApi.datamodel.session = session
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# Create our Database
database = Database(
database_name="my_database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"service_account_info": {
"type": "service_account",
"project_id": "black-sanctum-314419",
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
"private_key": "SECRET",
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
},
}
),
)
session.add(database)
session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
# Create our SSHTunnel
tunnel = SSHTunnel(
database_id=1,
database=database,
)
session.add(tunnel)
session.commit()
# Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel
assert isinstance(response_tunnel, SSHTunnel)
assert 1 == response_tunnel.database_id
# Delete the recently created SSHTunnel
response_delete_tunnel = client.delete("/api/v1/database/1/ssh_tunnel/")
assert response_delete_tunnel.json["message"] == "OK"
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel is None
def test_delete_ssh_tunnel_not_found(
mocker: MockFixture,
app: Any,
session: Session,
client: Any,
full_api_access: None,
) -> None:
"""
Test that we cannot delete a tunnel that does not exist
"""
with app.app_context():
from superset.databases.api import DatabaseRestApi
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
DatabaseRestApi.datamodel.session = session
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# Create our Database
database = Database(
database_name="my_database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"service_account_info": {
"type": "service_account",
"project_id": "black-sanctum-314419",
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
"private_key": "SECRET",
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
},
}
),
)
session.add(database)
session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
# Create our SSHTunnel
tunnel = SSHTunnel(
database_id=1,
database=database,
)
session.add(tunnel)
session.commit()
# Delete the recently created SSHTunnel
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
assert response_delete_tunnel.json["message"] == "Not found"
# Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel
assert isinstance(response_tunnel, SSHTunnel)
assert 1 == response_tunnel.database_id
response_tunnel = DatabaseDAO.get_ssh_tunnel(2)
assert response_tunnel is None

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,69 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterator
import pytest
from sqlalchemy.orm.session import Session
@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=db,
)
ssh_tunnel = SSHTunnel(
database_id=db.id,
database=db,
)
session.add(db)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()
yield session
session.rollback()
def test_database_get_ssh_tunnel(session_with_data: Session) -> None:
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
def test_database_get_ssh_tunnel_not_found(session_with_data: Session) -> None:
from superset.databases.dao import DatabaseDAO
result = DatabaseDAO.get_ssh_tunnel(2)
assert result is None

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterator
import pytest
from sqlalchemy.orm.session import Session
from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelInvalidError
def test_create_ssh_tunnel_command() -> None:
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
properties = {
"database_id": db.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}
result = CreateSSHTunnelCommand(db.id, properties).run()
assert result is not None
assert isinstance(result, SSHTunnel)
def test_create_ssh_tunnel_command_invalid_params() -> None:
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
# If we are trying to create a tunnel with a private_key_password
# then a private_key is mandatory
properties = {
"database_id": db.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"private_key_password": "bar",
}
command = CreateSSHTunnelCommand(db.id, properties)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")

View File

@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterator
import pytest
from sqlalchemy.orm.session import Session
@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=db,
)
ssh_tunnel = SSHTunnel(
database_id=db.id,
database=db,
)
session.add(db)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()
yield session
session.rollback()
def test_delete_ssh_tunnel_command(session_with_data: Session) -> None:
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
DeleteSSHTunnelCommand(1).run()
result = DatabaseDAO.get_ssh_tunnel(1)
assert result is None

View File

@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterator
import pytest
from sqlalchemy.orm.session import Session
from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelInvalidError
@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=db,
)
ssh_tunnel = SSHTunnel(database_id=db.id, database=db, server_address="Test")
session.add(db)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()
yield session
session.rollback()
def test_update_shh_tunnel_command(session_with_data: Session) -> None:
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
assert "Test" == result.server_address
update_payload = {"server_address": "Test2"}
UpdateSSHTunnelCommand(1, update_payload).run()
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert "Test2" == result.server_address
def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
assert "Test" == result.server_address
# If we are trying to update a tunnel with a private_key_password
# then a private_key is mandatory
update_payload = {"private_key_password": "pass"}
command = UpdateSSHTunnelCommand(1, update_payload)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")

View File

@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterator
import pytest
from sqlalchemy.orm.session import Session
def test_create_ssh_tunnel():
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
properties = {
"database_id": db.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}
result = SSHTunnelDAO.create(properties)
assert result is not None
assert isinstance(result, SSHTunnel)