feat(db): Adding DB_SQLA_URI_VALIDATOR (#27847)

This commit is contained in:
Craig Rueda 2024-04-02 09:00:32 -07:00 committed by GitHub
parent 9fece4f811
commit 8bdf457dfa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 36 additions and 1 deletions

View File

@ -44,6 +44,7 @@ from flask_appbuilder.security.manager import AUTH_DB
from flask_caching.backends.base import BaseCache
from pandas import Series
from pandas._libs.parsers import STR_NA_VALUES # pylint: disable=no-name-in-module
from sqlalchemy.engine.url import URL
from sqlalchemy.orm.query import Query
from superset.advanced_data_type.plugins.internet_address import internet_address
@ -1207,6 +1208,17 @@ DASHBOARD_TEMPLATE_ID = None
DB_CONNECTION_MUTATOR = None
# A callable that is invoked for every invocation of DB Engine Specs
# which allows for custom validation of the engine URI.
# See: superset.db_engine_specs.base.BaseEngineSpec.validate_database_uri
# Example:
# def DB_ENGINE_URI_VALIDATOR(sqlalchemy_uri: URL):
# if not <some condition>:
# raise Exception("URI invalid")
#
DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None
# A function that intercepts the SQL to be executed and can alter it.
# The use case is can be around adding some sort of comment header
# with information such as the username and worker node information

View File

@ -1956,6 +1956,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param sqlalchemy_uri:
"""
if db_engine_uri_validator := current_app.config["DB_SQLA_URI_VALIDATOR"]:
db_engine_uri_validator(sqlalchemy_uri)
if existing_disallowed := cls.disallow_uri_query_params.get(
sqlalchemy_uri.get_driver_name(), set()
).intersection(sqlalchemy_uri.query):

View File

@ -23,6 +23,7 @@ import pytest
from pytest_mock import MockFixture
from sqlalchemy import types
from sqlalchemy.dialects import sqlite
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import sqltypes
from superset.superset_typing import ResultSetColumnType, SQLAColumnType
@ -69,6 +70,25 @@ def test_parse_sql_multi_statement() -> None:
]
def test_validate_db_uri(mocker: MockFixture) -> None:
"""
Ensures that the `validate_database_uri` method invokes the validator correctly
"""
def mock_validate(sqlalchemy_uri: URL) -> None:
raise ValueError("Invalid URI")
mocker.patch(
"superset.db_engine_specs.base.current_app.config",
{"DB_SQLA_URI_VALIDATOR": mock_validate},
)
from superset.db_engine_specs.base import BaseEngineSpec
with pytest.raises(ValueError):
BaseEngineSpec.validate_database_uri(URL.create("sqlite"))
@pytest.mark.parametrize(
"original,expected",
[

View File

@ -124,7 +124,7 @@ def test_superset_limit(mocker: MockFixture, app_context: None, table1: None) ->
"""
mocker.patch(
"superset.extensions.metadb.current_app.config",
{"SUPERSET_META_DB_LIMIT": 1},
{"DB_SQLA_URI_VALIDATOR": None, "SUPERSET_META_DB_LIMIT": 1},
)
mocker.patch("superset.extensions.metadb.security_manager")