mirror of https://github.com/apache/superset.git
feat(ssh-tunnelling): Setup SSH Tunneling Commands for Database Connections (#21912)
Co-authored-by: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com> Co-authored-by: Elizabeth Thompson <eschutho@gmail.com>
This commit is contained in:
parent
a7a4561550
commit
ebaad10d6c
|
@ -19,6 +19,8 @@ babel==2.9.1
|
||||||
# via flask-babel
|
# via flask-babel
|
||||||
backoff==1.11.1
|
backoff==1.11.1
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
|
bcrypt==4.0.1
|
||||||
|
# via paramiko
|
||||||
billiard==3.6.4.0
|
billiard==3.6.4.0
|
||||||
# via celery
|
# via celery
|
||||||
bleach==3.3.1
|
bleach==3.3.1
|
||||||
|
@ -57,7 +59,9 @@ cron-descriptor==1.2.24
|
||||||
croniter==1.0.15
|
croniter==1.0.15
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
cryptography==3.4.7
|
cryptography==3.4.7
|
||||||
# via apache-superset
|
# via
|
||||||
|
# apache-superset
|
||||||
|
# paramiko
|
||||||
deprecation==2.1.0
|
deprecation==2.1.0
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
dnspython==2.1.0
|
dnspython==2.1.0
|
||||||
|
@ -167,6 +171,8 @@ packaging==21.3
|
||||||
# deprecation
|
# deprecation
|
||||||
pandas==1.5.2
|
pandas==1.5.2
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
|
paramiko==2.11.0
|
||||||
|
# via sshtunnel
|
||||||
parsedatetime==2.6
|
parsedatetime==2.6
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
pgsanity==0.2.9
|
pgsanity==0.2.9
|
||||||
|
@ -188,6 +194,8 @@ pyjwt==2.4.0
|
||||||
# flask-jwt-extended
|
# flask-jwt-extended
|
||||||
pymeeus==0.5.11
|
pymeeus==0.5.11
|
||||||
# via convertdate
|
# via convertdate
|
||||||
|
pynacl==1.5.0
|
||||||
|
# via paramiko
|
||||||
pyparsing==3.0.6
|
pyparsing==3.0.6
|
||||||
# via
|
# via
|
||||||
# apache-superset
|
# apache-superset
|
||||||
|
@ -231,6 +239,7 @@ six==1.16.0
|
||||||
# flask-talisman
|
# flask-talisman
|
||||||
# isodate
|
# isodate
|
||||||
# jsonschema
|
# jsonschema
|
||||||
|
# paramiko
|
||||||
# polyline
|
# polyline
|
||||||
# prison
|
# prison
|
||||||
# pyrsistent
|
# pyrsistent
|
||||||
|
@ -252,6 +261,8 @@ sqlalchemy-utils==0.38.3
|
||||||
# flask-appbuilder
|
# flask-appbuilder
|
||||||
sqlparse==0.4.3
|
sqlparse==0.4.3
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
|
sshtunnel==0.4.0
|
||||||
|
# via apache-superset
|
||||||
tabulate==0.8.9
|
tabulate==0.8.9
|
||||||
# via apache-superset
|
# via apache-superset
|
||||||
typing-extensions==4.4.0
|
typing-extensions==4.4.0
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -113,6 +113,7 @@ setup(
|
||||||
"PyJWT>=2.4.0, <3.0",
|
"PyJWT>=2.4.0, <3.0",
|
||||||
"redis",
|
"redis",
|
||||||
"selenium>=3.141.0",
|
"selenium>=3.141.0",
|
||||||
|
"sshtunnel>=0.4.0, <0.5",
|
||||||
"simplejson>=3.15.0",
|
"simplejson>=3.15.0",
|
||||||
"slack_sdk>=3.1.1, <4",
|
"slack_sdk>=3.1.1, <4",
|
||||||
"sqlalchemy>=1.4, <2",
|
"sqlalchemy>=1.4, <2",
|
||||||
|
|
|
@ -476,8 +476,30 @@ DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
|
||||||
"DRILL_TO_DETAIL": False,
|
"DRILL_TO_DETAIL": False,
|
||||||
"DATAPANEL_CLOSED_BY_DEFAULT": False,
|
"DATAPANEL_CLOSED_BY_DEFAULT": False,
|
||||||
"HORIZONTAL_FILTER_BAR": False,
|
"HORIZONTAL_FILTER_BAR": False,
|
||||||
|
# Allow users to enable ssh tunneling when creating a DB.
|
||||||
|
# Users must check whether the DB engine supports SSH Tunnels
|
||||||
|
# otherwise enabling this flag won't have any effect on the DB.
|
||||||
|
"SSH_TUNNELING": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# SSH Tunnel
|
||||||
|
# ------------------------------
|
||||||
|
# Allow users to set the host used when connecting to the SSH Tunnel
|
||||||
|
# as localhost and any other alias (0.0.0.0)
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
# |
|
||||||
|
# -------------+ | +----------+
|
||||||
|
# LOCAL | | | REMOTE | :22 SSH
|
||||||
|
# CLIENT | <== SSH ========> | SERVER | :8080 web service
|
||||||
|
# -------------+ | +----------+
|
||||||
|
# |
|
||||||
|
# FIREWALL (only port 22 is open)
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------
|
||||||
|
SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager"
|
||||||
|
SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
|
||||||
|
|
||||||
# Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars.
|
# Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars.
|
||||||
DEFAULT_FEATURE_FLAGS.update(
|
DEFAULT_FEATURE_FLAGS.update(
|
||||||
{
|
{
|
||||||
|
|
|
@ -139,6 +139,7 @@ MODEL_API_RW_METHOD_PERMISSION_MAP = {
|
||||||
"validate_sql": "read",
|
"validate_sql": "read",
|
||||||
"get_data": "read",
|
"get_data": "read",
|
||||||
"samples": "read",
|
"samples": "read",
|
||||||
|
"delete_ssh_tunnel": "write",
|
||||||
}
|
}
|
||||||
|
|
||||||
EXTRA_FORM_DATA_APPEND_KEYS = {
|
EXTRA_FORM_DATA_APPEND_KEYS = {
|
||||||
|
|
|
@ -72,6 +72,11 @@ from superset.databases.schemas import (
|
||||||
ValidateSQLRequest,
|
ValidateSQLRequest,
|
||||||
ValidateSQLResponse,
|
ValidateSQLResponse,
|
||||||
)
|
)
|
||||||
|
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
|
||||||
|
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||||
|
SSHTunnelDeleteFailedError,
|
||||||
|
SSHTunnelNotFoundError,
|
||||||
|
)
|
||||||
from superset.databases.utils import get_table_metadata
|
from superset.databases.utils import get_table_metadata
|
||||||
from superset.db_engine_specs import get_available_engine_specs
|
from superset.db_engine_specs import get_available_engine_specs
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
|
@ -80,6 +85,7 @@ from superset.extensions import security_manager
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.superset_typing import FlaskResponse
|
from superset.superset_typing import FlaskResponse
|
||||||
from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item
|
from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item
|
||||||
|
from superset.utils.ssh_tunnel import mask_password_info
|
||||||
from superset.views.base import json_errors_response
|
from superset.views.base import json_errors_response
|
||||||
from superset.views.base_api import (
|
from superset.views.base_api import (
|
||||||
BaseSupersetModelRestApi,
|
BaseSupersetModelRestApi,
|
||||||
|
@ -107,6 +113,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||||
"available",
|
"available",
|
||||||
"validate_parameters",
|
"validate_parameters",
|
||||||
"validate_sql",
|
"validate_sql",
|
||||||
|
"delete_ssh_tunnel",
|
||||||
}
|
}
|
||||||
resource_name = "database"
|
resource_name = "database"
|
||||||
class_permission_name = "Database"
|
class_permission_name = "Database"
|
||||||
|
@ -219,6 +226,47 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||||
ValidateSQLResponse,
|
ValidateSQLResponse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@expose("/<int:pk>", methods=["GET"])
|
||||||
|
@protect()
|
||||||
|
@safe
|
||||||
|
def get(self, pk: int, **kwargs: Any) -> Response:
|
||||||
|
"""Get a database
|
||||||
|
---
|
||||||
|
get:
|
||||||
|
description: >-
|
||||||
|
Get a database
|
||||||
|
parameters:
|
||||||
|
- in: path
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
description: The database id
|
||||||
|
name: pk
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: Database
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
400:
|
||||||
|
$ref: '#/components/responses/400'
|
||||||
|
401:
|
||||||
|
$ref: '#/components/responses/401'
|
||||||
|
422:
|
||||||
|
$ref: '#/components/responses/422'
|
||||||
|
500:
|
||||||
|
$ref: '#/components/responses/500'
|
||||||
|
"""
|
||||||
|
data = self.get_headless(pk, **kwargs)
|
||||||
|
try:
|
||||||
|
if ssh_tunnel := DatabaseDAO.get_ssh_tunnel(pk):
|
||||||
|
payload = data.json
|
||||||
|
payload["result"]["ssh_tunnel"] = ssh_tunnel.data
|
||||||
|
return payload
|
||||||
|
return data
|
||||||
|
except SupersetException as ex:
|
||||||
|
return self.response(ex.status, message=ex.message)
|
||||||
|
|
||||||
@expose("/", methods=["POST"])
|
@expose("/", methods=["POST"])
|
||||||
@protect()
|
@protect()
|
||||||
@safe
|
@safe
|
||||||
|
@ -280,6 +328,12 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||||
if new_model.driver:
|
if new_model.driver:
|
||||||
item["driver"] = new_model.driver
|
item["driver"] = new_model.driver
|
||||||
|
|
||||||
|
# Return SSH Tunnel and hide passwords if any
|
||||||
|
if item.get("ssh_tunnel"):
|
||||||
|
item["ssh_tunnel"] = mask_password_info(
|
||||||
|
new_model.ssh_tunnel # pylint: disable=no-member
|
||||||
|
)
|
||||||
|
|
||||||
return self.response(201, id=new_model.id, result=item)
|
return self.response(201, id=new_model.id, result=item)
|
||||||
except DatabaseInvalidError as ex:
|
except DatabaseInvalidError as ex:
|
||||||
return self.response_422(message=ex.normalized_messages())
|
return self.response_422(message=ex.normalized_messages())
|
||||||
|
@ -361,6 +415,9 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||||
item["sqlalchemy_uri"] = changed_model.sqlalchemy_uri
|
item["sqlalchemy_uri"] = changed_model.sqlalchemy_uri
|
||||||
if changed_model.parameters:
|
if changed_model.parameters:
|
||||||
item["parameters"] = changed_model.parameters
|
item["parameters"] = changed_model.parameters
|
||||||
|
# Return SSH Tunnel and hide passwords if any
|
||||||
|
if item.get("ssh_tunnel"):
|
||||||
|
item["ssh_tunnel"] = mask_password_info(changed_model.ssh_tunnel)
|
||||||
return self.response(200, id=changed_model.id, result=item)
|
return self.response(200, id=changed_model.id, result=item)
|
||||||
except DatabaseNotFoundError:
|
except DatabaseNotFoundError:
|
||||||
return self.response_404()
|
return self.response_404()
|
||||||
|
@ -1206,3 +1263,57 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||||
command = ValidateDatabaseParametersCommand(payload)
|
command = ValidateDatabaseParametersCommand(payload)
|
||||||
command.run()
|
command.run()
|
||||||
return self.response(200, message="OK")
|
return self.response(200, message="OK")
|
||||||
|
|
||||||
|
@expose("/<int:pk>/ssh_tunnel/", methods=["DELETE"])
|
||||||
|
@protect()
|
||||||
|
@statsd_metrics
|
||||||
|
@event_logger.log_this_with_context(
|
||||||
|
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
|
||||||
|
f".delete_ssh_tunnel",
|
||||||
|
log_to_statsd=False,
|
||||||
|
)
|
||||||
|
def delete_ssh_tunnel(self, pk: int) -> Response:
|
||||||
|
"""Deletes a SSH Tunnel
|
||||||
|
---
|
||||||
|
delete:
|
||||||
|
description: >-
|
||||||
|
Deletes a SSH Tunnel.
|
||||||
|
parameters:
|
||||||
|
- in: path
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
name: pk
|
||||||
|
responses:
|
||||||
|
200:
|
||||||
|
description: SSH Tunnel deleted
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
message:
|
||||||
|
type: string
|
||||||
|
401:
|
||||||
|
$ref: '#/components/responses/401'
|
||||||
|
403:
|
||||||
|
$ref: '#/components/responses/403'
|
||||||
|
404:
|
||||||
|
$ref: '#/components/responses/404'
|
||||||
|
422:
|
||||||
|
$ref: '#/components/responses/422'
|
||||||
|
500:
|
||||||
|
$ref: '#/components/responses/500'
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
DeleteSSHTunnelCommand(pk).run()
|
||||||
|
return self.response(200, message="OK")
|
||||||
|
except SSHTunnelNotFoundError:
|
||||||
|
return self.response_404()
|
||||||
|
except SSHTunnelDeleteFailedError as ex:
|
||||||
|
logger.error(
|
||||||
|
"Error deleting SSH Tunnel %s: %s",
|
||||||
|
self.__class__.__name__,
|
||||||
|
str(ex),
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return self.response_422(message=str(ex))
|
||||||
|
|
|
@ -31,6 +31,11 @@ from superset.databases.commands.exceptions import (
|
||||||
)
|
)
|
||||||
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
|
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
|
||||||
from superset.databases.dao import DatabaseDAO
|
from superset.databases.dao import DatabaseDAO
|
||||||
|
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||||
|
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||||
|
SSHTunnelCreateFailedError,
|
||||||
|
SSHTunnelInvalidError,
|
||||||
|
)
|
||||||
from superset.exceptions import SupersetErrorsException
|
from superset.exceptions import SupersetErrorsException
|
||||||
from superset.extensions import db, event_logger, security_manager
|
from superset.extensions import db, event_logger, security_manager
|
||||||
|
|
||||||
|
@ -71,12 +76,35 @@ class CreateDatabaseCommand(BaseCommand):
|
||||||
database = DatabaseDAO.create(self._properties, commit=False)
|
database = DatabaseDAO.create(self._properties, commit=False)
|
||||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||||
|
|
||||||
|
ssh_tunnel = None
|
||||||
|
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
|
||||||
|
try:
|
||||||
|
# So database.id is not None
|
||||||
|
db.session.flush()
|
||||||
|
ssh_tunnel = CreateSSHTunnelCommand(
|
||||||
|
database.id, ssh_tunnel_properties
|
||||||
|
).run()
|
||||||
|
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
|
||||||
|
event_logger.log_with_context(
|
||||||
|
action=f"db_creation_failed.{ex.__class__.__name__}",
|
||||||
|
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
|
||||||
|
)
|
||||||
|
# So we can show the original message
|
||||||
|
raise ex
|
||||||
|
except Exception as ex:
|
||||||
|
event_logger.log_with_context(
|
||||||
|
action=f"db_creation_failed.{ex.__class__.__name__}",
|
||||||
|
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
|
||||||
|
)
|
||||||
|
raise DatabaseCreateFailedError() from ex
|
||||||
|
|
||||||
# adding a new database we always want to force refresh schema list
|
# adding a new database we always want to force refresh schema list
|
||||||
schemas = database.get_all_schema_names(cache=False)
|
schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
|
||||||
for schema in schemas:
|
for schema in schemas:
|
||||||
security_manager.add_permission_view_menu(
|
security_manager.add_permission_view_menu(
|
||||||
"schema_access", security_manager.get_schema_perm(database, schema)
|
"schema_access", security_manager.get_schema_perm(database, schema)
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
except DAOCreateFailedError as ex:
|
except DAOCreateFailedError as ex:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
|
|
|
@ -32,6 +32,7 @@ from superset.databases.commands.exceptions import (
|
||||||
DatabaseTestConnectionUnexpectedError,
|
DatabaseTestConnectionUnexpectedError,
|
||||||
)
|
)
|
||||||
from superset.databases.dao import DatabaseDAO
|
from superset.databases.dao import DatabaseDAO
|
||||||
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
from superset.databases.utils import make_url_safe
|
from superset.databases.utils import make_url_safe
|
||||||
from superset.errors import ErrorLevel, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetErrorType
|
||||||
from superset.exceptions import (
|
from superset.exceptions import (
|
||||||
|
@ -90,6 +91,10 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
||||||
database.set_sqlalchemy_uri(uri)
|
database.set_sqlalchemy_uri(uri)
|
||||||
database.db_engine_spec.mutate_db_for_connection_test(database)
|
database.db_engine_spec.mutate_db_for_connection_test(database)
|
||||||
|
|
||||||
|
# Generate tunnel if present in the properties
|
||||||
|
if ssh_tunnel := self._properties.get("ssh_tunnel"):
|
||||||
|
ssh_tunnel = SSHTunnel(**ssh_tunnel)
|
||||||
|
|
||||||
event_logger.log_with_context(
|
event_logger.log_with_context(
|
||||||
action="test_connection_attempt",
|
action="test_connection_attempt",
|
||||||
engine=database.db_engine_spec.__name__,
|
engine=database.db_engine_spec.__name__,
|
||||||
|
@ -99,7 +104,9 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
||||||
with closing(engine.raw_connection()) as conn:
|
with closing(engine.raw_connection()) as conn:
|
||||||
return engine.dialect.do_ping(conn)
|
return engine.dialect.do_ping(conn)
|
||||||
|
|
||||||
with database.get_sqla_engine_with_context() as engine:
|
with database.get_sqla_engine_with_context(
|
||||||
|
override_ssh_tunnel=ssh_tunnel
|
||||||
|
) as engine:
|
||||||
try:
|
try:
|
||||||
alive = func_timeout(
|
alive = func_timeout(
|
||||||
app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(),
|
app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(),
|
||||||
|
|
|
@ -21,7 +21,7 @@ from flask_appbuilder.models.sqla import Model
|
||||||
from marshmallow import ValidationError
|
from marshmallow import ValidationError
|
||||||
|
|
||||||
from superset.commands.base import BaseCommand
|
from superset.commands.base import BaseCommand
|
||||||
from superset.dao.exceptions import DAOUpdateFailedError
|
from superset.dao.exceptions import DAOCreateFailedError, DAOUpdateFailedError
|
||||||
from superset.databases.commands.exceptions import (
|
from superset.databases.commands.exceptions import (
|
||||||
DatabaseConnectionFailedError,
|
DatabaseConnectionFailedError,
|
||||||
DatabaseExistsValidationError,
|
DatabaseExistsValidationError,
|
||||||
|
@ -30,6 +30,12 @@ from superset.databases.commands.exceptions import (
|
||||||
DatabaseUpdateFailedError,
|
DatabaseUpdateFailedError,
|
||||||
)
|
)
|
||||||
from superset.databases.dao import DatabaseDAO
|
from superset.databases.dao import DatabaseDAO
|
||||||
|
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||||
|
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||||
|
SSHTunnelCreateFailedError,
|
||||||
|
SSHTunnelInvalidError,
|
||||||
|
)
|
||||||
|
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
|
||||||
from superset.extensions import db, security_manager
|
from superset.extensions import db, security_manager
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.utils.core import DatasourceType
|
from superset.utils.core import DatasourceType
|
||||||
|
@ -94,10 +100,33 @@ class UpdateDatabaseCommand(BaseCommand):
|
||||||
security_manager.add_permission_view_menu(
|
security_manager.add_permission_view_menu(
|
||||||
"schema_access", security_manager.get_schema_perm(database, schema)
|
"schema_access", security_manager.get_schema_perm(database, schema)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
|
||||||
|
existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||||
|
if existing_ssh_tunnel_model is None:
|
||||||
|
# We couldn't found an existing tunnel so we need to create one
|
||||||
|
try:
|
||||||
|
CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run()
|
||||||
|
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
|
||||||
|
# So we can show the original message
|
||||||
|
raise ex
|
||||||
|
except Exception as ex:
|
||||||
|
raise DatabaseUpdateFailedError() from ex
|
||||||
|
else:
|
||||||
|
# We found an existing tunnel so we need to update it
|
||||||
|
try:
|
||||||
|
UpdateSSHTunnelCommand(
|
||||||
|
existing_ssh_tunnel_model.id, ssh_tunnel_properties
|
||||||
|
).run()
|
||||||
|
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
|
||||||
|
# So we can show the original message
|
||||||
|
raise ex
|
||||||
|
except Exception as ex:
|
||||||
|
raise DatabaseUpdateFailedError() from ex
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
except DAOUpdateFailedError as ex:
|
except (DAOUpdateFailedError, DAOCreateFailedError) as ex:
|
||||||
logger.exception(ex.exception)
|
|
||||||
raise DatabaseUpdateFailedError() from ex
|
raise DatabaseUpdateFailedError() from ex
|
||||||
return database
|
return database
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from superset.dao.base import BaseDAO
|
from superset.dao.base import BaseDAO
|
||||||
from superset.databases.filters import DatabaseFilter
|
from superset.databases.filters import DatabaseFilter
|
||||||
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
from superset.extensions import db
|
from superset.extensions import db
|
||||||
from superset.models.core import Database
|
from superset.models.core import Database
|
||||||
from superset.models.dashboard import Dashboard
|
from superset.models.dashboard import Dashboard
|
||||||
|
@ -124,3 +125,13 @@ class DatabaseDAO(BaseDAO):
|
||||||
return dict(
|
return dict(
|
||||||
charts=charts, dashboards=dashboards, sqllab_tab_states=sqllab_tab_states
|
charts=charts, dashboards=dashboards, sqllab_tab_states=sqllab_tab_states
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_ssh_tunnel(cls, database_id: int) -> Optional[SSHTunnel]:
|
||||||
|
ssh_tunnel = (
|
||||||
|
db.session.query(SSHTunnel)
|
||||||
|
.filter(SSHTunnel.database_id == database_id)
|
||||||
|
.one_or_none()
|
||||||
|
)
|
||||||
|
|
||||||
|
return ssh_tunnel
|
||||||
|
|
|
@ -365,6 +365,19 @@ class DatabaseValidateParametersSchema(Schema):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseSSHTunnel(Schema):
|
||||||
|
server_address = fields.String()
|
||||||
|
server_port = fields.Integer()
|
||||||
|
username = fields.String()
|
||||||
|
|
||||||
|
# Basic Authentication
|
||||||
|
password = fields.String(required=False)
|
||||||
|
|
||||||
|
# password protected private key authentication
|
||||||
|
private_key = fields.String(required=False)
|
||||||
|
private_key_password = fields.String(required=False)
|
||||||
|
|
||||||
|
|
||||||
class DatabasePostSchema(Schema, DatabaseParametersSchemaMixin):
|
class DatabasePostSchema(Schema, DatabaseParametersSchemaMixin):
|
||||||
class Meta: # pylint: disable=too-few-public-methods
|
class Meta: # pylint: disable=too-few-public-methods
|
||||||
unknown = EXCLUDE
|
unknown = EXCLUDE
|
||||||
|
@ -409,6 +422,7 @@ class DatabasePostSchema(Schema, DatabaseParametersSchemaMixin):
|
||||||
is_managed_externally = fields.Boolean(allow_none=True, default=False)
|
is_managed_externally = fields.Boolean(allow_none=True, default=False)
|
||||||
external_url = fields.String(allow_none=True)
|
external_url = fields.String(allow_none=True)
|
||||||
uuid = fields.String(required=False)
|
uuid = fields.String(required=False)
|
||||||
|
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
|
||||||
|
|
||||||
|
|
||||||
class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin):
|
class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin):
|
||||||
|
@ -454,6 +468,7 @@ class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin):
|
||||||
)
|
)
|
||||||
is_managed_externally = fields.Boolean(allow_none=True, default=False)
|
is_managed_externally = fields.Boolean(allow_none=True, default=False)
|
||||||
external_url = fields.String(allow_none=True)
|
external_url = fields.String(allow_none=True)
|
||||||
|
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
|
||||||
|
|
||||||
|
|
||||||
class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin):
|
class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin):
|
||||||
|
@ -482,6 +497,8 @@ class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin):
|
||||||
validate=[Length(1, 1024), sqlalchemy_uri_validator],
|
validate=[Length(1, 1024), sqlalchemy_uri_validator],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
|
||||||
|
|
||||||
|
|
||||||
class TableMetadataOptionsResponseSchema(Schema):
|
class TableMetadataOptionsResponseSchema(Schema):
|
||||||
deferrable = fields.Bool()
|
deferrable = fields.Bool()
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,92 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from flask_appbuilder.models.sqla import Model
|
||||||
|
from marshmallow import ValidationError
|
||||||
|
|
||||||
|
from superset.commands.base import BaseCommand
|
||||||
|
from superset.dao.exceptions import DAOCreateFailedError
|
||||||
|
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||||
|
SSHTunnelCreateFailedError,
|
||||||
|
SSHTunnelInvalidError,
|
||||||
|
SSHTunnelRequiredFieldValidationError,
|
||||||
|
)
|
||||||
|
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
|
||||||
|
from superset.extensions import db, event_logger
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CreateSSHTunnelCommand(BaseCommand):
|
||||||
|
def __init__(self, database_id: int, data: Dict[str, Any]):
|
||||||
|
self._properties = data.copy()
|
||||||
|
self._properties["database_id"] = database_id
|
||||||
|
|
||||||
|
def run(self) -> Model:
|
||||||
|
try:
|
||||||
|
# Start nested transaction since we are always creating the tunnel
|
||||||
|
# through a DB command (Create or Update). Without this, we cannot
|
||||||
|
# safely rollback changes to databases if any, i.e, things like
|
||||||
|
# test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail
|
||||||
|
db.session.begin_nested()
|
||||||
|
self.validate()
|
||||||
|
tunnel = SSHTunnelDAO.create(self._properties, commit=False)
|
||||||
|
except DAOCreateFailedError as ex:
|
||||||
|
# Rollback nested transaction
|
||||||
|
db.session.rollback()
|
||||||
|
raise SSHTunnelCreateFailedError() from ex
|
||||||
|
except SSHTunnelInvalidError as ex:
|
||||||
|
# Rollback nested transaction
|
||||||
|
db.session.rollback()
|
||||||
|
raise ex
|
||||||
|
|
||||||
|
return tunnel
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
# TODO(hughhh): check to make sure the server port is not localhost
|
||||||
|
# using the config.SSH_TUNNEL_MANAGER
|
||||||
|
exceptions: List[ValidationError] = []
|
||||||
|
database_id: Optional[int] = self._properties.get("database_id")
|
||||||
|
server_address: Optional[str] = self._properties.get("server_address")
|
||||||
|
server_port: Optional[int] = self._properties.get("server_port")
|
||||||
|
username: Optional[str] = self._properties.get("username")
|
||||||
|
private_key: Optional[str] = self._properties.get("private_key")
|
||||||
|
private_key_password: Optional[str] = self._properties.get(
|
||||||
|
"private_key_password"
|
||||||
|
)
|
||||||
|
if not database_id:
|
||||||
|
exceptions.append(SSHTunnelRequiredFieldValidationError("database_id"))
|
||||||
|
if not server_address:
|
||||||
|
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
|
||||||
|
if not server_port:
|
||||||
|
exceptions.append(SSHTunnelRequiredFieldValidationError("server_port"))
|
||||||
|
if not username:
|
||||||
|
exceptions.append(SSHTunnelRequiredFieldValidationError("username"))
|
||||||
|
if private_key_password and private_key is None:
|
||||||
|
exceptions.append(SSHTunnelRequiredFieldValidationError("private_key"))
|
||||||
|
if exceptions:
|
||||||
|
exception = SSHTunnelInvalidError()
|
||||||
|
exception.add_list(exceptions)
|
||||||
|
event_logger.log_with_context(
|
||||||
|
action="ssh_tunnel_creation_failed.{}.{}".format(
|
||||||
|
exception.__class__.__name__,
|
||||||
|
".".join(exception.get_list_classnames()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise exception
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from flask_appbuilder.models.sqla import Model
|
||||||
|
|
||||||
|
from superset.commands.base import BaseCommand
|
||||||
|
from superset.dao.exceptions import DAODeleteFailedError
|
||||||
|
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||||
|
SSHTunnelDeleteFailedError,
|
||||||
|
SSHTunnelNotFoundError,
|
||||||
|
)
|
||||||
|
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
|
||||||
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteSSHTunnelCommand(BaseCommand):
|
||||||
|
def __init__(self, model_id: int):
|
||||||
|
self._model_id = model_id
|
||||||
|
self._model: Optional[SSHTunnel] = None
|
||||||
|
|
||||||
|
def run(self) -> Model:
|
||||||
|
self.validate()
|
||||||
|
try:
|
||||||
|
ssh_tunnel = SSHTunnelDAO.delete(self._model)
|
||||||
|
except DAODeleteFailedError as ex:
|
||||||
|
raise SSHTunnelDeleteFailedError() from ex
|
||||||
|
return ssh_tunnel
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
# Validate/populate model exists
|
||||||
|
self._model = SSHTunnelDAO.find_by_id(self._model_id)
|
||||||
|
if not self._model:
|
||||||
|
raise SSHTunnelNotFoundError()
|
|
@ -0,0 +1,54 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
from flask_babel import lazy_gettext as _
|
||||||
|
from marshmallow import ValidationError
|
||||||
|
|
||||||
|
from superset.commands.exceptions import (
|
||||||
|
CommandException,
|
||||||
|
CommandInvalidError,
|
||||||
|
DeleteFailedError,
|
||||||
|
UpdateFailedError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SSHTunnelDeleteFailedError(DeleteFailedError):
|
||||||
|
message = _("SSH Tunnel could not be deleted.")
|
||||||
|
|
||||||
|
|
||||||
|
class SSHTunnelNotFoundError(CommandException):
|
||||||
|
status = 404
|
||||||
|
message = _("SSH Tunnel not found.")
|
||||||
|
|
||||||
|
|
||||||
|
class SSHTunnelInvalidError(CommandInvalidError):
|
||||||
|
message = _("SSH Tunnel parameters are invalid.")
|
||||||
|
|
||||||
|
|
||||||
|
class SSHTunnelUpdateFailedError(UpdateFailedError):
|
||||||
|
message = _("SSH Tunnel could not be updated.")
|
||||||
|
|
||||||
|
|
||||||
|
class SSHTunnelCreateFailedError(CommandException):
|
||||||
|
message = _("Creating SSH Tunnel failed for an unknown reason")
|
||||||
|
|
||||||
|
|
||||||
|
class SSHTunnelRequiredFieldValidationError(ValidationError):
|
||||||
|
def __init__(self, field_name: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
[_("Field is required")],
|
||||||
|
field_name=field_name,
|
||||||
|
)
|
|
@ -0,0 +1,62 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from flask_appbuilder.models.sqla import Model
|
||||||
|
|
||||||
|
from superset.commands.base import BaseCommand
|
||||||
|
from superset.dao.exceptions import DAOUpdateFailedError
|
||||||
|
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||||
|
SSHTunnelInvalidError,
|
||||||
|
SSHTunnelNotFoundError,
|
||||||
|
SSHTunnelRequiredFieldValidationError,
|
||||||
|
SSHTunnelUpdateFailedError,
|
||||||
|
)
|
||||||
|
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
|
||||||
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateSSHTunnelCommand(BaseCommand):
|
||||||
|
def __init__(self, model_id: int, data: Dict[str, Any]):
|
||||||
|
self._properties = data.copy()
|
||||||
|
self._model_id = model_id
|
||||||
|
self._model: Optional[SSHTunnel] = None
|
||||||
|
|
||||||
|
def run(self) -> Model:
|
||||||
|
self.validate()
|
||||||
|
try:
|
||||||
|
tunnel = SSHTunnelDAO.update(self._model, self._properties)
|
||||||
|
except DAOUpdateFailedError as ex:
|
||||||
|
raise SSHTunnelUpdateFailedError() from ex
|
||||||
|
return tunnel
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
# Validate/populate model exists
|
||||||
|
self._model = SSHTunnelDAO.find_by_id(self._model_id)
|
||||||
|
if not self._model:
|
||||||
|
raise SSHTunnelNotFoundError()
|
||||||
|
private_key: Optional[str] = self._properties.get("private_key")
|
||||||
|
private_key_password: Optional[str] = self._properties.get(
|
||||||
|
"private_key_password"
|
||||||
|
)
|
||||||
|
if private_key_password and private_key is None:
|
||||||
|
exception = SSHTunnelInvalidError()
|
||||||
|
exception.add(SSHTunnelRequiredFieldValidationError("private_key"))
|
||||||
|
raise exception
|
|
@ -0,0 +1,26 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from superset.dao.base import BaseDAO
|
||||||
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SSHTunnelDAO(BaseDAO):
|
||||||
|
model_cls = SSHTunnel
|
|
@ -0,0 +1,76 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from flask import current_app
|
||||||
|
from flask_appbuilder import Model
|
||||||
|
from sqlalchemy.orm import backref, relationship
|
||||||
|
from sqlalchemy_utils import EncryptedType
|
||||||
|
|
||||||
|
from superset.models.core import Database
|
||||||
|
from superset.models.helpers import (
|
||||||
|
AuditMixinNullable,
|
||||||
|
ExtraJSONMixin,
|
||||||
|
ImportExportMixin,
|
||||||
|
)
|
||||||
|
|
||||||
|
app_config = current_app.config
|
||||||
|
|
||||||
|
|
||||||
|
class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
|
||||||
|
"""
|
||||||
|
A ssh tunnel configuration in a database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "ssh_tunnels"
|
||||||
|
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
database_id = sa.Column(
|
||||||
|
sa.Integer, sa.ForeignKey("dbs.id"), nullable=False, unique=True
|
||||||
|
)
|
||||||
|
database: Database = relationship(
|
||||||
|
"Database",
|
||||||
|
backref=backref("ssh_tunnels", uselist=False, cascade="all, delete-orphan"),
|
||||||
|
foreign_keys=[database_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
server_address = sa.Column(sa.Text)
|
||||||
|
server_port = sa.Column(sa.Integer)
|
||||||
|
username = sa.Column(EncryptedType(sa.String, app_config["SECRET_KEY"]))
|
||||||
|
|
||||||
|
# basic authentication
|
||||||
|
password = sa.Column(
|
||||||
|
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# password protected pkey authentication
|
||||||
|
private_key = sa.Column(
|
||||||
|
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
|
||||||
|
)
|
||||||
|
private_key_password = sa.Column(
|
||||||
|
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"server_address": self.server_address,
|
||||||
|
"server_port": self.server_port,
|
||||||
|
"username": self.username,
|
||||||
|
}
|
|
@ -193,6 +193,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
engine_aliases: Set[str] = set()
|
engine_aliases: Set[str] = set()
|
||||||
drivers: Dict[str, str] = {}
|
drivers: Dict[str, str] = {}
|
||||||
default_driver: Optional[str] = None
|
default_driver: Optional[str] = None
|
||||||
|
allow_ssh_tunneling = False
|
||||||
|
|
||||||
_date_trunc_functions: Dict[str, str] = {}
|
_date_trunc_functions: Dict[str, str] = {}
|
||||||
_time_grain_expressions: Dict[Optional[str], str] = {}
|
_time_grain_expressions: Dict[Optional[str], str] = {}
|
||||||
|
|
|
@ -166,6 +166,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
|
||||||
class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
|
class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
|
||||||
engine = "postgresql"
|
engine = "postgresql"
|
||||||
engine_aliases = {"postgres"}
|
engine_aliases = {"postgres"}
|
||||||
|
allow_ssh_tunneling = True
|
||||||
|
|
||||||
default_driver = "psycopg2"
|
default_driver = "psycopg2"
|
||||||
sqlalchemy_uri_placeholder = (
|
sqlalchemy_uri_placeholder = (
|
||||||
|
|
|
@ -28,6 +28,7 @@ from flask_talisman import Talisman
|
||||||
from flask_wtf.csrf import CSRFProtect
|
from flask_wtf.csrf import CSRFProtect
|
||||||
from werkzeug.local import LocalProxy
|
from werkzeug.local import LocalProxy
|
||||||
|
|
||||||
|
from superset.extensions.ssh import SSHManagerFactory
|
||||||
from superset.utils.async_query_manager import AsyncQueryManager
|
from superset.utils.async_query_manager import AsyncQueryManager
|
||||||
from superset.utils.cache_manager import CacheManager
|
from superset.utils.cache_manager import CacheManager
|
||||||
from superset.utils.encrypt import EncryptedFieldFactory
|
from superset.utils.encrypt import EncryptedFieldFactory
|
||||||
|
@ -127,3 +128,4 @@ profiling = ProfilingExtension()
|
||||||
results_backend_manager = ResultsBackendManager()
|
results_backend_manager = ResultsBackendManager()
|
||||||
security_manager = LocalProxy(lambda: appbuilder.sm)
|
security_manager = LocalProxy(lambda: appbuilder.sm)
|
||||||
talisman = Talisman()
|
talisman = Talisman()
|
||||||
|
ssh_manager_factory = SSHManagerFactory()
|
||||||
|
|
|
@ -0,0 +1,88 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from flask import Flask
|
||||||
|
from sshtunnel import open_tunnel, SSHTunnelForwarder
|
||||||
|
|
||||||
|
from superset.databases.utils import make_url_safe
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
|
|
||||||
|
|
||||||
|
class SSHManager:
|
||||||
|
def __init__(self, app: Flask) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
|
||||||
|
|
||||||
|
def build_sqla_url( # pylint: disable=no-self-use
|
||||||
|
self, sqlalchemy_url: str, server: SSHTunnelForwarder
|
||||||
|
) -> str:
|
||||||
|
# override any ssh tunnel configuration object
|
||||||
|
url = make_url_safe(sqlalchemy_url)
|
||||||
|
return url.set(
|
||||||
|
host=server.local_bind_address[0],
|
||||||
|
port=server.local_bind_port,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_tunnel(
|
||||||
|
self,
|
||||||
|
ssh_tunnel: "SSHTunnel",
|
||||||
|
sqlalchemy_database_uri: str,
|
||||||
|
) -> SSHTunnelForwarder:
|
||||||
|
url = make_url_safe(sqlalchemy_database_uri)
|
||||||
|
params = {
|
||||||
|
"ssh_address_or_host": ssh_tunnel.server_address,
|
||||||
|
"ssh_port": ssh_tunnel.server_port,
|
||||||
|
"ssh_username": ssh_tunnel.username,
|
||||||
|
"remote_bind_address": (url.host, url.port), # bind_port, bind_host
|
||||||
|
"local_bind_address": (self.local_bind_address,),
|
||||||
|
}
|
||||||
|
|
||||||
|
if ssh_tunnel.password:
|
||||||
|
params["ssh_password"] = ssh_tunnel.password
|
||||||
|
elif ssh_tunnel.private_key:
|
||||||
|
params["private_key"] = ssh_tunnel.private_key
|
||||||
|
params["private_key_password"] = ssh_tunnel.private_key_password
|
||||||
|
|
||||||
|
return open_tunnel(**params)
|
||||||
|
|
||||||
|
|
||||||
|
class SSHManagerFactory:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._ssh_manager = None
|
||||||
|
|
||||||
|
def init_app(self, app: Flask) -> None:
|
||||||
|
ssh_manager_fqclass = app.config["SSH_TUNNEL_MANAGER_CLASS"]
|
||||||
|
ssh_manager_classname = ssh_manager_fqclass[
|
||||||
|
ssh_manager_fqclass.rfind(".") + 1 :
|
||||||
|
]
|
||||||
|
ssh_manager_module_name = ssh_manager_fqclass[
|
||||||
|
0 : ssh_manager_fqclass.rfind(".")
|
||||||
|
]
|
||||||
|
ssh_manager_class = getattr(
|
||||||
|
importlib.import_module(ssh_manager_module_name), ssh_manager_classname
|
||||||
|
)
|
||||||
|
|
||||||
|
self._ssh_manager = ssh_manager_class(app)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def instance(self) -> SSHManager:
|
||||||
|
return self._ssh_manager # type: ignore
|
|
@ -45,6 +45,7 @@ from superset.extensions import (
|
||||||
migrate,
|
migrate,
|
||||||
profiling,
|
profiling,
|
||||||
results_backend_manager,
|
results_backend_manager,
|
||||||
|
ssh_manager_factory,
|
||||||
talisman,
|
talisman,
|
||||||
)
|
)
|
||||||
from superset.security import SupersetSecurityManager
|
from superset.security import SupersetSecurityManager
|
||||||
|
@ -417,6 +418,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
||||||
self.configure_data_sources()
|
self.configure_data_sources()
|
||||||
self.configure_auth_provider()
|
self.configure_auth_provider()
|
||||||
self.configure_async_queries()
|
self.configure_async_queries()
|
||||||
|
self.configure_ssh_manager()
|
||||||
|
|
||||||
# Hook that provides administrators a handle on the Flask APP
|
# Hook that provides administrators a handle on the Flask APP
|
||||||
# after initialization
|
# after initialization
|
||||||
|
@ -474,6 +476,9 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
||||||
def configure_auth_provider(self) -> None:
|
def configure_auth_provider(self) -> None:
|
||||||
machine_auth_provider_factory.init_app(self.superset_app)
|
machine_auth_provider_factory.init_app(self.superset_app)
|
||||||
|
|
||||||
|
def configure_ssh_manager(self) -> None:
|
||||||
|
ssh_manager_factory.init_app(self.superset_app)
|
||||||
|
|
||||||
def setup_event_logger(self) -> None:
|
def setup_event_logger(self) -> None:
|
||||||
_event_logger["event_logger"] = get_event_logger_from_cfg_value(
|
_event_logger["event_logger"] = get_event_logger_from_cfg_value(
|
||||||
self.superset_app.config.get("EVENT_LOGGER", DBEventLogger())
|
self.superset_app.config.get("EVENT_LOGGER", DBEventLogger())
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""create_ssh_tunnel_credentials_tbl
|
||||||
|
|
||||||
|
Revision ID: f3c2d8ec8595
|
||||||
|
Revises: 4ce1d9b25135
|
||||||
|
Create Date: 2022-10-20 10:48:08.722861
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "f3c2d8ec8595"
|
||||||
|
down_revision = "4ce1d9b25135"
|
||||||
|
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy_utils import UUIDType
|
||||||
|
|
||||||
|
from superset import app
|
||||||
|
from superset.extensions import encrypted_field_factory
|
||||||
|
|
||||||
|
app_config = app.config
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
op.create_table(
|
||||||
|
"ssh_tunnels",
|
||||||
|
# AuditMixinNullable
|
||||||
|
sa.Column("created_on", sa.DateTime(), nullable=True),
|
||||||
|
sa.Column("changed_on", sa.DateTime(), nullable=True),
|
||||||
|
sa.Column("created_by_fk", sa.Integer(), nullable=True),
|
||||||
|
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
|
||||||
|
# ExtraJSONMixin
|
||||||
|
sa.Column("extra_json", sa.Text(), nullable=True),
|
||||||
|
# ImportExportMixin
|
||||||
|
sa.Column(
|
||||||
|
"uuid",
|
||||||
|
UUIDType(binary=True),
|
||||||
|
primary_key=False,
|
||||||
|
default=uuid4,
|
||||||
|
unique=True,
|
||||||
|
index=True,
|
||||||
|
),
|
||||||
|
# SSHTunnelCredentials
|
||||||
|
sa.Column("id", sa.Integer(), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"database_id",
|
||||||
|
sa.INTEGER(),
|
||||||
|
sa.ForeignKey("dbs.id"),
|
||||||
|
unique=True,
|
||||||
|
index=True,
|
||||||
|
),
|
||||||
|
sa.Column("server_address", sa.String(256)),
|
||||||
|
sa.Column("server_port", sa.INTEGER()),
|
||||||
|
sa.Column("username", encrypted_field_factory.create(sa.String(256))),
|
||||||
|
sa.Column(
|
||||||
|
"password", encrypted_field_factory.create(sa.String(256)), nullable=True
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"private_key",
|
||||||
|
encrypted_field_factory.create(sa.String(1024)),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"private_key_password",
|
||||||
|
encrypted_field_factory.create(sa.String(256)),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
op.drop_table("ssh_tunnels")
|
|
@ -21,10 +21,10 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import textwrap
|
import textwrap
|
||||||
from ast import literal_eval
|
from ast import literal_eval
|
||||||
from contextlib import closing, contextmanager
|
from contextlib import closing, contextmanager, nullcontext
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
@ -57,7 +57,12 @@ from superset import app, db_engine_specs
|
||||||
from superset.constants import PASSWORD_MASK
|
from superset.constants import PASSWORD_MASK
|
||||||
from superset.databases.utils import make_url_safe
|
from superset.databases.utils import make_url_safe
|
||||||
from superset.db_engine_specs.base import MetricType, TimeGrain
|
from superset.db_engine_specs.base import MetricType, TimeGrain
|
||||||
from superset.extensions import cache_manager, encrypted_field_factory, security_manager
|
from superset.extensions import (
|
||||||
|
cache_manager,
|
||||||
|
encrypted_field_factory,
|
||||||
|
security_manager,
|
||||||
|
ssh_manager_factory,
|
||||||
|
)
|
||||||
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
|
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
|
||||||
from superset.result_set import SupersetResultSet
|
from superset.result_set import SupersetResultSet
|
||||||
from superset.utils import cache as cache_util, core as utils
|
from superset.utils import cache as cache_util, core as utils
|
||||||
|
@ -71,6 +76,9 @@ log_query = config["QUERY_LOGGER"]
|
||||||
metadata = Model.metadata # pylint: disable=no-member
|
metadata = Model.metadata # pylint: disable=no-member
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
|
|
||||||
DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"]
|
DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"]
|
||||||
|
|
||||||
|
|
||||||
|
@ -373,17 +381,48 @@ class Database(
|
||||||
schema: Optional[str] = None,
|
schema: Optional[str] = None,
|
||||||
nullpool: bool = True,
|
nullpool: bool = True,
|
||||||
source: Optional[utils.QuerySource] = None,
|
source: Optional[utils.QuerySource] = None,
|
||||||
|
override_ssh_tunnel: Optional["SSHTunnel"] = None,
|
||||||
) -> Engine:
|
) -> Engine:
|
||||||
yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
|
from superset.databases.dao import ( # pylint: disable=import-outside-toplevel
|
||||||
|
DatabaseDAO,
|
||||||
|
)
|
||||||
|
|
||||||
|
sqlalchemy_uri = self.sqlalchemy_uri_decrypted
|
||||||
|
engine_context = nullcontext()
|
||||||
|
ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel(
|
||||||
|
database_id=self.id
|
||||||
|
)
|
||||||
|
|
||||||
|
if ssh_tunnel:
|
||||||
|
# if ssh_tunnel is available build engine with information
|
||||||
|
engine_context = ssh_manager_factory.instance.create_tunnel(
|
||||||
|
ssh_tunnel=ssh_tunnel,
|
||||||
|
sqlalchemy_database_uri=self.sqlalchemy_uri_decrypted,
|
||||||
|
)
|
||||||
|
|
||||||
|
with engine_context as server_context:
|
||||||
|
if ssh_tunnel:
|
||||||
|
sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
|
||||||
|
sqlalchemy_uri, server_context
|
||||||
|
)
|
||||||
|
yield self._get_sqla_engine(
|
||||||
|
schema=schema,
|
||||||
|
nullpool=nullpool,
|
||||||
|
source=source,
|
||||||
|
sqlalchemy_uri=sqlalchemy_uri,
|
||||||
|
)
|
||||||
|
|
||||||
def _get_sqla_engine(
|
def _get_sqla_engine(
|
||||||
self,
|
self,
|
||||||
schema: Optional[str] = None,
|
schema: Optional[str] = None,
|
||||||
nullpool: bool = True,
|
nullpool: bool = True,
|
||||||
source: Optional[utils.QuerySource] = None,
|
source: Optional[utils.QuerySource] = None,
|
||||||
|
sqlalchemy_uri: Optional[str] = None,
|
||||||
) -> Engine:
|
) -> Engine:
|
||||||
extra = self.get_extra()
|
extra = self.get_extra()
|
||||||
sqlalchemy_url = make_url_safe(self.sqlalchemy_uri_decrypted)
|
sqlalchemy_url = make_url_safe(
|
||||||
|
sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
|
||||||
|
)
|
||||||
sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
|
sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
|
||||||
effective_username = self.get_effective_user(sqlalchemy_url)
|
effective_username = self.get_effective_user(sqlalchemy_url)
|
||||||
# If using MySQL or Presto for example, will set url.username
|
# If using MySQL or Presto for example, will set url.username
|
||||||
|
@ -423,7 +462,6 @@ class Database(
|
||||||
sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
|
sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
|
||||||
sqlalchemy_url, params, effective_username, security_manager, source
|
sqlalchemy_url, params, effective_username, security_manager, source
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return create_engine(sqlalchemy_url, **params)
|
return create_engine(sqlalchemy_url, **params)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
|
@ -477,7 +515,7 @@ class Database(
|
||||||
security_manager,
|
security_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
with closing(engine.raw_connection()) as conn:
|
with self.get_raw_connection(schema=schema) as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
for sql_ in sqls[:-1]:
|
for sql_ in sqls[:-1]:
|
||||||
_log_query(sql_)
|
_log_query(sql_)
|
||||||
|
@ -574,14 +612,16 @@ class Database(
|
||||||
:return: The table/schema pairs
|
:return: The table/schema pairs
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return {
|
with self.get_inspector_with_context() as inspector:
|
||||||
|
tables = {
|
||||||
(table, schema)
|
(table, schema)
|
||||||
for table in self.db_engine_spec.get_table_names(
|
for table in self.db_engine_spec.get_table_names(
|
||||||
database=self,
|
database=self,
|
||||||
inspector=self.inspector,
|
inspector=inspector,
|
||||||
schema=schema,
|
schema=schema,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
return tables
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
|
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
|
||||||
|
|
||||||
|
@ -608,17 +648,27 @@ class Database(
|
||||||
:return: set of views
|
:return: set of views
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
with self.get_inspector_with_context() as inspector:
|
||||||
return {
|
return {
|
||||||
(view, schema)
|
(view, schema)
|
||||||
for view in self.db_engine_spec.get_view_names(
|
for view in self.db_engine_spec.get_view_names(
|
||||||
database=self,
|
database=self,
|
||||||
inspector=self.inspector,
|
inspector=inspector,
|
||||||
schema=schema,
|
schema=schema,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
|
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_inspector_with_context(
|
||||||
|
self, ssh_tunnel: Optional["SSHTunnel"] = None
|
||||||
|
) -> Inspector:
|
||||||
|
with self.get_sqla_engine_with_context(
|
||||||
|
override_ssh_tunnel=ssh_tunnel
|
||||||
|
) as engine:
|
||||||
|
yield sqla.inspect(engine)
|
||||||
|
|
||||||
@cache_util.memoized_func(
|
@cache_util.memoized_func(
|
||||||
key="db:{self.id}:schema_list",
|
key="db:{self.id}:schema_list",
|
||||||
cache=cache_manager.cache,
|
cache=cache_manager.cache,
|
||||||
|
@ -628,6 +678,7 @@ class Database(
|
||||||
cache: bool = False,
|
cache: bool = False,
|
||||||
cache_timeout: Optional[int] = None,
|
cache_timeout: Optional[int] = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
|
ssh_tunnel: Optional["SSHTunnel"] = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Parameters need to be passed as keyword arguments.
|
"""Parameters need to be passed as keyword arguments.
|
||||||
|
|
||||||
|
@ -640,7 +691,8 @@ class Database(
|
||||||
:return: schema list
|
:return: schema list
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return self.db_engine_spec.get_schema_names(self.inspector)
|
with self.get_inspector_with_context(ssh_tunnel=ssh_tunnel) as inspector:
|
||||||
|
return self.db_engine_spec.get_schema_names(inspector)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
||||||
|
|
||||||
|
@ -703,30 +755,35 @@ class Database(
|
||||||
def get_table_comment(
|
def get_table_comment(
|
||||||
self, table_name: str, schema: Optional[str] = None
|
self, table_name: str, schema: Optional[str] = None
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
return self.db_engine_spec.get_table_comment(self.inspector, table_name, schema)
|
with self.get_inspector_with_context() as inspector:
|
||||||
|
return self.db_engine_spec.get_table_comment(inspector, table_name, schema)
|
||||||
|
|
||||||
def get_columns(
|
def get_columns(
|
||||||
self, table_name: str, schema: Optional[str] = None
|
self, table_name: str, schema: Optional[str] = None
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
return self.db_engine_spec.get_columns(self.inspector, table_name, schema)
|
with self.get_inspector_with_context() as inspector:
|
||||||
|
return self.db_engine_spec.get_columns(inspector, table_name, schema)
|
||||||
|
|
||||||
def get_metrics(
|
def get_metrics(
|
||||||
self,
|
self,
|
||||||
table_name: str,
|
table_name: str,
|
||||||
schema: Optional[str] = None,
|
schema: Optional[str] = None,
|
||||||
) -> List[MetricType]:
|
) -> List[MetricType]:
|
||||||
return self.db_engine_spec.get_metrics(self, self.inspector, table_name, schema)
|
with self.get_inspector_with_context() as inspector:
|
||||||
|
return self.db_engine_spec.get_metrics(self, inspector, table_name, schema)
|
||||||
|
|
||||||
def get_indexes(
|
def get_indexes(
|
||||||
self, table_name: str, schema: Optional[str] = None
|
self, table_name: str, schema: Optional[str] = None
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
indexes = self.inspector.get_indexes(table_name, schema)
|
with self.get_inspector_with_context() as inspector:
|
||||||
|
indexes = inspector.get_indexes(table_name, schema)
|
||||||
return self.db_engine_spec.normalize_indexes(indexes)
|
return self.db_engine_spec.normalize_indexes(indexes)
|
||||||
|
|
||||||
def get_pk_constraint(
|
def get_pk_constraint(
|
||||||
self, table_name: str, schema: Optional[str] = None
|
self, table_name: str, schema: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
pk_constraint = self.inspector.get_pk_constraint(table_name, schema) or {}
|
with self.get_inspector_with_context() as inspector:
|
||||||
|
pk_constraint = inspector.get_pk_constraint(table_name, schema) or {}
|
||||||
|
|
||||||
def _convert(value: Any) -> Any:
|
def _convert(value: Any) -> Any:
|
||||||
try:
|
try:
|
||||||
|
@ -739,7 +796,8 @@ class Database(
|
||||||
def get_foreign_keys(
|
def get_foreign_keys(
|
||||||
self, table_name: str, schema: Optional[str] = None
|
self, table_name: str, schema: Optional[str] = None
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
return self.inspector.get_foreign_keys(table_name, schema)
|
with self.get_inspector_with_context() as inspector:
|
||||||
|
return inspector.get_foreign_keys(table_name, schema)
|
||||||
|
|
||||||
def get_schema_access_for_file_upload( # pylint: disable=invalid-name
|
def get_schema_access_for_file_upload( # pylint: disable=invalid-name
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from superset.constants import PASSWORD_MASK
|
||||||
|
|
||||||
|
|
||||||
|
def mask_password_info(ssh_tunnel: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
if ssh_tunnel.pop("password", None) is not None:
|
||||||
|
ssh_tunnel["password"] = PASSWORD_MASK
|
||||||
|
if ssh_tunnel.pop("private_key", None) is not None:
|
||||||
|
ssh_tunnel["private_key"] = PASSWORD_MASK
|
||||||
|
if ssh_tunnel.pop("private_key_password", None) is not None:
|
||||||
|
ssh_tunnel["private_key_password"] = PASSWORD_MASK
|
||||||
|
return ssh_tunnel
|
|
@ -28,7 +28,7 @@ from __future__ import annotations
|
||||||
from typing import Callable, TYPE_CHECKING
|
from typing import Callable, TYPE_CHECKING
|
||||||
from unittest.mock import MagicMock, Mock, PropertyMock
|
from unittest.mock import MagicMock, Mock, PropertyMock
|
||||||
|
|
||||||
from flask import Flask
|
from flask import current_app, Flask
|
||||||
from flask.ctx import AppContext
|
from flask.ctx import AppContext
|
||||||
from pytest import fixture
|
from pytest import fixture
|
||||||
|
|
||||||
|
@ -40,6 +40,7 @@ from tests.example_data.data_loading.pandas.pands_data_loading_conf import (
|
||||||
from tests.example_data.data_loading.pandas.table_df_convertor import (
|
from tests.example_data.data_loading.pandas.table_df_convertor import (
|
||||||
TableToDfConvertorImpl,
|
TableToDfConvertorImpl,
|
||||||
)
|
)
|
||||||
|
from tests.integration_tests.test_app import app
|
||||||
|
|
||||||
SUPPORT_DATETIME_TYPE = "support_datetime_type"
|
SUPPORT_DATETIME_TYPE = "support_datetime_type"
|
||||||
|
|
||||||
|
@ -70,6 +71,7 @@ def example_db_provider() -> Callable[[], Database]:
|
||||||
|
|
||||||
@fixture(scope="session")
|
@fixture(scope="session")
|
||||||
def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine:
|
def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine:
|
||||||
|
with app.app_context():
|
||||||
with example_db_provider().get_sqla_engine_with_context() as engine:
|
with example_db_provider().get_sqla_engine_with_context() as engine:
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,8 @@ from sqlalchemy.sql import func
|
||||||
|
|
||||||
from superset import db, security_manager
|
from superset import db, security_manager
|
||||||
from superset.connectors.sqla.models import SqlaTable
|
from superset.connectors.sqla.models import SqlaTable
|
||||||
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
|
from superset.databases.utils import make_url_safe
|
||||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||||
from superset.db_engine_specs.redshift import RedshiftEngineSpec
|
from superset.db_engine_specs.redshift import RedshiftEngineSpec
|
||||||
|
@ -280,6 +282,314 @@ class TestDatabaseApi(SupersetTestCase):
|
||||||
db.session.delete(model)
|
db.session.delete(model)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
@mock.patch(
|
||||||
|
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||||
|
)
|
||||||
|
@mock.patch(
|
||||||
|
"superset.models.core.Database.get_all_schema_names",
|
||||||
|
)
|
||||||
|
def test_create_database_with_ssh_tunnel(
|
||||||
|
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Database API: Test create with SSH Tunnel
|
||||||
|
"""
|
||||||
|
self.login(username="admin")
|
||||||
|
example_db = get_example_database()
|
||||||
|
if example_db.backend == "sqlite":
|
||||||
|
return
|
||||||
|
ssh_tunnel_properties = {
|
||||||
|
"server_address": "123.132.123.1",
|
||||||
|
"server_port": 8080,
|
||||||
|
"username": "foo",
|
||||||
|
"password": "bar",
|
||||||
|
}
|
||||||
|
database_data = {
|
||||||
|
"database_name": "test-db-with-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||||
|
"ssh_tunnel": ssh_tunnel_properties,
|
||||||
|
}
|
||||||
|
uri = "api/v1/database/"
|
||||||
|
rv = self.client.post(uri, json=database_data)
|
||||||
|
response = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(rv.status_code, 201)
|
||||||
|
model_ssh_tunnel = (
|
||||||
|
db.session.query(SSHTunnel)
|
||||||
|
.filter(SSHTunnel.database_id == response.get("id"))
|
||||||
|
.one()
|
||||||
|
)
|
||||||
|
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||||
|
# Cleanup
|
||||||
|
model = db.session.query(Database).get(response.get("id"))
|
||||||
|
db.session.delete(model)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
@mock.patch(
|
||||||
|
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||||
|
)
|
||||||
|
@mock.patch(
|
||||||
|
"superset.models.core.Database.get_all_schema_names",
|
||||||
|
)
|
||||||
|
def test_update_database_with_ssh_tunnel(
|
||||||
|
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Database API: Test update with SSH Tunnel
|
||||||
|
"""
|
||||||
|
self.login(username="admin")
|
||||||
|
example_db = get_example_database()
|
||||||
|
if example_db.backend == "sqlite":
|
||||||
|
return
|
||||||
|
ssh_tunnel_properties = {
|
||||||
|
"server_address": "123.132.123.1",
|
||||||
|
"server_port": 8080,
|
||||||
|
"username": "foo",
|
||||||
|
"password": "bar",
|
||||||
|
}
|
||||||
|
database_data = {
|
||||||
|
"database_name": "test-db-with-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||||
|
}
|
||||||
|
database_data_with_ssh_tunnel = {
|
||||||
|
"database_name": "test-db-with-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||||
|
"ssh_tunnel": ssh_tunnel_properties,
|
||||||
|
}
|
||||||
|
|
||||||
|
uri = "api/v1/database/"
|
||||||
|
rv = self.client.post(uri, json=database_data)
|
||||||
|
response = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(rv.status_code, 201)
|
||||||
|
|
||||||
|
uri = "api/v1/database/{}".format(response.get("id"))
|
||||||
|
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
|
||||||
|
response_update = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(rv.status_code, 200)
|
||||||
|
|
||||||
|
model_ssh_tunnel = (
|
||||||
|
db.session.query(SSHTunnel)
|
||||||
|
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||||
|
.one()
|
||||||
|
)
|
||||||
|
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
|
||||||
|
# Cleanup
|
||||||
|
model = db.session.query(Database).get(response.get("id"))
|
||||||
|
db.session.delete(model)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
@mock.patch(
|
||||||
|
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||||
|
)
|
||||||
|
@mock.patch(
|
||||||
|
"superset.models.core.Database.get_all_schema_names",
|
||||||
|
)
|
||||||
|
def test_update_ssh_tunnel_via_database_api(
|
||||||
|
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Database API: Test update with SSH Tunnel
|
||||||
|
"""
|
||||||
|
self.login(username="admin")
|
||||||
|
example_db = get_example_database()
|
||||||
|
|
||||||
|
if example_db.backend == "sqlite":
|
||||||
|
return
|
||||||
|
initial_ssh_tunnel_properties = {
|
||||||
|
"server_address": "123.132.123.1",
|
||||||
|
"server_port": 8080,
|
||||||
|
"username": "foo",
|
||||||
|
"password": "bar",
|
||||||
|
}
|
||||||
|
updated_ssh_tunnel_properties = {
|
||||||
|
"server_address": "123.132.123.1",
|
||||||
|
"server_port": 8080,
|
||||||
|
"username": "Test",
|
||||||
|
}
|
||||||
|
database_data_with_ssh_tunnel = {
|
||||||
|
"database_name": "test-db-with-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||||
|
"ssh_tunnel": initial_ssh_tunnel_properties,
|
||||||
|
}
|
||||||
|
database_data_with_ssh_tunnel_update = {
|
||||||
|
"database_name": "test-db-with-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||||
|
"ssh_tunnel": updated_ssh_tunnel_properties,
|
||||||
|
}
|
||||||
|
|
||||||
|
uri = "api/v1/database/"
|
||||||
|
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
|
||||||
|
response = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(rv.status_code, 201)
|
||||||
|
model_ssh_tunnel = (
|
||||||
|
db.session.query(SSHTunnel)
|
||||||
|
.filter(SSHTunnel.database_id == response.get("id"))
|
||||||
|
.one()
|
||||||
|
)
|
||||||
|
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||||
|
self.assertEqual(model_ssh_tunnel.username, "foo")
|
||||||
|
uri = "api/v1/database/{}".format(response.get("id"))
|
||||||
|
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update)
|
||||||
|
response_update = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(rv.status_code, 200)
|
||||||
|
model_ssh_tunnel = (
|
||||||
|
db.session.query(SSHTunnel)
|
||||||
|
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||||
|
.one()
|
||||||
|
)
|
||||||
|
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
|
||||||
|
self.assertEqual(model_ssh_tunnel.username, "Test")
|
||||||
|
self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1")
|
||||||
|
self.assertEqual(model_ssh_tunnel.server_port, 8080)
|
||||||
|
# Cleanup
|
||||||
|
model = db.session.query(Database).get(response.get("id"))
|
||||||
|
db.session.delete(model)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
@mock.patch(
|
||||||
|
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||||
|
)
|
||||||
|
@mock.patch(
|
||||||
|
"superset.models.core.Database.get_all_schema_names",
|
||||||
|
)
|
||||||
|
def test_cascade_delete_ssh_tunnel(
|
||||||
|
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Database API: Test create with SSH Tunnel
|
||||||
|
"""
|
||||||
|
self.login(username="admin")
|
||||||
|
example_db = get_example_database()
|
||||||
|
if example_db.backend == "sqlite":
|
||||||
|
return
|
||||||
|
ssh_tunnel_properties = {
|
||||||
|
"server_address": "123.132.123.1",
|
||||||
|
"server_port": 8080,
|
||||||
|
"username": "foo",
|
||||||
|
"password": "bar",
|
||||||
|
}
|
||||||
|
database_data = {
|
||||||
|
"database_name": "test-db-with-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||||
|
"ssh_tunnel": ssh_tunnel_properties,
|
||||||
|
}
|
||||||
|
|
||||||
|
uri = "api/v1/database/"
|
||||||
|
rv = self.client.post(uri, json=database_data)
|
||||||
|
response = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(rv.status_code, 201)
|
||||||
|
model_ssh_tunnel = (
|
||||||
|
db.session.query(SSHTunnel)
|
||||||
|
.filter(SSHTunnel.database_id == response.get("id"))
|
||||||
|
.one()
|
||||||
|
)
|
||||||
|
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||||
|
# Cleanup
|
||||||
|
model = db.session.query(Database).get(response.get("id"))
|
||||||
|
db.session.delete(model)
|
||||||
|
db.session.commit()
|
||||||
|
model_ssh_tunnel = (
|
||||||
|
db.session.query(SSHTunnel)
|
||||||
|
.filter(SSHTunnel.database_id == response.get("id"))
|
||||||
|
.one_or_none()
|
||||||
|
)
|
||||||
|
assert model_ssh_tunnel is None
|
||||||
|
|
||||||
|
@mock.patch(
|
||||||
|
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||||
|
)
|
||||||
|
@mock.patch(
|
||||||
|
"superset.models.core.Database.get_all_schema_names",
|
||||||
|
)
|
||||||
|
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
|
||||||
|
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Database API: Test create with SSH Tunnel
|
||||||
|
"""
|
||||||
|
self.login(username="admin")
|
||||||
|
example_db = get_example_database()
|
||||||
|
if example_db.backend == "sqlite":
|
||||||
|
return
|
||||||
|
ssh_tunnel_properties = {
|
||||||
|
"server_address": "123.132.123.1",
|
||||||
|
}
|
||||||
|
database_data = {
|
||||||
|
"database_name": "test-db-failure-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||||
|
"ssh_tunnel": ssh_tunnel_properties,
|
||||||
|
}
|
||||||
|
fail_message = {"message": "SSH Tunnel parameters are invalid."}
|
||||||
|
|
||||||
|
uri = "api/v1/database/"
|
||||||
|
rv = self.client.post(uri, json=database_data)
|
||||||
|
response = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(rv.status_code, 422)
|
||||||
|
model_ssh_tunnel = (
|
||||||
|
db.session.query(SSHTunnel)
|
||||||
|
.filter(SSHTunnel.database_id == response.get("id"))
|
||||||
|
.one_or_none()
|
||||||
|
)
|
||||||
|
assert model_ssh_tunnel is None
|
||||||
|
self.assertEqual(response, fail_message)
|
||||||
|
# Cleanup
|
||||||
|
model = (
|
||||||
|
db.session.query(Database)
|
||||||
|
.filter(Database.database_name == "test-db-failure-ssh-tunnel")
|
||||||
|
.one_or_none()
|
||||||
|
)
|
||||||
|
# the DB should not be created
|
||||||
|
assert model is None
|
||||||
|
|
||||||
|
@mock.patch(
|
||||||
|
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||||
|
)
|
||||||
|
@mock.patch(
|
||||||
|
"superset.models.core.Database.get_all_schema_names",
|
||||||
|
)
|
||||||
|
def test_get_database_returns_related_ssh_tunnel(
|
||||||
|
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Database API: Test GET Database returns its related SSH Tunnel
|
||||||
|
"""
|
||||||
|
self.login(username="admin")
|
||||||
|
example_db = get_example_database()
|
||||||
|
if example_db.backend == "sqlite":
|
||||||
|
return
|
||||||
|
ssh_tunnel_properties = {
|
||||||
|
"server_address": "123.132.123.1",
|
||||||
|
"server_port": 8080,
|
||||||
|
"username": "foo",
|
||||||
|
"password": "bar",
|
||||||
|
}
|
||||||
|
database_data = {
|
||||||
|
"database_name": "test-db-with-ssh-tunnel",
|
||||||
|
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||||
|
"ssh_tunnel": ssh_tunnel_properties,
|
||||||
|
}
|
||||||
|
response_ssh_tunnel = {
|
||||||
|
"server_address": "123.132.123.1",
|
||||||
|
"server_port": 8080,
|
||||||
|
"username": "foo",
|
||||||
|
"password": "XXXXXXXXXX",
|
||||||
|
}
|
||||||
|
|
||||||
|
uri = "api/v1/database/"
|
||||||
|
rv = self.client.post(uri, json=database_data)
|
||||||
|
response = json.loads(rv.data.decode("utf-8"))
|
||||||
|
self.assertEqual(rv.status_code, 201)
|
||||||
|
model_ssh_tunnel = (
|
||||||
|
db.session.query(SSHTunnel)
|
||||||
|
.filter(SSHTunnel.database_id == response.get("id"))
|
||||||
|
.one()
|
||||||
|
)
|
||||||
|
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||||
|
self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel)
|
||||||
|
# Cleanup
|
||||||
|
model = db.session.query(Database).get(response.get("id"))
|
||||||
|
db.session.delete(model)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
def test_create_database_invalid_configuration_method(self):
|
def test_create_database_invalid_configuration_method(self):
|
||||||
"""
|
"""
|
||||||
Database API: Test create with an invalid configuration method.
|
Database API: Test create with an invalid configuration method.
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,76 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
from unittest import mock, skip
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from superset import security_manager
|
||||||
|
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||||
|
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
|
||||||
|
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||||
|
SSHTunnelInvalidError,
|
||||||
|
SSHTunnelNotFoundError,
|
||||||
|
)
|
||||||
|
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
|
||||||
|
from tests.integration_tests.base_tests import SupersetTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateSSHTunnelCommand(SupersetTestCase):
|
||||||
|
@mock.patch("superset.utils.core.g")
|
||||||
|
def test_create_invalid_database_id(self, mock_g):
|
||||||
|
mock_g.user = security_manager.find_user("admin")
|
||||||
|
command = CreateSSHTunnelCommand(
|
||||||
|
None,
|
||||||
|
{
|
||||||
|
"server_address": "127.0.0.1",
|
||||||
|
"server_port": 5432,
|
||||||
|
"username": "test_user",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||||
|
command.run()
|
||||||
|
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
||||||
|
|
||||||
|
|
||||||
|
class TestUpdateSSHTunnelCommand(SupersetTestCase):
|
||||||
|
@mock.patch("superset.utils.core.g")
|
||||||
|
def test_update_ssh_tunnel_not_found(self, mock_g):
|
||||||
|
mock_g.user = security_manager.find_user("admin")
|
||||||
|
# We have not created a SSH Tunnel yet so id = 1 is invalid
|
||||||
|
command = UpdateSSHTunnelCommand(
|
||||||
|
1,
|
||||||
|
{
|
||||||
|
"server_address": "127.0.0.1",
|
||||||
|
"server_port": 5432,
|
||||||
|
"username": "test_user",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with pytest.raises(SSHTunnelNotFoundError) as excinfo:
|
||||||
|
command.run()
|
||||||
|
assert str(excinfo.value) == ("SSH Tunnel not found.")
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteSSHTunnelCommand(SupersetTestCase):
|
||||||
|
@mock.patch("superset.utils.core.g")
|
||||||
|
def test_delete_ssh_tunnel_not_found(self, mock_g):
|
||||||
|
mock_g.user = security_manager.find_user("admin")
|
||||||
|
# We have not created a SSH Tunnel yet so id = 1 is invalid
|
||||||
|
command = DeleteSSHTunnelCommand(1)
|
||||||
|
with pytest.raises(SSHTunnelNotFoundError) as excinfo:
|
||||||
|
command.run()
|
||||||
|
assert str(excinfo.value) == ("SSH Tunnel not found.")
|
|
@ -191,3 +191,147 @@ def test_non_zip_import(client: Any, full_api_access: None) -> None:
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_ssh_tunnel(
|
||||||
|
mocker: MockFixture,
|
||||||
|
app: Any,
|
||||||
|
session: Session,
|
||||||
|
client: Any,
|
||||||
|
full_api_access: None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test that we can delete SSH Tunnel
|
||||||
|
"""
|
||||||
|
with app.app_context():
|
||||||
|
from superset.databases.api import DatabaseRestApi
|
||||||
|
from superset.databases.dao import DatabaseDAO
|
||||||
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
|
from superset.models.core import Database
|
||||||
|
|
||||||
|
DatabaseRestApi.datamodel.session = session
|
||||||
|
|
||||||
|
# create table for databases
|
||||||
|
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||||
|
|
||||||
|
# Create our Database
|
||||||
|
database = Database(
|
||||||
|
database_name="my_database",
|
||||||
|
sqlalchemy_uri="gsheets://",
|
||||||
|
encrypted_extra=json.dumps(
|
||||||
|
{
|
||||||
|
"service_account_info": {
|
||||||
|
"type": "service_account",
|
||||||
|
"project_id": "black-sanctum-314419",
|
||||||
|
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
|
||||||
|
"private_key": "SECRET",
|
||||||
|
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
|
||||||
|
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
|
||||||
|
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||||
|
"token_uri": "https://oauth2.googleapis.com/token",
|
||||||
|
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
||||||
|
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
session.add(database)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# mock the lookup so that we don't need to include the driver
|
||||||
|
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||||
|
mocker.patch("superset.utils.log.DBEventLogger.log")
|
||||||
|
|
||||||
|
# Create our SSHTunnel
|
||||||
|
tunnel = SSHTunnel(
|
||||||
|
database_id=1,
|
||||||
|
database=database,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(tunnel)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Get our recently created SSHTunnel
|
||||||
|
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
|
||||||
|
assert response_tunnel
|
||||||
|
assert isinstance(response_tunnel, SSHTunnel)
|
||||||
|
assert 1 == response_tunnel.database_id
|
||||||
|
|
||||||
|
# Delete the recently created SSHTunnel
|
||||||
|
response_delete_tunnel = client.delete("/api/v1/database/1/ssh_tunnel/")
|
||||||
|
assert response_delete_tunnel.json["message"] == "OK"
|
||||||
|
|
||||||
|
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
|
||||||
|
assert response_tunnel is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_ssh_tunnel_not_found(
|
||||||
|
mocker: MockFixture,
|
||||||
|
app: Any,
|
||||||
|
session: Session,
|
||||||
|
client: Any,
|
||||||
|
full_api_access: None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Test that we cannot delete a tunnel that does not exist
|
||||||
|
"""
|
||||||
|
with app.app_context():
|
||||||
|
from superset.databases.api import DatabaseRestApi
|
||||||
|
from superset.databases.dao import DatabaseDAO
|
||||||
|
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||||
|
from superset.models.core import Database
|
||||||
|
|
||||||
|
DatabaseRestApi.datamodel.session = session
|
||||||
|
|
||||||
|
# create table for databases
|
||||||
|
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||||
|
|
||||||
|
# Create our Database
|
||||||
|
database = Database(
|
||||||
|
database_name="my_database",
|
||||||
|
sqlalchemy_uri="gsheets://",
|
||||||
|
encrypted_extra=json.dumps(
|
||||||
|
{
|
||||||
|
"service_account_info": {
|
||||||
|
"type": "service_account",
|
||||||
|
"project_id": "black-sanctum-314419",
|
||||||
|
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
|
||||||
|
"private_key": "SECRET",
|
||||||
|
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
|
||||||
|
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
|
||||||
|
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||||
|
"token_uri": "https://oauth2.googleapis.com/token",
|
||||||
|
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
||||||
|
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
session.add(database)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# mock the lookup so that we don't need to include the driver
|
||||||
|
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||||
|
mocker.patch("superset.utils.log.DBEventLogger.log")
|
||||||
|
|
||||||
|
# Create our SSHTunnel
|
||||||
|
tunnel = SSHTunnel(
|
||||||
|
database_id=1,
|
||||||
|
database=database,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.add(tunnel)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Delete the recently created SSHTunnel
|
||||||
|
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
|
||||||
|
assert response_delete_tunnel.json["message"] == "Not found"
|
||||||
|
|
||||||
|
# Get our recently created SSHTunnel
|
||||||
|
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
|
||||||
|
assert response_tunnel
|
||||||
|
assert isinstance(response_tunnel, SSHTunnel)
|
||||||
|
assert 1 == response_tunnel.database_id
|
||||||
|
|
||||||
|
response_tunnel = DatabaseDAO.get_ssh_tunnel(2)
|
||||||
|
assert response_tunnel is None
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,69 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.orm.session import Session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
    """Seed the test session with a database, a table, and an SSH tunnel.

    Yields the populated session, then rolls back so no state leaks
    between tests.
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.databases.ssh_tunnel.models import SSHTunnel
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    table = SqlaTable(
        table_name="my_sqla_table",
        columns=[],
        metrics=[],
        database=database,
    )
    tunnel = SSHTunnel(
        database_id=database.id,
        database=database,
    )

    session.add_all([database, table, tunnel])
    session.flush()
    yield session
    session.rollback()
|
||||||
|
|
||||||
|
|
||||||
|
def test_database_get_ssh_tunnel(session_with_data: Session) -> None:
    """DatabaseDAO.get_ssh_tunnel returns the tunnel of an existing database."""
    from superset.databases.dao import DatabaseDAO
    from superset.databases.ssh_tunnel.models import SSHTunnel

    tunnel = DatabaseDAO.get_ssh_tunnel(1)

    assert tunnel
    assert isinstance(tunnel, SSHTunnel)
    assert tunnel.database_id == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_database_get_ssh_tunnel_not_found(session_with_data: Session) -> None:
    """DatabaseDAO.get_ssh_tunnel returns None when no tunnel exists."""
    from superset.databases.dao import DatabaseDAO

    missing = DatabaseDAO.get_ssh_tunnel(2)
    assert missing is None
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,68 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.orm.session import Session
|
||||||
|
|
||||||
|
from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelInvalidError
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_ssh_tunnel_command() -> None:
    """CreateSSHTunnelCommand returns an SSHTunnel for valid parameters."""
    from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
    from superset.databases.ssh_tunnel.models import SSHTunnel
    from superset.models.core import Database

    database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")

    payload = {
        "database_id": database.id,
        "server_address": "123.132.123.1",
        "server_port": "3005",
        "username": "foo",
        "password": "bar",
    }

    tunnel = CreateSSHTunnelCommand(database.id, payload).run()

    assert tunnel is not None
    assert isinstance(tunnel, SSHTunnel)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_ssh_tunnel_command_invalid_params() -> None:
    """A private_key_password without a private_key is rejected.

    The create command must raise SSHTunnelInvalidError, since a key
    password only makes sense alongside the key itself.
    """
    from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
    from superset.models.core import Database

    db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")

    # If we are trying to create a tunnel with a private_key_password
    # then a private_key is mandatory
    properties = {
        "database_id": db.id,
        "server_address": "123.132.123.1",
        "server_port": "3005",
        "username": "foo",
        "private_key_password": "bar",
    }

    command = CreateSSHTunnelCommand(db.id, properties)

    with pytest.raises(SSHTunnelInvalidError) as excinfo:
        command.run()
    assert str(excinfo.value) == "SSH Tunnel parameters are invalid."
|
|
@ -0,0 +1,68 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.orm.session import Session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
    """Populate the session with one database, one table, and one tunnel.

    The session is yielded to the test and rolled back afterwards.
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.databases.ssh_tunnel.models import SSHTunnel
    from superset.models.core import Database

    bind = session.get_bind()
    SqlaTable.metadata.create_all(bind)  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    table = SqlaTable(
        table_name="my_sqla_table",
        columns=[],
        metrics=[],
        database=database,
    )
    tunnel = SSHTunnel(
        database_id=database.id,
        database=database,
    )

    for obj in (database, table, tunnel):
        session.add(obj)
    session.flush()
    yield session
    session.rollback()
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_ssh_tunnel_command(session_with_data: Session) -> None:
    """DeleteSSHTunnelCommand removes an existing tunnel."""
    from superset.databases.dao import DatabaseDAO
    from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
    from superset.databases.ssh_tunnel.models import SSHTunnel

    # The fixture guarantees a tunnel exists for database 1.
    tunnel = DatabaseDAO.get_ssh_tunnel(1)
    assert tunnel
    assert isinstance(tunnel, SSHTunnel)
    assert tunnel.database_id == 1

    DeleteSSHTunnelCommand(1).run()

    # After deletion the DAO must no longer find it.
    assert DatabaseDAO.get_ssh_tunnel(1) is None
|
|
@ -0,0 +1,93 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.orm.session import Session
|
||||||
|
|
||||||
|
from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelInvalidError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
    """Seed the session with a database, a table, and a tunnel at "Test".

    Yields the session and rolls back afterwards so tests stay isolated.
    """
    from superset.connectors.sqla.models import SqlaTable
    from superset.databases.ssh_tunnel.models import SSHTunnel
    from superset.models.core import Database

    bind = session.get_bind()
    SqlaTable.metadata.create_all(bind)  # pylint: disable=no-member

    database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
    table = SqlaTable(
        table_name="my_sqla_table",
        columns=[],
        metrics=[],
        database=database,
    )
    tunnel = SSHTunnel(database_id=database.id, database=database, server_address="Test")

    session.add_all([database, table, tunnel])
    session.flush()
    yield session
    session.rollback()
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_shh_tunnel_command(session_with_data: Session) -> None:
    """UpdateSSHTunnelCommand persists a new server_address.

    NOTE(review): "shh" in the test name looks like a typo for "ssh";
    kept as-is to avoid changing the public test name.
    """
    from superset.databases.dao import DatabaseDAO
    from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
    from superset.databases.ssh_tunnel.models import SSHTunnel

    before = DatabaseDAO.get_ssh_tunnel(1)
    assert before
    assert isinstance(before, SSHTunnel)
    assert before.database_id == 1
    assert before.server_address == "Test"

    UpdateSSHTunnelCommand(1, {"server_address": "Test2"}).run()

    after = DatabaseDAO.get_ssh_tunnel(1)
    assert after
    assert isinstance(after, SSHTunnel)
    assert after.server_address == "Test2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
    """Updating with a private_key_password but no private_key fails.

    NOTE(review): "shh" in the test name looks like a typo for "ssh";
    kept as-is to avoid changing the public test name.
    """
    from superset.databases.dao import DatabaseDAO
    from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
    from superset.databases.ssh_tunnel.models import SSHTunnel

    existing = DatabaseDAO.get_ssh_tunnel(1)
    assert existing
    assert isinstance(existing, SSHTunnel)
    assert existing.database_id == 1
    assert existing.server_address == "Test"

    # If we are trying to update a tunnel with a private_key_password
    # then a private_key is mandatory
    command = UpdateSSHTunnelCommand(1, {"private_key_password": "pass"})

    with pytest.raises(SSHTunnelInvalidError) as excinfo:
        command.run()
    assert str(excinfo.value) == "SSH Tunnel parameters are invalid."
|
|
@ -0,0 +1,43 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.orm.session import Session
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_ssh_tunnel() -> None:
    """SSHTunnelDAO.create returns an SSHTunnel for a valid payload.

    Fixes: removed the unused DatabaseDAO import and added the `-> None`
    return annotation used by every sibling test.
    """
    from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
    from superset.databases.ssh_tunnel.models import SSHTunnel
    from superset.models.core import Database

    db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")

    properties = {
        "database_id": db.id,
        "server_address": "123.132.123.1",
        "server_port": "3005",
        "username": "foo",
        "password": "bar",
    }

    result = SSHTunnelDAO.create(properties)

    assert result is not None
    assert isinstance(result, SSHTunnel)
|
Loading…
Reference in New Issue