diff --git a/superset/databases/api.py b/superset/databases/api.py index 0a0dc3c989..f9e2c4150d 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -14,126 +14,78 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional +import logging +from typing import Any, Optional +from flask import g, request, Response from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface +from marshmallow import ValidationError from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError from superset import event_logger +from superset.constants import RouteMethod +from superset.databases.commands.create import CreateDatabaseCommand +from superset.databases.commands.delete import DeleteDatabaseCommand +from superset.databases.commands.exceptions import ( + DatabaseCreateFailedError, + DatabaseDeleteDatasetsExistFailedError, + DatabaseDeleteFailedError, + DatabaseInvalidError, + DatabaseNotFoundError, + DatabaseUpdateFailedError, +) +from superset.databases.commands.update import UpdateDatabaseCommand from superset.databases.decorators import check_datasource_access +from superset.databases.filters import DatabaseFilter from superset.databases.schemas import ( database_schemas_query_schema, + DatabasePostSchema, + DatabasePutSchema, SchemasResponseSchema, SelectStarResponseSchema, TableMetadataResponseSchema, ) +from superset.databases.utils import get_table_metadata from superset.extensions import security_manager from superset.models.core import Database from superset.typing import FlaskResponse from superset.utils.core import error_msg_from_exception from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics -from superset.views.database.filters import DatabaseFilter -from superset.views.database.validators import sqlalchemy_uri_validator - -def get_foreign_keys_metadata( - database: Database, table_name: str, schema_name: Optional[str] -) -> List[Dict[str, Any]]: - foreign_keys = database.get_foreign_keys(table_name, schema_name) - for fk in foreign_keys: - fk["column_names"] = fk.pop("constrained_columns") - fk["type"] = "fk" - return foreign_keys - - -def get_indexes_metadata( - database: Database, table_name: str, schema_name: Optional[str] -) -> List[Dict[str, Any]]: - indexes = database.get_indexes(table_name, schema_name) - for idx in indexes: - idx["type"] = "index" - return indexes - - -def get_col_type(col: Dict[Any, Any]) -> str: - try: - dtype = f"{col['type']}" - except Exception: # pylint: disable=broad-except - # sqla.types.JSON __str__ has a bug, so using __class__. - dtype = col["type"].__class__.__name__ - return dtype - - -def get_table_metadata( - database: Database, table_name: str, schema_name: Optional[str] -) -> Dict[str, Any]: - """ - Get table metadata information, including type, pk, fks. - This function raises SQLAlchemyError when a schema is not found. 
- - :param database: The database model - :param table_name: Table name - :param schema_name: schema name - :return: Dict table metadata ready for API response - """ - keys = [] - columns = database.get_columns(table_name, schema_name) - primary_key = database.get_pk_constraint(table_name, schema_name) - if primary_key and primary_key.get("constrained_columns"): - primary_key["column_names"] = primary_key.pop("constrained_columns") - primary_key["type"] = "pk" - keys += [primary_key] - foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name) - indexes = get_indexes_metadata(database, table_name, schema_name) - keys += foreign_keys + indexes - payload_columns: List[Dict[str, Any]] = [] - for col in columns: - dtype = get_col_type(col) - payload_columns.append( - { - "name": col["name"], - "type": dtype.split("(")[0] if "(" in dtype else dtype, - "longType": dtype, - "keys": [k for k in keys if col["name"] in k["column_names"]], - } - ) - return { - "name": table_name, - "columns": payload_columns, - "selectStar": database.select_star( - table_name, - schema=schema_name, - show_cols=True, - indent=True, - cols=columns, - latest_partition=True, - ), - "primaryKey": primary_key, - "foreignKeys": foreign_keys, - "indexes": keys, - } +logger = logging.getLogger(__name__) class DatabaseRestApi(BaseSupersetModelRestApi): datamodel = SQLAInterface(Database) - include_route_methods = { - "get_list", + include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { "table_metadata", "select_star", "schemas", } class_permission_name = "DatabaseView" - method_permission_name = { - "get_list": "list", - "table_metadata": "list", - "select_star": "list", - "schemas": "list", - } resource_name = "database" allow_browser_login = True base_filters = [["id", DatabaseFilter, lambda: []]] + show_columns = [ + "id", + "database_name", + "cache_timeout", + "expose_in_sqllab", + "allow_run_async", + "allow_csv_upload", + "allow_ctas", + "allow_cvas", + "allow_dml", + "force_ctas_schema", + "allow_multi_schema_metadata_fetch", + "impersonate_user", + "encrypted_extra", + "extra", + "server_cert", + "sqlalchemy_uri", + ] list_columns = [ "id", "database_name", @@ -152,10 +104,30 @@ class DatabaseRestApi(BaseSupersetModelRestApi): "backend", "function_names", ] + add_columns = [ + "database_name", + "sqlalchemy_uri", + "cache_timeout", + "expose_in_sqllab", + "allow_run_async", + "allow_csv_upload", + "allow_ctas", + "allow_cvas", + "allow_dml", + "force_ctas_schema", + "impersonate_user", + "allow_multi_schema_metadata_fetch", + "extra", + "encrypted_extra", + "server_cert", + ] + edit_columns = add_columns + list_select_columns = list_columns + ["extra", "sqlalchemy_uri", "password"] # Removes the local limit for the page size max_page_size = -1 - validators_columns = {"sqlalchemy_uri": sqlalchemy_uri_validator} + add_model_schema = DatabasePostSchema() + edit_model_schema = DatabasePutSchema() apispec_parameter_schemas = { "database_schemas_query_schema": database_schemas_query_schema, @@ -167,6 +139,186 @@ class DatabaseRestApi(BaseSupersetModelRestApi): SchemasResponseSchema, ) + @expose("/", methods=["POST"]) + @protect() + @safe + @statsd_metrics + def post(self) -> Response: + """Creates a new Database + --- + post: + description: >- + Create a new Database. 
+ requestBody: + description: Database schema + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/{{self.__class__.__name__}}.post' + responses: + 201: + description: Database added + content: + application/json: + schema: + type: object + properties: + id: + type: number + result: + $ref: '#/components/schemas/{{self.__class__.__name__}}.post' + 302: + description: Redirects to the current digest + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + if not request.is_json: + return self.response_400(message="Request is not JSON") + try: + item = self.add_model_schema.load(request.json) + # This validates custom Schema with custom validations + except ValidationError as error: + return self.response_400(message=error.messages) + try: + new_model = CreateDatabaseCommand(g.user, item).run() + # Return censored version for sqlalchemy URI + item["sqlalchemy_uri"] = new_model.sqlalchemy_uri + return self.response(201, id=new_model.id, result=item) + except DatabaseInvalidError as ex: + return self.response_422(message=ex.normalized_messages()) + except DatabaseCreateFailedError as ex: + logger.error( + "Error creating model %s: %s", self.__class__.__name__, str(ex) + ) + return self.response_422(message=str(ex)) + + @expose("/", methods=["PUT"]) + @protect() + @safe + @statsd_metrics + def put( # pylint: disable=too-many-return-statements, arguments-differ + self, pk: int + ) -> Response: + """Changes a Database + --- + put: + description: >- + Changes a Database. + parameters: + - in: path + schema: + type: integer + name: pk + requestBody: + description: Database schema + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/{{self.__class__.__name__}}.put' + responses: + 200: + description: Database changed + content: + application/json: + schema: + type: object + properties: + id: + type: number + result: + $ref: '#/components/schemas/{{self.__class__.__name__}}.put' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + if not request.is_json: + return self.response_400(message="Request is not JSON") + try: + item = self.edit_model_schema.load(request.json) + # This validates custom Schema with custom validations + except ValidationError as error: + return self.response_400(message=error.messages) + try: + changed_model = UpdateDatabaseCommand(g.user, pk, item).run() + # Return censored version for sqlalchemy URI + item["sqlalchemy_uri"] = changed_model.sqlalchemy_uri + return self.response(200, id=changed_model.id, result=item) + except DatabaseNotFoundError: + return self.response_404() + except DatabaseInvalidError as ex: + return self.response_422(message=ex.normalized_messages()) + except DatabaseUpdateFailedError as ex: + logger.error( + "Error updating model %s: %s", self.__class__.__name__, str(ex) + ) + return self.response_422(message=str(ex)) + + @expose("/", methods=["DELETE"]) + @protect() + @safe + @statsd_metrics + def delete(self, pk: int) -> Response: # pylint: disable=arguments-differ + """Deletes a Database + --- + delete: + description: >- + Deletes a Database. 
+ parameters: + - in: path + schema: + type: integer + name: pk + responses: + 200: + description: Database deleted + content: + application/json: + schema: + type: object + properties: + message: + type: string + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + try: + DeleteDatabaseCommand(g.user, pk).run() + return self.response(200, message="OK") + except DatabaseNotFoundError: + return self.response_404() + except DatabaseDeleteDatasetsExistFailedError as ex: + return self.response_422(message=str(ex)) + except DatabaseDeleteFailedError as ex: + logger.error( + "Error deleting model %s: %s", self.__class__.__name__, str(ex) + ) + return self.response_422(message=str(ex)) + @expose("//schemas/") @protect() @safe diff --git a/superset/databases/commands/__init__.py b/superset/databases/commands/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/superset/databases/commands/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py new file mode 100644 index 0000000000..115652bc69 --- /dev/null +++ b/superset/databases/commands/create.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
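+"""
+Command that creates a Database model (invoked by DatabaseRestApi.post in
+superset/databases/api.py). Illustrative usage only, mirroring that caller;
+``item`` stands for a payload already validated by DatabasePostSchema::
+
+    from flask import g
+
+    from superset.databases.commands.create import CreateDatabaseCommand
+
+    # Raises DatabaseInvalidError or DatabaseCreateFailedError on failure.
+    new_model = CreateDatabaseCommand(g.user, item).run()
+"""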
+import logging +from typing import Any, Dict, List, Optional + +from flask_appbuilder.models.sqla import Model +from flask_appbuilder.security.sqla.models import User +from marshmallow import ValidationError + +from superset.commands.base import BaseCommand +from superset.dao.exceptions import DAOCreateFailedError +from superset.databases.commands.exceptions import ( + DatabaseConnectionFailedError, + DatabaseCreateFailedError, + DatabaseExistsValidationError, + DatabaseInvalidError, + DatabaseRequiredFieldValidationError, +) +from superset.databases.dao import DatabaseDAO +from superset.extensions import db, security_manager + +logger = logging.getLogger(__name__) + + +class CreateDatabaseCommand(BaseCommand): + def __init__(self, user: User, data: Dict[str, Any]): + self._actor = user + self._properties = data.copy() + + def run(self) -> Model: + self.validate() + try: + database = DatabaseDAO.create(self._properties, commit=False) + database.set_sqlalchemy_uri(database.sqlalchemy_uri) + # adding a new database we always want to force refresh schema list + # TODO Improve this simplistic implementation for catching DB conn fails + try: + schemas = database.get_all_schema_names() + except Exception: + db.session.rollback() + raise DatabaseConnectionFailedError() + for schema in schemas: + security_manager.add_permission_view_menu( + "schema_access", security_manager.get_schema_perm(database, schema) + ) + security_manager.add_permission_view_menu("database_access", database.perm) + db.session.commit() + except DAOCreateFailedError as ex: + logger.exception(ex.exception) + raise DatabaseCreateFailedError() + return database + + def validate(self) -> None: + exceptions: List[ValidationError] = list() + sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri") + database_name: Optional[str] = self._properties.get("database_name") + + if not sqlalchemy_uri: + exceptions.append(DatabaseRequiredFieldValidationError("sqlalchemy_uri")) + if not database_name: + exceptions.append(DatabaseRequiredFieldValidationError("database_name")) + else: + # Check database_name uniqueness + if not DatabaseDAO.validate_uniqueness(database_name): + exceptions.append(DatabaseExistsValidationError()) + + if exceptions: + exception = DatabaseInvalidError() + exception.add_list(exceptions) + raise exception diff --git a/superset/databases/commands/delete.py b/superset/databases/commands/delete.py new file mode 100644 index 0000000000..5ea8de3731 --- /dev/null +++ b/superset/databases/commands/delete.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
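+"""
+Command that deletes a Database model (invoked by DatabaseRestApi.delete in
+superset/databases/api.py). Illustrative usage only, mirroring that caller;
+``pk`` stands for the target database id::
+
+    from flask import g
+
+    from superset.databases.commands.delete import DeleteDatabaseCommand
+
+    # Raises DatabaseNotFoundError, DatabaseDeleteDatasetsExistFailedError
+    # or DatabaseDeleteFailedError on failure.
+    DeleteDatabaseCommand(g.user, pk).run()
+"""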
+import logging +from typing import Optional + +from flask_appbuilder.models.sqla import Model +from flask_appbuilder.security.sqla.models import User + +from superset.commands.base import BaseCommand +from superset.dao.exceptions import DAODeleteFailedError +from superset.databases.commands.exceptions import ( + DatabaseDeleteDatasetsExistFailedError, + DatabaseDeleteFailedError, + DatabaseNotFoundError, +) +from superset.databases.dao import DatabaseDAO +from superset.models.core import Database + +logger = logging.getLogger(__name__) + + +class DeleteDatabaseCommand(BaseCommand): + def __init__(self, user: User, model_id: int): + self._actor = user + self._model_id = model_id + self._model: Optional[Database] = None + + def run(self) -> Model: + self.validate() + try: + database = DatabaseDAO.delete(self._model) + except DAODeleteFailedError as ex: + logger.exception(ex.exception) + raise DatabaseDeleteFailedError() + return database + + def validate(self) -> None: + # Validate/populate model exists + self._model = DatabaseDAO.find_by_id(self._model_id) + if not self._model: + raise DatabaseNotFoundError() + # Check if there are datasets for this database + if self._model.tables: + raise DatabaseDeleteDatasetsExistFailedError() diff --git a/superset/databases/commands/exceptions.py b/superset/databases/commands/exceptions.py new file mode 100644 index 0000000000..66a3245b17 --- /dev/null +++ b/superset/databases/commands/exceptions.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from flask_babel import lazy_gettext as _ +from marshmallow.validate import ValidationError + +from superset.commands.exceptions import ( + CommandException, + CommandInvalidError, + CreateFailedError, + DeleteFailedError, + UpdateFailedError, +) + + +class DatabaseInvalidError(CommandInvalidError): + message = _("Dashboard parameters are invalid.") + + +class DatabaseExistsValidationError(ValidationError): + """ + Marshmallow validation error for dataset already exists + """ + + def __init__(self) -> None: + super().__init__( + _("A database with the same name already exists"), + field_name="database_name", + ) + + +class DatabaseRequiredFieldValidationError(ValidationError): + def __init__(self, field_name: str) -> None: + super().__init__( + [_("Field is required")], field_name=field_name, + ) + + +class DatabaseExtraJSONValidationError(ValidationError): + """ + Marshmallow validation error for database encrypted extra must be a valid JSON + """ + + def __init__(self, json_error: str = "") -> None: + super().__init__( + [ + _( + "Field cannot be decoded by JSON. 
%{json_error}s", + json_error=json_error, + ) + ], + field_name="extra", + ) + + +class DatabaseExtraValidationError(ValidationError): + """ + Marshmallow validation error for database encrypted extra must be a valid JSON + """ + + def __init__(self, key: str = "") -> None: + super().__init__( + [ + _( + "The metadata_params in Extra field " + "is not configured correctly. The key " + "%{key}s is invalid.", + key=key, + ) + ], + field_name="extra", + ) + + +class DatabaseNotFoundError(CommandException): + message = _("Database not found.") + + +class DatabaseCreateFailedError(CreateFailedError): + message = _("Database could not be created.") + + +class DatabaseUpdateFailedError(UpdateFailedError): + message = _("Database could not be updated.") + + +class DatabaseConnectionFailedError( # pylint: disable=too-many-ancestors + DatabaseCreateFailedError, DatabaseUpdateFailedError, +): + message = _("Could not connect to database.") + + +class DatabaseDeleteDatasetsExistFailedError(DeleteFailedError): + message = _("Cannot delete a database that has tables attached") + + +class DatabaseDeleteFailedError(DeleteFailedError): + message = _("Database could not be deleted.") diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py new file mode 100644 index 0000000000..5d33c08bbd --- /dev/null +++ b/superset/databases/commands/update.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
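+"""
+Command that updates a Database model (invoked by DatabaseRestApi.put in
+superset/databases/api.py). Illustrative usage only, mirroring that caller;
+``item`` stands for a payload already validated by DatabasePutSchema::
+
+    from flask import g
+
+    from superset.databases.commands.update import UpdateDatabaseCommand
+
+    # Raises DatabaseNotFoundError, DatabaseInvalidError or
+    # DatabaseUpdateFailedError on failure.
+    changed_model = UpdateDatabaseCommand(g.user, pk, item).run()
+"""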
+import logging +from typing import Any, Dict, List, Optional + +from flask_appbuilder.models.sqla import Model +from flask_appbuilder.security.sqla.models import User +from marshmallow import ValidationError + +from superset.commands.base import BaseCommand +from superset.dao.exceptions import DAOUpdateFailedError +from superset.databases.commands.exceptions import ( + DatabaseConnectionFailedError, + DatabaseExistsValidationError, + DatabaseInvalidError, + DatabaseNotFoundError, + DatabaseUpdateFailedError, +) +from superset.databases.dao import DatabaseDAO +from superset.extensions import db, security_manager +from superset.models.core import Database + +logger = logging.getLogger(__name__) + + +class UpdateDatabaseCommand(BaseCommand): + def __init__(self, user: User, model_id: int, data: Dict[str, Any]): + self._actor = user + self._properties = data.copy() + self._model_id = model_id + self._model: Optional[Database] = None + + def run(self) -> Model: + self.validate() + try: + database = DatabaseDAO.update(self._model, self._properties, commit=False) + database.set_sqlalchemy_uri(database.sqlalchemy_uri) + security_manager.add_permission_view_menu("database_access", database.perm) + # adding a new database we always want to force refresh schema list + # TODO Improve this simplistic implementation for catching DB conn fails + try: + schemas = database.get_all_schema_names() + except Exception: + db.session.rollback() + raise DatabaseConnectionFailedError() + for schema in schemas: + security_manager.add_permission_view_menu( + "schema_access", security_manager.get_schema_perm(database, schema) + ) + db.session.commit() + + except DAOUpdateFailedError as ex: + logger.exception(ex.exception) + raise DatabaseUpdateFailedError() + return database + + def validate(self) -> None: + exceptions: List[ValidationError] = list() + # Validate/populate model exists + self._model = DatabaseDAO.find_by_id(self._model_id) + if not self._model: + raise DatabaseNotFoundError() + database_name: Optional[str] = self._properties.get("database_name") + if database_name: + # Check database_name uniqueness + if not DatabaseDAO.validate_update_uniqueness( + self._model_id, database_name + ): + exceptions.append(DatabaseExistsValidationError()) + if exceptions: + exception = DatabaseInvalidError() + exception.add_list(exceptions) + raise exception diff --git a/superset/databases/dao.py b/superset/databases/dao.py new file mode 100644 index 0000000000..88009800ed --- /dev/null +++ b/superset/databases/dao.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
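+"""
+Data access object for Database models, used by the commands in this diff.
+Illustrative usage only; the database name below is a made-up example::
+
+    from superset.databases.dao import DatabaseDAO
+
+    # True when no database with that name exists yet.
+    can_create = DatabaseDAO.validate_uniqueness("reporting_db")
+"""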
+import logging + +from superset.dao.base import BaseDAO +from superset.databases.filters import DatabaseFilter +from superset.extensions import db +from superset.models.core import Database + +logger = logging.getLogger(__name__) + + +class DatabaseDAO(BaseDAO): + model_cls = Database + base_filter = DatabaseFilter + + @staticmethod + def validate_uniqueness(database_name: str) -> bool: + database_query = db.session.query(Database).filter( + Database.database_name == database_name + ) + return not db.session.query(database_query.exists()).scalar() + + @staticmethod + def validate_update_uniqueness(database_id: int, database_name: str) -> bool: + database_query = db.session.query(Database).filter( + Database.database_name == database_name, Database.id != database_id, + ) + return not db.session.query(database_query.exists()).scalar() diff --git a/superset/views/database/filters.py b/superset/databases/filters.py similarity index 100% rename from superset/views/database/filters.py rename to superset/databases/filters.py diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index b322a47946..dff09f9c01 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -14,13 +14,266 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import inspect +import json + +from flask_babel import lazy_gettext as _ from marshmallow import fields, Schema +from marshmallow.validate import Length, ValidationError +from sqlalchemy import MetaData +from sqlalchemy.engine.url import make_url +from sqlalchemy.exc import ArgumentError + +from superset import app +from superset.exceptions import CertificateException +from superset.utils.core import markdown, parse_ssl_cert database_schemas_query_schema = { "type": "object", "properties": {"force": {"type": "boolean"}}, } +database_name_description = "A database name to identify this connection." +cache_timeout_description = ( + "Duration (in seconds) of the caching timeout for charts of this database. " + "A timeout of 0 indicates that the cache never expires. " + "Note this defaults to the global timeout if undefined." +) +expose_in_sqllab_description = "Expose this database to SQLLab" +allow_run_async_description = ( + "Operate the database in asynchronous mode, meaning " + "that the queries are executed on remote workers as opposed " + "to on the web server itself. " + "This assumes that you have a Celery worker setup as well " + "as a results backend. Refer to the installation docs " + "for more information." +) +allow_csv_upload_description = ( + "Allow to upload CSV file data into this database" + "If selected, please set the schemas allowed for csv upload in Extra." +) +allow_ctas_description = "Allow CREATE TABLE AS option in SQL Lab" +allow_cvas_description = "Allow CREATE VIEW AS option in SQL Lab" +allow_dml_description = ( + "Allow users to run non-SELECT statements " + "(UPDATE, DELETE, CREATE, ...) " + "in SQL Lab" +) +allow_multi_schema_metadata_fetch_description = ( + "Allow SQL Lab to fetch a list of all tables and all views across " + "all database schemas. For large data warehouse with thousands of " + "tables, this can be expensive and put strain on the system." +) # pylint: disable=invalid-name +impersonate_user_description = ( + "If Presto, all the queries in SQL Lab are going to be executed as the " + "currently logged on user who must have permission to run them.
" + "If Hive and hive.server2.enable.doAs is enabled, will run the queries as " + "service account, but impersonate the currently logged on user " + "via hive.server2.proxy.user property." +) +force_ctas_schema_description = ( + "When allowing CREATE TABLE AS option in SQL Lab, " + "this option forces the table to be created in this schema" +) +encrypted_extra_description = markdown( + "JSON string containing additional connection configuration.
" + "This is used to provide connection information for systems like " + "Hive, Presto, and BigQuery, which do not conform to the username:password " + "syntax normally used by SQLAlchemy.", + True, +) +extra_description = markdown( + "JSON string containing extra configuration elements.
" + "1. The ``engine_params`` object gets unpacked into the " + "[sqlalchemy.create_engine]" + "(https://docs.sqlalchemy.org/en/latest/core/engines.html#" + "sqlalchemy.create_engine) call, while the ``metadata_params`` " + "gets unpacked into the [sqlalchemy.MetaData]" + "(https://docs.sqlalchemy.org/en/rel_1_0/core/metadata.html" + "#sqlalchemy.schema.MetaData) call.
" + "2. The ``metadata_cache_timeout`` is a cache timeout setting " + "in seconds for metadata fetch of this database. Specify it as " + '**"metadata_cache_timeout": {"schema_cache_timeout": 600, ' + '"table_cache_timeout": 600}**. ' + "If unset, cache will not be enabled for the functionality. " + "A timeout of 0 indicates that the cache never expires.
" + "3. The ``schemas_allowed_for_csv_upload`` is a comma separated list " + "of schemas that CSVs are allowed to upload to. " + 'Specify it as **"schemas_allowed_for_csv_upload": ' + '["public", "csv_upload"]**. ' + "If database flavor does not support schema or any schema is allowed " + "to be accessed, just leave the list empty
" + "4. the ``version`` field is a string specifying the this db's version. " + "This should be used with Presto DBs so that the syntax is correct
" + "5. The ``allows_virtual_table_explore`` field is a boolean specifying " + "whether or not the Explore button in SQL Lab results is shown.", + True, +) +sqlalchemy_uri_description = markdown( + "Refer to the " + "[SqlAlchemy docs]" + "(https://docs.sqlalchemy.org/en/rel_1_2/core/engines.html#" + "database-urls) " + "for more information on how to structure your URI.", + True, +) +server_cert_description = markdown( + "Optional CA_BUNDLE contents to validate HTTPS requests. Only available " + "on certain database engines.", + True, +) + + +def sqlalchemy_uri_validator(value: str) -> str: + """ + Validate if it's a valid SQLAlchemy URI and refuse SQLLite by default + """ + try: + make_url(value.strip()) + except (ArgumentError, AttributeError): + raise ValidationError( + [ + _( + "Invalid connection string, a valid string usually follows:" + "'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'" + "

" + "Example:'postgresql://user:password@your-postgres-db/database'" + "

" + ) + ] + ) + if app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] and value: + if value.startswith("sqlite"): + raise ValidationError( + [ + _( + "SQLite database cannot be used as a data source for " + "security reasons." + ) + ] + ) + + return value + + +def server_cert_validator(value: str) -> str: + """ + Validate the server certificate + """ + if value: + try: + parse_ssl_cert(value) + except CertificateException: + raise ValidationError([_("Invalid certificate")]) + return value + + +def encrypted_extra_validator(value: str) -> str: + """ + Validate that encrypted extra is a valid JSON string + """ + if value: + try: + json.loads(value) + except json.JSONDecodeError as ex: + raise ValidationError( + [_("Field cannot be decoded by JSON. %(msg)s", msg=str(ex))] + ) + return value + + +def extra_validator(value: str) -> str: + """ + Validate that extra is a valid JSON string, and that metadata_params + keys are on the call signature for SQLAlchemy Metadata + """ + if value: + try: + extra_ = json.loads(value) + except json.JSONDecodeError as ex: + raise ValidationError( + [_("Field cannot be decoded by JSON. %(msg)s", msg=str(ex))] + ) + else: + metadata_signature = inspect.signature(MetaData) + for key in extra_.get("metadata_params", {}): + if key not in metadata_signature.parameters: + raise ValidationError( + [ + _( + "The metadata_params in Extra field " + "is not configured correctly. The key " + "%(key)s is invalid.", + key=key, + ) + ] + ) + return value + + +class DatabasePostSchema(Schema): + database_name = fields.String( + description=database_name_description, required=True, validate=Length(1, 250), + ) + cache_timeout = fields.Integer(description=cache_timeout_description) + expose_in_sqllab = fields.Boolean(description=expose_in_sqllab_description) + allow_run_async = fields.Boolean(description=allow_run_async_description) + allow_csv_upload = fields.Boolean(description=allow_csv_upload_description) + allow_ctas = fields.Boolean(description=allow_ctas_description) + allow_cvas = fields.Boolean(description=allow_cvas_description) + allow_dml = fields.Boolean(description=allow_dml_description) + force_ctas_schema = fields.String( + description=force_ctas_schema_description, validate=Length(0, 250) + ) + allow_multi_schema_metadata_fetch = fields.Boolean( + description=allow_multi_schema_metadata_fetch_description, + ) + impersonate_user = fields.Boolean(description=impersonate_user_description) + encrypted_extra = fields.String( + description=encrypted_extra_description, validate=encrypted_extra_validator + ) + extra = fields.String(description=extra_description, validate=extra_validator) + server_cert = fields.String( + description=server_cert_description, validate=server_cert_validator + ) + sqlalchemy_uri = fields.String( + description=sqlalchemy_uri_description, + required=True, + validate=[Length(1, 1024), sqlalchemy_uri_validator], + ) + + +class DatabasePutSchema(Schema): + database_name = fields.String( + description=database_name_description, allow_none=True, validate=Length(1, 250), + ) + cache_timeout = fields.Integer(description=cache_timeout_description) + expose_in_sqllab = fields.Boolean(description=expose_in_sqllab_description) + allow_run_async = fields.Boolean(description=allow_run_async_description) + allow_csv_upload = fields.Boolean(description=allow_csv_upload_description) + allow_ctas = fields.Boolean(description=allow_ctas_description) + allow_cvas = fields.Boolean(description=allow_cvas_description) + allow_dml = 
fields.Boolean(description=allow_dml_description) + force_ctas_schema = fields.String( + description=force_ctas_schema_description, validate=Length(0, 250) + ) + allow_multi_schema_metadata_fetch = fields.Boolean( + description=allow_multi_schema_metadata_fetch_description + ) + impersonate_user = fields.Boolean(description=impersonate_user_description) + encrypted_extra = fields.String( + description=encrypted_extra_description, validate=encrypted_extra_validator + ) + extra = fields.String(description=extra_description, validate=extra_validator) + server_cert = fields.String( + description=server_cert_description, validate=server_cert_validator + ) + sqlalchemy_uri = fields.String( + description=sqlalchemy_uri_description, + allow_none=True, + validate=[Length(0, 1024), sqlalchemy_uri_validator], + ) + class TableMetadataOptionsResponseSchema(Schema): deferrable = fields.Bool() diff --git a/superset/databases/utils.py b/superset/databases/utils.py new file mode 100644 index 0000000000..c28c3e85cc --- /dev/null +++ b/superset/databases/utils.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, List, Optional + +from superset import app +from superset.models.core import Database + +custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] + + +def get_foreign_keys_metadata( + database: Database, table_name: str, schema_name: Optional[str] +) -> List[Dict[str, Any]]: + foreign_keys = database.get_foreign_keys(table_name, schema_name) + for fk in foreign_keys: + fk["column_names"] = fk.pop("constrained_columns") + fk["type"] = "fk" + return foreign_keys + + +def get_indexes_metadata( + database: Database, table_name: str, schema_name: Optional[str] +) -> List[Dict[str, Any]]: + indexes = database.get_indexes(table_name, schema_name) + for idx in indexes: + idx["type"] = "index" + return indexes + + +def get_col_type(col: Dict[Any, Any]) -> str: + try: + dtype = f"{col['type']}" + except Exception: # pylint: disable=broad-except + # sqla.types.JSON __str__ has a bug, so using __class__. + dtype = col["type"].__class__.__name__ + return dtype + + +def get_table_metadata( + database: Database, table_name: str, schema_name: Optional[str] +) -> Dict[str, Any]: + """ + Get table metadata information, including type, pk, fks. + This function raises SQLAlchemyError when a schema is not found. 
+ + :param database: The database model + :param table_name: Table name + :param schema_name: schema name + :return: Dict table metadata ready for API response + """ + keys = [] + columns = database.get_columns(table_name, schema_name) + primary_key = database.get_pk_constraint(table_name, schema_name) + if primary_key and primary_key.get("constrained_columns"): + primary_key["column_names"] = primary_key.pop("constrained_columns") + primary_key["type"] = "pk" + keys += [primary_key] + foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name) + indexes = get_indexes_metadata(database, table_name, schema_name) + keys += foreign_keys + indexes + payload_columns: List[Dict[str, Any]] = [] + for col in columns: + dtype = get_col_type(col) + payload_columns.append( + { + "name": col["name"], + "type": dtype.split("(")[0] if "(" in dtype else dtype, + "longType": dtype, + "keys": [k for k in keys if col["name"] in k["column_names"]], + } + ) + return { + "name": table_name, + "columns": payload_columns, + "selectStar": database.select_star( + table_name, + schema=schema_name, + show_cols=True, + indent=True, + cols=columns, + latest_partition=True, + ), + "primaryKey": primary_key, + "foreignKeys": foreign_keys, + "indexes": keys, + } diff --git a/superset/datasets/api.py b/superset/datasets/api.py index b04cef8184..93c2b92496 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -25,6 +25,7 @@ from marshmallow import ValidationError from superset.connectors.sqla.models import SqlaTable from superset.constants import RouteMethod +from superset.databases.filters import DatabaseFilter from superset.datasets.commands.create import CreateDatasetCommand from superset.datasets.commands.delete import DeleteDatasetCommand from superset.datasets.commands.exceptions import ( @@ -51,7 +52,6 @@ from superset.views.base_api import ( RelatedFieldFilter, statsd_metrics, ) -from superset.views.database.filters import DatabaseFilter from superset.views.filters import FilterRelatedOwners logger = logging.getLogger(__name__) diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 4d7fe7d267..a4707876dc 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -106,6 +106,9 @@ class BaseSupersetModelRestApi(ModelRestApi): "data": "list", "viz_types": "list", "related_objects": "list", + "table_metadata": "list", + "select_star": "list", + "schemas": "list", } order_rel_fields: Dict[str, Tuple[str, str]] = {} diff --git a/superset/views/core.py b/superset/views/core.py index 7ddd801bdb..a96ce15d2a 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -68,6 +68,7 @@ from superset.connectors.sqla.models import ( TableColumn, ) from superset.dashboards.dao import DashboardDAO +from superset.databases.filters import DatabaseFilter from superset.exceptions import ( CertificateException, DatabaseNotFound, @@ -109,7 +110,6 @@ from superset.views.base import ( json_success, validate_sqlatable, ) -from superset.views.database.filters import DatabaseFilter from superset.views.utils import ( _deserialize_results_payload, apply_display_max_row_limit, diff --git a/superset/views/database/mixins.py b/superset/views/database/mixins.py index 49c003b456..d5d561247f 100644 --- a/superset/views/database/mixins.py +++ b/superset/views/database/mixins.py @@ -21,11 +21,11 @@ from flask_babel import lazy_gettext as _ from sqlalchemy import MetaData from superset import app, security_manager +from superset.databases.filters import DatabaseFilter from 
superset.exceptions import SupersetException from superset.models.core import Database from superset.security.analytics_db_safety import check_sqlalchemy_uri from superset.utils import core as utils -from superset.views.database.filters import DatabaseFilter class DatabaseMixin: @@ -240,7 +240,7 @@ class DatabaseMixin: extra = database.get_extra() except Exception as ex: raise Exception( - _("Extra field cannot be decoded by JSON. %{msg}s", msg=str(ex)) + _("Extra field cannot be decoded by JSON. %(msg)s", msg=str(ex)) ) # this will check whether 'metadata_params' is configured correctly @@ -264,5 +264,5 @@ class DatabaseMixin: database.get_encrypted_extra() except Exception as ex: raise Exception( - _("Extra field cannot be decoded by JSON. %{msg}s", msg=str(ex)) + _("Extra field cannot be decoded by JSON. %(msg)s", msg=str(ex)) ) diff --git a/superset/views/database/validators.py b/superset/views/database/validators.py index a1592a01a8..a54b853b43 100644 --- a/superset/views/database/validators.py +++ b/superset/views/database/validators.py @@ -36,11 +36,15 @@ def sqlalchemy_uri_validator( make_url(uri.strip()) except (ArgumentError, AttributeError): raise exception( - _( - "Invalid connection string, a valid string usually follows:" - "'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'" - "

Example:'postgresql://user:password@your-postgres-db/database'</p>"
-            )
+            [
+                _(
+                    "Invalid connection string, a valid string usually follows:"
+                    "'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'"
+                    "<p>"
+                    "Example:'postgresql://user:password@your-postgres-db/database'"
+                    "
" + ) + ] ) diff --git a/tests/database_api_tests.py b/tests/database_api_tests.py deleted file mode 100644 index 49ede7b233..0000000000 --- a/tests/database_api_tests.py +++ /dev/null @@ -1,268 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# isort:skip_file -"""Unit tests for Superset""" -import json -from unittest import mock - -import prison -from sqlalchemy.sql import func - -import tests.test_app -from superset import db, security_manager -from superset.connectors.sqla.models import SqlaTable -from superset.models.core import Database -from superset.utils.core import get_example_database, get_main_database - -from .base_tests import SupersetTestCase - - -class TestDatabaseApi(SupersetTestCase): - def test_get_items(self): - """ - Database API: Test get items - """ - self.login(username="admin") - uri = "api/v1/database/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) - response = json.loads(rv.data.decode("utf-8")) - expected_columns = [ - "allow_csv_upload", - "allow_ctas", - "allow_cvas", - "allow_dml", - "allow_multi_schema_metadata_fetch", - "allow_run_async", - "allows_cost_estimate", - "allows_subquery", - "allows_virtual_table_explore", - "backend", - "database_name", - "explore_database_id", - "expose_in_sqllab", - "force_ctas_schema", - "function_names", - "id", - ] - self.assertEqual(response["count"], 2) - self.assertEqual(list(response["result"][0].keys()), expected_columns) - - def test_get_items_filter(self): - """ - Database API: Test get items with filter - """ - fake_db = ( - db.session.query(Database).filter_by(database_name="fake_db_100").one() - ) - old_expose_in_sqllab = fake_db.expose_in_sqllab - fake_db.expose_in_sqllab = False - db.session.commit() - self.login(username="admin") - arguments = { - "keys": ["none"], - "filters": [{"col": "expose_in_sqllab", "opr": "eq", "value": True}], - "order_columns": "database_name", - "order_direction": "asc", - "page": 0, - "page_size": -1, - } - uri = f"api/v1/database/?q={prison.dumps(arguments)}" - rv = self.client.get(uri) - response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(rv.status_code, 200) - self.assertEqual(response["count"], 1) - - fake_db = ( - db.session.query(Database).filter_by(database_name="fake_db_100").one() - ) - fake_db.expose_in_sqllab = old_expose_in_sqllab - db.session.commit() - - def test_get_items_not_allowed(self): - """ - Database API: Test get items not allowed - """ - self.login(username="gamma") - uri = f"api/v1/database/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) - response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(response["count"], 0) - - def test_get_table_metadata(self): - """ - Database API: Test get table metadata info - """ - example_db = get_example_database() - 
self.login(username="admin") - uri = f"api/v1/database/{example_db.id}/table/birth_names/null/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) - response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(response["name"], "birth_names") - self.assertTrue(len(response["columns"]) > 5) - self.assertTrue(response.get("selectStar").startswith("SELECT")) - - def test_get_invalid_database_table_metadata(self): - """ - Database API: Test get invalid database from table metadata - """ - database_id = 1000 - self.login(username="admin") - uri = f"api/v1/database/{database_id}/table/some_table/some_schema/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) - - uri = f"api/v1/database/some_database/table/some_table/some_schema/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) - - def test_get_invalid_table_table_metadata(self): - """ - Database API: Test get invalid table from table metadata - """ - example_db = get_example_database() - uri = f"api/v1/database/{example_db.id}/wrong_table/null/" - self.login(username="admin") - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) - - def test_get_table_metadata_no_db_permission(self): - """ - Database API: Test get table metadata from not permitted db - """ - self.login(username="gamma") - example_db = get_example_database() - uri = f"api/v1/database/{example_db.id}/birth_names/null/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) - - def test_get_select_star(self): - """ - Database API: Test get select star - """ - self.login(username="admin") - example_db = get_example_database() - uri = f"api/v1/database/{example_db.id}/select_star/birth_names/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) - response = json.loads(rv.data.decode("utf-8")) - self.assertIn("gender", response["result"]) - - def test_get_select_star_not_allowed(self): - """ - Database API: Test get select star not allowed - """ - self.login(username="gamma") - example_db = get_example_database() - uri = f"api/v1/database/{example_db.id}/select_star/birth_names/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) - - def test_get_select_star_datasource_access(self): - """ - Database API: Test get select star with datasource access - """ - session = db.session - table = SqlaTable( - schema="main", table_name="ab_permission", database=get_main_database() - ) - session.add(table) - session.commit() - - tmp_table_perm = security_manager.find_permission_view_menu( - "datasource_access", table.get_perm() - ) - gamma_role = security_manager.find_role("Gamma") - security_manager.add_permission_role(gamma_role, tmp_table_perm) - - self.login(username="gamma") - main_db = get_main_database() - uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 200) - - # rollback changes - security_manager.del_permission_role(gamma_role, tmp_table_perm) - db.session.delete(table) - db.session.delete(main_db) - db.session.commit() - - def test_get_select_star_not_found_database(self): - """ - Database API: Test get select star not found database - """ - self.login(username="admin") - max_id = db.session.query(func.max(Database.id)).scalar() - uri = f"api/v1/database/{max_id + 1}/select_star/birth_names/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) - - def test_get_select_star_not_found_table(self): - """ - Database API: Test get select star not found database - """ - 
self.login(username="admin") - example_db = get_example_database() - # sqllite will not raise a NoSuchTableError - if example_db.backend == "sqlite": - return - uri = f"api/v1/database/{example_db.id}/select_star/table_does_not_exist/" - rv = self.client.get(uri) - # TODO(bkyryliuk): investigate why presto returns 500 - self.assertEqual(rv.status_code, 404 if example_db.backend != "presto" else 500) - - def test_database_schemas(self): - """ - Database API: Test database schemas - """ - self.login("admin") - database = db.session.query(Database).first() - schemas = database.get_all_schema_names() - - rv = self.client.get(f"api/v1/database/{database.id}/schemas/") - response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(schemas, response["result"]) - - rv = self.client.get( - f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}" - ) - response = json.loads(rv.data.decode("utf-8")) - self.assertEqual(schemas, response["result"]) - - def test_database_schemas_not_found(self): - """ - Database API: Test database schemas not found - """ - self.logout() - self.login(username="gamma") - example_db = get_example_database() - uri = f"api/v1/database/{example_db.id}/schemas/" - rv = self.client.get(uri) - self.assertEqual(rv.status_code, 404) - - def test_database_schemas_invalid_query(self): - """ - Database API: Test database schemas with invalid query - """ - self.login("admin") - database = db.session.query(Database).first() - rv = self.client.get( - f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': 'nop'})}" - ) - self.assertEqual(rv.status_code, 400) diff --git a/tests/databases/__init__.py b/tests/databases/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/databases/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/databases/api_tests.py b/tests/databases/api_tests.py new file mode 100644 index 0000000000..503387af39 --- /dev/null +++ b/tests/databases/api_tests.py @@ -0,0 +1,650 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# isort:skip_file +"""Unit tests for Superset""" +import json + +import prison +from sqlalchemy.sql import func + +import tests.test_app +from superset import db, security_manager +from superset.connectors.sqla.models import SqlaTable +from superset.models.core import Database +from superset.utils.core import get_example_database, get_main_database +from tests.base_tests import SupersetTestCase +from tests.fixtures.certificates import ssl_certificate + + +class TestDatabaseApi(SupersetTestCase): + def insert_database( + self, + database_name: str, + sqlalchemy_uri: str, + extra: str = "", + encrypted_extra: str = "", + server_cert: str = "", + expose_in_sqllab: bool = False, + ) -> Database: + database = Database( + database_name=database_name, + sqlalchemy_uri=sqlalchemy_uri, + extra=extra, + encrypted_extra=encrypted_extra, + server_cert=server_cert, + expose_in_sqllab=expose_in_sqllab, + ) + db.session.add(database) + db.session.commit() + return database + + def test_get_items(self): + """ + Database API: Test get items + """ + self.login(username="admin") + uri = "api/v1/database/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + expected_columns = [ + "allow_csv_upload", + "allow_ctas", + "allow_cvas", + "allow_dml", + "allow_multi_schema_metadata_fetch", + "allow_run_async", + "allows_cost_estimate", + "allows_subquery", + "allows_virtual_table_explore", + "backend", + "database_name", + "explore_database_id", + "expose_in_sqllab", + "force_ctas_schema", + "function_names", + "id", + ] + self.assertEqual(response["count"], 2) + self.assertEqual(list(response["result"][0].keys()), expected_columns) + + def test_get_items_filter(self): + """ + Database API: Test get items with filter + """ + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted, expose_in_sqllab=True + ) + dbs = db.session.query(Database).filter_by(expose_in_sqllab=True).all() + + self.login(username="admin") + arguments = { + "keys": ["none"], + "filters": [{"col": "expose_in_sqllab", "opr": "eq", "value": True}], + "order_columns": "database_name", + "order_direction": "asc", + "page": 0, + "page_size": -1, + } + uri = f"api/v1/database/?q={prison.dumps(arguments)}" + rv = self.client.get(uri) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + self.assertEqual(response["count"], len(dbs)) + + # Cleanup + db.session.delete(test_database) + db.session.commit() + + def test_get_items_not_allowed(self): + """ + Database API: Test get items not allowed + """ + self.login(username="gamma") + uri = f"api/v1/database/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["count"], 0) + + def test_create_database(self): + """ + Database API: Test create + """ + extra = { + "metadata_params": {}, + "engine_params": {}, + "metadata_cache_timeout": {}, + "schemas_allowed_for_csv_upload": [], + } + + self.login(username="admin") + example_db = get_example_database() + if example_db.backend == "sqlite": + return + database_data = { + "database_name": "test-database", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "server_cert": ssl_certificate, + "extra": json.dumps(extra), + } + + uri = "api/v1/database/" + rv = self.client.post(uri, 
json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + + def test_create_database_server_cert_validate(self): + """ + Database API: Test create server cert validation + """ + example_db = get_example_database() + if example_db.backend == "sqlite": + return + + self.login(username="admin") + database_data = { + "database_name": "test-database", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "server_cert": "INVALID CERT", + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + expected_response = {"message": {"server_cert": ["Invalid certificate"]}} + self.assertEqual(rv.status_code, 400) + self.assertEqual(response, expected_response) + + def test_create_database_json_validate(self): + """ + Database API: Test create encrypted extra and extra validation + """ + example_db = get_example_database() + if example_db.backend == "sqlite": + return + + self.login(username="admin") + database_data = { + "database_name": "test-database", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "encrypted_extra": '{"A": "a", "B", "C"}', + "extra": '["A": "a", "B", "C"]', + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + expected_response = { + "message": { + "encrypted_extra": [ + "Field cannot be decoded by JSON. Expecting ':' " + "delimiter: line 1 column 15 (char 14)" + ], + "extra": [ + "Field cannot be decoded by JSON. Expecting ','" + " delimiter: line 1 column 5 (char 4)" + ], + } + } + self.assertEqual(rv.status_code, 400) + self.assertEqual(response, expected_response) + + def test_create_database_extra_metadata_validate(self): + """ + Database API: Test create extra metadata_params validation + """ + example_db = get_example_database() + if example_db.backend == "sqlite": + return + + extra = { + "metadata_params": {"wrong_param": "some_value"}, + "engine_params": {}, + "metadata_cache_timeout": {}, + "schemas_allowed_for_csv_upload": [], + } + self.login(username="admin") + database_data = { + "database_name": "test-database", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + "extra": json.dumps(extra), + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + expected_response = { + "message": { + "extra": [ + "The metadata_params in Extra field is not configured correctly." + " The key wrong_param is invalid." 
+ ] + } + } + self.assertEqual(rv.status_code, 400) + self.assertEqual(response, expected_response) + + def test_create_database_unique_validate(self): + """ + Database API: Test create database_name already exists + """ + example_db = get_example_database() + if example_db.backend == "sqlite": + return + + self.login(username="admin") + database_data = { + "database_name": "examples", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + expected_response = { + "message": {"database_name": "A database with the same name already exists"} + } + self.assertEqual(rv.status_code, 422) + self.assertEqual(response, expected_response) + + def test_create_database_uri_validate(self): + """ + Database API: Test create fail validate sqlalchemy uri + """ + self.login(username="admin") + database_data = { + "database_name": "test-database", + "sqlalchemy_uri": "wrong_uri", + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 400) + expected_response = { + "message": { + "sqlalchemy_uri": [ + "Invalid connection string, a valid string usually " + "follows:'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'" + "

\nExample:'postgresql://user:password@your-postgres-db/database'" + "\n
" + ] + } + } + self.assertEqual(response, expected_response) + + def test_create_database_fail_sqllite(self): + """ + Database API: Test create fail with sqllite + """ + database_data = { + "database_name": "test-database", + "sqlalchemy_uri": "sqlite:////some.db", + } + + uri = "api/v1/database/" + self.login(username="admin") + response = self.client.post(uri, json=database_data) + response_data = json.loads(response.data.decode("utf-8")) + expected_response = { + "message": { + "sqlalchemy_uri": [ + "SQLite database cannot be used as a data source " + "for security reasons." + ] + } + } + self.assertEqual(response.status_code, 400) + self.assertEqual(response_data, expected_response) + + def test_create_database_conn_fail(self): + """ + Database API: Test create fails connection + """ + example_db = get_example_database() + if example_db.backend in ("sqlite", "hive", "presto"): + return + example_db.password = "wrong_password" + database_data = { + "database_name": "test-database", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + } + + uri = "api/v1/database/" + self.login(username="admin") + response = self.client.post(uri, json=database_data) + response_data = json.loads(response.data.decode("utf-8")) + expected_response = {"message": "Could not connect to database."} + self.assertEqual(response.status_code, 422) + self.assertEqual(response_data, expected_response) + + def test_update_database(self): + """ + Database API: Test update + """ + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted + ) + + self.login(username="admin") + database_data = {"database_name": "test-database-updated"} + uri = f"api/v1/database/{test_database.id}" + rv = self.client.put(uri, json=database_data) + self.assertEqual(rv.status_code, 200) + # Cleanup + model = db.session.query(Database).get(test_database.id) + db.session.delete(model) + db.session.commit() + + def test_update_database_conn_fail(self): + """ + Database API: Test update fails connection + """ + example_db = get_example_database() + if example_db.backend in ("sqlite", "hive", "presto"): + return + + test_database = self.insert_database( + "test-database1", example_db.sqlalchemy_uri_decrypted + ) + example_db.password = "wrong_password" + database_data = { + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + } + + uri = f"api/v1/database/{test_database.id}" + self.login(username="admin") + rv = self.client.put(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + expected_response = {"message": "Could not connect to database."} + self.assertEqual(rv.status_code, 422) + self.assertEqual(response, expected_response) + # Cleanup + model = db.session.query(Database).get(test_database.id) + db.session.delete(model) + db.session.commit() + + def test_update_database_uniqueness(self): + """ + Database API: Test update uniqueness + """ + example_db = get_example_database() + test_database1 = self.insert_database( + "test-database1", example_db.sqlalchemy_uri_decrypted + ) + test_database2 = self.insert_database( + "test-database2", example_db.sqlalchemy_uri_decrypted + ) + + self.login(username="admin") + database_data = {"database_name": "test-database2"} + uri = f"api/v1/database/{test_database1.id}" + rv = self.client.put(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + expected_response = { + "message": {"database_name": "A database with the same name already exists"} + } + 
self.assertEqual(rv.status_code, 422) + self.assertEqual(response, expected_response) + # Cleanup + db.session.delete(test_database1) + db.session.delete(test_database2) + db.session.commit() + + def test_update_database_invalid(self): + """ + Database API: Test update invalid request + """ + self.login(username="admin") + database_data = {"database_name": "test-database-updated"} + uri = f"api/v1/database/invalid" + rv = self.client.put(uri, json=database_data) + self.assertEqual(rv.status_code, 404) + + def test_update_database_uri_validate(self): + """ + Database API: Test update sqlalchemy_uri validate + """ + example_db = get_example_database() + test_database = self.insert_database( + "test-database", example_db.sqlalchemy_uri_decrypted + ) + + self.login(username="admin") + database_data = { + "database_name": "test-database-updated", + "sqlalchemy_uri": "wrong_uri", + } + uri = f"api/v1/database/{test_database.id}" + rv = self.client.put(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 400) + expected_response = { + "message": { + "sqlalchemy_uri": [ + "Invalid connection string, a valid string usually " + "follows:'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'" + "

\nExample:'postgresql://user:password@your-postgres-db/database'" + "\n
" + ] + } + } + self.assertEqual(response, expected_response) + + def test_delete_database(self): + """ + Database API: Test delete + """ + database_id = self.insert_database("test-database", "test_uri").id + self.login(username="admin") + uri = f"api/v1/database/{database_id}" + rv = self.delete_assert_metric(uri, "delete") + self.assertEqual(rv.status_code, 200) + model = db.session.query(Database).get(database_id) + self.assertEqual(model, None) + + def test_delete_database_not_found(self): + """ + Database API: Test delete not found + """ + max_id = db.session.query(func.max(Database.id)).scalar() + self.login(username="admin") + uri = f"api/v1/database/{max_id + 1}" + rv = self.delete_assert_metric(uri, "delete") + self.assertEqual(rv.status_code, 404) + + def test_delete_database_with_datasets(self): + """ + Database API: Test delete fails because it has depending datasets + """ + database_id = ( + db.session.query(Database).filter_by(database_name="examples").one() + ).id + self.login(username="admin") + uri = f"api/v1/database/{database_id}" + rv = self.delete_assert_metric(uri, "delete") + self.assertEqual(rv.status_code, 422) + + def test_get_table_metadata(self): + """ + Database API: Test get table metadata info + """ + example_db = get_example_database() + self.login(username="admin") + uri = f"api/v1/database/{example_db.id}/table/birth_names/null/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["name"], "birth_names") + self.assertTrue(len(response["columns"]) > 5) + self.assertTrue(response.get("selectStar").startswith("SELECT")) + + def test_get_invalid_database_table_metadata(self): + """ + Database API: Test get invalid database from table metadata + """ + database_id = 1000 + self.login(username="admin") + uri = f"api/v1/database/{database_id}/table/some_table/some_schema/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + uri = f"api/v1/database/some_database/table/some_table/some_schema/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + def test_get_invalid_table_table_metadata(self): + """ + Database API: Test get invalid table from table metadata + """ + example_db = get_example_database() + uri = f"api/v1/database/{example_db.id}/wrong_table/null/" + self.login(username="admin") + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + def test_get_table_metadata_no_db_permission(self): + """ + Database API: Test get table metadata from not permitted db + """ + self.login(username="gamma") + example_db = get_example_database() + uri = f"api/v1/database/{example_db.id}/birth_names/null/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + def test_get_select_star(self): + """ + Database API: Test get select star + """ + self.login(username="admin") + example_db = get_example_database() + uri = f"api/v1/database/{example_db.id}/select_star/birth_names/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertIn("gender", response["result"]) + + def test_get_select_star_not_allowed(self): + """ + Database API: Test get select star not allowed + """ + self.login(username="gamma") + example_db = get_example_database() + uri = f"api/v1/database/{example_db.id}/select_star/birth_names/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + def test_get_select_star_datasource_access(self): + """ + Database 
API: Test get select star with datasource access + """ + session = db.session + table = SqlaTable( + schema="main", table_name="ab_permission", database=get_main_database() + ) + session.add(table) + session.commit() + + tmp_table_perm = security_manager.find_permission_view_menu( + "datasource_access", table.get_perm() + ) + gamma_role = security_manager.find_role("Gamma") + security_manager.add_permission_role(gamma_role, tmp_table_perm) + + self.login(username="gamma") + main_db = get_main_database() + uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + + # rollback changes + security_manager.del_permission_role(gamma_role, tmp_table_perm) + db.session.delete(table) + db.session.delete(main_db) + db.session.commit() + + def test_get_select_star_not_found_database(self): + """ + Database API: Test get select star not found database + """ + self.login(username="admin") + max_id = db.session.query(func.max(Database.id)).scalar() + uri = f"api/v1/database/{max_id + 1}/select_star/birth_names/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + def test_get_select_star_not_found_table(self): + """ + Database API: Test get select star not found database + """ + self.login(username="admin") + example_db = get_example_database() + # sqllite will not raise a NoSuchTableError + if example_db.backend == "sqlite": + return + uri = f"api/v1/database/{example_db.id}/select_star/table_does_not_exist/" + rv = self.client.get(uri) + # TODO(bkyryliuk): investigate why presto returns 500 + self.assertEqual(rv.status_code, 404 if example_db.backend != "presto" else 500) + + def test_database_schemas(self): + """ + Database API: Test database schemas + """ + self.login("admin") + database = db.session.query(Database).first() + schemas = database.get_all_schema_names() + + rv = self.client.get(f"api/v1/database/{database.id}/schemas/") + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(schemas, response["result"]) + + rv = self.client.get( + f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}" + ) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(schemas, response["result"]) + + def test_database_schemas_not_found(self): + """ + Database API: Test database schemas not found + """ + self.logout() + self.login(username="gamma") + example_db = get_example_database() + uri = f"api/v1/database/{example_db.id}/schemas/" + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + def test_database_schemas_invalid_query(self): + """ + Database API: Test database schemas with invalid query + """ + self.login("admin") + database = db.session.query(Database).first() + rv = self.client.get( + f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': 'nop'})}" + ) + self.assertEqual(rv.status_code, 400)
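For readers less familiar with the rison-encoded "q" argument exercised in test_get_items_filter, test_database_schemas and test_database_schemas_invalid_query above, the snippet below is a small illustrative sketch (not part of the patch) of how such a payload is built with the prison library these tests already import.

    import prison

    # Filter/order/pagination arguments, mirroring test_get_items_filter above.
    arguments = {
        "filters": [{"col": "expose_in_sqllab", "opr": "eq", "value": True}],
        "order_columns": "database_name",
        "order_direction": "asc",
        "page": 0,
        "page_size": -1,
    }

    # prison.dumps produces the rison string the tests append as ?q=<...> to
    # GET /api/v1/database/ and /api/v1/database/<id>/schemas/.
    query = prison.dumps(arguments)
    print(f"api/v1/database/?q={query}")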
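The create, update and delete tests above all follow the same insert-then-clean-up pattern: insert_database adds a row, and each test ends with an explicit db.session.delete plus commit. As a purely illustrative sketch, assuming the same Superset test-app context these tests run under, the pattern could be expressed once as a context manager; temporary_database is a hypothetical helper, not something introduced by this change.

    from contextlib import contextmanager

    from superset import db
    from superset.models.core import Database


    @contextmanager
    def temporary_database(database_name: str, sqlalchemy_uri: str, **kwargs):
        """Insert a Database row for a test and always delete it afterwards."""
        database = Database(
            database_name=database_name, sqlalchemy_uri=sqlalchemy_uri, **kwargs
        )
        db.session.add(database)
        db.session.commit()
        try:
            yield database
        finally:
            db.session.delete(database)
            db.session.commit()

A test could then wrap its body in "with temporary_database('test-database', uri) as database:" and drop the manual cleanup block.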
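Taken together, the POST, PUT and DELETE tests above pin down the response-code contract of the new endpoints. The helper below is only a hedged summary of that contract written against the same Flask test client that SupersetTestCase exposes as self.client; the function itself is hypothetical and exists only to restate what the assertions check.

    def create_database(client, name: str, sqlalchemy_uri: str) -> int:
        """POST to the new database endpoint and return the HTTP status code.

        Codes asserted by the tests above:
          201 - database created
          400 - schema validation failed (malformed URI, invalid server_cert,
                bad extra/encrypted_extra JSON, or a rejected SQLite URI)
          422 - duplicate database_name, failed connection test, or a delete
                blocked by existing datasets
        """
        payload = {"database_name": name, "sqlalchemy_uri": sqlalchemy_uri}
        rv = client.post("api/v1/database/", json=payload)
        return rv.status_code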