diff --git a/superset/config.py b/superset/config.py index 3dfeb3ad47..06002960d5 100644 --- a/superset/config.py +++ b/superset/config.py @@ -63,6 +63,7 @@ from superset.utils.core import is_test, NO_TIME_RANGE, parse_boolean_string from superset.utils.encrypt import SQLAlchemyUtilsAdapter from superset.utils.log import DBEventLogger from superset.utils.logging_configurator import DefaultLoggingConfigurator +from superset.utils.database_connect_modifier import BaseDBConnectModifier logger = logging.getLogger(__name__) @@ -1258,6 +1259,42 @@ DASHBOARD_TEMPLATE_ID = None DB_CONNECTION_MUTATOR = None +# Whether to enable the DB_CONNECTION_MODIFIER feature +DB_CONNECTION_MODIFIER_ENABLED = False + +# A dictionary of database connect modifiers (by engine) that allows altering +# the database connection URL and params on the fly, at runtime. This allows for things +# like impersonation or arbitrary logic. For instance you can wire different users to +# use different connection parameters, or pass their email address as the +# username. The function receives the connection uri object, connection +# params, the username, and returns the mutated uri and params objects. +# Example: +# class PostgresDBConnectModifier(BaseDBConnectModifier): +# # When connecting to a postgres data source, +# # replace the default connection username and password +# +# @classmethod +# def run(cls, sqlalchemy_url: URL, params: dict[str, Any], username: str, *args: Any, +# **kwargs: Any) -> (URL, dict[str, Any]): +# new_password = cls._get_new_password(username) +# sqlalchemy_url.username = username +# sqlalchemy_url.password = new_password +# return sqlalchemy_url, params +# +# @staticmethod +# def _get_new_password(username): +# # 实现密码生成逻辑 +# return 'new_password_' + username +# +# DB_CONNECTION_MODIFIER: dict[str, type[BaseDBConnectModifier]] = { +# "postgresql": PostgresDBConnectModifier, +# } +# +# Note that the returned uri and params are passed directly to sqlalchemy's +# as such `create_engine(url, **params)` +DB_CONNECTION_MODIFIER: dict[str, type[BaseDBConnectModifier]] = {} + + # A callable that is invoked for every invocation of DB Engine Specs # which allows for custom validation of the engine URI. # See: superset.db_engine_specs.base.BaseEngineSpec.validate_database_uri diff --git a/superset/models/core.py b/superset/models/core.py index 78bbf55cdf..552c4a7969 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -93,6 +93,9 @@ if TYPE_CHECKING: DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"] +DB_CONNECTION_MODIFIER_ENABLED = config["DB_CONNECTION_MODIFIER_ENABLED"] +DB_CONNECTION_MODIFIER = config["DB_CONNECTION_MODIFIER"] + class KeyValue(Model): # pylint: disable=too-few-public-methods """Used for any type of key-value store""" @@ -535,6 +538,16 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable security_manager, source, ) + + is_db_connect_modify = ( + DB_CONNECTION_MODIFIER_ENABLED and DB_CONNECTION_MODIFIER + and sqlalchemy_url.drivername in DB_CONNECTION_MODIFIER + ) + if is_db_connect_modify: + url_modified = DB_CONNECTION_MODIFIER[sqlalchemy_url.drivername] + sqlalchemy_url, params = url_modified.run( + sqlalchemy_url, params, effective_username) + try: return create_engine(sqlalchemy_url, **params) except Exception as ex: diff --git a/superset/utils/database_connect_modifier.py b/superset/utils/database_connect_modifier.py new file mode 100644 index 0000000000..0ab4bf955b --- /dev/null +++ b/superset/utils/database_connect_modifier.py @@ -0,0 +1,12 @@ +from typing import Any + +from sqlalchemy.engine.url import URL + + +class BaseDBConnectModifier: + name = "BaseURLModifier" + + @classmethod + def run(cls, sqlalchemy_url: URL, params: dict[str, Any], username: str, *args: Any, + **kwargs: Any) -> (URL, dict[str, Any]): + raise NotImplementedError diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index d085beba78..6d49ff14df 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -1156,6 +1156,35 @@ class TestCore(SupersetTestCase): data = self.get_resp(url) self.assertIn("Error message", data) + @pytest.mark.skip( + "TODO This test was wrong - 'Error message' was in the language pack" + ) + @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") + @mock.patch("superset.models.core.DB_CONNECTION_MODIFIER") + def test_explore_with_modifier_injected_exceptions(self, mock_db_connection_modifier): + """ + Handle injected exceptions from the db modifier + """ + # Assert we can handle a custom exception at the modifier level + exception = SupersetException("Error message") + mock_db_connection_modifier.side_effect = exception + slice = db.session.query(Slice).first() + url = f"/explore/?form_data=%7B%22slice_id%22%3A%20{slice.id}%7D" + + self.login(ADMIN_USERNAME) + data = self.get_resp(url) + self.assertIn("Error message", data) + + # Assert we can handle a driver exception at the modifier level + exception = SQLAlchemyError("Error message") + mock_db_connection_modifier.side_effect = exception + slice = db.session.query(Slice).first() + url = f"/explore/?form_data=%7B%22slice_id%22%3A%20{slice.id}%7D" + + self.login(ADMIN_USERNAME) + data = self.get_resp(url) + self.assertIn("Error message", data) + @pytest.mark.skip( "TODO This test was wrong - 'Error message' was in the language pack" ) @@ -1186,6 +1215,36 @@ class TestCore(SupersetTestCase): data = self.get_resp(url) self.assertIn("Error message", data) + @pytest.mark.skip( + "TODO This test was wrong - 'Error message' was in the language pack" + ) + @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") + @mock.patch("superset.models.core.DB_CONNECTION_MODIFIER") + def test_dashboard_with_modifier_injected_exceptions(self, mock_db_connection_modifier): + """ + Handle injected exceptions from the db modifier + """ + + # Assert we can handle a custom exception at the modifier level + exception = SupersetException("Error message") + mock_db_connection_modifier.side_effect = exception + dash = db.session.query(Dashboard).first() + url = f"/superset/dashboard/{dash.id}/" + + self.login(ADMIN_USERNAME) + data = self.get_resp(url) + self.assertIn("Error message", data) + + # Assert we can handle a driver exception at the modifier level + exception = SQLAlchemyError("Error message") + mock_db_connection_modifier.side_effect = exception + dash = db.session.query(Dashboard).first() + url = f"/superset/dashboard/{dash.id}/" + + self.login(ADMIN_USERNAME) + data = self.get_resp(url) + self.assertIn("Error message", data) + @pytest.mark.usefixtures("load_energy_table_with_slice") @mock.patch("superset.commands.explore.form_data.create.CreateFormDataCommand.run") def test_explore_redirect(self, mock_command: mock.Mock):