mirror of https://github.com/apache/superset.git
feat(ssh-tunnelling): Setup SSH Tunneling Commands for Database Connections (#21912)
Co-authored-by: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com> Co-authored-by: Elizabeth Thompson <eschutho@gmail.com>
This commit is contained in:
parent
a7a4561550
commit
ebaad10d6c
|
@ -19,6 +19,8 @@ babel==2.9.1
|
|||
# via flask-babel
|
||||
backoff==1.11.1
|
||||
# via apache-superset
|
||||
bcrypt==4.0.1
|
||||
# via paramiko
|
||||
billiard==3.6.4.0
|
||||
# via celery
|
||||
bleach==3.3.1
|
||||
|
@ -57,7 +59,9 @@ cron-descriptor==1.2.24
|
|||
croniter==1.0.15
|
||||
# via apache-superset
|
||||
cryptography==3.4.7
|
||||
# via apache-superset
|
||||
# via
|
||||
# apache-superset
|
||||
# paramiko
|
||||
deprecation==2.1.0
|
||||
# via apache-superset
|
||||
dnspython==2.1.0
|
||||
|
@ -167,6 +171,8 @@ packaging==21.3
|
|||
# deprecation
|
||||
pandas==1.5.2
|
||||
# via apache-superset
|
||||
paramiko==2.11.0
|
||||
# via sshtunnel
|
||||
parsedatetime==2.6
|
||||
# via apache-superset
|
||||
pgsanity==0.2.9
|
||||
|
@ -188,6 +194,8 @@ pyjwt==2.4.0
|
|||
# flask-jwt-extended
|
||||
pymeeus==0.5.11
|
||||
# via convertdate
|
||||
pynacl==1.5.0
|
||||
# via paramiko
|
||||
pyparsing==3.0.6
|
||||
# via
|
||||
# apache-superset
|
||||
|
@ -231,6 +239,7 @@ six==1.16.0
|
|||
# flask-talisman
|
||||
# isodate
|
||||
# jsonschema
|
||||
# paramiko
|
||||
# polyline
|
||||
# prison
|
||||
# pyrsistent
|
||||
|
@ -252,6 +261,8 @@ sqlalchemy-utils==0.38.3
|
|||
# flask-appbuilder
|
||||
sqlparse==0.4.3
|
||||
# via apache-superset
|
||||
sshtunnel==0.4.0
|
||||
# via apache-superset
|
||||
tabulate==0.8.9
|
||||
# via apache-superset
|
||||
typing-extensions==4.4.0
|
||||
|
|
1
setup.py
1
setup.py
|
@ -113,6 +113,7 @@ setup(
|
|||
"PyJWT>=2.4.0, <3.0",
|
||||
"redis",
|
||||
"selenium>=3.141.0",
|
||||
"sshtunnel>=0.4.0, <0.5",
|
||||
"simplejson>=3.15.0",
|
||||
"slack_sdk>=3.1.1, <4",
|
||||
"sqlalchemy>=1.4, <2",
|
||||
|
|
|
@ -476,8 +476,30 @@ DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
|
|||
"DRILL_TO_DETAIL": False,
|
||||
"DATAPANEL_CLOSED_BY_DEFAULT": False,
|
||||
"HORIZONTAL_FILTER_BAR": False,
|
||||
# Allow users to enable ssh tunneling when creating a DB.
|
||||
# Users must check whether the DB engine supports SSH Tunnels
|
||||
# otherwise enabling this flag won't have any effect on the DB.
|
||||
"SSH_TUNNELING": False,
|
||||
}
|
||||
|
||||
# ------------------------------
|
||||
# SSH Tunnel
|
||||
# ------------------------------
|
||||
# Allow users to set the host used when connecting to the SSH Tunnel
|
||||
# as localhost and any other alias (0.0.0.0)
|
||||
# ----------------------------------------------------------------------
|
||||
# |
|
||||
# -------------+ | +----------+
|
||||
# LOCAL | | | REMOTE | :22 SSH
|
||||
# CLIENT | <== SSH ========> | SERVER | :8080 web service
|
||||
# -------------+ | +----------+
|
||||
# |
|
||||
# FIREWALL (only port 22 is open)
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager"
|
||||
SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
|
||||
|
||||
# Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars.
|
||||
DEFAULT_FEATURE_FLAGS.update(
|
||||
{
|
||||
|
@ -1506,7 +1528,7 @@ elif importlib.util.find_spec("superset_config") and not is_test():
|
|||
try:
|
||||
# pylint: disable=import-error,wildcard-import,unused-wildcard-import
|
||||
import superset_config
|
||||
from superset_config import * # type:ignore
|
||||
from superset_config import * # type: ignore
|
||||
|
||||
print(f"Loaded your LOCAL configuration at [{superset_config.__file__}]")
|
||||
except Exception:
|
||||
|
|
|
@ -139,6 +139,7 @@ MODEL_API_RW_METHOD_PERMISSION_MAP = {
|
|||
"validate_sql": "read",
|
||||
"get_data": "read",
|
||||
"samples": "read",
|
||||
"delete_ssh_tunnel": "write",
|
||||
}
|
||||
|
||||
EXTRA_FORM_DATA_APPEND_KEYS = {
|
||||
|
|
|
@ -72,6 +72,11 @@ from superset.databases.schemas import (
|
|||
ValidateSQLRequest,
|
||||
ValidateSQLResponse,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelDeleteFailedError,
|
||||
SSHTunnelNotFoundError,
|
||||
)
|
||||
from superset.databases.utils import get_table_metadata
|
||||
from superset.db_engine_specs import get_available_engine_specs
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
|
@ -80,6 +85,7 @@ from superset.extensions import security_manager
|
|||
from superset.models.core import Database
|
||||
from superset.superset_typing import FlaskResponse
|
||||
from superset.utils.core import error_msg_from_exception, parse_js_uri_path_item
|
||||
from superset.utils.ssh_tunnel import mask_password_info
|
||||
from superset.views.base import json_errors_response
|
||||
from superset.views.base_api import (
|
||||
BaseSupersetModelRestApi,
|
||||
|
@ -107,6 +113,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
"available",
|
||||
"validate_parameters",
|
||||
"validate_sql",
|
||||
"delete_ssh_tunnel",
|
||||
}
|
||||
resource_name = "database"
|
||||
class_permission_name = "Database"
|
||||
|
@ -219,6 +226,47 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
ValidateSQLResponse,
|
||||
)
|
||||
|
||||
@expose("/<int:pk>", methods=["GET"])
|
||||
@protect()
|
||||
@safe
|
||||
def get(self, pk: int, **kwargs: Any) -> Response:
|
||||
"""Get a database
|
||||
---
|
||||
get:
|
||||
description: >-
|
||||
Get a database
|
||||
parameters:
|
||||
- in: path
|
||||
schema:
|
||||
type: integer
|
||||
description: The database id
|
||||
name: pk
|
||||
responses:
|
||||
200:
|
||||
description: Database
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
400:
|
||||
$ref: '#/components/responses/400'
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
422:
|
||||
$ref: '#/components/responses/422'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
data = self.get_headless(pk, **kwargs)
|
||||
try:
|
||||
if ssh_tunnel := DatabaseDAO.get_ssh_tunnel(pk):
|
||||
payload = data.json
|
||||
payload["result"]["ssh_tunnel"] = ssh_tunnel.data
|
||||
return payload
|
||||
return data
|
||||
except SupersetException as ex:
|
||||
return self.response(ex.status, message=ex.message)
|
||||
|
||||
@expose("/", methods=["POST"])
|
||||
@protect()
|
||||
@safe
|
||||
|
@ -280,6 +328,12 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
if new_model.driver:
|
||||
item["driver"] = new_model.driver
|
||||
|
||||
# Return SSH Tunnel and hide passwords if any
|
||||
if item.get("ssh_tunnel"):
|
||||
item["ssh_tunnel"] = mask_password_info(
|
||||
new_model.ssh_tunnel # pylint: disable=no-member
|
||||
)
|
||||
|
||||
return self.response(201, id=new_model.id, result=item)
|
||||
except DatabaseInvalidError as ex:
|
||||
return self.response_422(message=ex.normalized_messages())
|
||||
|
@ -361,6 +415,9 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
item["sqlalchemy_uri"] = changed_model.sqlalchemy_uri
|
||||
if changed_model.parameters:
|
||||
item["parameters"] = changed_model.parameters
|
||||
# Return SSH Tunnel and hide passwords if any
|
||||
if item.get("ssh_tunnel"):
|
||||
item["ssh_tunnel"] = mask_password_info(changed_model.ssh_tunnel)
|
||||
return self.response(200, id=changed_model.id, result=item)
|
||||
except DatabaseNotFoundError:
|
||||
return self.response_404()
|
||||
|
@ -1206,3 +1263,57 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
command = ValidateDatabaseParametersCommand(payload)
|
||||
command.run()
|
||||
return self.response(200, message="OK")
|
||||
|
||||
@expose("/<int:pk>/ssh_tunnel/", methods=["DELETE"])
|
||||
@protect()
|
||||
@statsd_metrics
|
||||
@event_logger.log_this_with_context(
|
||||
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
|
||||
f".delete_ssh_tunnel",
|
||||
log_to_statsd=False,
|
||||
)
|
||||
def delete_ssh_tunnel(self, pk: int) -> Response:
|
||||
"""Deletes a SSH Tunnel
|
||||
---
|
||||
delete:
|
||||
description: >-
|
||||
Deletes a SSH Tunnel.
|
||||
parameters:
|
||||
- in: path
|
||||
schema:
|
||||
type: integer
|
||||
name: pk
|
||||
responses:
|
||||
200:
|
||||
description: SSH Tunnel deleted
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
message:
|
||||
type: string
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
403:
|
||||
$ref: '#/components/responses/403'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
422:
|
||||
$ref: '#/components/responses/422'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
try:
|
||||
DeleteSSHTunnelCommand(pk).run()
|
||||
return self.response(200, message="OK")
|
||||
except SSHTunnelNotFoundError:
|
||||
return self.response_404()
|
||||
except SSHTunnelDeleteFailedError as ex:
|
||||
logger.error(
|
||||
"Error deleting SSH Tunnel %s: %s",
|
||||
self.__class__.__name__,
|
||||
str(ex),
|
||||
exc_info=True,
|
||||
)
|
||||
return self.response_422(message=str(ex))
|
||||
|
|
|
@ -31,6 +31,11 @@ from superset.databases.commands.exceptions import (
|
|||
)
|
||||
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelInvalidError,
|
||||
)
|
||||
from superset.exceptions import SupersetErrorsException
|
||||
from superset.extensions import db, event_logger, security_manager
|
||||
|
||||
|
@ -71,12 +76,35 @@ class CreateDatabaseCommand(BaseCommand):
|
|||
database = DatabaseDAO.create(self._properties, commit=False)
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
|
||||
ssh_tunnel = None
|
||||
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
|
||||
try:
|
||||
# So database.id is not None
|
||||
db.session.flush()
|
||||
ssh_tunnel = CreateSSHTunnelCommand(
|
||||
database.id, ssh_tunnel_properties
|
||||
).run()
|
||||
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
|
||||
event_logger.log_with_context(
|
||||
action=f"db_creation_failed.{ex.__class__.__name__}",
|
||||
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
|
||||
)
|
||||
# So we can show the original message
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
event_logger.log_with_context(
|
||||
action=f"db_creation_failed.{ex.__class__.__name__}",
|
||||
engine=self._properties.get("sqlalchemy_uri", "").split(":")[0],
|
||||
)
|
||||
raise DatabaseCreateFailedError() from ex
|
||||
|
||||
# adding a new database we always want to force refresh schema list
|
||||
schemas = database.get_all_schema_names(cache=False)
|
||||
schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel)
|
||||
for schema in schemas:
|
||||
security_manager.add_permission_view_menu(
|
||||
"schema_access", security_manager.get_schema_perm(database, schema)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
except DAOCreateFailedError as ex:
|
||||
db.session.rollback()
|
||||
|
|
|
@ -32,6 +32,7 @@ from superset.databases.commands.exceptions import (
|
|||
DatabaseTestConnectionUnexpectedError,
|
||||
)
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.errors import ErrorLevel, SupersetErrorType
|
||||
from superset.exceptions import (
|
||||
|
@ -90,6 +91,10 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
database.set_sqlalchemy_uri(uri)
|
||||
database.db_engine_spec.mutate_db_for_connection_test(database)
|
||||
|
||||
# Generate tunnel if present in the properties
|
||||
if ssh_tunnel := self._properties.get("ssh_tunnel"):
|
||||
ssh_tunnel = SSHTunnel(**ssh_tunnel)
|
||||
|
||||
event_logger.log_with_context(
|
||||
action="test_connection_attempt",
|
||||
engine=database.db_engine_spec.__name__,
|
||||
|
@ -99,7 +104,9 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
with closing(engine.raw_connection()) as conn:
|
||||
return engine.dialect.do_ping(conn)
|
||||
|
||||
with database.get_sqla_engine_with_context() as engine:
|
||||
with database.get_sqla_engine_with_context(
|
||||
override_ssh_tunnel=ssh_tunnel
|
||||
) as engine:
|
||||
try:
|
||||
alive = func_timeout(
|
||||
app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(),
|
||||
|
|
|
@ -21,7 +21,7 @@ from flask_appbuilder.models.sqla import Model
|
|||
from marshmallow import ValidationError
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.dao.exceptions import DAOUpdateFailedError
|
||||
from superset.dao.exceptions import DAOCreateFailedError, DAOUpdateFailedError
|
||||
from superset.databases.commands.exceptions import (
|
||||
DatabaseConnectionFailedError,
|
||||
DatabaseExistsValidationError,
|
||||
|
@ -30,6 +30,12 @@ from superset.databases.commands.exceptions import (
|
|||
DatabaseUpdateFailedError,
|
||||
)
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelInvalidError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
|
||||
from superset.extensions import db, security_manager
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import DatasourceType
|
||||
|
@ -94,10 +100,33 @@ class UpdateDatabaseCommand(BaseCommand):
|
|||
security_manager.add_permission_view_menu(
|
||||
"schema_access", security_manager.get_schema_perm(database, schema)
|
||||
)
|
||||
|
||||
if ssh_tunnel_properties := self._properties.get("ssh_tunnel"):
|
||||
existing_ssh_tunnel_model = DatabaseDAO.get_ssh_tunnel(database.id)
|
||||
if existing_ssh_tunnel_model is None:
|
||||
# We couldn't found an existing tunnel so we need to create one
|
||||
try:
|
||||
CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run()
|
||||
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
|
||||
# So we can show the original message
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
raise DatabaseUpdateFailedError() from ex
|
||||
else:
|
||||
# We found an existing tunnel so we need to update it
|
||||
try:
|
||||
UpdateSSHTunnelCommand(
|
||||
existing_ssh_tunnel_model.id, ssh_tunnel_properties
|
||||
).run()
|
||||
except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex:
|
||||
# So we can show the original message
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
raise DatabaseUpdateFailedError() from ex
|
||||
|
||||
db.session.commit()
|
||||
|
||||
except DAOUpdateFailedError as ex:
|
||||
logger.exception(ex.exception)
|
||||
except (DAOUpdateFailedError, DAOCreateFailedError) as ex:
|
||||
raise DatabaseUpdateFailedError() from ex
|
||||
return database
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ from typing import Any, Dict, Optional
|
|||
|
||||
from superset.dao.base import BaseDAO
|
||||
from superset.databases.filters import DatabaseFilter
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.extensions import db
|
||||
from superset.models.core import Database
|
||||
from superset.models.dashboard import Dashboard
|
||||
|
@ -124,3 +125,13 @@ class DatabaseDAO(BaseDAO):
|
|||
return dict(
|
||||
charts=charts, dashboards=dashboards, sqllab_tab_states=sqllab_tab_states
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_ssh_tunnel(cls, database_id: int) -> Optional[SSHTunnel]:
|
||||
ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == database_id)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
return ssh_tunnel
|
||||
|
|
|
@ -365,6 +365,19 @@ class DatabaseValidateParametersSchema(Schema):
|
|||
)
|
||||
|
||||
|
||||
class DatabaseSSHTunnel(Schema):
|
||||
server_address = fields.String()
|
||||
server_port = fields.Integer()
|
||||
username = fields.String()
|
||||
|
||||
# Basic Authentication
|
||||
password = fields.String(required=False)
|
||||
|
||||
# password protected private key authentication
|
||||
private_key = fields.String(required=False)
|
||||
private_key_password = fields.String(required=False)
|
||||
|
||||
|
||||
class DatabasePostSchema(Schema, DatabaseParametersSchemaMixin):
|
||||
class Meta: # pylint: disable=too-few-public-methods
|
||||
unknown = EXCLUDE
|
||||
|
@ -409,6 +422,7 @@ class DatabasePostSchema(Schema, DatabaseParametersSchemaMixin):
|
|||
is_managed_externally = fields.Boolean(allow_none=True, default=False)
|
||||
external_url = fields.String(allow_none=True)
|
||||
uuid = fields.String(required=False)
|
||||
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
|
||||
|
||||
|
||||
class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin):
|
||||
|
@ -454,6 +468,7 @@ class DatabasePutSchema(Schema, DatabaseParametersSchemaMixin):
|
|||
)
|
||||
is_managed_externally = fields.Boolean(allow_none=True, default=False)
|
||||
external_url = fields.String(allow_none=True)
|
||||
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
|
||||
|
||||
|
||||
class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin):
|
||||
|
@ -482,6 +497,8 @@ class DatabaseTestConnectionSchema(Schema, DatabaseParametersSchemaMixin):
|
|||
validate=[Length(1, 1024), sqlalchemy_uri_validator],
|
||||
)
|
||||
|
||||
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
|
||||
|
||||
|
||||
class TableMetadataOptionsResponseSchema(Schema):
|
||||
deferrable = fields.Bool()
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,92 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.dao.exceptions import DAOCreateFailedError
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelRequiredFieldValidationError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
|
||||
from superset.extensions import db, event_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreateSSHTunnelCommand(BaseCommand):
|
||||
def __init__(self, database_id: int, data: Dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
self._properties["database_id"] = database_id
|
||||
|
||||
def run(self) -> Model:
|
||||
try:
|
||||
# Start nested transaction since we are always creating the tunnel
|
||||
# through a DB command (Create or Update). Without this, we cannot
|
||||
# safely rollback changes to databases if any, i.e, things like
|
||||
# test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail
|
||||
db.session.begin_nested()
|
||||
self.validate()
|
||||
tunnel = SSHTunnelDAO.create(self._properties, commit=False)
|
||||
except DAOCreateFailedError as ex:
|
||||
# Rollback nested transaction
|
||||
db.session.rollback()
|
||||
raise SSHTunnelCreateFailedError() from ex
|
||||
except SSHTunnelInvalidError as ex:
|
||||
# Rollback nested transaction
|
||||
db.session.rollback()
|
||||
raise ex
|
||||
|
||||
return tunnel
|
||||
|
||||
def validate(self) -> None:
|
||||
# TODO(hughhh): check to make sure the server port is not localhost
|
||||
# using the config.SSH_TUNNEL_MANAGER
|
||||
exceptions: List[ValidationError] = []
|
||||
database_id: Optional[int] = self._properties.get("database_id")
|
||||
server_address: Optional[str] = self._properties.get("server_address")
|
||||
server_port: Optional[int] = self._properties.get("server_port")
|
||||
username: Optional[str] = self._properties.get("username")
|
||||
private_key: Optional[str] = self._properties.get("private_key")
|
||||
private_key_password: Optional[str] = self._properties.get(
|
||||
"private_key_password"
|
||||
)
|
||||
if not database_id:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("database_id"))
|
||||
if not server_address:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
|
||||
if not server_port:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("server_port"))
|
||||
if not username:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("username"))
|
||||
if private_key_password and private_key is None:
|
||||
exceptions.append(SSHTunnelRequiredFieldValidationError("private_key"))
|
||||
if exceptions:
|
||||
exception = SSHTunnelInvalidError()
|
||||
exception.add_list(exceptions)
|
||||
event_logger.log_with_context(
|
||||
action="ssh_tunnel_creation_failed.{}.{}".format(
|
||||
exception.__class__.__name__,
|
||||
".".join(exception.get_list_classnames()),
|
||||
)
|
||||
)
|
||||
raise exception
|
|
@ -0,0 +1,51 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.dao.exceptions import DAODeleteFailedError
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelDeleteFailedError,
|
||||
SSHTunnelNotFoundError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeleteSSHTunnelCommand(BaseCommand):
|
||||
def __init__(self, model_id: int):
|
||||
self._model_id = model_id
|
||||
self._model: Optional[SSHTunnel] = None
|
||||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
try:
|
||||
ssh_tunnel = SSHTunnelDAO.delete(self._model)
|
||||
except DAODeleteFailedError as ex:
|
||||
raise SSHTunnelDeleteFailedError() from ex
|
||||
return ssh_tunnel
|
||||
|
||||
def validate(self) -> None:
|
||||
# Validate/populate model exists
|
||||
self._model = SSHTunnelDAO.find_by_id(self._model_id)
|
||||
if not self._model:
|
||||
raise SSHTunnelNotFoundError()
|
|
@ -0,0 +1,54 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from flask_babel import lazy_gettext as _
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset.commands.exceptions import (
|
||||
CommandException,
|
||||
CommandInvalidError,
|
||||
DeleteFailedError,
|
||||
UpdateFailedError,
|
||||
)
|
||||
|
||||
|
||||
class SSHTunnelDeleteFailedError(DeleteFailedError):
|
||||
message = _("SSH Tunnel could not be deleted.")
|
||||
|
||||
|
||||
class SSHTunnelNotFoundError(CommandException):
|
||||
status = 404
|
||||
message = _("SSH Tunnel not found.")
|
||||
|
||||
|
||||
class SSHTunnelInvalidError(CommandInvalidError):
|
||||
message = _("SSH Tunnel parameters are invalid.")
|
||||
|
||||
|
||||
class SSHTunnelUpdateFailedError(UpdateFailedError):
|
||||
message = _("SSH Tunnel could not be updated.")
|
||||
|
||||
|
||||
class SSHTunnelCreateFailedError(CommandException):
|
||||
message = _("Creating SSH Tunnel failed for an unknown reason")
|
||||
|
||||
|
||||
class SSHTunnelRequiredFieldValidationError(ValidationError):
|
||||
def __init__(self, field_name: str) -> None:
|
||||
super().__init__(
|
||||
[_("Field is required")],
|
||||
field_name=field_name,
|
||||
)
|
|
@ -0,0 +1,62 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.dao.exceptions import DAOUpdateFailedError
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelNotFoundError,
|
||||
SSHTunnelRequiredFieldValidationError,
|
||||
SSHTunnelUpdateFailedError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpdateSSHTunnelCommand(BaseCommand):
|
||||
def __init__(self, model_id: int, data: Dict[str, Any]):
|
||||
self._properties = data.copy()
|
||||
self._model_id = model_id
|
||||
self._model: Optional[SSHTunnel] = None
|
||||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
try:
|
||||
tunnel = SSHTunnelDAO.update(self._model, self._properties)
|
||||
except DAOUpdateFailedError as ex:
|
||||
raise SSHTunnelUpdateFailedError() from ex
|
||||
return tunnel
|
||||
|
||||
def validate(self) -> None:
|
||||
# Validate/populate model exists
|
||||
self._model = SSHTunnelDAO.find_by_id(self._model_id)
|
||||
if not self._model:
|
||||
raise SSHTunnelNotFoundError()
|
||||
private_key: Optional[str] = self._properties.get("private_key")
|
||||
private_key_password: Optional[str] = self._properties.get(
|
||||
"private_key_password"
|
||||
)
|
||||
if private_key_password and private_key is None:
|
||||
exception = SSHTunnelInvalidError()
|
||||
exception.add(SSHTunnelRequiredFieldValidationError("private_key"))
|
||||
raise exception
|
|
@ -0,0 +1,26 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
|
||||
from superset.dao.base import BaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSHTunnelDAO(BaseDAO):
|
||||
model_cls = SSHTunnel
|
|
@ -0,0 +1,76 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import current_app
|
||||
from flask_appbuilder import Model
|
||||
from sqlalchemy.orm import backref, relationship
|
||||
from sqlalchemy_utils import EncryptedType
|
||||
|
||||
from superset.models.core import Database
|
||||
from superset.models.helpers import (
|
||||
AuditMixinNullable,
|
||||
ExtraJSONMixin,
|
||||
ImportExportMixin,
|
||||
)
|
||||
|
||||
app_config = current_app.config
|
||||
|
||||
|
||||
class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
|
||||
"""
|
||||
A ssh tunnel configuration in a database.
|
||||
"""
|
||||
|
||||
__tablename__ = "ssh_tunnels"
|
||||
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
database_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey("dbs.id"), nullable=False, unique=True
|
||||
)
|
||||
database: Database = relationship(
|
||||
"Database",
|
||||
backref=backref("ssh_tunnels", uselist=False, cascade="all, delete-orphan"),
|
||||
foreign_keys=[database_id],
|
||||
)
|
||||
|
||||
server_address = sa.Column(sa.Text)
|
||||
server_port = sa.Column(sa.Integer)
|
||||
username = sa.Column(EncryptedType(sa.String, app_config["SECRET_KEY"]))
|
||||
|
||||
# basic authentication
|
||||
password = sa.Column(
|
||||
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
|
||||
)
|
||||
|
||||
# password protected pkey authentication
|
||||
private_key = sa.Column(
|
||||
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
|
||||
)
|
||||
private_key_password = sa.Column(
|
||||
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
|
||||
)
|
||||
|
||||
@property
|
||||
def data(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"server_address": self.server_address,
|
||||
"server_port": self.server_port,
|
||||
"username": self.username,
|
||||
}
|
|
@ -193,6 +193,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
engine_aliases: Set[str] = set()
|
||||
drivers: Dict[str, str] = {}
|
||||
default_driver: Optional[str] = None
|
||||
allow_ssh_tunneling = False
|
||||
|
||||
_date_trunc_functions: Dict[str, str] = {}
|
||||
_time_grain_expressions: Dict[Optional[str], str] = {}
|
||||
|
|
|
@ -166,6 +166,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
|
|||
class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
|
||||
engine = "postgresql"
|
||||
engine_aliases = {"postgres"}
|
||||
allow_ssh_tunneling = True
|
||||
|
||||
default_driver = "psycopg2"
|
||||
sqlalchemy_uri_placeholder = (
|
||||
|
|
|
@ -28,6 +28,7 @@ from flask_talisman import Talisman
|
|||
from flask_wtf.csrf import CSRFProtect
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from superset.extensions.ssh import SSHManagerFactory
|
||||
from superset.utils.async_query_manager import AsyncQueryManager
|
||||
from superset.utils.cache_manager import CacheManager
|
||||
from superset.utils.encrypt import EncryptedFieldFactory
|
||||
|
@ -127,3 +128,4 @@ profiling = ProfilingExtension()
|
|||
results_backend_manager = ResultsBackendManager()
|
||||
security_manager = LocalProxy(lambda: appbuilder.sm)
|
||||
talisman = Talisman()
|
||||
ssh_manager_factory = SSHManagerFactory()
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from flask import Flask
|
||||
from sshtunnel import open_tunnel, SSHTunnelForwarder
|
||||
|
||||
from superset.databases.utils import make_url_safe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
|
||||
class SSHManager:
|
||||
def __init__(self, app: Flask) -> None:
|
||||
super().__init__()
|
||||
self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
|
||||
|
||||
def build_sqla_url( # pylint: disable=no-self-use
|
||||
self, sqlalchemy_url: str, server: SSHTunnelForwarder
|
||||
) -> str:
|
||||
# override any ssh tunnel configuration object
|
||||
url = make_url_safe(sqlalchemy_url)
|
||||
return url.set(
|
||||
host=server.local_bind_address[0],
|
||||
port=server.local_bind_port,
|
||||
)
|
||||
|
||||
def create_tunnel(
|
||||
self,
|
||||
ssh_tunnel: "SSHTunnel",
|
||||
sqlalchemy_database_uri: str,
|
||||
) -> SSHTunnelForwarder:
|
||||
url = make_url_safe(sqlalchemy_database_uri)
|
||||
params = {
|
||||
"ssh_address_or_host": ssh_tunnel.server_address,
|
||||
"ssh_port": ssh_tunnel.server_port,
|
||||
"ssh_username": ssh_tunnel.username,
|
||||
"remote_bind_address": (url.host, url.port), # bind_port, bind_host
|
||||
"local_bind_address": (self.local_bind_address,),
|
||||
}
|
||||
|
||||
if ssh_tunnel.password:
|
||||
params["ssh_password"] = ssh_tunnel.password
|
||||
elif ssh_tunnel.private_key:
|
||||
params["private_key"] = ssh_tunnel.private_key
|
||||
params["private_key_password"] = ssh_tunnel.private_key_password
|
||||
|
||||
return open_tunnel(**params)
|
||||
|
||||
|
||||
class SSHManagerFactory:
|
||||
def __init__(self) -> None:
|
||||
self._ssh_manager = None
|
||||
|
||||
def init_app(self, app: Flask) -> None:
|
||||
ssh_manager_fqclass = app.config["SSH_TUNNEL_MANAGER_CLASS"]
|
||||
ssh_manager_classname = ssh_manager_fqclass[
|
||||
ssh_manager_fqclass.rfind(".") + 1 :
|
||||
]
|
||||
ssh_manager_module_name = ssh_manager_fqclass[
|
||||
0 : ssh_manager_fqclass.rfind(".")
|
||||
]
|
||||
ssh_manager_class = getattr(
|
||||
importlib.import_module(ssh_manager_module_name), ssh_manager_classname
|
||||
)
|
||||
|
||||
self._ssh_manager = ssh_manager_class(app)
|
||||
|
||||
@property
|
||||
def instance(self) -> SSHManager:
|
||||
return self._ssh_manager # type: ignore
|
|
@ -45,6 +45,7 @@ from superset.extensions import (
|
|||
migrate,
|
||||
profiling,
|
||||
results_backend_manager,
|
||||
ssh_manager_factory,
|
||||
talisman,
|
||||
)
|
||||
from superset.security import SupersetSecurityManager
|
||||
|
@ -417,6 +418,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
|||
self.configure_data_sources()
|
||||
self.configure_auth_provider()
|
||||
self.configure_async_queries()
|
||||
self.configure_ssh_manager()
|
||||
|
||||
# Hook that provides administrators a handle on the Flask APP
|
||||
# after initialization
|
||||
|
@ -474,6 +476,9 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
|||
def configure_auth_provider(self) -> None:
|
||||
machine_auth_provider_factory.init_app(self.superset_app)
|
||||
|
||||
def configure_ssh_manager(self) -> None:
|
||||
ssh_manager_factory.init_app(self.superset_app)
|
||||
|
||||
def setup_event_logger(self) -> None:
|
||||
_event_logger["event_logger"] = get_event_logger_from_cfg_value(
|
||||
self.superset_app.config.get("EVENT_LOGGER", DBEventLogger())
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""create_ssh_tunnel_credentials_tbl
|
||||
|
||||
Revision ID: f3c2d8ec8595
|
||||
Revises: 4ce1d9b25135
|
||||
Create Date: 2022-10-20 10:48:08.722861
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f3c2d8ec8595"
|
||||
down_revision = "4ce1d9b25135"
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy_utils import UUIDType
|
||||
|
||||
from superset import app
|
||||
from superset.extensions import encrypted_field_factory
|
||||
|
||||
app_config = app.config
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"ssh_tunnels",
|
||||
# AuditMixinNullable
|
||||
sa.Column("created_on", sa.DateTime(), nullable=True),
|
||||
sa.Column("changed_on", sa.DateTime(), nullable=True),
|
||||
sa.Column("created_by_fk", sa.Integer(), nullable=True),
|
||||
sa.Column("changed_by_fk", sa.Integer(), nullable=True),
|
||||
# ExtraJSONMixin
|
||||
sa.Column("extra_json", sa.Text(), nullable=True),
|
||||
# ImportExportMixin
|
||||
sa.Column(
|
||||
"uuid",
|
||||
UUIDType(binary=True),
|
||||
primary_key=False,
|
||||
default=uuid4,
|
||||
unique=True,
|
||||
index=True,
|
||||
),
|
||||
# SSHTunnelCredentials
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column(
|
||||
"database_id",
|
||||
sa.INTEGER(),
|
||||
sa.ForeignKey("dbs.id"),
|
||||
unique=True,
|
||||
index=True,
|
||||
),
|
||||
sa.Column("server_address", sa.String(256)),
|
||||
sa.Column("server_port", sa.INTEGER()),
|
||||
sa.Column("username", encrypted_field_factory.create(sa.String(256))),
|
||||
sa.Column(
|
||||
"password", encrypted_field_factory.create(sa.String(256)), nullable=True
|
||||
),
|
||||
sa.Column(
|
||||
"private_key",
|
||||
encrypted_field_factory.create(sa.String(1024)),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"private_key_password",
|
||||
encrypted_field_factory.create(sa.String(256)),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("ssh_tunnels")
|
|
@ -21,10 +21,10 @@ import json
|
|||
import logging
|
||||
import textwrap
|
||||
from ast import literal_eval
|
||||
from contextlib import closing, contextmanager
|
||||
from contextlib import closing, contextmanager, nullcontext
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
import numpy
|
||||
import pandas as pd
|
||||
|
@ -57,7 +57,12 @@ from superset import app, db_engine_specs
|
|||
from superset.constants import PASSWORD_MASK
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs.base import MetricType, TimeGrain
|
||||
from superset.extensions import cache_manager, encrypted_field_factory, security_manager
|
||||
from superset.extensions import (
|
||||
cache_manager,
|
||||
encrypted_field_factory,
|
||||
security_manager,
|
||||
ssh_manager_factory,
|
||||
)
|
||||
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
|
||||
from superset.result_set import SupersetResultSet
|
||||
from superset.utils import cache as cache_util, core as utils
|
||||
|
@ -71,6 +76,9 @@ log_query = config["QUERY_LOGGER"]
|
|||
metadata = Model.metadata # pylint: disable=no-member
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"]
|
||||
|
||||
|
||||
|
@ -373,17 +381,48 @@ class Database(
|
|||
schema: Optional[str] = None,
|
||||
nullpool: bool = True,
|
||||
source: Optional[utils.QuerySource] = None,
|
||||
override_ssh_tunnel: Optional["SSHTunnel"] = None,
|
||||
) -> Engine:
|
||||
yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
|
||||
from superset.databases.dao import ( # pylint: disable=import-outside-toplevel
|
||||
DatabaseDAO,
|
||||
)
|
||||
|
||||
sqlalchemy_uri = self.sqlalchemy_uri_decrypted
|
||||
engine_context = nullcontext()
|
||||
ssh_tunnel = override_ssh_tunnel or DatabaseDAO.get_ssh_tunnel(
|
||||
database_id=self.id
|
||||
)
|
||||
|
||||
if ssh_tunnel:
|
||||
# if ssh_tunnel is available build engine with information
|
||||
engine_context = ssh_manager_factory.instance.create_tunnel(
|
||||
ssh_tunnel=ssh_tunnel,
|
||||
sqlalchemy_database_uri=self.sqlalchemy_uri_decrypted,
|
||||
)
|
||||
|
||||
with engine_context as server_context:
|
||||
if ssh_tunnel:
|
||||
sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
|
||||
sqlalchemy_uri, server_context
|
||||
)
|
||||
yield self._get_sqla_engine(
|
||||
schema=schema,
|
||||
nullpool=nullpool,
|
||||
source=source,
|
||||
sqlalchemy_uri=sqlalchemy_uri,
|
||||
)
|
||||
|
||||
def _get_sqla_engine(
|
||||
self,
|
||||
schema: Optional[str] = None,
|
||||
nullpool: bool = True,
|
||||
source: Optional[utils.QuerySource] = None,
|
||||
sqlalchemy_uri: Optional[str] = None,
|
||||
) -> Engine:
|
||||
extra = self.get_extra()
|
||||
sqlalchemy_url = make_url_safe(self.sqlalchemy_uri_decrypted)
|
||||
sqlalchemy_url = make_url_safe(
|
||||
sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
|
||||
)
|
||||
sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
|
||||
effective_username = self.get_effective_user(sqlalchemy_url)
|
||||
# If using MySQL or Presto for example, will set url.username
|
||||
|
@ -423,7 +462,6 @@ class Database(
|
|||
sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
|
||||
sqlalchemy_url, params, effective_username, security_manager, source
|
||||
)
|
||||
|
||||
try:
|
||||
return create_engine(sqlalchemy_url, **params)
|
||||
except Exception as ex:
|
||||
|
@ -477,7 +515,7 @@ class Database(
|
|||
security_manager,
|
||||
)
|
||||
|
||||
with closing(engine.raw_connection()) as conn:
|
||||
with self.get_raw_connection(schema=schema) as conn:
|
||||
cursor = conn.cursor()
|
||||
for sql_ in sqls[:-1]:
|
||||
_log_query(sql_)
|
||||
|
@ -574,14 +612,16 @@ class Database(
|
|||
:return: The table/schema pairs
|
||||
"""
|
||||
try:
|
||||
return {
|
||||
(table, schema)
|
||||
for table in self.db_engine_spec.get_table_names(
|
||||
database=self,
|
||||
inspector=self.inspector,
|
||||
schema=schema,
|
||||
)
|
||||
}
|
||||
with self.get_inspector_with_context() as inspector:
|
||||
tables = {
|
||||
(table, schema)
|
||||
for table in self.db_engine_spec.get_table_names(
|
||||
database=self,
|
||||
inspector=inspector,
|
||||
schema=schema,
|
||||
)
|
||||
}
|
||||
return tables
|
||||
except Exception as ex:
|
||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
|
||||
|
||||
|
@ -608,17 +648,27 @@ class Database(
|
|||
:return: set of views
|
||||
"""
|
||||
try:
|
||||
return {
|
||||
(view, schema)
|
||||
for view in self.db_engine_spec.get_view_names(
|
||||
database=self,
|
||||
inspector=self.inspector,
|
||||
schema=schema,
|
||||
)
|
||||
}
|
||||
with self.get_inspector_with_context() as inspector:
|
||||
return {
|
||||
(view, schema)
|
||||
for view in self.db_engine_spec.get_view_names(
|
||||
database=self,
|
||||
inspector=inspector,
|
||||
schema=schema,
|
||||
)
|
||||
}
|
||||
except Exception as ex:
|
||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
|
||||
|
||||
@contextmanager
|
||||
def get_inspector_with_context(
|
||||
self, ssh_tunnel: Optional["SSHTunnel"] = None
|
||||
) -> Inspector:
|
||||
with self.get_sqla_engine_with_context(
|
||||
override_ssh_tunnel=ssh_tunnel
|
||||
) as engine:
|
||||
yield sqla.inspect(engine)
|
||||
|
||||
@cache_util.memoized_func(
|
||||
key="db:{self.id}:schema_list",
|
||||
cache=cache_manager.cache,
|
||||
|
@ -628,6 +678,7 @@ class Database(
|
|||
cache: bool = False,
|
||||
cache_timeout: Optional[int] = None,
|
||||
force: bool = False,
|
||||
ssh_tunnel: Optional["SSHTunnel"] = None,
|
||||
) -> List[str]:
|
||||
"""Parameters need to be passed as keyword arguments.
|
||||
|
||||
|
@ -640,7 +691,8 @@ class Database(
|
|||
:return: schema list
|
||||
"""
|
||||
try:
|
||||
return self.db_engine_spec.get_schema_names(self.inspector)
|
||||
with self.get_inspector_with_context(ssh_tunnel=ssh_tunnel) as inspector:
|
||||
return self.db_engine_spec.get_schema_names(inspector)
|
||||
except Exception as ex:
|
||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
||||
|
||||
|
@ -703,43 +755,49 @@ class Database(
|
|||
def get_table_comment(
|
||||
self, table_name: str, schema: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
return self.db_engine_spec.get_table_comment(self.inspector, table_name, schema)
|
||||
with self.get_inspector_with_context() as inspector:
|
||||
return self.db_engine_spec.get_table_comment(inspector, table_name, schema)
|
||||
|
||||
def get_columns(
|
||||
self, table_name: str, schema: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
return self.db_engine_spec.get_columns(self.inspector, table_name, schema)
|
||||
with self.get_inspector_with_context() as inspector:
|
||||
return self.db_engine_spec.get_columns(inspector, table_name, schema)
|
||||
|
||||
def get_metrics(
|
||||
self,
|
||||
table_name: str,
|
||||
schema: Optional[str] = None,
|
||||
) -> List[MetricType]:
|
||||
return self.db_engine_spec.get_metrics(self, self.inspector, table_name, schema)
|
||||
with self.get_inspector_with_context() as inspector:
|
||||
return self.db_engine_spec.get_metrics(self, inspector, table_name, schema)
|
||||
|
||||
def get_indexes(
|
||||
self, table_name: str, schema: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
indexes = self.inspector.get_indexes(table_name, schema)
|
||||
return self.db_engine_spec.normalize_indexes(indexes)
|
||||
with self.get_inspector_with_context() as inspector:
|
||||
indexes = inspector.get_indexes(table_name, schema)
|
||||
return self.db_engine_spec.normalize_indexes(indexes)
|
||||
|
||||
def get_pk_constraint(
|
||||
self, table_name: str, schema: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
pk_constraint = self.inspector.get_pk_constraint(table_name, schema) or {}
|
||||
with self.get_inspector_with_context() as inspector:
|
||||
pk_constraint = inspector.get_pk_constraint(table_name, schema) or {}
|
||||
|
||||
def _convert(value: Any) -> Any:
|
||||
try:
|
||||
return utils.base_json_conv(value)
|
||||
except TypeError:
|
||||
return None
|
||||
def _convert(value: Any) -> Any:
|
||||
try:
|
||||
return utils.base_json_conv(value)
|
||||
except TypeError:
|
||||
return None
|
||||
|
||||
return {key: _convert(value) for key, value in pk_constraint.items()}
|
||||
return {key: _convert(value) for key, value in pk_constraint.items()}
|
||||
|
||||
def get_foreign_keys(
|
||||
self, table_name: str, schema: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
return self.inspector.get_foreign_keys(table_name, schema)
|
||||
with self.get_inspector_with_context() as inspector:
|
||||
return inspector.get_foreign_keys(table_name, schema)
|
||||
|
||||
def get_schema_access_for_file_upload( # pylint: disable=invalid-name
|
||||
self,
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from superset.constants import PASSWORD_MASK
|
||||
|
||||
|
||||
def mask_password_info(ssh_tunnel: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if ssh_tunnel.pop("password", None) is not None:
|
||||
ssh_tunnel["password"] = PASSWORD_MASK
|
||||
if ssh_tunnel.pop("private_key", None) is not None:
|
||||
ssh_tunnel["private_key"] = PASSWORD_MASK
|
||||
if ssh_tunnel.pop("private_key_password", None) is not None:
|
||||
ssh_tunnel["private_key_password"] = PASSWORD_MASK
|
||||
return ssh_tunnel
|
|
@ -28,7 +28,7 @@ from __future__ import annotations
|
|||
from typing import Callable, TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, Mock, PropertyMock
|
||||
|
||||
from flask import Flask
|
||||
from flask import current_app, Flask
|
||||
from flask.ctx import AppContext
|
||||
from pytest import fixture
|
||||
|
||||
|
@ -40,6 +40,7 @@ from tests.example_data.data_loading.pandas.pands_data_loading_conf import (
|
|||
from tests.example_data.data_loading.pandas.table_df_convertor import (
|
||||
TableToDfConvertorImpl,
|
||||
)
|
||||
from tests.integration_tests.test_app import app
|
||||
|
||||
SUPPORT_DATETIME_TYPE = "support_datetime_type"
|
||||
|
||||
|
@ -70,8 +71,9 @@ def example_db_provider() -> Callable[[], Database]:
|
|||
|
||||
@fixture(scope="session")
|
||||
def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine:
|
||||
with example_db_provider().get_sqla_engine_with_context() as engine:
|
||||
return engine
|
||||
with app.app_context():
|
||||
with example_db_provider().get_sqla_engine_with_context() as engine:
|
||||
return engine
|
||||
|
||||
|
||||
@fixture(scope="session")
|
||||
|
|
|
@ -35,6 +35,8 @@ from sqlalchemy.sql import func
|
|||
|
||||
from superset import db, security_manager
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
from superset.db_engine_specs.redshift import RedshiftEngineSpec
|
||||
|
@ -280,6 +282,314 @@ class TestDatabaseApi(SupersetTestCase):
|
|||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_create_database_with_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test create with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_update_database_with_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test update with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
}
|
||||
database_data_with_ssh_tunnel = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
|
||||
uri = "api/v1/database/{}".format(response.get("id"))
|
||||
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
|
||||
response_update = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_update_ssh_tunnel_via_database_api(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test update with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
initial_ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
updated_ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "Test",
|
||||
}
|
||||
database_data_with_ssh_tunnel = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": initial_ssh_tunnel_properties,
|
||||
}
|
||||
database_data_with_ssh_tunnel_update = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": updated_ssh_tunnel_properties,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
self.assertEqual(model_ssh_tunnel.username, "foo")
|
||||
uri = "api/v1/database/{}".format(response.get("id"))
|
||||
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update)
|
||||
response_update = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
|
||||
self.assertEqual(model_ssh_tunnel.username, "Test")
|
||||
self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1")
|
||||
self.assertEqual(model_ssh_tunnel.server_port, 8080)
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_cascade_delete_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test create with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one_or_none()
|
||||
)
|
||||
assert model_ssh_tunnel is None
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test create with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-failure-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
fail_message = {"message": "SSH Tunnel parameters are invalid."}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one_or_none()
|
||||
)
|
||||
assert model_ssh_tunnel is None
|
||||
self.assertEqual(response, fail_message)
|
||||
# Cleanup
|
||||
model = (
|
||||
db.session.query(Database)
|
||||
.filter(Database.database_name == "test-db-failure-ssh-tunnel")
|
||||
.one_or_none()
|
||||
)
|
||||
# the DB should not be created
|
||||
assert model is None
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_get_database_returns_related_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test GET Database returns its related SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
response_ssh_tunnel = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "XXXXXXXXXX",
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel)
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
def test_create_database_invalid_configuration_method(self):
|
||||
"""
|
||||
Database API: Test create with an invalid configuration method.
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,76 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock, skip
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from superset import security_manager
|
||||
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelNotFoundError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
|
||||
|
||||
class TestCreateSSHTunnelCommand(SupersetTestCase):
|
||||
@mock.patch("superset.utils.core.g")
|
||||
def test_create_invalid_database_id(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
command = CreateSSHTunnelCommand(
|
||||
None,
|
||||
{
|
||||
"server_address": "127.0.0.1",
|
||||
"server_port": 5432,
|
||||
"username": "test_user",
|
||||
},
|
||||
)
|
||||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
||||
|
||||
|
||||
class TestUpdateSSHTunnelCommand(SupersetTestCase):
|
||||
@mock.patch("superset.utils.core.g")
|
||||
def test_update_ssh_tunnel_not_found(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
# We have not created a SSH Tunnel yet so id = 1 is invalid
|
||||
command = UpdateSSHTunnelCommand(
|
||||
1,
|
||||
{
|
||||
"server_address": "127.0.0.1",
|
||||
"server_port": 5432,
|
||||
"username": "test_user",
|
||||
},
|
||||
)
|
||||
with pytest.raises(SSHTunnelNotFoundError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel not found.")
|
||||
|
||||
|
||||
class TestDeleteSSHTunnelCommand(SupersetTestCase):
|
||||
@mock.patch("superset.utils.core.g")
|
||||
def test_delete_ssh_tunnel_not_found(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
# We have not created a SSH Tunnel yet so id = 1 is invalid
|
||||
command = DeleteSSHTunnelCommand(1)
|
||||
with pytest.raises(SSHTunnelNotFoundError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel not found.")
|
|
@ -191,3 +191,147 @@ def test_non_zip_import(client: Any, full_api_access: None) -> None:
|
|||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_delete_ssh_tunnel(
|
||||
mocker: MockFixture,
|
||||
app: Any,
|
||||
session: Session,
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
) -> None:
|
||||
"""
|
||||
Test that we can delete SSH Tunnel
|
||||
"""
|
||||
with app.app_context():
|
||||
from superset.databases.api import DatabaseRestApi
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
DatabaseRestApi.datamodel.session = session
|
||||
|
||||
# create table for databases
|
||||
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||
|
||||
# Create our Database
|
||||
database = Database(
|
||||
database_name="my_database",
|
||||
sqlalchemy_uri="gsheets://",
|
||||
encrypted_extra=json.dumps(
|
||||
{
|
||||
"service_account_info": {
|
||||
"type": "service_account",
|
||||
"project_id": "black-sanctum-314419",
|
||||
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
|
||||
"private_key": "SECRET",
|
||||
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
|
||||
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
||||
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
mocker.patch("superset.utils.log.DBEventLogger.log")
|
||||
|
||||
# Create our SSHTunnel
|
||||
tunnel = SSHTunnel(
|
||||
database_id=1,
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(tunnel)
|
||||
session.commit()
|
||||
|
||||
# Get our recently created SSHTunnel
|
||||
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
|
||||
assert response_tunnel
|
||||
assert isinstance(response_tunnel, SSHTunnel)
|
||||
assert 1 == response_tunnel.database_id
|
||||
|
||||
# Delete the recently created SSHTunnel
|
||||
response_delete_tunnel = client.delete("/api/v1/database/1/ssh_tunnel/")
|
||||
assert response_delete_tunnel.json["message"] == "OK"
|
||||
|
||||
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
|
||||
assert response_tunnel is None
|
||||
|
||||
|
||||
def test_delete_ssh_tunnel_not_found(
|
||||
mocker: MockFixture,
|
||||
app: Any,
|
||||
session: Session,
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
) -> None:
|
||||
"""
|
||||
Test that we cannot delete a tunnel that does not exist
|
||||
"""
|
||||
with app.app_context():
|
||||
from superset.databases.api import DatabaseRestApi
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
DatabaseRestApi.datamodel.session = session
|
||||
|
||||
# create table for databases
|
||||
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||
|
||||
# Create our Database
|
||||
database = Database(
|
||||
database_name="my_database",
|
||||
sqlalchemy_uri="gsheets://",
|
||||
encrypted_extra=json.dumps(
|
||||
{
|
||||
"service_account_info": {
|
||||
"type": "service_account",
|
||||
"project_id": "black-sanctum-314419",
|
||||
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
|
||||
"private_key": "SECRET",
|
||||
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
|
||||
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
||||
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
mocker.patch("superset.utils.log.DBEventLogger.log")
|
||||
|
||||
# Create our SSHTunnel
|
||||
tunnel = SSHTunnel(
|
||||
database_id=1,
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(tunnel)
|
||||
session.commit()
|
||||
|
||||
# Delete the recently created SSHTunnel
|
||||
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
|
||||
assert response_delete_tunnel.json["message"] == "Not found"
|
||||
|
||||
# Get our recently created SSHTunnel
|
||||
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
|
||||
assert response_tunnel
|
||||
assert isinstance(response_tunnel, SSHTunnel)
|
||||
assert 1 == response_tunnel.database_id
|
||||
|
||||
response_tunnel = DatabaseDAO.get_ssh_tunnel(2)
|
||||
assert response_tunnel is None
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,69 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_with_data(session: Session) -> Iterator[Session]:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=db,
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(
|
||||
database_id=db.id,
|
||||
database=db,
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(sqla_table)
|
||||
session.add(ssh_tunnel)
|
||||
session.flush()
|
||||
yield session
|
||||
session.rollback()
|
||||
|
||||
|
||||
def test_database_get_ssh_tunnel(session_with_data: Session) -> None:
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
|
||||
|
||||
def test_database_get_ssh_tunnel_not_found(session_with_data: Session) -> None:
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(2)
|
||||
|
||||
assert result is None
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,68 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelInvalidError
|
||||
|
||||
|
||||
def test_create_ssh_tunnel_command() -> None:
|
||||
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
properties = {
|
||||
"database_id": db.id,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
|
||||
result = CreateSSHTunnelCommand(db.id, properties).run()
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, SSHTunnel)
|
||||
|
||||
|
||||
def test_create_ssh_tunnel_command_invalid_params() -> None:
|
||||
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
# If we are trying to create a tunnel with a private_key_password
|
||||
# then a private_key is mandatory
|
||||
properties = {
|
||||
"database_id": db.id,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"private_key_password": "bar",
|
||||
}
|
||||
|
||||
command = CreateSSHTunnelCommand(db.id, properties)
|
||||
|
||||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
|
@ -0,0 +1,68 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_with_data(session: Session) -> Iterator[Session]:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=db,
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(
|
||||
database_id=db.id,
|
||||
database=db,
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(sqla_table)
|
||||
session.add(ssh_tunnel)
|
||||
session.flush()
|
||||
yield session
|
||||
session.rollback()
|
||||
|
||||
|
||||
def test_delete_ssh_tunnel_command(session_with_data: Session) -> None:
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
|
||||
DeleteSSHTunnelCommand(1).run()
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result is None
|
|
@ -0,0 +1,93 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelInvalidError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_with_data(session: Session) -> Iterator[Session]:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=db,
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(database_id=db.id, database=db, server_address="Test")
|
||||
|
||||
session.add(db)
|
||||
session.add(sqla_table)
|
||||
session.add(ssh_tunnel)
|
||||
session.flush()
|
||||
yield session
|
||||
session.rollback()
|
||||
|
||||
|
||||
def test_update_shh_tunnel_command(session_with_data: Session) -> None:
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
assert "Test" == result.server_address
|
||||
|
||||
update_payload = {"server_address": "Test2"}
|
||||
UpdateSSHTunnelCommand(1, update_payload).run()
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert "Test2" == result.server_address
|
||||
|
||||
|
||||
def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
assert "Test" == result.server_address
|
||||
|
||||
# If we are trying to update a tunnel with a private_key_password
|
||||
# then a private_key is mandatory
|
||||
update_payload = {"private_key_password": "pass"}
|
||||
command = UpdateSSHTunnelCommand(1, update_payload)
|
||||
|
||||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
|
@ -0,0 +1,43 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
|
||||
def test_create_ssh_tunnel():
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
properties = {
|
||||
"database_id": db.id,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
|
||||
result = SSHTunnelDAO.create(properties)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, SSHTunnel)
|
Loading…
Reference in New Issue