mirror of https://github.com/apache/superset.git
feat(ssh_tunnel): Add feature flag to SSH Tunnel API (#22805)
This commit is contained in:
parent
0045816772
commit
d6a4a5da79
|
@ -75,6 +75,7 @@ from superset.databases.schemas import (
|
|||
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelDeleteFailedError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelNotFoundError,
|
||||
)
|
||||
from superset.databases.utils import get_table_metadata
|
||||
|
@ -349,6 +350,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
exc_info=True,
|
||||
)
|
||||
return self.response_422(message=str(ex))
|
||||
except SSHTunnelingNotEnabledError as ex:
|
||||
return self.response_400(message=str(ex))
|
||||
except SupersetException as ex:
|
||||
return self.response(ex.status, message=ex.message)
|
||||
|
||||
|
@ -433,6 +436,8 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
exc_info=True,
|
||||
)
|
||||
return self.response_422(message=str(ex))
|
||||
except SSHTunnelingNotEnabledError as ex:
|
||||
return self.response_400(message=str(ex))
|
||||
|
||||
@expose("/<int:pk>", methods=["DELETE"])
|
||||
@protect()
|
||||
|
@ -782,8 +787,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
# This validates custom Schema with custom validations
|
||||
except ValidationError as error:
|
||||
return self.response_400(message=error.messages)
|
||||
TestConnectionDatabaseCommand(item).run()
|
||||
return self.response(200, message="OK")
|
||||
try:
|
||||
TestConnectionDatabaseCommand(item).run()
|
||||
return self.response(200, message="OK")
|
||||
except SSHTunnelingNotEnabledError as ex:
|
||||
return self.response_400(message=str(ex))
|
||||
|
||||
@expose("/<int:pk>/related_objects/", methods=["GET"])
|
||||
@protect()
|
||||
|
@ -1320,3 +1328,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
exc_info=True,
|
||||
)
|
||||
return self.response_422(message=str(ex))
|
||||
except SSHTunnelingNotEnabledError as ex:
|
||||
logger.error(
|
||||
"Error deleting SSH Tunnel %s: %s",
|
||||
self.__class__.__name__,
|
||||
str(ex),
|
||||
exc_info=True,
|
||||
)
|
||||
return self.response_400(message=str(ex))
|
||||
|
|
|
@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional
|
|||
from flask_appbuilder.models.sqla import Model
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset import is_feature_enabled
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.dao.exceptions import DAOCreateFailedError
|
||||
from superset.databases.commands.exceptions import (
|
||||
|
@ -34,6 +35,7 @@ from superset.databases.dao import DatabaseDAO
|
|||
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelInvalidError,
|
||||
)
|
||||
from superset.exceptions import SupersetErrorsException
|
||||
|
@ -52,7 +54,7 @@ class CreateDatabaseCommand(BaseCommand):
|
|||
try:
|
||||
# Test connection before starting create transaction
|
||||
TestConnectionDatabaseCommand(self._properties).run()
|
||||
except SupersetErrorsException as ex:
|
||||
except (SupersetErrorsException, SSHTunnelingNotEnabledError) as ex:
|
||||
event_logger.log_with_context(
|
||||
action=f"db_creation_failed.{ex.__class__.__name__}",
|
||||
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
|
||||
|
@ -78,6 +80,9 @@ class CreateDatabaseCommand(BaseCommand):
|
|||
|
||||
ssh_tunnel = None
|
||||
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
|
||||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
db.session.rollback()
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
try:
|
||||
# So database.id is not None
|
||||
db.session.flush()
|
||||
|
|
|
@ -25,6 +25,7 @@ from func_timeout import func_timeout, FunctionTimedOut
|
|||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.exc import DBAPIError, NoSuchModuleError
|
||||
|
||||
from superset import is_feature_enabled
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.databases.commands.exceptions import (
|
||||
DatabaseSecurityUnsafeError,
|
||||
|
@ -32,6 +33,9 @@ from superset.databases.commands.exceptions import (
|
|||
DatabaseTestConnectionUnexpectedError,
|
||||
)
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelingNotEnabledError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.databases.utils import make_url_safe
|
||||
|
@ -64,7 +68,7 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
self._properties = data.copy()
|
||||
self._model: Optional[Database] = None
|
||||
|
||||
def run(self) -> None: # pylint: disable=too-many-statements
|
||||
def run(self) -> None: # pylint: disable=too-many-statements, too-many-branches
|
||||
self.validate()
|
||||
ex_str = ""
|
||||
uri = self._properties.get("sqlalchemy_uri", "")
|
||||
|
@ -107,6 +111,8 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
|
||||
# Generate tunnel if present in the properties
|
||||
if ssh_tunnel:
|
||||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
# If there's an existing tunnel for that DB we need to use the stored
|
||||
# password, private_key and private_key_password instead
|
||||
if ssh_tunnel_id := ssh_tunnel.pop("id", None):
|
||||
|
@ -203,6 +209,15 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
)
|
||||
# bubble up the exception to return a 408
|
||||
raise ex
|
||||
except SSHTunnelingNotEnabledError as ex:
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action(
|
||||
"test_connection_error", ssh_tunnel, ex
|
||||
),
|
||||
engine=database.db_engine_spec.__name__,
|
||||
)
|
||||
# bubble up the exception to return a 400
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action(
|
||||
|
|
|
@ -20,6 +20,7 @@ from typing import Any, Dict, List, Optional
|
|||
from flask_appbuilder.models.sqla import Model
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset import is_feature_enabled
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.dao.exceptions import DAOCreateFailedError, DAOUpdateFailedError
|
||||
from superset.databases.commands.exceptions import (
|
||||
|
@ -33,7 +34,9 @@ from superset.databases.dao import DatabaseDAO
|
|||
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelUpdateFailedError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
|
||||
from superset.extensions import db, security_manager
|
||||
|
@ -102,6 +105,9 @@ class UpdateDatabaseCommand(BaseCommand):
|
|||
)
|
||||
|
||||
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
|
||||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
db.session.rollback()
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
if existing_ssh_tunnel_model is None:
|
||||
# We couldn't found an existing tunnel so we need to create one
|
||||
|
@ -118,7 +124,7 @@ class UpdateDatabaseCommand(BaseCommand):
|
|||
UpdateSSHTunnelCommand(
|
||||
existing_ssh_tunnel_model.id, ssh_tunnel_properties
|
||||
).run()
|
||||
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
|
||||
except (SSHTunnelInvalidError, SSHTunnelUpdateFailedError) as ex:
|
||||
# So we can show the original message
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
|
|
|
@ -19,10 +19,12 @@ from typing import Optional
|
|||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
|
||||
from superset import is_feature_enabled
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.dao.exceptions import DAODeleteFailedError
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelDeleteFailedError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelNotFoundError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
|
||||
|
@ -37,6 +39,8 @@ class DeleteSSHTunnelCommand(BaseCommand):
|
|||
self._model: Optional[SSHTunnel] = None
|
||||
|
||||
def run(self) -> Model:
|
||||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
self.validate()
|
||||
try:
|
||||
ssh_tunnel = SSHTunnelDAO.delete(self._model)
|
||||
|
|
|
@ -46,6 +46,11 @@ class SSHTunnelCreateFailedError(CommandException):
|
|||
message = _("Creating SSH Tunnel failed for an unknown reason")
|
||||
|
||||
|
||||
class SSHTunnelingNotEnabledError(CommandException):
|
||||
status = 400
|
||||
message = _("SSH Tunneling is not enabled")
|
||||
|
||||
|
||||
class SSHTunnelRequiredFieldValidationError(ValidationError):
|
||||
def __init__(self, field_name: str) -> None:
|
||||
super().__init__(
|
||||
|
|
|
@ -285,15 +285,20 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch("superset.databases.commands.create.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_create_database_with_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
):
|
||||
"""
|
||||
Database API: Test create with SSH Tunnel
|
||||
"""
|
||||
mock_create_is_feature_enabled.return_value = True
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
|
@ -328,15 +333,23 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch("superset.databases.commands.create.is_feature_enabled")
|
||||
@mock.patch("superset.databases.commands.update.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_update_database_with_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_update_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
):
|
||||
"""
|
||||
Database API: Test update with SSH Tunnel
|
||||
Database API: Test update Database with SSH Tunnel
|
||||
"""
|
||||
mock_create_is_feature_enabled.return_value = True
|
||||
mock_update_is_feature_enabled.return_value = True
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
|
@ -381,15 +394,23 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch("superset.databases.commands.create.is_feature_enabled")
|
||||
@mock.patch("superset.databases.commands.update.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_update_ssh_tunnel_via_database_api(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_update_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
):
|
||||
"""
|
||||
Database API: Test update with SSH Tunnel
|
||||
Database API: Test update SSH Tunnel via Database API
|
||||
"""
|
||||
mock_create_is_feature_enabled.return_value = True
|
||||
mock_update_is_feature_enabled.return_value = True
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
|
||||
|
@ -456,12 +477,17 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
@mock.patch("superset.databases.commands.create.is_feature_enabled")
|
||||
def test_cascade_delete_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_get_all_schema_names,
|
||||
mock_create_is_feature_enabled,
|
||||
):
|
||||
"""
|
||||
Database API: Test create with SSH Tunnel
|
||||
Database API: SSH Tunnel gets deleted if Database gets deleted
|
||||
"""
|
||||
mock_create_is_feature_enabled.return_value = True
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
|
@ -502,15 +528,20 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch("superset.databases.commands.create.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
):
|
||||
"""
|
||||
Database API: Test create with SSH Tunnel
|
||||
Database API: Test Database is not created if SSH Tunnel creation fails
|
||||
"""
|
||||
mock_create_is_feature_enabled.return_value = True
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
|
@ -548,15 +579,20 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch("superset.databases.commands.create.is_feature_enabled")
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_get_database_returns_related_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_create_is_feature_enabled,
|
||||
mock_get_all_schema_names,
|
||||
):
|
||||
"""
|
||||
Database API: Test GET Database returns its related SSH Tunnel
|
||||
"""
|
||||
mock_create_is_feature_enabled.return_value = True
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
|
@ -595,6 +631,56 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_if_ssh_tunneling_flag_is_not_active_it_raises_new_exception(
|
||||
self,
|
||||
mock_test_connection_database_command_run,
|
||||
mock_get_all_schema_names,
|
||||
):
|
||||
"""
|
||||
Database API: Test raises SSHTunneling feature flag not enabled
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel-7",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertEqual(response, {"message": "SSH Tunneling is not enabled"})
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one_or_none()
|
||||
)
|
||||
assert model_ssh_tunnel is None
|
||||
# Cleanup
|
||||
model = (
|
||||
db.session.query(Database)
|
||||
.filter(Database.database_name == "test-db-with-ssh-tunnel-7")
|
||||
.one_or_none()
|
||||
)
|
||||
# the DB should not be created
|
||||
assert model is None
|
||||
|
||||
def test_create_database_invalid_configuration_method(self):
|
||||
"""
|
||||
Database API: Test create with an invalid configuration method.
|
||||
|
|
|
@ -67,8 +67,10 @@ class TestUpdateSSHTunnelCommand(SupersetTestCase):
|
|||
|
||||
class TestDeleteSSHTunnelCommand(SupersetTestCase):
|
||||
@mock.patch("superset.utils.core.g")
|
||||
def test_delete_ssh_tunnel_not_found(self, mock_g):
|
||||
@mock.patch("superset.databases.ssh_tunnel.commands.delete.is_feature_enabled")
|
||||
def test_delete_ssh_tunnel_not_found(self, mock_g, mock_delete_is_feature_enabled):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
mock_delete_is_feature_enabled.return_value = True
|
||||
# We have not created a SSH Tunnel yet so id = 1 is invalid
|
||||
command = DeleteSSHTunnelCommand(1)
|
||||
with pytest.raises(SSHTunnelNotFoundError) as excinfo:
|
||||
|
|
|
@ -241,6 +241,10 @@ def test_delete_ssh_tunnel(
|
|||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
mocker.patch("superset.utils.log.DBEventLogger.log")
|
||||
mocker.patch(
|
||||
"superset.databases.ssh_tunnel.commands.delete.is_feature_enabled",
|
||||
return_value=True,
|
||||
)
|
||||
|
||||
# Create our SSHTunnel
|
||||
tunnel = SSHTunnel(
|
||||
|
@ -313,6 +317,10 @@ def test_delete_ssh_tunnel_not_found(
|
|||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
mocker.patch("superset.utils.log.DBEventLogger.log")
|
||||
mocker.patch(
|
||||
"superset.databases.ssh_tunnel.commands.delete.is_feature_enabled",
|
||||
return_value=True,
|
||||
)
|
||||
|
||||
# Create our SSHTunnel
|
||||
tunnel = SSHTunnel(
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
|
||||
|
@ -50,7 +51,9 @@ def session_with_data(session: Session) -> Iterator[Session]:
|
|||
session.rollback()
|
||||
|
||||
|
||||
def test_delete_ssh_tunnel_command(session_with_data: Session) -> None:
|
||||
def test_delete_ssh_tunnel_command(
|
||||
mocker: MockFixture, session_with_data: Session
|
||||
) -> None:
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
@ -60,9 +63,11 @@ def test_delete_ssh_tunnel_command(session_with_data: Session) -> None:
|
|||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
|
||||
mocker.patch(
|
||||
"superset.databases.ssh_tunnel.commands.delete.is_feature_enabled",
|
||||
return_value=True,
|
||||
)
|
||||
DeleteSSHTunnelCommand(1).run()
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result is None
|
||||
|
|
Loading…
Reference in New Issue