feat(db_engine): Add custom_user_agent when connecting to MotherDuck (#27665)

This commit is contained in:
Guen Prawiroatmodjo 2024-03-28 18:05:28 -07:00 committed by GitHub
parent 8ae4662f17
commit fcf90dffa8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 2 deletions

View File

@ -25,7 +25,8 @@ from flask_babel import gettext as __
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
from superset.constants import TimeGrain
from superset.config import VERSION_STRING
from superset.constants import TimeGrain, USER_AGENT
from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import SupersetErrorType
@ -41,6 +42,8 @@ class DuckDBEngineSpec(BaseEngineSpec):
engine = "duckdb"
engine_name = "DuckDB"
sqlalchemy_uri_placeholder = "duckdb:////path/to/duck.db"
_time_grain_expressions = {
None: "{col}",
TimeGrain.SECOND: "DATE_TRUNC('second', {col})",
@ -81,9 +84,28 @@ class DuckDBEngineSpec(BaseEngineSpec):
) -> set[str]:
return set(inspector.get_table_names(schema))
@staticmethod
def get_extra_params(database: Database) -> dict[str, Any]:
"""
Add a user agent to be used in the requests.
"""
extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
config: dict[str, Any] = connect_args.setdefault("config", {})
custom_user_agent = config.pop("custom_user_agent", "")
delim = " " if custom_user_agent else ""
user_agent = USER_AGENT.replace(" ", "-").lower()
user_agent = f"{user_agent}/{VERSION_STRING}{delim}{custom_user_agent}"
config.setdefault("custom_user_agent", user_agent)
return extra
class MotherDuckEngineSpec(DuckDBEngineSpec):
engine = "duckdb"
engine_name = "MotherDuck"
sqlalchemy_uri_placeholder = "duckdb:///md:{SERVICE_TOKEN}@{database_name}"
sqlalchemy_uri_placeholder = (
"duckdb:///md:{database_name}?motherduck_token={SERVICE_TOKEN}"
)

View File

@ -15,11 +15,14 @@
# specific language governing permissions and limitations
# under the License.
import json
from datetime import datetime
from typing import Optional
import pytest
from pytest_mock import MockerFixture
from superset.config import VERSION_STRING
from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm
from tests.unit_tests.fixtures.common import dttm
@ -38,3 +41,34 @@ def test_convert_dttm(
from superset.db_engine_specs.duckdb import DuckDBEngineSpec as spec
assert_convert_dttm(spec, target_type, expected_result, dttm)
def test_get_extra_params(mocker: MockerFixture) -> None:
"""
Test the ``get_extra_params`` method.
"""
from superset.db_engine_specs.duckdb import DuckDBEngineSpec
database = mocker.MagicMock()
database.extra = {}
assert DuckDBEngineSpec.get_extra_params(database) == {
"engine_params": {
"connect_args": {
"config": {"custom_user_agent": f"apache-superset/{VERSION_STRING}"}
}
}
}
database.extra = json.dumps(
{"engine_params": {"connect_args": {"config": {"custom_user_agent": "my-app"}}}}
)
assert DuckDBEngineSpec.get_extra_params(database) == {
"engine_params": {
"connect_args": {
"config": {
"custom_user_agent": f"apache-superset/{VERSION_STRING} my-app"
}
}
}
}