From ebaad10d6ce72fa9d939833720b44880d5139bb9 Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Tue, 3 Jan 2023 17:22:42 -0500 Subject: [PATCH] feat(ssh-tunnelling): Setup SSH Tunneling Commands for Database Connections (#21912) Co-authored-by: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com> Co-authored-by: Elizabeth Thompson --- requirements/base.txt | 13 +- setup.py | 1 + superset/config.py | 24 +- superset/constants.py | 1 + superset/databases/api.py | 111 +++++++ superset/databases/commands/create.py | 30 +- .../databases/commands/test_connection.py | 9 +- superset/databases/commands/update.py | 35 +- superset/databases/dao.py | 11 + superset/databases/schemas.py | 17 + superset/databases/ssh_tunnel/__init__.py | 16 + .../databases/ssh_tunnel/commands/__init__.py | 16 + .../databases/ssh_tunnel/commands/create.py | 92 ++++++ .../databases/ssh_tunnel/commands/delete.py | 51 +++ .../ssh_tunnel/commands/exceptions.py | 54 +++ .../databases/ssh_tunnel/commands/update.py | 62 ++++ superset/databases/ssh_tunnel/dao.py | 26 ++ superset/databases/ssh_tunnel/models.py | 76 +++++ superset/db_engine_specs/base.py | 1 + superset/db_engine_specs/postgres.py | 1 + superset/extensions/__init__.py | 2 + superset/extensions/ssh.py | 88 +++++ superset/initialization/__init__.py | 5 + ...c8595_create_ssh_tunnel_credentials_tbl.py | 89 +++++ superset/models/core.py | 132 +++++--- superset/utils/ssh_tunnel.py | 30 ++ tests/conftest.py | 8 +- .../integration_tests/databases/api_tests.py | 310 ++++++++++++++++++ .../databases/ssh_tunnel/__init__.py | 16 + .../databases/ssh_tunnel/commands/__init__.py | 16 + .../ssh_tunnel/commands/commands_tests.py | 76 +++++ tests/unit_tests/databases/api_test.py | 144 ++++++++ tests/unit_tests/databases/dao/__init__.py | 16 + tests/unit_tests/databases/dao/dao_tests.py | 69 ++++ .../databases/ssh_tunnel/__init__.py | 16 + .../databases/ssh_tunnel/commands/__init__.py | 16 + .../ssh_tunnel/commands/create_test.py | 68 ++++ .../ssh_tunnel/commands/delete_test.py | 68 ++++ .../ssh_tunnel/commands/update_test.py | 93 ++++++ .../databases/ssh_tunnel/dao_tests.py | 43 +++ 40 files changed, 1905 insertions(+), 47 deletions(-) create mode 100644 superset/databases/ssh_tunnel/__init__.py create mode 100644 superset/databases/ssh_tunnel/commands/__init__.py create mode 100644 superset/databases/ssh_tunnel/commands/create.py create mode 100644 superset/databases/ssh_tunnel/commands/delete.py create mode 100644 superset/databases/ssh_tunnel/commands/exceptions.py create mode 100644 superset/databases/ssh_tunnel/commands/update.py create mode 100644 superset/databases/ssh_tunnel/dao.py create mode 100644 superset/databases/ssh_tunnel/models.py create mode 100644 superset/extensions/ssh.py create mode 100644 superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py create mode 100644 superset/utils/ssh_tunnel.py create mode 100644 tests/integration_tests/databases/ssh_tunnel/__init__.py create mode 100644 tests/integration_tests/databases/ssh_tunnel/commands/__init__.py create mode 100644 tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py create mode 100644 tests/unit_tests/databases/dao/__init__.py create mode 100644 tests/unit_tests/databases/dao/dao_tests.py create mode 100644 tests/unit_tests/databases/ssh_tunnel/__init__.py create mode 100644 tests/unit_tests/databases/ssh_tunnel/commands/__init__.py create mode 100644 tests/unit_tests/databases/ssh_tunnel/commands/create_test.py create mode 100644 tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py create mode 100644 tests/unit_tests/databases/ssh_tunnel/commands/update_test.py create mode 100644 tests/unit_tests/databases/ssh_tunnel/dao_tests.py diff --git a/requirements/base.txt b/requirements/base.txt index 3f494c6b9e..4b7363ca18 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -19,6 +19,8 @@ babel==2.9.1 # via flask-babel backoff==1.11.1 # via apache-superset +bcrypt==4.0.1 + # via paramiko billiard==3.6.4.0 # via celery bleach==3.3.1 @@ -57,7 +59,9 @@ cron-descriptor==1.2.24 croniter==1.0.15 # via apache-superset cryptography==3.4.7 - # via apache-superset + # via + # apache-superset + # paramiko deprecation==2.1.0 # via apache-superset dnspython==2.1.0 @@ -167,6 +171,8 @@ packaging==21.3 # deprecation pandas==1.5.2 # via apache-superset +paramiko==2.11.0 + # via sshtunnel parsedatetime==2.6 # via apache-superset pgsanity==0.2.9 @@ -188,6 +194,8 @@ pyjwt==2.4.0 # flask-jwt-extended pymeeus==0.5.11 # via convertdate +pynacl==1.5.0 + # via paramiko pyparsing==3.0.6 # via # apache-superset @@ -231,6 +239,7 @@ six==1.16.0 # flask-talisman # isodate # jsonschema + # paramiko # polyline # prison # pyrsistent @@ -252,6 +261,8 @@ sqlalchemy-utils==0.38.3 # flask-appbuilder sqlparse==0.4.3 # via apache-superset +sshtunnel==0.4.0 + # via apache-superset tabulate==0.8.9 # via apache-superset typing-extensions==4.4.0 diff --git a/setup.py b/setup.py index 95178a879a..3e017fe263 100644 --- a/setup.py +++ b/setup.py @@ -113,6 +113,7 @@ setup( "PyJWT>=2.4.0, <3.0", "redis", "selenium>=3.141.0", + "sshtunnel>=0.4.0, <0.5", "simplejson>=3.15.0", "slack_sdk>=3.1.1, <4", "sqlalchemy>=1.4, <2", diff --git a/superset/config.py b/superset/config.py index 948e234e4c..6760a0af72 100644 --- a/superset/config.py +++ b/superset/config.py @@ -476,8 +476,30 @@ DEFAULT_FEATURE_FLAGS: Dict[str, bool] = { "DRILL_TO_DETAIL": False, "DATAPANEL_CLOSED_BY_DEFAULT": False, "HORIZONTAL_FILTER_BAR": False, + # Allow users to enable ssh tunneling when creating a DB. + # Users must check whether the DB engine supports SSH Tunnels + # otherwise enabling this flag won't have any effect on the DB. + "SSH_TUNNELING": False, } +# ------------------------------ +# SSH Tunnel +# ------------------------------ +# Allow users to set the host used when connecting to the SSH Tunnel +# as localhost and any other alias (0.0.0.0) +# ---------------------------------------------------------------------- +# | +# -------------+ | +----------+ +# LOCAL | | | REMOTE | :22 SSH +# CLIENT | <== SSH ========> | SERVER | :8080 web service +# -------------+ | +----------+ +# | +# FIREWALL (only port 22 is open) + +# ---------------------------------------------------------------------- +SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager" +SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1" + # Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars. DEFAULT_FEATURE_FLAGS.update( { @@ -1506,7 +1528,7 @@ elif importlib.util.find_spec("superset_config") and not is_test(): try: # pylint: disable=import-error,wildcard-import,unused-wildcard-import import superset_config - from superset_config import * # type:ignore + from superset_config import * # type: ignore print(f"Loaded your LOCAL configuration at [{superset_config.__file__}]") except Exception: diff --git a/superset/constants.py b/superset/constants.py index 5091d65a43..ea7920ff2f 100644 --- a/superset/constants.py +++ b/superset/constants.py @@ -139,6 +139,7 @@ MODEL_API_RW_METHOD_PERMISSION_MAP = { "validate_sql": "read", "get_data": "read", "samples": "read", + "delete_ssh_tunnel": "write", } EXTRA_FORM_DATA_APPEND_KEYS = { diff --git a/superset/databases/api.py b/superset/databases/api.py index 3f737ec4da..1c75204f79 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -72,6 +72,11 @@ from superset.databases.schemas import ( ValidateSQLRequest, ValidateSQLResponse, ) +from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelDeleteFailedError, + SSHTunnelNotFoundError, +) from superset.databases.utils import get_table_metadata from superset.db_engine_specs import get_available_engine_specs from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -80,6 +85,7 @@ from superset.extensions import security_manager from superset.models.core import Database from superset.superset_typing import FlaskResponse from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item +from superset.utils.ssh_tunnel import mask_password_info from superset.views.base import json_errors_response from superset.views.base_api import ( BaseSupersetModelRestApi, @@ -107,6 +113,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "available", "validate_parameters", "validate_sql", + "delete_ssh_tunnel", } resource_name = "database" class_permission_name = "Database" @@ -219,6 +226,47 @@ class DatabaseRestApi(BaseSupersetModelRestApi): ValidateSQLResponse, ) + @expose("/", methods=["GET"]) + @protect() + @safe + def get(self, pk: int, **kwargs: Any) -> Response: + """Get a database + --- + get: + description: >- + Get a database + parameters: + - in: path + schema: + type: integer + description: The database id + name: pk + responses: + 200: + description: Database + content: + application/json: + schema: + type: object + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + data = self.get_headless(pk, **kwargs) + try: + if ssh_tunnel := DatabaseDAO.get_ssh_tunnel(pk): + payload = data.json + payload["result"]["ssh_tunnel"] = ssh_tunnel.data + return payload + return data + except SupersetException as ex: + return self.response(ex.status, message=ex.message) + @expose("/", methods=["POST"]) @protect() @safe @@ -280,6 +328,12 @@ class DatabaseRestApi(BaseSupersetModelRestApi): if new_model.driver: item["driver"] = new_model.driver + # Return SSH Tunnel and hide passwords if any + if item.get("ssh_tunnel"): + item["ssh_tunnel"] = mask_password_info( + new_model.ssh_tunnel # pylint: disable=no-member + ) + return self.response(201, id=new_model.id, result=item) except DatabaseInvalidError as ex: return self.response_422(message=ex.normalized_messages()) @@ -361,6 +415,9 @@ class DatabaseRestApi(BaseSupersetModelRestApi): item["sqlalchemy_uri"] = changed_model.sqlalchemy_uri if changed_model.parameters: item["parameters"] = changed_model.parameters + # Return SSH Tunnel and hide passwords if any + if item.get("ssh_tunnel"): + item["ssh_tunnel"] = mask_password_info(changed_model.ssh_tunnel) return self.response(200, id=changed_model.id, result=item) except DatabaseNotFoundError: return self.response_404() @@ -1206,3 +1263,57 @@ class DatabaseRestApi(BaseSupersetModelRestApi): command = ValidateDatabaseParametersCommand(payload) command.run() return self.response(200, message="OK") + + @expose("//ssh_tunnel/", methods=["DELETE"]) + @protect() + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" + f".delete_ssh_tunnel", + log_to_statsd=False, + ) + def delete_ssh_tunnel(self, pk: int) -> Response: + """Deletes a SSH Tunnel + --- + delete: + description: >- + Deletes a SSH Tunnel. + parameters: + - in: path + schema: + type: integer + name: pk + responses: + 200: + description: SSH Tunnel deleted + content: + application/json: + schema: + type: object + properties: + message: + type: string + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + try: + DeleteSSHTunnelCommand(pk).run() + return self.response(200, message="OK") + except SSHTunnelNotFoundError: + return self.response_404() + except SSHTunnelDeleteFailedError as ex: + logger.error( + "Error deleting SSH Tunnel %s: %s", + self.__class__.__name__, + str(ex), + exc_info=True, + ) + return self.response_422(message=str(ex)) diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py index 4dc8e8eda4..c826d82835 100644 --- a/superset/databases/commands/create.py +++ b/superset/databases/commands/create.py @@ -31,6 +31,11 @@ from superset.databases.commands.exceptions import ( ) from superset.databases.commands.test_connection import TestConnectionDatabaseCommand from superset.databases.dao import DatabaseDAO +from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelCreateFailedError, + SSHTunnelInvalidError, +) from superset.exceptions import SupersetErrorsException from superset.extensions import db, event_logger, security_manager @@ -71,12 +76,35 @@ class CreateDatabaseCommand(BaseCommand): database = DatabaseDAO.create(self._properties, commit=False) database.set_sqlalchemy_uri(database.sqlalchemy_uri) + ssh_tunnel = None + if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): + try: + # So database.id is not None + db.session.flush() + ssh_tunnel = CreateSSHTunnelCommand( + database.id, ssh_tunnel_properties + ).run() + except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: + event_logger.log_with_context( + action=f"db_creation_failed.{ex.__class__.__name__}", + engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], + ) + # So we can show the original message + raise ex + except Exception as ex: + event_logger.log_with_context( + action=f"db_creation_failed.{ex.__class__.__name__}", + engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], + ) + raise DatabaseCreateFailedError() from ex + # adding a new database we always want to force refresh schema list - schemas = database.get_all_schema_names(cache=False) + schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel) for schema in schemas: security_manager.add_permission_view_menu( "schema_access", security_manager.get_schema_perm(database, schema) ) + db.session.commit() except DAOCreateFailedError as ex: db.session.rollback() diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 3393be67b7..8027efcb49 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -32,6 +32,7 @@ from superset.databases.commands.exceptions import ( DatabaseTestConnectionUnexpectedError, ) from superset.databases.dao import DatabaseDAO +from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetErrorType from superset.exceptions import ( @@ -90,6 +91,10 @@ class TestConnectionDatabaseCommand(BaseCommand): database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) + # Generate tunnel if present in the properties + if ssh_tunnel := self._properties.get("ssh_tunnel"): + ssh_tunnel = SSHTunnel(**ssh_tunnel) + event_logger.log_with_context( action="test_connection_attempt", engine=database.db_engine_spec.__name__, @@ -99,7 +104,9 @@ class TestConnectionDatabaseCommand(BaseCommand): with closing(engine.raw_connection()) as conn: return engine.dialect.do_ping(conn) - with database.get_sqla_engine_with_context() as engine: + with database.get_sqla_engine_with_context( + override_ssh_tunnel=ssh_tunnel + ) as engine: try: alive = func_timeout( app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(), diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py index 80e3a9b54e..2e5931788e 100644 --- a/superset/databases/commands/update.py +++ b/superset/databases/commands/update.py @@ -21,7 +21,7 @@ from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError from superset.commands.base import BaseCommand -from superset.dao.exceptions import DAOUpdateFailedError +from superset.dao.exceptions import DAOCreateFailedError, DAOUpdateFailedError from superset.databases.commands.exceptions import ( DatabaseConnectionFailedError, DatabaseExistsValidationError, @@ -30,6 +30,12 @@ from superset.databases.commands.exceptions import ( DatabaseUpdateFailedError, ) from superset.databases.dao import DatabaseDAO +from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelCreateFailedError, + SSHTunnelInvalidError, +) +from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand from superset.extensions import db, security_manager from superset.models.core import Database from superset.utils.core import DatasourceType @@ -94,10 +100,33 @@ class UpdateDatabaseCommand(BaseCommand): security_manager.add_permission_view_menu( "schema_access", security_manager.get_schema_perm(database, schema) ) + + if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): + existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id) + if existing_ssh_tunnel_model is None: + # We couldn't found an existing tunnel so we need to create one + try: + CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run() + except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: + # So we can show the original message + raise ex + except Exception as ex: + raise DatabaseUpdateFailedError() from ex + else: + # We found an existing tunnel so we need to update it + try: + UpdateSSHTunnelCommand( + existing_ssh_tunnel_model.id, ssh_tunnel_properties + ).run() + except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: + # So we can show the original message + raise ex + except Exception as ex: + raise DatabaseUpdateFailedError() from ex + db.session.commit() - except DAOUpdateFailedError as ex: - logger.exception(ex.exception) + except (DAOUpdateFailedError, DAOCreateFailedError) as ex: raise DatabaseUpdateFailedError() from ex return database diff --git a/superset/databases/dao.py b/superset/databases/dao.py index 568755dd32..c82f0db574 100644 --- a/superset/databases/dao.py +++ b/superset/databases/dao.py @@ -19,6 +19,7 @@ from typing import Any, Dict, Optional from superset.dao.base import BaseDAO from superset.databases.filters import DatabaseFilter +from superset.databases.ssh_tunnel.models import SSHTunnel from superset.extensions import db from superset.models.core import Database from superset.models.dashboard import Dashboard @@ -124,3 +125,13 @@ class DatabaseDAO(BaseDAO): return dict( charts=charts, dashboards=dashboards, sqllab_tab_states=sqllab_tab_states ) + + @classmethod + def get_ssh_tunnel(cls, database_id: int) -> Optional[SSHTunnel]: + ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == database_id) + .one_or_none() + ) + + return ssh_tunnel diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index ef22374ef8..1732b01ecc 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -365,6 +365,19 @@ class DatabaseValidateParametersSchema(Schema): ) +class DatabaseSSHTunnel(Schema): + server_address = fields.String() + server_port = fields.Integer() + username = fields.String() + + # Basic Authentication + password = fields.String(required=False) + + # password protected private key authentication + private_key = fields.String(required=False) + private_key_password = fields.String(required=False) + + class DatabasePostSchema(Schema, DatabaseParametersSchemaMixin): class Meta: # pylint: disable=too-few-public-methods unknown = EXCLUDE @@ -409,6 +422,7 @@ class DatabasePostSchema(Schema, DatabaseParametersSchemaMixin): is_managed_externally = fields.Boolean(allow_none=True, default=False) external_url = fields.String(allow_none=True) uuid = fields.String(required=False) + ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True) class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin): @@ -454,6 +468,7 @@ class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin): ) is_managed_externally = fields.Boolean(allow_none=True, default=False) external_url = fields.String(allow_none=True) + ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True) class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin): @@ -482,6 +497,8 @@ class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin): validate=[Length(1, 1024), sqlalchemy_uri_validator], ) + ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True) + class TableMetadataOptionsResponseSchema(Schema): deferrable = fields.Bool() diff --git a/superset/databases/ssh_tunnel/__init__.py b/superset/databases/ssh_tunnel/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/superset/databases/ssh_tunnel/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/databases/ssh_tunnel/commands/__init__.py b/superset/databases/ssh_tunnel/commands/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/superset/databases/ssh_tunnel/commands/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/databases/ssh_tunnel/commands/create.py b/superset/databases/ssh_tunnel/commands/create.py new file mode 100644 index 0000000000..9c17149ba3 --- /dev/null +++ b/superset/databases/ssh_tunnel/commands/create.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Any, Dict, List, Optional + +from flask_appbuilder.models.sqla import Model +from marshmallow import ValidationError + +from superset.commands.base import BaseCommand +from superset.dao.exceptions import DAOCreateFailedError +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelCreateFailedError, + SSHTunnelInvalidError, + SSHTunnelRequiredFieldValidationError, +) +from superset.databases.ssh_tunnel.dao import SSHTunnelDAO +from superset.extensions import db, event_logger + +logger = logging.getLogger(__name__) + + +class CreateSSHTunnelCommand(BaseCommand): + def __init__(self, database_id: int, data: Dict[str, Any]): + self._properties = data.copy() + self._properties["database_id"] = database_id + + def run(self) -> Model: + try: + # Start nested transaction since we are always creating the tunnel + # through a DB command (Create or Update). Without this, we cannot + # safely rollback changes to databases if any, i.e, things like + # test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail + db.session.begin_nested() + self.validate() + tunnel = SSHTunnelDAO.create(self._properties, commit=False) + except DAOCreateFailedError as ex: + # Rollback nested transaction + db.session.rollback() + raise SSHTunnelCreateFailedError() from ex + except SSHTunnelInvalidError as ex: + # Rollback nested transaction + db.session.rollback() + raise ex + + return tunnel + + def validate(self) -> None: + # TODO(hughhh): check to make sure the server port is not localhost + # using the config.SSH_TUNNEL_MANAGER + exceptions: List[ValidationError] = [] + database_id: Optional[int] = self._properties.get("database_id") + server_address: Optional[str] = self._properties.get("server_address") + server_port: Optional[int] = self._properties.get("server_port") + username: Optional[str] = self._properties.get("username") + private_key: Optional[str] = self._properties.get("private_key") + private_key_password: Optional[str] = self._properties.get( + "private_key_password" + ) + if not database_id: + exceptions.append(SSHTunnelRequiredFieldValidationError("database_id")) + if not server_address: + exceptions.append(SSHTunnelRequiredFieldValidationError("server_address")) + if not server_port: + exceptions.append(SSHTunnelRequiredFieldValidationError("server_port")) + if not username: + exceptions.append(SSHTunnelRequiredFieldValidationError("username")) + if private_key_password and private_key is None: + exceptions.append(SSHTunnelRequiredFieldValidationError("private_key")) + if exceptions: + exception = SSHTunnelInvalidError() + exception.add_list(exceptions) + event_logger.log_with_context( + action="ssh_tunnel_creation_failed.{}.{}".format( + exception.__class__.__name__, + ".".join(exception.get_list_classnames()), + ) + ) + raise exception diff --git a/superset/databases/ssh_tunnel/commands/delete.py b/superset/databases/ssh_tunnel/commands/delete.py new file mode 100644 index 0000000000..3ad2fc2a15 --- /dev/null +++ b/superset/databases/ssh_tunnel/commands/delete.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Optional + +from flask_appbuilder.models.sqla import Model + +from superset.commands.base import BaseCommand +from superset.dao.exceptions import DAODeleteFailedError +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelDeleteFailedError, + SSHTunnelNotFoundError, +) +from superset.databases.ssh_tunnel.dao import SSHTunnelDAO +from superset.databases.ssh_tunnel.models import SSHTunnel + +logger = logging.getLogger(__name__) + + +class DeleteSSHTunnelCommand(BaseCommand): + def __init__(self, model_id: int): + self._model_id = model_id + self._model: Optional[SSHTunnel] = None + + def run(self) -> Model: + self.validate() + try: + ssh_tunnel = SSHTunnelDAO.delete(self._model) + except DAODeleteFailedError as ex: + raise SSHTunnelDeleteFailedError() from ex + return ssh_tunnel + + def validate(self) -> None: + # Validate/populate model exists + self._model = SSHTunnelDAO.find_by_id(self._model_id) + if not self._model: + raise SSHTunnelNotFoundError() diff --git a/superset/databases/ssh_tunnel/commands/exceptions.py b/superset/databases/ssh_tunnel/commands/exceptions.py new file mode 100644 index 0000000000..db2d3173de --- /dev/null +++ b/superset/databases/ssh_tunnel/commands/exceptions.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from flask_babel import lazy_gettext as _ +from marshmallow import ValidationError + +from superset.commands.exceptions import ( + CommandException, + CommandInvalidError, + DeleteFailedError, + UpdateFailedError, +) + + +class SSHTunnelDeleteFailedError(DeleteFailedError): + message = _("SSH Tunnel could not be deleted.") + + +class SSHTunnelNotFoundError(CommandException): + status = 404 + message = _("SSH Tunnel not found.") + + +class SSHTunnelInvalidError(CommandInvalidError): + message = _("SSH Tunnel parameters are invalid.") + + +class SSHTunnelUpdateFailedError(UpdateFailedError): + message = _("SSH Tunnel could not be updated.") + + +class SSHTunnelCreateFailedError(CommandException): + message = _("Creating SSH Tunnel failed for an unknown reason") + + +class SSHTunnelRequiredFieldValidationError(ValidationError): + def __init__(self, field_name: str) -> None: + super().__init__( + [_("Field is required")], + field_name=field_name, + ) diff --git a/superset/databases/ssh_tunnel/commands/update.py b/superset/databases/ssh_tunnel/commands/update.py new file mode 100644 index 0000000000..8d2feaf1b0 --- /dev/null +++ b/superset/databases/ssh_tunnel/commands/update.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Any, Dict, Optional + +from flask_appbuilder.models.sqla import Model + +from superset.commands.base import BaseCommand +from superset.dao.exceptions import DAOUpdateFailedError +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelInvalidError, + SSHTunnelNotFoundError, + SSHTunnelRequiredFieldValidationError, + SSHTunnelUpdateFailedError, +) +from superset.databases.ssh_tunnel.dao import SSHTunnelDAO +from superset.databases.ssh_tunnel.models import SSHTunnel + +logger = logging.getLogger(__name__) + + +class UpdateSSHTunnelCommand(BaseCommand): + def __init__(self, model_id: int, data: Dict[str, Any]): + self._properties = data.copy() + self._model_id = model_id + self._model: Optional[SSHTunnel] = None + + def run(self) -> Model: + self.validate() + try: + tunnel = SSHTunnelDAO.update(self._model, self._properties) + except DAOUpdateFailedError as ex: + raise SSHTunnelUpdateFailedError() from ex + return tunnel + + def validate(self) -> None: + # Validate/populate model exists + self._model = SSHTunnelDAO.find_by_id(self._model_id) + if not self._model: + raise SSHTunnelNotFoundError() + private_key: Optional[str] = self._properties.get("private_key") + private_key_password: Optional[str] = self._properties.get( + "private_key_password" + ) + if private_key_password and private_key is None: + exception = SSHTunnelInvalidError() + exception.add(SSHTunnelRequiredFieldValidationError("private_key")) + raise exception diff --git a/superset/databases/ssh_tunnel/dao.py b/superset/databases/ssh_tunnel/dao.py new file mode 100644 index 0000000000..9241481644 --- /dev/null +++ b/superset/databases/ssh_tunnel/dao.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging + +from superset.dao.base import BaseDAO +from superset.databases.ssh_tunnel.models import SSHTunnel + +logger = logging.getLogger(__name__) + + +class SSHTunnelDAO(BaseDAO): + model_cls = SSHTunnel diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py new file mode 100644 index 0000000000..79e8b918d9 --- /dev/null +++ b/superset/databases/ssh_tunnel/models.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict + +import sqlalchemy as sa +from flask import current_app +from flask_appbuilder import Model +from sqlalchemy.orm import backref, relationship +from sqlalchemy_utils import EncryptedType + +from superset.models.core import Database +from superset.models.helpers import ( + AuditMixinNullable, + ExtraJSONMixin, + ImportExportMixin, +) + +app_config = current_app.config + + +class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): + """ + A ssh tunnel configuration in a database. + """ + + __tablename__ = "ssh_tunnels" + + id = sa.Column(sa.Integer, primary_key=True) + database_id = sa.Column( + sa.Integer, sa.ForeignKey("dbs.id"), nullable=False, unique=True + ) + database: Database = relationship( + "Database", + backref=backref("ssh_tunnels", uselist=False, cascade="all, delete-orphan"), + foreign_keys=[database_id], + ) + + server_address = sa.Column(sa.Text) + server_port = sa.Column(sa.Integer) + username = sa.Column(EncryptedType(sa.String, app_config["SECRET_KEY"])) + + # basic authentication + password = sa.Column( + EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True + ) + + # password protected pkey authentication + private_key = sa.Column( + EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True + ) + private_key_password = sa.Column( + EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True + ) + + @property + def data(self) -> Dict[str, Any]: + return { + "server_address": self.server_address, + "server_port": self.server_port, + "username": self.username, + } diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 921b82bc30..a2a98ee6c0 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -193,6 +193,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods engine_aliases: Set[str] = set() drivers: Dict[str, str] = {} default_driver: Optional[str] = None + allow_ssh_tunneling = False _date_trunc_functions: Dict[str, str] = {} _time_grain_expressions: Dict[Optional[str], str] = {} diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 286b6e80a1..3a6a2e17d8 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -166,6 +166,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec): class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): engine = "postgresql" engine_aliases = {"postgres"} + allow_ssh_tunneling = True default_driver = "psycopg2" sqlalchemy_uri_placeholder = ( diff --git a/superset/extensions/__init__.py b/superset/extensions/__init__.py index 1f5882f749..cccf3a526f 100644 --- a/superset/extensions/__init__.py +++ b/superset/extensions/__init__.py @@ -28,6 +28,7 @@ from flask_talisman import Talisman from flask_wtf.csrf import CSRFProtect from werkzeug.local import LocalProxy +from superset.extensions.ssh import SSHManagerFactory from superset.utils.async_query_manager import AsyncQueryManager from superset.utils.cache_manager import CacheManager from superset.utils.encrypt import EncryptedFieldFactory @@ -127,3 +128,4 @@ profiling = ProfilingExtension() results_backend_manager = ResultsBackendManager() security_manager = LocalProxy(lambda: appbuilder.sm) talisman = Talisman() +ssh_manager_factory = SSHManagerFactory() diff --git a/superset/extensions/ssh.py b/superset/extensions/ssh.py new file mode 100644 index 0000000000..4ae8d508fc --- /dev/null +++ b/superset/extensions/ssh.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import importlib +from typing import TYPE_CHECKING + +from flask import Flask +from sshtunnel import open_tunnel, SSHTunnelForwarder + +from superset.databases.utils import make_url_safe + +if TYPE_CHECKING: + from superset.databases.ssh_tunnel.models import SSHTunnel + + +class SSHManager: + def __init__(self, app: Flask) -> None: + super().__init__() + self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"] + + def build_sqla_url( # pylint: disable=no-self-use + self, sqlalchemy_url: str, server: SSHTunnelForwarder + ) -> str: + # override any ssh tunnel configuration object + url = make_url_safe(sqlalchemy_url) + return url.set( + host=server.local_bind_address[0], + port=server.local_bind_port, + ) + + def create_tunnel( + self, + ssh_tunnel: "SSHTunnel", + sqlalchemy_database_uri: str, + ) -> SSHTunnelForwarder: + url = make_url_safe(sqlalchemy_database_uri) + params = { + "ssh_address_or_host": ssh_tunnel.server_address, + "ssh_port": ssh_tunnel.server_port, + "ssh_username": ssh_tunnel.username, + "remote_bind_address": (url.host, url.port), # bind_port, bind_host + "local_bind_address": (self.local_bind_address,), + } + + if ssh_tunnel.password: + params["ssh_password"] = ssh_tunnel.password + elif ssh_tunnel.private_key: + params["private_key"] = ssh_tunnel.private_key + params["private_key_password"] = ssh_tunnel.private_key_password + + return open_tunnel(**params) + + +class SSHManagerFactory: + def __init__(self) -> None: + self._ssh_manager = None + + def init_app(self, app: Flask) -> None: + ssh_manager_fqclass = app.config["SSH_TUNNEL_MANAGER_CLASS"] + ssh_manager_classname = ssh_manager_fqclass[ + ssh_manager_fqclass.rfind(".") + 1 : + ] + ssh_manager_module_name = ssh_manager_fqclass[ + 0 : ssh_manager_fqclass.rfind(".") + ] + ssh_manager_class = getattr( + importlib.import_module(ssh_manager_module_name), ssh_manager_classname + ) + + self._ssh_manager = ssh_manager_class(app) + + @property + def instance(self) -> SSHManager: + return self._ssh_manager # type: ignore diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 8c53c4c8e7..2b02d5106e 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -45,6 +45,7 @@ from superset.extensions import ( migrate, profiling, results_backend_manager, + ssh_manager_factory, talisman, ) from superset.security import SupersetSecurityManager @@ -417,6 +418,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods self.configure_data_sources() self.configure_auth_provider() self.configure_async_queries() + self.configure_ssh_manager() # Hook that provides administrators a handle on the Flask APP # after initialization @@ -474,6 +476,9 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods def configure_auth_provider(self) -> None: machine_auth_provider_factory.init_app(self.superset_app) + def configure_ssh_manager(self) -> None: + ssh_manager_factory.init_app(self.superset_app) + def setup_event_logger(self) -> None: _event_logger["event_logger"] = get_event_logger_from_cfg_value( self.superset_app.config.get("EVENT_LOGGER", DBEventLogger()) diff --git a/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py new file mode 100644 index 0000000000..b373020cb1 --- /dev/null +++ b/superset/migrations/versions/2022-10-20_10-48_f3c2d8ec8595_create_ssh_tunnel_credentials_tbl.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""create_ssh_tunnel_credentials_tbl + +Revision ID: f3c2d8ec8595 +Revises: 4ce1d9b25135 +Create Date: 2022-10-20 10:48:08.722861 + +""" + +# revision identifiers, used by Alembic. +revision = "f3c2d8ec8595" +down_revision = "4ce1d9b25135" + +from uuid import uuid4 + +import sqlalchemy as sa +from alembic import op +from sqlalchemy_utils import UUIDType + +from superset import app +from superset.extensions import encrypted_field_factory + +app_config = app.config + + +def upgrade(): + op.create_table( + "ssh_tunnels", + # AuditMixinNullable + sa.Column("created_on", sa.DateTime(), nullable=True), + sa.Column("changed_on", sa.DateTime(), nullable=True), + sa.Column("created_by_fk", sa.Integer(), nullable=True), + sa.Column("changed_by_fk", sa.Integer(), nullable=True), + # ExtraJSONMixin + sa.Column("extra_json", sa.Text(), nullable=True), + # ImportExportMixin + sa.Column( + "uuid", + UUIDType(binary=True), + primary_key=False, + default=uuid4, + unique=True, + index=True, + ), + # SSHTunnelCredentials + sa.Column("id", sa.Integer(), primary_key=True), + sa.Column( + "database_id", + sa.INTEGER(), + sa.ForeignKey("dbs.id"), + unique=True, + index=True, + ), + sa.Column("server_address", sa.String(256)), + sa.Column("server_port", sa.INTEGER()), + sa.Column("username", encrypted_field_factory.create(sa.String(256))), + sa.Column( + "password", encrypted_field_factory.create(sa.String(256)), nullable=True + ), + sa.Column( + "private_key", + encrypted_field_factory.create(sa.String(1024)), + nullable=True, + ), + sa.Column( + "private_key_password", + encrypted_field_factory.create(sa.String(256)), + nullable=True, + ), + ) + + +def downgrade(): + op.drop_table("ssh_tunnels") diff --git a/superset/models/core.py b/superset/models/core.py index 12ce9ef95e..173bd5b590 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -21,10 +21,10 @@ import json import logging import textwrap from ast import literal_eval -from contextlib import closing, contextmanager +from contextlib import closing, contextmanager, nullcontext from copy import deepcopy from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING import numpy import pandas as pd @@ -57,7 +57,12 @@ from superset import app, db_engine_specs from superset.constants import PASSWORD_MASK from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import MetricType, TimeGrain -from superset.extensions import cache_manager, encrypted_field_factory, security_manager +from superset.extensions import ( + cache_manager, + encrypted_field_factory, + security_manager, + ssh_manager_factory, +) from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.result_set import SupersetResultSet from superset.utils import cache as cache_util, core as utils @@ -71,6 +76,9 @@ log_query = config["QUERY_LOGGER"] metadata = Model.metadata # pylint: disable=no-member logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from superset.databases.ssh_tunnel.models import SSHTunnel + DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"] @@ -373,17 +381,48 @@ class Database( schema: Optional[str] = None, nullpool: bool = True, source: Optional[utils.QuerySource] = None, + override_ssh_tunnel: Optional["SSHTunnel"] = None, ) -> Engine: - yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source) + from superset.databases.dao import ( # pylint: disable=import-outside-toplevel + DatabaseDAO, + ) + + sqlalchemy_uri = self.sqlalchemy_uri_decrypted + engine_context = nullcontext() + ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel( + database_id=self.id + ) + + if ssh_tunnel: + # if ssh_tunnel is available build engine with information + engine_context = ssh_manager_factory.instance.create_tunnel( + ssh_tunnel=ssh_tunnel, + sqlalchemy_database_uri=self.sqlalchemy_uri_decrypted, + ) + + with engine_context as server_context: + if ssh_tunnel: + sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url( + sqlalchemy_uri, server_context + ) + yield self._get_sqla_engine( + schema=schema, + nullpool=nullpool, + source=source, + sqlalchemy_uri=sqlalchemy_uri, + ) def _get_sqla_engine( self, schema: Optional[str] = None, nullpool: bool = True, source: Optional[utils.QuerySource] = None, + sqlalchemy_uri: Optional[str] = None, ) -> Engine: extra = self.get_extra() - sqlalchemy_url = make_url_safe(self.sqlalchemy_uri_decrypted) + sqlalchemy_url = make_url_safe( + sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted + ) sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema) effective_username = self.get_effective_user(sqlalchemy_url) # If using MySQL or Presto for example, will set url.username @@ -423,7 +462,6 @@ class Database( sqlalchemy_url, params = DB_CONNECTION_MUTATOR( sqlalchemy_url, params, effective_username, security_manager, source ) - try: return create_engine(sqlalchemy_url, **params) except Exception as ex: @@ -477,7 +515,7 @@ class Database( security_manager, ) - with closing(engine.raw_connection()) as conn: + with self.get_raw_connection(schema=schema) as conn: cursor = conn.cursor() for sql_ in sqls[:-1]: _log_query(sql_) @@ -574,14 +612,16 @@ class Database( :return: The table/schema pairs """ try: - return { - (table, schema) - for table in self.db_engine_spec.get_table_names( - database=self, - inspector=self.inspector, - schema=schema, - ) - } + with self.get_inspector_with_context() as inspector: + tables = { + (table, schema) + for table in self.db_engine_spec.get_table_names( + database=self, + inspector=inspector, + schema=schema, + ) + } + return tables except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) @@ -608,17 +648,27 @@ class Database( :return: set of views """ try: - return { - (view, schema) - for view in self.db_engine_spec.get_view_names( - database=self, - inspector=self.inspector, - schema=schema, - ) - } + with self.get_inspector_with_context() as inspector: + return { + (view, schema) + for view in self.db_engine_spec.get_view_names( + database=self, + inspector=inspector, + schema=schema, + ) + } except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) + @contextmanager + def get_inspector_with_context( + self, ssh_tunnel: Optional["SSHTunnel"] = None + ) -> Inspector: + with self.get_sqla_engine_with_context( + override_ssh_tunnel=ssh_tunnel + ) as engine: + yield sqla.inspect(engine) + @cache_util.memoized_func( key="db:{self.id}:schema_list", cache=cache_manager.cache, @@ -628,6 +678,7 @@ class Database( cache: bool = False, cache_timeout: Optional[int] = None, force: bool = False, + ssh_tunnel: Optional["SSHTunnel"] = None, ) -> List[str]: """Parameters need to be passed as keyword arguments. @@ -640,7 +691,8 @@ class Database( :return: schema list """ try: - return self.db_engine_spec.get_schema_names(self.inspector) + with self.get_inspector_with_context(ssh_tunnel=ssh_tunnel) as inspector: + return self.db_engine_spec.get_schema_names(inspector) except Exception as ex: raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @@ -703,43 +755,49 @@ class Database( def get_table_comment( self, table_name: str, schema: Optional[str] = None ) -> Optional[str]: - return self.db_engine_spec.get_table_comment(self.inspector, table_name, schema) + with self.get_inspector_with_context() as inspector: + return self.db_engine_spec.get_table_comment(inspector, table_name, schema) def get_columns( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: - return self.db_engine_spec.get_columns(self.inspector, table_name, schema) + with self.get_inspector_with_context() as inspector: + return self.db_engine_spec.get_columns(inspector, table_name, schema) def get_metrics( self, table_name: str, schema: Optional[str] = None, ) -> List[MetricType]: - return self.db_engine_spec.get_metrics(self, self.inspector, table_name, schema) + with self.get_inspector_with_context() as inspector: + return self.db_engine_spec.get_metrics(self, inspector, table_name, schema) def get_indexes( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: - indexes = self.inspector.get_indexes(table_name, schema) - return self.db_engine_spec.normalize_indexes(indexes) + with self.get_inspector_with_context() as inspector: + indexes = inspector.get_indexes(table_name, schema) + return self.db_engine_spec.normalize_indexes(indexes) def get_pk_constraint( self, table_name: str, schema: Optional[str] = None ) -> Dict[str, Any]: - pk_constraint = self.inspector.get_pk_constraint(table_name, schema) or {} + with self.get_inspector_with_context() as inspector: + pk_constraint = inspector.get_pk_constraint(table_name, schema) or {} - def _convert(value: Any) -> Any: - try: - return utils.base_json_conv(value) - except TypeError: - return None + def _convert(value: Any) -> Any: + try: + return utils.base_json_conv(value) + except TypeError: + return None - return {key: _convert(value) for key, value in pk_constraint.items()} + return {key: _convert(value) for key, value in pk_constraint.items()} def get_foreign_keys( self, table_name: str, schema: Optional[str] = None ) -> List[Dict[str, Any]]: - return self.inspector.get_foreign_keys(table_name, schema) + with self.get_inspector_with_context() as inspector: + return inspector.get_foreign_keys(table_name, schema) def get_schema_access_for_file_upload( # pylint: disable=invalid-name self, diff --git a/superset/utils/ssh_tunnel.py b/superset/utils/ssh_tunnel.py new file mode 100644 index 0000000000..6562a8bbb5 --- /dev/null +++ b/superset/utils/ssh_tunnel.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Any, Dict + +from superset.constants import PASSWORD_MASK + + +def mask_password_info(ssh_tunnel: Dict[str, Any]) -> Dict[str, Any]: + if ssh_tunnel.pop("password", None) is not None: + ssh_tunnel["password"] = PASSWORD_MASK + if ssh_tunnel.pop("private_key", None) is not None: + ssh_tunnel["private_key"] = PASSWORD_MASK + if ssh_tunnel.pop("private_key_password", None) is not None: + ssh_tunnel["private_key_password"] = PASSWORD_MASK + return ssh_tunnel diff --git a/tests/conftest.py b/tests/conftest.py index a5945f2f5c..9d13e58170 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,7 +28,7 @@ from __future__ import annotations from typing import Callable, TYPE_CHECKING from unittest.mock import MagicMock, Mock, PropertyMock -from flask import Flask +from flask import current_app, Flask from flask.ctx import AppContext from pytest import fixture @@ -40,6 +40,7 @@ from tests.example_data.data_loading.pandas.pands_data_loading_conf import ( from tests.example_data.data_loading.pandas.table_df_convertor import ( TableToDfConvertorImpl, ) +from tests.integration_tests.test_app import app SUPPORT_DATETIME_TYPE = "support_datetime_type" @@ -70,8 +71,9 @@ def example_db_provider() -> Callable[[], Database]: @fixture(scope="session") def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine: - with example_db_provider().get_sqla_engine_with_context() as engine: - return engine + with app.app_context(): + with example_db_provider().get_sqla_engine_with_context() as engine: + return engine @fixture(scope="session") diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 8a96184b81..aeb74ec91e 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -35,6 +35,8 @@ from sqlalchemy.sql import func from superset import db, security_manager from superset.connectors.sqla.models import SqlaTable +from superset.databases.ssh_tunnel.models import SSHTunnel +from superset.databases.utils import make_url_safe from superset.db_engine_specs.mysql import MySQLEngineSpec from superset.db_engine_specs.postgres import PostgresEngineSpec from superset.db_engine_specs.redshift import RedshiftEngineSpec @@ -280,6 +282,314 @@ class TestDatabaseApi(SupersetTestCase): db.session.delete(model) db.session.commit() + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_create_database_with_ssh_tunnel( + self, mock_test_connection_database_command_run, mock_get_all_schema_names + ): + """ + Database API: Test create with SSH Tunnel + """ + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + database_data = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": ssh_tunnel_properties, + } + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_update_database_with_ssh_tunnel( + self, mock_test_connection_database_command_run, mock_get_all_schema_names + ): + """ + Database API: Test update with SSH Tunnel + """ + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + database_data = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + } + database_data_with_ssh_tunnel = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": ssh_tunnel_properties, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + + uri = "api/v1/database/{}".format(response.get("id")) + rv = self.client.put(uri, json=database_data_with_ssh_tunnel) + response_update = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response_update.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id")) + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_update_ssh_tunnel_via_database_api( + self, mock_test_connection_database_command_run, mock_get_all_schema_names + ): + """ + Database API: Test update with SSH Tunnel + """ + self.login(username="admin") + example_db = get_example_database() + + if example_db.backend == "sqlite": + return + initial_ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + updated_ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "Test", + } + database_data_with_ssh_tunnel = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": initial_ssh_tunnel_properties, + } + database_data_with_ssh_tunnel_update = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": updated_ssh_tunnel_properties, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data_with_ssh_tunnel) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) + self.assertEqual(model_ssh_tunnel.username, "foo") + uri = "api/v1/database/{}".format(response.get("id")) + rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update) + response_update = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response_update.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id")) + self.assertEqual(model_ssh_tunnel.username, "Test") + self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1") + self.assertEqual(model_ssh_tunnel.server_port, 8080) + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_cascade_delete_ssh_tunnel( + self, mock_test_connection_database_command_run, mock_get_all_schema_names + ): + """ + Database API: Test create with SSH Tunnel + """ + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + database_data = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": ssh_tunnel_properties, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one_or_none() + ) + assert model_ssh_tunnel is None + + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_do_not_create_database_if_ssh_tunnel_creation_fails( + self, mock_test_connection_database_command_run, mock_get_all_schema_names + ): + """ + Database API: Test create with SSH Tunnel + """ + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + } + database_data = { + "database_name": "test-db-failure-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": ssh_tunnel_properties, + } + fail_message = {"message": "SSH Tunnel parameters are invalid."} + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 422) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one_or_none() + ) + assert model_ssh_tunnel is None + self.assertEqual(response, fail_message) + # Cleanup + model = ( + db.session.query(Database) + .filter(Database.database_name == "test-db-failure-ssh-tunnel") + .one_or_none() + ) + # the DB should not be created + assert model is None + + @mock.patch( + "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch( + "superset.models.core.Database.get_all_schema_names", + ) + def test_get_database_returns_related_ssh_tunnel( + self, mock_test_connection_database_command_run, mock_get_all_schema_names + ): + """ + Database API: Test GET Database returns its related SSH Tunnel + """ + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + database_data = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "ssh_tunnel": ssh_tunnel_properties, + } + response_ssh_tunnel = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "XXXXXXXXXX", + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one() + ) + self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) + self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel) + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + def test_create_database_invalid_configuration_method(self): """ Database API: Test create with an invalid configuration method. diff --git a/tests/integration_tests/databases/ssh_tunnel/__init__.py b/tests/integration_tests/databases/ssh_tunnel/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/integration_tests/databases/ssh_tunnel/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/integration_tests/databases/ssh_tunnel/commands/__init__.py b/tests/integration_tests/databases/ssh_tunnel/commands/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/integration_tests/databases/ssh_tunnel/commands/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py new file mode 100644 index 0000000000..75e5a55e86 --- /dev/null +++ b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from unittest import mock, skip +from unittest.mock import patch + +import pytest + +from superset import security_manager +from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand +from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand +from superset.databases.ssh_tunnel.commands.exceptions import ( + SSHTunnelInvalidError, + SSHTunnelNotFoundError, +) +from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand +from tests.integration_tests.base_tests import SupersetTestCase + + +class TestCreateSSHTunnelCommand(SupersetTestCase): + @mock.patch("superset.utils.core.g") + def test_create_invalid_database_id(self, mock_g): + mock_g.user = security_manager.find_user("admin") + command = CreateSSHTunnelCommand( + None, + { + "server_address": "127.0.0.1", + "server_port": 5432, + "username": "test_user", + }, + ) + with pytest.raises(SSHTunnelInvalidError) as excinfo: + command.run() + assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.") + + +class TestUpdateSSHTunnelCommand(SupersetTestCase): + @mock.patch("superset.utils.core.g") + def test_update_ssh_tunnel_not_found(self, mock_g): + mock_g.user = security_manager.find_user("admin") + # We have not created a SSH Tunnel yet so id = 1 is invalid + command = UpdateSSHTunnelCommand( + 1, + { + "server_address": "127.0.0.1", + "server_port": 5432, + "username": "test_user", + }, + ) + with pytest.raises(SSHTunnelNotFoundError) as excinfo: + command.run() + assert str(excinfo.value) == ("SSH Tunnel not found.") + + +class TestDeleteSSHTunnelCommand(SupersetTestCase): + @mock.patch("superset.utils.core.g") + def test_delete_ssh_tunnel_not_found(self, mock_g): + mock_g.user = security_manager.find_user("admin") + # We have not created a SSH Tunnel yet so id = 1 is invalid + command = DeleteSSHTunnelCommand(1) + with pytest.raises(SSHTunnelNotFoundError) as excinfo: + command.run() + assert str(excinfo.value) == ("SSH Tunnel not found.") diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index d6f8897c4a..fe4211289c 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -191,3 +191,147 @@ def test_non_zip_import(client: Any, full_api_access: None) -> None: } ] } + + +def test_delete_ssh_tunnel( + mocker: MockFixture, + app: Any, + session: Session, + client: Any, + full_api_access: None, +) -> None: + """ + Test that we can delete SSH Tunnel + """ + with app.app_context(): + from superset.databases.api import DatabaseRestApi + from superset.databases.dao import DatabaseDAO + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + DatabaseRestApi.datamodel.session = session + + # create table for databases + Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member + + # Create our Database + database = Database( + database_name="my_database", + sqlalchemy_uri="gsheets://", + encrypted_extra=json.dumps( + { + "service_account_info": { + "type": "service_account", + "project_id": "black-sanctum-314419", + "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", + "private_key": "SECRET", + "client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", + "client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com", + }, + } + ), + ) + session.add(database) + session.commit() + + # mock the lookup so that we don't need to include the driver + mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") + mocker.patch("superset.utils.log.DBEventLogger.log") + + # Create our SSHTunnel + tunnel = SSHTunnel( + database_id=1, + database=database, + ) + + session.add(tunnel) + session.commit() + + # Get our recently created SSHTunnel + response_tunnel = DatabaseDAO.get_ssh_tunnel(1) + assert response_tunnel + assert isinstance(response_tunnel, SSHTunnel) + assert 1 == response_tunnel.database_id + + # Delete the recently created SSHTunnel + response_delete_tunnel = client.delete("/api/v1/database/1/ssh_tunnel/") + assert response_delete_tunnel.json["message"] == "OK" + + response_tunnel = DatabaseDAO.get_ssh_tunnel(1) + assert response_tunnel is None + + +def test_delete_ssh_tunnel_not_found( + mocker: MockFixture, + app: Any, + session: Session, + client: Any, + full_api_access: None, +) -> None: + """ + Test that we cannot delete a tunnel that does not exist + """ + with app.app_context(): + from superset.databases.api import DatabaseRestApi + from superset.databases.dao import DatabaseDAO + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + DatabaseRestApi.datamodel.session = session + + # create table for databases + Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member + + # Create our Database + database = Database( + database_name="my_database", + sqlalchemy_uri="gsheets://", + encrypted_extra=json.dumps( + { + "service_account_info": { + "type": "service_account", + "project_id": "black-sanctum-314419", + "private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", + "private_key": "SECRET", + "client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", + "client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com", + }, + } + ), + ) + session.add(database) + session.commit() + + # mock the lookup so that we don't need to include the driver + mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") + mocker.patch("superset.utils.log.DBEventLogger.log") + + # Create our SSHTunnel + tunnel = SSHTunnel( + database_id=1, + database=database, + ) + + session.add(tunnel) + session.commit() + + # Delete the recently created SSHTunnel + response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/") + assert response_delete_tunnel.json["message"] == "Not found" + + # Get our recently created SSHTunnel + response_tunnel = DatabaseDAO.get_ssh_tunnel(1) + assert response_tunnel + assert isinstance(response_tunnel, SSHTunnel) + assert 1 == response_tunnel.database_id + + response_tunnel = DatabaseDAO.get_ssh_tunnel(2) + assert response_tunnel is None diff --git a/tests/unit_tests/databases/dao/__init__.py b/tests/unit_tests/databases/dao/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/unit_tests/databases/dao/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/databases/dao/dao_tests.py b/tests/unit_tests/databases/dao/dao_tests.py new file mode 100644 index 0000000000..47db402670 --- /dev/null +++ b/tests/unit_tests/databases/dao/dao_tests.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + + +@pytest.fixture +def session_with_data(session: Session) -> Iterator[Session]: + from superset.connectors.sqla.models import SqlaTable + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + sqla_table = SqlaTable( + table_name="my_sqla_table", + columns=[], + metrics=[], + database=db, + ) + ssh_tunnel = SSHTunnel( + database_id=db.id, + database=db, + ) + + session.add(db) + session.add(sqla_table) + session.add(ssh_tunnel) + session.flush() + yield session + session.rollback() + + +def test_database_get_ssh_tunnel(session_with_data: Session) -> None: + from superset.databases.dao import DatabaseDAO + from superset.databases.ssh_tunnel.models import SSHTunnel + + result = DatabaseDAO.get_ssh_tunnel(1) + + assert result + assert isinstance(result, SSHTunnel) + assert 1 == result.database_id + + +def test_database_get_ssh_tunnel_not_found(session_with_data: Session) -> None: + from superset.databases.dao import DatabaseDAO + + result = DatabaseDAO.get_ssh_tunnel(2) + + assert result is None diff --git a/tests/unit_tests/databases/ssh_tunnel/__init__.py b/tests/unit_tests/databases/ssh_tunnel/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/unit_tests/databases/ssh_tunnel/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/__init__.py b/tests/unit_tests/databases/ssh_tunnel/commands/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/unit_tests/databases/ssh_tunnel/commands/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py new file mode 100644 index 0000000000..2a5738ebd3 --- /dev/null +++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + +from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelInvalidError + + +def test_create_ssh_tunnel_command() -> None: + from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + + properties = { + "database_id": db.id, + "server_address": "123.132.123.1", + "server_port": "3005", + "username": "foo", + "password": "bar", + } + + result = CreateSSHTunnelCommand(db.id, properties).run() + + assert result is not None + assert isinstance(result, SSHTunnel) + + +def test_create_ssh_tunnel_command_invalid_params() -> None: + from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + + # If we are trying to create a tunnel with a private_key_password + # then a private_key is mandatory + properties = { + "database_id": db.id, + "server_address": "123.132.123.1", + "server_port": "3005", + "username": "foo", + "private_key_password": "bar", + } + + command = CreateSSHTunnelCommand(db.id, properties) + + with pytest.raises(SSHTunnelInvalidError) as excinfo: + command.run() + assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.") diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py new file mode 100644 index 0000000000..17afebfa0f --- /dev/null +++ b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + + +@pytest.fixture +def session_with_data(session: Session) -> Iterator[Session]: + from superset.connectors.sqla.models import SqlaTable + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + sqla_table = SqlaTable( + table_name="my_sqla_table", + columns=[], + metrics=[], + database=db, + ) + ssh_tunnel = SSHTunnel( + database_id=db.id, + database=db, + ) + + session.add(db) + session.add(sqla_table) + session.add(ssh_tunnel) + session.flush() + yield session + session.rollback() + + +def test_delete_ssh_tunnel_command(session_with_data: Session) -> None: + from superset.databases.dao import DatabaseDAO + from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand + from superset.databases.ssh_tunnel.models import SSHTunnel + + result = DatabaseDAO.get_ssh_tunnel(1) + + assert result + assert isinstance(result, SSHTunnel) + assert 1 == result.database_id + + DeleteSSHTunnelCommand(1).run() + + result = DatabaseDAO.get_ssh_tunnel(1) + + assert result is None diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py new file mode 100644 index 0000000000..58f90054cc --- /dev/null +++ b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + +from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelInvalidError + + +@pytest.fixture +def session_with_data(session: Session) -> Iterator[Session]: + from superset.connectors.sqla.models import SqlaTable + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + sqla_table = SqlaTable( + table_name="my_sqla_table", + columns=[], + metrics=[], + database=db, + ) + ssh_tunnel = SSHTunnel(database_id=db.id, database=db, server_address="Test") + + session.add(db) + session.add(sqla_table) + session.add(ssh_tunnel) + session.flush() + yield session + session.rollback() + + +def test_update_shh_tunnel_command(session_with_data: Session) -> None: + from superset.databases.dao import DatabaseDAO + from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand + from superset.databases.ssh_tunnel.models import SSHTunnel + + result = DatabaseDAO.get_ssh_tunnel(1) + + assert result + assert isinstance(result, SSHTunnel) + assert 1 == result.database_id + assert "Test" == result.server_address + + update_payload = {"server_address": "Test2"} + UpdateSSHTunnelCommand(1, update_payload).run() + + result = DatabaseDAO.get_ssh_tunnel(1) + + assert result + assert isinstance(result, SSHTunnel) + assert "Test2" == result.server_address + + +def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None: + from superset.databases.dao import DatabaseDAO + from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand + from superset.databases.ssh_tunnel.models import SSHTunnel + + result = DatabaseDAO.get_ssh_tunnel(1) + + assert result + assert isinstance(result, SSHTunnel) + assert 1 == result.database_id + assert "Test" == result.server_address + + # If we are trying to update a tunnel with a private_key_password + # then a private_key is mandatory + update_payload = {"private_key_password": "pass"} + command = UpdateSSHTunnelCommand(1, update_payload) + + with pytest.raises(SSHTunnelInvalidError) as excinfo: + command.run() + assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.") diff --git a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py new file mode 100644 index 0000000000..ae5b6e9bd3 --- /dev/null +++ b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + + +def test_create_ssh_tunnel(): + from superset.databases.dao import DatabaseDAO + from superset.databases.ssh_tunnel.dao import SSHTunnelDAO + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + + properties = { + "database_id": db.id, + "server_address": "123.132.123.1", + "server_port": "3005", + "username": "foo", + "password": "bar", + } + + result = SSHTunnelDAO.create(properties) + + assert result is not None + assert isinstance(result, SSHTunnel)