mirror of https://github.com/apache/superset.git
feat: Expose different hooks dynamically to inject different database connection logics.
This commit is contained in:
parent
ddc9f06786
commit
d1c6152c62
|
@ -63,6 +63,7 @@ from superset.utils.core import is_test, NO_TIME_RANGE, parse_boolean_string
|
||||||
from superset.utils.encrypt import SQLAlchemyUtilsAdapter
|
from superset.utils.encrypt import SQLAlchemyUtilsAdapter
|
||||||
from superset.utils.log import DBEventLogger
|
from superset.utils.log import DBEventLogger
|
||||||
from superset.utils.logging_configurator import DefaultLoggingConfigurator
|
from superset.utils.logging_configurator import DefaultLoggingConfigurator
|
||||||
|
from superset.utils.database_connect_modifier import BaseDBConnectModifier
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -1258,6 +1259,42 @@ DASHBOARD_TEMPLATE_ID = None
|
||||||
DB_CONNECTION_MUTATOR = None
|
DB_CONNECTION_MUTATOR = None
|
||||||
|
|
||||||
|
|
||||||
|
# Whether to enable the DB_CONNECTION_MODIFIER feature
|
||||||
|
DB_CONNECTION_MODIFIER_ENABLED = False
|
||||||
|
|
||||||
|
# A dictionary of database connect modifiers (by engine) that allows altering
|
||||||
|
# the database connection URL and params on the fly, at runtime. This allows for things
|
||||||
|
# like impersonation or arbitrary logic. For instance you can wire different users to
|
||||||
|
# use different connection parameters, or pass their email address as the
|
||||||
|
# username. The function receives the connection uri object, connection
|
||||||
|
# params, the username, and returns the mutated uri and params objects.
|
||||||
|
# Example:
|
||||||
|
# class PostgresDBConnectModifier(BaseDBConnectModifier):
|
||||||
|
# # When connecting to a postgres data source,
|
||||||
|
# # replace the default connection username and password
|
||||||
|
#
|
||||||
|
# @classmethod
|
||||||
|
# def run(cls, sqlalchemy_url: URL, params: dict[str, Any], username: str, *args: Any,
|
||||||
|
# **kwargs: Any) -> (URL, dict[str, Any]):
|
||||||
|
# new_password = cls._get_new_password(username)
|
||||||
|
# sqlalchemy_url.username = username
|
||||||
|
# sqlalchemy_url.password = new_password
|
||||||
|
# return sqlalchemy_url, params
|
||||||
|
#
|
||||||
|
# @staticmethod
|
||||||
|
# def _get_new_password(username):
|
||||||
|
# # 实现密码生成逻辑
|
||||||
|
# return 'new_password_' + username
|
||||||
|
#
|
||||||
|
# DB_CONNECTION_MODIFIER: dict[str, type[BaseDBConnectModifier]] = {
|
||||||
|
# "postgresql": PostgresDBConnectModifier,
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# Note that the returned uri and params are passed directly to sqlalchemy's
|
||||||
|
# as such `create_engine(url, **params)`
|
||||||
|
DB_CONNECTION_MODIFIER: dict[str, type[BaseDBConnectModifier]] = {}
|
||||||
|
|
||||||
|
|
||||||
# A callable that is invoked for every invocation of DB Engine Specs
|
# A callable that is invoked for every invocation of DB Engine Specs
|
||||||
# which allows for custom validation of the engine URI.
|
# which allows for custom validation of the engine URI.
|
||||||
# See: superset.db_engine_specs.base.BaseEngineSpec.validate_database_uri
|
# See: superset.db_engine_specs.base.BaseEngineSpec.validate_database_uri
|
||||||
|
|
|
@ -93,6 +93,9 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"]
|
DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"]
|
||||||
|
|
||||||
|
DB_CONNECTION_MODIFIER_ENABLED = config["DB_CONNECTION_MODIFIER_ENABLED"]
|
||||||
|
DB_CONNECTION_MODIFIER = config["DB_CONNECTION_MODIFIER"]
|
||||||
|
|
||||||
|
|
||||||
class KeyValue(Model): # pylint: disable=too-few-public-methods
|
class KeyValue(Model): # pylint: disable=too-few-public-methods
|
||||||
"""Used for any type of key-value store"""
|
"""Used for any type of key-value store"""
|
||||||
|
@ -535,6 +538,16 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
|
||||||
security_manager,
|
security_manager,
|
||||||
source,
|
source,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
is_db_connect_modify = (
|
||||||
|
DB_CONNECTION_MODIFIER_ENABLED and DB_CONNECTION_MODIFIER
|
||||||
|
and sqlalchemy_url.drivername in DB_CONNECTION_MODIFIER
|
||||||
|
)
|
||||||
|
if is_db_connect_modify:
|
||||||
|
url_modified = DB_CONNECTION_MODIFIER[sqlalchemy_url.drivername]
|
||||||
|
sqlalchemy_url, params = url_modified.run(
|
||||||
|
sqlalchemy_url, params, effective_username)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return create_engine(sqlalchemy_url, **params)
|
return create_engine(sqlalchemy_url, **params)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.engine.url import URL
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDBConnectModifier:
|
||||||
|
name = "BaseURLModifier"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def run(cls, sqlalchemy_url: URL, params: dict[str, Any], username: str, *args: Any,
|
||||||
|
**kwargs: Any) -> (URL, dict[str, Any]):
|
||||||
|
raise NotImplementedError
|
|
@ -1156,6 +1156,35 @@ class TestCore(SupersetTestCase):
|
||||||
data = self.get_resp(url)
|
data = self.get_resp(url)
|
||||||
self.assertIn("Error message", data)
|
self.assertIn("Error message", data)
|
||||||
|
|
||||||
|
@pytest.mark.skip(
|
||||||
|
"TODO This test was wrong - 'Error message' was in the language pack"
|
||||||
|
)
|
||||||
|
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||||
|
@mock.patch("superset.models.core.DB_CONNECTION_MODIFIER")
|
||||||
|
def test_explore_with_modifier_injected_exceptions(self, mock_db_connection_modifier):
|
||||||
|
"""
|
||||||
|
Handle injected exceptions from the db modifier
|
||||||
|
"""
|
||||||
|
# Assert we can handle a custom exception at the modifier level
|
||||||
|
exception = SupersetException("Error message")
|
||||||
|
mock_db_connection_modifier.side_effect = exception
|
||||||
|
slice = db.session.query(Slice).first()
|
||||||
|
url = f"/explore/?form_data=%7B%22slice_id%22%3A%20{slice.id}%7D"
|
||||||
|
|
||||||
|
self.login(ADMIN_USERNAME)
|
||||||
|
data = self.get_resp(url)
|
||||||
|
self.assertIn("Error message", data)
|
||||||
|
|
||||||
|
# Assert we can handle a driver exception at the modifier level
|
||||||
|
exception = SQLAlchemyError("Error message")
|
||||||
|
mock_db_connection_modifier.side_effect = exception
|
||||||
|
slice = db.session.query(Slice).first()
|
||||||
|
url = f"/explore/?form_data=%7B%22slice_id%22%3A%20{slice.id}%7D"
|
||||||
|
|
||||||
|
self.login(ADMIN_USERNAME)
|
||||||
|
data = self.get_resp(url)
|
||||||
|
self.assertIn("Error message", data)
|
||||||
|
|
||||||
@pytest.mark.skip(
|
@pytest.mark.skip(
|
||||||
"TODO This test was wrong - 'Error message' was in the language pack"
|
"TODO This test was wrong - 'Error message' was in the language pack"
|
||||||
)
|
)
|
||||||
|
@ -1186,6 +1215,36 @@ class TestCore(SupersetTestCase):
|
||||||
data = self.get_resp(url)
|
data = self.get_resp(url)
|
||||||
self.assertIn("Error message", data)
|
self.assertIn("Error message", data)
|
||||||
|
|
||||||
|
@pytest.mark.skip(
|
||||||
|
"TODO This test was wrong - 'Error message' was in the language pack"
|
||||||
|
)
|
||||||
|
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
|
||||||
|
@mock.patch("superset.models.core.DB_CONNECTION_MODIFIER")
|
||||||
|
def test_dashboard_with_modifier_injected_exceptions(self, mock_db_connection_modifier):
|
||||||
|
"""
|
||||||
|
Handle injected exceptions from the db modifier
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Assert we can handle a custom exception at the modifier level
|
||||||
|
exception = SupersetException("Error message")
|
||||||
|
mock_db_connection_modifier.side_effect = exception
|
||||||
|
dash = db.session.query(Dashboard).first()
|
||||||
|
url = f"/superset/dashboard/{dash.id}/"
|
||||||
|
|
||||||
|
self.login(ADMIN_USERNAME)
|
||||||
|
data = self.get_resp(url)
|
||||||
|
self.assertIn("Error message", data)
|
||||||
|
|
||||||
|
# Assert we can handle a driver exception at the modifier level
|
||||||
|
exception = SQLAlchemyError("Error message")
|
||||||
|
mock_db_connection_modifier.side_effect = exception
|
||||||
|
dash = db.session.query(Dashboard).first()
|
||||||
|
url = f"/superset/dashboard/{dash.id}/"
|
||||||
|
|
||||||
|
self.login(ADMIN_USERNAME)
|
||||||
|
data = self.get_resp(url)
|
||||||
|
self.assertIn("Error message", data)
|
||||||
|
|
||||||
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
@pytest.mark.usefixtures("load_energy_table_with_slice")
|
||||||
@mock.patch("superset.commands.explore.form_data.create.CreateFormDataCommand.run")
|
@mock.patch("superset.commands.explore.form_data.create.CreateFormDataCommand.run")
|
||||||
def test_explore_redirect(self, mock_command: mock.Mock):
|
def test_explore_redirect(self, mock_command: mock.Mock):
|
||||||
|
|
Loading…
Reference in New Issue