feat: Expose different hooks dynamically to inject different database connection logics.

This commit is contained in:
wugeer 2024-06-18 11:05:19 +08:00
parent ddc9f06786
commit d1c6152c62
4 changed files with 121 additions and 0 deletions

View File

@ -63,6 +63,7 @@ from superset.utils.core import is_test, NO_TIME_RANGE, parse_boolean_string
from superset.utils.encrypt import SQLAlchemyUtilsAdapter
from superset.utils.log import DBEventLogger
from superset.utils.logging_configurator import DefaultLoggingConfigurator
from superset.utils.database_connect_modifier import BaseDBConnectModifier
logger = logging.getLogger(__name__)
@ -1258,6 +1259,42 @@ DASHBOARD_TEMPLATE_ID = None
DB_CONNECTION_MUTATOR = None
# Whether to enable the DB_CONNECTION_MODIFIER feature
DB_CONNECTION_MODIFIER_ENABLED = False
# A dictionary of database connect modifiers (by engine) that allows altering
# the database connection URL and params on the fly, at runtime. This allows for things
# like impersonation or arbitrary logic. For instance you can wire different users to
# use different connection parameters, or pass their email address as the
# username. The function receives the connection uri object, connection
# params, the username, and returns the mutated uri and params objects.
# Example:
# class PostgresDBConnectModifier(BaseDBConnectModifier):
# # When connecting to a postgres data source,
# # replace the default connection username and password
#
# @classmethod
# def run(cls, sqlalchemy_url: URL, params: dict[str, Any], username: str, *args: Any,
# **kwargs: Any) -> (URL, dict[str, Any]):
# new_password = cls._get_new_password(username)
# sqlalchemy_url.username = username
# sqlalchemy_url.password = new_password
# return sqlalchemy_url, params
#
# @staticmethod
# def _get_new_password(username):
# # 实现密码生成逻辑
# return 'new_password_' + username
#
# DB_CONNECTION_MODIFIER: dict[str, type[BaseDBConnectModifier]] = {
# "postgresql": PostgresDBConnectModifier,
# }
#
# Note that the returned uri and params are passed directly to sqlalchemy's
# as such `create_engine(url, **params)`
DB_CONNECTION_MODIFIER: dict[str, type[BaseDBConnectModifier]] = {}
# A callable that is invoked for every invocation of DB Engine Specs
# which allows for custom validation of the engine URI.
# See: superset.db_engine_specs.base.BaseEngineSpec.validate_database_uri

View File

@ -93,6 +93,9 @@ if TYPE_CHECKING:
DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"]
DB_CONNECTION_MODIFIER_ENABLED = config["DB_CONNECTION_MODIFIER_ENABLED"]
DB_CONNECTION_MODIFIER = config["DB_CONNECTION_MODIFIER"]
class KeyValue(Model): # pylint: disable=too-few-public-methods
"""Used for any type of key-value store"""
@ -535,6 +538,16 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable
security_manager,
source,
)
is_db_connect_modify = (
DB_CONNECTION_MODIFIER_ENABLED and DB_CONNECTION_MODIFIER
and sqlalchemy_url.drivername in DB_CONNECTION_MODIFIER
)
if is_db_connect_modify:
url_modified = DB_CONNECTION_MODIFIER[sqlalchemy_url.drivername]
sqlalchemy_url, params = url_modified.run(
sqlalchemy_url, params, effective_username)
try:
return create_engine(sqlalchemy_url, **params)
except Exception as ex:

View File

@ -0,0 +1,12 @@
from typing import Any
from sqlalchemy.engine.url import URL
class BaseDBConnectModifier:
name = "BaseURLModifier"
@classmethod
def run(cls, sqlalchemy_url: URL, params: dict[str, Any], username: str, *args: Any,
**kwargs: Any) -> (URL, dict[str, Any]):
raise NotImplementedError

View File

@ -1156,6 +1156,35 @@ class TestCore(SupersetTestCase):
data = self.get_resp(url)
self.assertIn("Error message", data)
@pytest.mark.skip(
"TODO This test was wrong - 'Error message' was in the language pack"
)
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
@mock.patch("superset.models.core.DB_CONNECTION_MODIFIER")
def test_explore_with_modifier_injected_exceptions(self, mock_db_connection_modifier):
"""
Handle injected exceptions from the db modifier
"""
# Assert we can handle a custom exception at the modifier level
exception = SupersetException("Error message")
mock_db_connection_modifier.side_effect = exception
slice = db.session.query(Slice).first()
url = f"/explore/?form_data=%7B%22slice_id%22%3A%20{slice.id}%7D"
self.login(ADMIN_USERNAME)
data = self.get_resp(url)
self.assertIn("Error message", data)
# Assert we can handle a driver exception at the modifier level
exception = SQLAlchemyError("Error message")
mock_db_connection_modifier.side_effect = exception
slice = db.session.query(Slice).first()
url = f"/explore/?form_data=%7B%22slice_id%22%3A%20{slice.id}%7D"
self.login(ADMIN_USERNAME)
data = self.get_resp(url)
self.assertIn("Error message", data)
@pytest.mark.skip(
"TODO This test was wrong - 'Error message' was in the language pack"
)
@ -1186,6 +1215,36 @@ class TestCore(SupersetTestCase):
data = self.get_resp(url)
self.assertIn("Error message", data)
@pytest.mark.skip(
"TODO This test was wrong - 'Error message' was in the language pack"
)
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
@mock.patch("superset.models.core.DB_CONNECTION_MODIFIER")
def test_dashboard_with_modifier_injected_exceptions(self, mock_db_connection_modifier):
"""
Handle injected exceptions from the db modifier
"""
# Assert we can handle a custom exception at the modifier level
exception = SupersetException("Error message")
mock_db_connection_modifier.side_effect = exception
dash = db.session.query(Dashboard).first()
url = f"/superset/dashboard/{dash.id}/"
self.login(ADMIN_USERNAME)
data = self.get_resp(url)
self.assertIn("Error message", data)
# Assert we can handle a driver exception at the modifier level
exception = SQLAlchemyError("Error message")
mock_db_connection_modifier.side_effect = exception
dash = db.session.query(Dashboard).first()
url = f"/superset/dashboard/{dash.id}/"
self.login(ADMIN_USERNAME)
data = self.get_resp(url)
self.assertIn("Error message", data)
@pytest.mark.usefixtures("load_energy_table_with_slice")
@mock.patch("superset.commands.explore.form_data.create.CreateFormDataCommand.run")
def test_explore_redirect(self, mock_command: mock.Mock):