diff --git a/pyproject.toml b/pyproject.toml index 2e20fae77b..234af86bef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,7 @@ db2 = ["ibm-db-sa>0.3.8, <=0.4.0"] dremio = ["sqlalchemy-dremio>=1.1.5, <1.3"] drill = ["sqlalchemy-drill>=1.1.4, <2"] druid = ["pydruid>=0.6.5,<0.7"] -duckdb = ["duckdb-engine>=0.9.5, <0.10"] +duckdb = ["duckdb-engine>=0.12.1, <0.13"] dynamodb = ["pydynamodb>=0.4.2"] solr = ["sqlalchemy-solr >= 0.2.0"] elasticsearch = ["elasticsearch-dbapi>=0.2.9, <0.3.0"] @@ -141,6 +141,7 @@ hive = [ impala = ["impyla>0.16.2, <0.17"] kusto = ["sqlalchemy-kusto>=2.0.0, <3"] kylin = ["kylinpy>=2.8.1, <2.9"] +motherduck = ["duckdb==0.10.2", "duckdb-engine>=0.12.1, <0.13"] mssql = ["pymssql>=2.2.8, <3"] mysql = ["mysqlclient>=2.1.0, <3"] ocient = [ diff --git a/superset/db_engine_specs/duckdb.py b/superset/db_engine_specs/duckdb.py index 89c45fb572..4f2e8fb611 100644 --- a/superset/db_engine_specs/duckdb.py +++ b/superset/db_engine_specs/duckdb.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + from __future__ import annotations import re @@ -259,6 +260,9 @@ class MotherDuckEngineSpec(DuckDBEngineSpec): engine_name = "MotherDuck" engine_aliases: set[str] = {"duckdb"} + supports_catalog = True + supports_dynamic_catalog = True + sqlalchemy_uri_placeholder = ( "duckdb:///md:{database_name}?motherduck_token={SERVICE_TOKEN}" ) @@ -293,3 +297,33 @@ class MotherDuckEngineSpec(DuckDBEngineSpec): return str( URL(drivername=DuckDBEngineSpec.engine, database=database, query=query) ) + + @classmethod + def adjust_engine_params( + cls, + uri: URL, + connect_args: dict[str, Any], + catalog: str | None = None, + schema: str | None = None, + ) -> tuple[URL, dict[str, Any]]: + if catalog: + uri = uri.set(database=f"md:{catalog}") + + return uri, connect_args + + @classmethod + def get_default_catalog(cls, database: Database) -> str | None: + return database.url_object.database.split(":", 1)[1] + + @classmethod + def get_catalog_names( + cls, + database: Database, + inspector: Inspector, + ) -> set[str]: + return { + catalog + for (catalog,) in inspector.bind.execute( + "SELECT alias FROM MD_ALL_DATABASES() WHERE is_attached;" + ) + }