feat(db): Adding DB_SQLA_URI_VALIDATOR (#27847)

This commit is contained in:
Craig Rueda 2024-04-02 09:00:32 -07:00 committed by GitHub
parent 9fece4f811
commit 8bdf457dfa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 36 additions and 1 deletions

View File

@ -44,6 +44,7 @@ from flask_appbuilder.security.manager import AUTH_DB
from flask_caching.backends.base import BaseCache from flask_caching.backends.base import BaseCache
from pandas import Series from pandas import Series
from pandas._libs.parsers import STR_NA_VALUES # pylint: disable=no-name-in-module from pandas._libs.parsers import STR_NA_VALUES # pylint: disable=no-name-in-module
from sqlalchemy.engine.url import URL
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
from superset.advanced_data_type.plugins.internet_address import internet_address from superset.advanced_data_type.plugins.internet_address import internet_address
@ -1207,6 +1208,17 @@ DASHBOARD_TEMPLATE_ID = None
DB_CONNECTION_MUTATOR = None DB_CONNECTION_MUTATOR = None
# A callable that is invoked for every invocation of DB Engine Specs
# which allows for custom validation of the engine URI.
# See: superset.db_engine_specs.base.BaseEngineSpec.validate_database_uri
# Example:
# def DB_ENGINE_URI_VALIDATOR(sqlalchemy_uri: URL):
# if not <some condition>:
# raise Exception("URI invalid")
#
DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None
# A function that intercepts the SQL to be executed and can alter it. # A function that intercepts the SQL to be executed and can alter it.
# The use case is can be around adding some sort of comment header # The use case is can be around adding some sort of comment header
# with information such as the username and worker node information # with information such as the username and worker node information

View File

@ -1956,6 +1956,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param sqlalchemy_uri: :param sqlalchemy_uri:
""" """
if db_engine_uri_validator := current_app.config["DB_SQLA_URI_VALIDATOR"]:
db_engine_uri_validator(sqlalchemy_uri)
if existing_disallowed := cls.disallow_uri_query_params.get( if existing_disallowed := cls.disallow_uri_query_params.get(
sqlalchemy_uri.get_driver_name(), set() sqlalchemy_uri.get_driver_name(), set()
).intersection(sqlalchemy_uri.query): ).intersection(sqlalchemy_uri.query):

View File

@ -23,6 +23,7 @@ import pytest
from pytest_mock import MockFixture from pytest_mock import MockFixture
from sqlalchemy import types from sqlalchemy import types
from sqlalchemy.dialects import sqlite from sqlalchemy.dialects import sqlite
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import sqltypes from sqlalchemy.sql import sqltypes
from superset.superset_typing import ResultSetColumnType, SQLAColumnType from superset.superset_typing import ResultSetColumnType, SQLAColumnType
@ -69,6 +70,25 @@ def test_parse_sql_multi_statement() -> None:
] ]
def test_validate_db_uri(mocker: MockFixture) -> None:
"""
Ensures that the `validate_database_uri` method invokes the validator correctly
"""
def mock_validate(sqlalchemy_uri: URL) -> None:
raise ValueError("Invalid URI")
mocker.patch(
"superset.db_engine_specs.base.current_app.config",
{"DB_SQLA_URI_VALIDATOR": mock_validate},
)
from superset.db_engine_specs.base import BaseEngineSpec
with pytest.raises(ValueError):
BaseEngineSpec.validate_database_uri(URL.create("sqlite"))
@pytest.mark.parametrize( @pytest.mark.parametrize(
"original,expected", "original,expected",
[ [

View File

@ -124,7 +124,7 @@ def test_superset_limit(mocker: MockFixture, app_context: None, table1: None) ->
""" """
mocker.patch( mocker.patch(
"superset.extensions.metadb.current_app.config", "superset.extensions.metadb.current_app.config",
{"SUPERSET_META_DB_LIMIT": 1}, {"DB_SQLA_URI_VALIDATOR": None, "SUPERSET_META_DB_LIMIT": 1},
) )
mocker.patch("superset.extensions.metadb.security_manager") mocker.patch("superset.extensions.metadb.security_manager")