feat(sshtunnel): add configuration for SSH_TIMEOUT (#24369)

This commit is contained in:
Hugh A. Miles II 2023-06-13 12:29:40 -04:00 committed by GitHub
parent 1328c56aab
commit eb05225f0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 11 additions and 2 deletions

View File

@ -498,7 +498,11 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
# ---------------------------------------------------------------------- # ----------------------------------------------------------------------
SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager" SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager"
SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1" SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
#: Timeout (seconds) for tunnel connection (open_channel timeout)
SSH_TUNNEL_TIMEOUT_SEC = 10.0 SSH_TUNNEL_TIMEOUT_SEC = 10.0
#: Timeout (seconds) for transport socket (``socket.settimeout``)
SSH_TUNNEL_PACKET_TIMEOUT_SEC = 1.0
# Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars. # Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars.
DEFAULT_FEATURE_FLAGS.update( DEFAULT_FEATURE_FLAGS.update(

View File

@ -35,6 +35,7 @@ class SSHManager:
super().__init__() super().__init__()
self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"] self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
sshtunnel.TUNNEL_TIMEOUT = app.config["SSH_TUNNEL_TIMEOUT_SEC"] sshtunnel.TUNNEL_TIMEOUT = app.config["SSH_TUNNEL_TIMEOUT_SEC"]
sshtunnel.SSH_TIMEOUT = app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"]
def build_sqla_url( # pylint: disable=no-self-use def build_sqla_url( # pylint: disable=no-self-use
self, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder self, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder

View File

@ -31,6 +31,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING
import numpy import numpy
import pandas as pd import pandas as pd
import sqlalchemy as sqla import sqlalchemy as sqla
import sshtunnel
from flask import g, request from flask import g, request
from flask_appbuilder import Model from flask_appbuilder import Model
from sqlalchemy import ( from sqlalchemy import (
@ -406,9 +407,10 @@ class Database(
with engine_context as server_context: with engine_context as server_context:
if ssh_tunnel and server_context: if ssh_tunnel and server_context:
logger.info( logger.info(
"[SSH] Successfully create tunnel at %s: %s", "[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s ssh_timeout at %s",
sshtunnel.TUNNEL_TIMEOUT,
sshtunnel.SSH_TIMEOUT,
server_context.local_bind_address, server_context.local_bind_address,
server_context.local_bind_port,
) )
sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url( sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
sqlalchemy_uri, server_context sqlalchemy_uri, server_context

View File

@ -28,8 +28,10 @@ def test_ssh_tunnel_timeout_setting() -> None:
"SSH_TUNNEL_MAX_RETRIES": 2, "SSH_TUNNEL_MAX_RETRIES": 2,
"SSH_TUNNEL_LOCAL_BIND_ADDRESS": "test", "SSH_TUNNEL_LOCAL_BIND_ADDRESS": "test",
"SSH_TUNNEL_TIMEOUT_SEC": 123.0, "SSH_TUNNEL_TIMEOUT_SEC": 123.0,
"SSH_TUNNEL_PACKET_TIMEOUT_SEC": 321.0,
"SSH_TUNNEL_MANAGER_CLASS": "superset.extensions.ssh.SSHManager", "SSH_TUNNEL_MANAGER_CLASS": "superset.extensions.ssh.SSHManager",
} }
factory = SSHManagerFactory() factory = SSHManagerFactory()
factory.init_app(app) factory.init_app(app)
assert sshtunnel.TUNNEL_TIMEOUT == 123.0 assert sshtunnel.TUNNEL_TIMEOUT == 123.0
assert sshtunnel.SSH_TIMEOUT == 321.0