diff --git a/superset/config.py b/superset/config.py index 7f26ac27c4..b01b0ccab0 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1073,13 +1073,13 @@ SQL_VALIDATORS_BY_ENGINE = { # A list of preferred databases, in order. These databases will be # displayed prominently in the "Add Database" dialog. You should -# use the "engine" attribute of the corresponding DB engine spec in -# `superset/db_engine_specs/`. +# use the "engine_name" attribute of the corresponding DB engine spec +# in `superset/db_engine_specs/`. PREFERRED_DATABASES: List[str] = [ - # "postgresql", - # "presto", - # "mysql", - # "sqlite", + # "PostgreSQL", + # "Presto", + # "MySQL", + # "SQLite", # etc. ] diff --git a/superset/databases/api.py b/superset/databases/api.py index c313f64c8a..75a218741b 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -886,6 +886,17 @@ class DatabaseRestApi(BaseSupersetModelRestApi): name: description: Name of the database type: string + engine: + description: Name of the SQLAlchemy engine + type: string + available_drivers: + description: Installed drivers for the engine + type: array + items: + type: string + default_driver: + description: Default driver for the engine + type: string preferred: description: Is the database preferred? type: boolean @@ -894,6 +905,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): type: string parameters: description: JSON schema defining the needed parameters + type: object 400: $ref: '#/components/responses/400' 500: @@ -901,15 +913,22 @@ class DatabaseRestApi(BaseSupersetModelRestApi): """ preferred_databases: List[str] = app.config.get("PREFERRED_DATABASES", []) available_databases = [] - for engine_spec in get_available_engine_specs(): + for engine_spec, drivers in get_available_engine_specs().items(): payload: Dict[str, Any] = { "name": engine_spec.engine_name, "engine": engine_spec.engine, - "preferred": engine_spec.engine in preferred_databases, + "available_drivers": sorted(drivers), + "preferred": engine_spec.engine_name in preferred_databases, } - if hasattr(engine_spec, "parameters_json_schema") and hasattr( - engine_spec, "sqlalchemy_uri_placeholder" + if hasattr(engine_spec, "default_driver"): + payload["default_driver"] = engine_spec.default_driver # type: ignore + + # show configuration parameters for DBs that support it + if ( + hasattr(engine_spec, "parameters_json_schema") + and hasattr(engine_spec, "sqlalchemy_uri_placeholder") + and getattr(engine_spec, "default_driver") in drivers ): payload[ "parameters" @@ -920,13 +939,25 @@ class DatabaseRestApi(BaseSupersetModelRestApi): available_databases.append(payload) - available_databases.sort( - key=lambda payload: preferred_databases.index(payload["engine"]) - if payload["engine"] in preferred_databases - else len(preferred_databases) + # sort preferred first + response = sorted( + (payload for payload in available_databases if payload["preferred"]), + key=lambda payload: preferred_databases.index(payload["name"]), ) - return self.response(200, databases=available_databases) + # add others + response.extend( + sorted( + ( + payload + for payload in available_databases + if not payload["preferred"] + ), + key=lambda payload: payload["name"], + ) + ) + + return self.response(200, databases=response) @expose("/validate_parameters", methods=["POST"]) @protect() diff --git a/superset/db_engine_specs/__init__.py b/superset/db_engine_specs/__init__.py index a4e083cf6e..f4ced6f323 100644 --- a/superset/db_engine_specs/__init__.py +++ b/superset/db_engine_specs/__init__.py @@ -30,12 +30,15 @@ The general idea is to use static classes and an inheritance scheme. import inspect import logging import pkgutil +from collections import defaultdict from importlib import import_module from pathlib import Path from typing import Any, Dict, List, Set, Type import sqlalchemy.databases +import sqlalchemy.dialects from pkg_resources import iter_entry_points +from sqlalchemy.engine.default import DefaultDialect from superset.db_engine_specs.base import BaseEngineSpec @@ -85,12 +88,31 @@ def get_engine_specs() -> Dict[str, Type[BaseEngineSpec]]: return engine_specs_map -def get_available_engine_specs() -> List[Type[BaseEngineSpec]]: +def get_available_engine_specs() -> Dict[Type[BaseEngineSpec], Set[str]]: + """ + Return available engine specs and installed drivers for them. + """ + drivers: Dict[str, Set[str]] = defaultdict(set) + # native SQLAlchemy dialects - backends: Set[str] = { - getattr(sqlalchemy.databases, attr).dialect.name - for attr in sqlalchemy.databases.__all__ - } + for attr in sqlalchemy.databases.__all__: + dialect = getattr(sqlalchemy.dialects, attr) + for attribute in dialect.__dict__.values(): + if ( + hasattr(attribute, "dialect") + and inspect.isclass(attribute.dialect) + and issubclass(attribute.dialect, DefaultDialect) + ): + try: + attribute.dialect.dbapi() + except ModuleNotFoundError: + continue + except Exception as ex: # pylint: disable=broad-except + logger.warning( + "Unable to load dialect %s: %s", attribute.dialect, ex + ) + continue + drivers[attr].add(attribute.dialect.driver) # installed 3rd-party dialects for ep in iter_entry_points("sqlalchemy.dialects"): @@ -99,7 +121,11 @@ def get_available_engine_specs() -> List[Type[BaseEngineSpec]]: except Exception: # pylint: disable=broad-except logger.warning("Unable to load SQLAlchemy dialect: %s", dialect) else: - backends.add(dialect.name) + drivers[dialect.name].add(dialect.driver) engine_specs = get_engine_specs() - return [engine_specs[backend] for backend in backends if backend in engine_specs] + return { + engine_specs[backend]: drivers + for backend, drivers in drivers.items() + if backend in engine_specs + } diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 97321b8b9a..a7463a48e3 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1328,7 +1328,7 @@ class BasicParametersMixin: individual parameters, instead of the full SQLAlchemy URI. This mixin is for the most common pattern of URI: - drivername://user:password@host:port/dbname[?key=value&key=value...] + engine+driver://user:password@host:port/dbname[?key=value&key=value...] """ @@ -1336,11 +1336,11 @@ class BasicParametersMixin: parameters_schema = BasicParametersSchema() # recommended driver name for the DB engine spec - drivername = "" + default_driver = "" # placeholder with the SQLAlchemy URI template sqlalchemy_uri_placeholder = ( - "drivername://user:password@host:port/dbname[?key=value&key=value...]" + "engine+driver://user:password@host:port/dbname[?key=value&key=value...]" ) # query parameter to enable encryption in the database connection @@ -1361,7 +1361,7 @@ class BasicParametersMixin: return str( URL( - cls.drivername, + f"{cls.engine}+{cls.default_driver}".rstrip("+"), # type: ignore username=parameters.get("username"), password=parameters.get("password"), host=parameters["host"], diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 5f46bb3cfa..0fca504e7e 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -67,7 +67,7 @@ class BigQueryEngineSpec(BaseEngineSpec): max_column_name_length = 128 parameters_schema = BigQueryParametersSchema() - drivername = engine + default_driver = "bigquery" sqlalchemy_uri_placeholder = "bigquery://{project_id}" # BigQuery doesn't maintain context when running multiple statements in the @@ -313,7 +313,7 @@ class BigQueryEngineSpec(BaseEngineSpec): project_id = encrypted_extra.get("credentials_info", {}).get("project_id") if project_id: - return f"{cls.drivername}://{project_id}" + return f"{cls.engine}+{cls.default_driver}://{project_id}" raise SupersetGenericDBErrorException( message="Big Query encrypted_extra is not available.", diff --git a/superset/db_engine_specs/cockroachdb.py b/superset/db_engine_specs/cockroachdb.py index 80b547bd8d..8c83bd793d 100644 --- a/superset/db_engine_specs/cockroachdb.py +++ b/superset/db_engine_specs/cockroachdb.py @@ -20,4 +20,4 @@ from superset.db_engine_specs.postgres import PostgresEngineSpec class CockroachDbEngineSpec(PostgresEngineSpec): engine = "cockroachdb" engine_name = "CockroachDB" - drivername = "cockroach" + default_driver = "" diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 870dbe853d..4bb5979706 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -58,11 +58,10 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): engine_name = "MySQL" max_column_name_length = 64 - drivername = "mysql+mysqldb" + default_driver = "mysqldb" sqlalchemy_uri_placeholder = ( "mysql://user:password@host:port/dbname[?key=value&key=value...]" ) - encryption_parameters = {"ssl": "1"} column_type_mappings: Tuple[ diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 5c8ff40365..513e01eb0c 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -159,9 +159,9 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): engine = "postgresql" engine_aliases = {"postgres"} - drivername = "postgresql+psycopg2" + default_driver = "psycopg2" sqlalchemy_uri_placeholder = ( - "postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]" + "postgresql://user:password@host:port/dbname[?key=value&key=value...]" ) # https://www.postgresql.org/docs/9.1/libpq-ssl.html#LIBQ-SSL-CERTIFICATES encryption_parameters = {"sslmode": "verify-ca"} diff --git a/tests/databases/api_tests.py b/tests/databases/api_tests.py index 3e82a7eb47..aea98d3f2c 100644 --- a/tests/databases/api_tests.py +++ b/tests/databases/api_tests.py @@ -1372,16 +1372,14 @@ class TestDatabaseApi(SupersetTestCase): @mock.patch("superset.databases.api.get_available_engine_specs") @mock.patch("superset.databases.api.app") def test_available(self, app, get_available_engine_specs): - app.config = { - "PREFERRED_DATABASES": ["postgresql", "biqquery", "mysql", "redshift"] + app.config = {"PREFERRED_DATABASES": ["PostgreSQL", "Google BigQuery"]} + get_available_engine_specs.return_value = { + PostgresEngineSpec: {"psycopg2"}, + BigQueryEngineSpec: {"bigquery"}, + MySQLEngineSpec: {"mysqlconnector", "mysqldb"}, + RedshiftEngineSpec: {"psycopg2"}, + HanaEngineSpec: {""}, } - get_available_engine_specs.return_value = [ - PostgresEngineSpec, - BigQueryEngineSpec, - MySQLEngineSpec, - RedshiftEngineSpec, - HanaEngineSpec, - ] self.login(username="admin") uri = "api/v1/database/available/" @@ -1392,6 +1390,8 @@ class TestDatabaseApi(SupersetTestCase): assert response == { "databases": [ { + "available_drivers": ["psycopg2"], + "default_driver": "psycopg2", "engine": "postgresql", "name": "PostgreSQL", "parameters": { @@ -1433,9 +1433,36 @@ class TestDatabaseApi(SupersetTestCase): "type": "object", }, "preferred": True, - "sqlalchemy_uri_placeholder": "postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]", + "sqlalchemy_uri_placeholder": "postgresql://user:password@host:port/dbname[?key=value&key=value...]", }, { + "available_drivers": ["bigquery"], + "default_driver": "bigquery", + "engine": "bigquery", + "name": "Google BigQuery", + "parameters": { + "properties": { + "credentials_info": { + "description": "Contents of BigQuery JSON credentials.", + "type": "string", + "x-encrypted-extra": True, + } + }, + "type": "object", + }, + "preferred": True, + "sqlalchemy_uri_placeholder": "bigquery://{project_id}", + }, + { + "available_drivers": ["psycopg2"], + "default_driver": "", + "engine": "redshift", + "name": "Amazon Redshift", + "preferred": False, + }, + { + "available_drivers": ["mysqlconnector", "mysqldb"], + "default_driver": "mysqldb", "engine": "mysql", "name": "MySQL", "parameters": { @@ -1476,70 +1503,48 @@ class TestDatabaseApi(SupersetTestCase): "required": ["database", "host", "port", "username"], "type": "object", }, - "preferred": True, + "preferred": False, "sqlalchemy_uri_placeholder": "mysql://user:password@host:port/dbname[?key=value&key=value...]", }, { - "engine": "redshift", - "name": "Amazon Redshift", - "parameters": { - "properties": { - "database": { - "description": "Database name", - "type": "string", - }, - "encryption": { - "description": "Use an encrypted connection to the database", - "type": "boolean", - }, - "host": { - "description": "Hostname or IP address", - "type": "string", - }, - "password": { - "description": "Password", - "nullable": True, - "type": "string", - }, - "port": { - "description": "Database port", - "format": "int32", - "type": "integer", - }, - "query": { - "additionalProperties": {}, - "description": "Additional parameters", - "type": "object", - }, - "username": { - "description": "Username", - "nullable": True, - "type": "string", - }, - }, - "required": ["database", "host", "port", "username"], - "type": "object", - }, + "available_drivers": [""], + "engine": "hana", + "name": "SAP HANA", + "preferred": False, + }, + ] + } + + @mock.patch("superset.databases.api.get_available_engine_specs") + @mock.patch("superset.databases.api.app") + def test_available_no_default(self, app, get_available_engine_specs): + app.config = {"PREFERRED_DATABASES": ["MySQL"]} + get_available_engine_specs.return_value = { + MySQLEngineSpec: {"mysqlconnector"}, + HanaEngineSpec: {""}, + } + + self.login(username="admin") + uri = "api/v1/database/available/" + + rv = self.client.get(uri) + response = json.loads(rv.data.decode("utf-8")) + assert rv.status_code == 200 + assert response == { + "databases": [ + { + "available_drivers": ["mysqlconnector"], + "default_driver": "mysqldb", + "engine": "mysql", + "name": "MySQL", "preferred": True, - "sqlalchemy_uri_placeholder": "redshift+psycopg2://user:password@host:port/dbname[?key=value&key=value...]", }, { - "engine": "bigquery", - "name": "Google BigQuery", - "parameters": { - "properties": { - "credentials_info": { - "description": "Contents of BigQuery JSON credentials.", - "type": "string", - "x-encrypted-extra": True, - } - }, - "type": "object", - }, + "available_drivers": [""], + "engine": "hana", + "name": "SAP HANA", "preferred": False, - "sqlalchemy_uri_placeholder": "bigquery://{project_id}", }, - {"engine": "hana", "name": "SAP HANA", "preferred": False}, ] } diff --git a/tests/databases/schema_tests.py b/tests/databases/schema_tests.py index 80859a7496..021ac578b0 100644 --- a/tests/databases/schema_tests.py +++ b/tests/databases/schema_tests.py @@ -29,7 +29,8 @@ class DummySchema(Schema, DatabaseParametersSchemaMixin): class DummyEngine(BasicParametersMixin): - drivername = "dummy" + engine = "dummy" + default_driver = "dummy" class InvalidEngine: @@ -54,7 +55,7 @@ def test_database_parameters_schema_mixin(get_engine_specs): result = schema.load(payload) assert result == { "configuration_method": ConfigurationMethod.DYNAMIC_FORM, - "sqlalchemy_uri": "dummy://username:password@localhost:12345/dbname", + "sqlalchemy_uri": "dummy+dummy://username:password@localhost:12345/dbname", }