feat(database): POST, PUT, DELETE API endpoints (#10741)

* feat(database): POST, PUT, DELETE API endpoints

* post tests

* more tests

* lint

* lint

* debug ci

* fix test

* fix test

* fix test

* fix test

* fix test

* fix test

* cleanup

* handle db connection failures

* lint

* skip hive and presto for connection fail test

* fix typo
This commit is contained in:
Daniel Vaz Gaspar 2020-09-02 17:58:08 +01:00 committed by GitHub
parent b5aecaff5c
commit 77a3167412
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 1676 additions and 367 deletions

View File

@ -14,126 +14,78 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from typing import Any, Dict, List, Optional import logging
from typing import Any, Optional
from flask import g, request, Response
from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.models.sqla.interface import SQLAInterface
from marshmallow import ValidationError
from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
from superset import event_logger from superset import event_logger
from superset.constants import RouteMethod
from superset.databases.commands.create import CreateDatabaseCommand
from superset.databases.commands.delete import DeleteDatabaseCommand
from superset.databases.commands.exceptions import (
DatabaseCreateFailedError,
DatabaseDeleteDatasetsExistFailedError,
DatabaseDeleteFailedError,
DatabaseInvalidError,
DatabaseNotFoundError,
DatabaseUpdateFailedError,
)
from superset.databases.commands.update import UpdateDatabaseCommand
from superset.databases.decorators import check_datasource_access from superset.databases.decorators import check_datasource_access
from superset.databases.filters import DatabaseFilter
from superset.databases.schemas import ( from superset.databases.schemas import (
database_schemas_query_schema, database_schemas_query_schema,
DatabasePostSchema,
DatabasePutSchema,
SchemasResponseSchema, SchemasResponseSchema,
SelectStarResponseSchema, SelectStarResponseSchema,
TableMetadataResponseSchema, TableMetadataResponseSchema,
) )
from superset.databases.utils import get_table_metadata
from superset.extensions import security_manager from superset.extensions import security_manager
from superset.models.core import Database from superset.models.core import Database
from superset.typing import FlaskResponse from superset.typing import FlaskResponse
from superset.utils.core import error_msg_from_exception from superset.utils.core import error_msg_from_exception
from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics
from superset.views.database.filters import DatabaseFilter
from superset.views.database.validators import sqlalchemy_uri_validator
logger = logging.getLogger(__name__)
def get_foreign_keys_metadata(
database: Database, table_name: str, schema_name: Optional[str]
) -> List[Dict[str, Any]]:
foreign_keys = database.get_foreign_keys(table_name, schema_name)
for fk in foreign_keys:
fk["column_names"] = fk.pop("constrained_columns")
fk["type"] = "fk"
return foreign_keys
def get_indexes_metadata(
database: Database, table_name: str, schema_name: Optional[str]
) -> List[Dict[str, Any]]:
indexes = database.get_indexes(table_name, schema_name)
for idx in indexes:
idx["type"] = "index"
return indexes
def get_col_type(col: Dict[Any, Any]) -> str:
try:
dtype = f"{col['type']}"
except Exception: # pylint: disable=broad-except
# sqla.types.JSON __str__ has a bug, so using __class__.
dtype = col["type"].__class__.__name__
return dtype
def get_table_metadata(
database: Database, table_name: str, schema_name: Optional[str]
) -> Dict[str, Any]:
"""
Get table metadata information, including type, pk, fks.
This function raises SQLAlchemyError when a schema is not found.
:param database: The database model
:param table_name: Table name
:param schema_name: schema name
:return: Dict table metadata ready for API response
"""
keys = []
columns = database.get_columns(table_name, schema_name)
primary_key = database.get_pk_constraint(table_name, schema_name)
if primary_key and primary_key.get("constrained_columns"):
primary_key["column_names"] = primary_key.pop("constrained_columns")
primary_key["type"] = "pk"
keys += [primary_key]
foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
indexes = get_indexes_metadata(database, table_name, schema_name)
keys += foreign_keys + indexes
payload_columns: List[Dict[str, Any]] = []
for col in columns:
dtype = get_col_type(col)
payload_columns.append(
{
"name": col["name"],
"type": dtype.split("(")[0] if "(" in dtype else dtype,
"longType": dtype,
"keys": [k for k in keys if col["name"] in k["column_names"]],
}
)
return {
"name": table_name,
"columns": payload_columns,
"selectStar": database.select_star(
table_name,
schema=schema_name,
show_cols=True,
indent=True,
cols=columns,
latest_partition=True,
),
"primaryKey": primary_key,
"foreignKeys": foreign_keys,
"indexes": keys,
}
class DatabaseRestApi(BaseSupersetModelRestApi): class DatabaseRestApi(BaseSupersetModelRestApi):
datamodel = SQLAInterface(Database) datamodel = SQLAInterface(Database)
include_route_methods = { include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {
"get_list",
"table_metadata", "table_metadata",
"select_star", "select_star",
"schemas", "schemas",
} }
class_permission_name = "DatabaseView" class_permission_name = "DatabaseView"
method_permission_name = {
"get_list": "list",
"table_metadata": "list",
"select_star": "list",
"schemas": "list",
}
resource_name = "database" resource_name = "database"
allow_browser_login = True allow_browser_login = True
base_filters = [["id", DatabaseFilter, lambda: []]] base_filters = [["id", DatabaseFilter, lambda: []]]
show_columns = [
"id",
"database_name",
"cache_timeout",
"expose_in_sqllab",
"allow_run_async",
"allow_csv_upload",
"allow_ctas",
"allow_cvas",
"allow_dml",
"force_ctas_schema",
"allow_multi_schema_metadata_fetch",
"impersonate_user",
"encrypted_extra",
"extra",
"server_cert",
"sqlalchemy_uri",
]
list_columns = [ list_columns = [
"id", "id",
"database_name", "database_name",
@ -152,10 +104,30 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"backend", "backend",
"function_names", "function_names",
] ]
add_columns = [
"database_name",
"sqlalchemy_uri",
"cache_timeout",
"expose_in_sqllab",
"allow_run_async",
"allow_csv_upload",
"allow_ctas",
"allow_cvas",
"allow_dml",
"force_ctas_schema",
"impersonate_user",
"allow_multi_schema_metadata_fetch",
"extra",
"encrypted_extra",
"server_cert",
]
edit_columns = add_columns
list_select_columns = list_columns + ["extra", "sqlalchemy_uri", "password"] list_select_columns = list_columns + ["extra", "sqlalchemy_uri", "password"]
# Removes the local limit for the page size # Removes the local limit for the page size
max_page_size = -1 max_page_size = -1
validators_columns = {"sqlalchemy_uri": sqlalchemy_uri_validator} add_model_schema = DatabasePostSchema()
edit_model_schema = DatabasePutSchema()
apispec_parameter_schemas = { apispec_parameter_schemas = {
"database_schemas_query_schema": database_schemas_query_schema, "database_schemas_query_schema": database_schemas_query_schema,
@ -167,6 +139,186 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
SchemasResponseSchema, SchemasResponseSchema,
) )
@expose("/", methods=["POST"])
@protect()
@safe
@statsd_metrics
def post(self) -> Response:
"""Creates a new Database
---
post:
description: >-
Create a new Database.
requestBody:
description: Database schema
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/{{self.__class__.__name__}}.post'
responses:
201:
description: Database added
content:
application/json:
schema:
type: object
properties:
id:
type: number
result:
$ref: '#/components/schemas/{{self.__class__.__name__}}.post'
302:
description: Redirects to the current digest
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
500:
$ref: '#/components/responses/500'
"""
if not request.is_json:
return self.response_400(message="Request is not JSON")
try:
item = self.add_model_schema.load(request.json)
# This validates custom Schema with custom validations
except ValidationError as error:
return self.response_400(message=error.messages)
try:
new_model = CreateDatabaseCommand(g.user, item).run()
# Return censored version for sqlalchemy URI
item["sqlalchemy_uri"] = new_model.sqlalchemy_uri
return self.response(201, id=new_model.id, result=item)
except DatabaseInvalidError as ex:
return self.response_422(message=ex.normalized_messages())
except DatabaseCreateFailedError as ex:
logger.error(
"Error creating model %s: %s", self.__class__.__name__, str(ex)
)
return self.response_422(message=str(ex))
@expose("/<int:pk>", methods=["PUT"])
@protect()
@safe
@statsd_metrics
def put( # pylint: disable=too-many-return-statements, arguments-differ
self, pk: int
) -> Response:
"""Changes a Database
---
put:
description: >-
Changes a Database.
parameters:
- in: path
schema:
type: integer
name: pk
requestBody:
description: Database schema
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/{{self.__class__.__name__}}.put'
responses:
200:
description: Database changed
content:
application/json:
schema:
type: object
properties:
id:
type: number
result:
$ref: '#/components/schemas/{{self.__class__.__name__}}.put'
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
403:
$ref: '#/components/responses/403'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
if not request.is_json:
return self.response_400(message="Request is not JSON")
try:
item = self.edit_model_schema.load(request.json)
# This validates custom Schema with custom validations
except ValidationError as error:
return self.response_400(message=error.messages)
try:
changed_model = UpdateDatabaseCommand(g.user, pk, item).run()
# Return censored version for sqlalchemy URI
item["sqlalchemy_uri"] = changed_model.sqlalchemy_uri
return self.response(200, id=changed_model.id, result=item)
except DatabaseNotFoundError:
return self.response_404()
except DatabaseInvalidError as ex:
return self.response_422(message=ex.normalized_messages())
except DatabaseUpdateFailedError as ex:
logger.error(
"Error updating model %s: %s", self.__class__.__name__, str(ex)
)
return self.response_422(message=str(ex))
@expose("/<int:pk>", methods=["DELETE"])
@protect()
@safe
@statsd_metrics
def delete(self, pk: int) -> Response: # pylint: disable=arguments-differ
"""Deletes a Database
---
delete:
description: >-
Deletes a Database.
parameters:
- in: path
schema:
type: integer
name: pk
responses:
200:
description: Database deleted
content:
application/json:
schema:
type: object
properties:
message:
type: string
401:
$ref: '#/components/responses/401'
403:
$ref: '#/components/responses/403'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
try:
DeleteDatabaseCommand(g.user, pk).run()
return self.response(200, message="OK")
except DatabaseNotFoundError:
return self.response_404()
except DatabaseDeleteDatasetsExistFailedError as ex:
return self.response_422(message=str(ex))
except DatabaseDeleteFailedError as ex:
logger.error(
"Error deleting model %s: %s", self.__class__.__name__, str(ex)
)
return self.response_422(message=str(ex))
@expose("/<int:pk>/schemas/") @expose("/<int:pk>/schemas/")
@protect() @protect()
@safe @safe

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError
from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAOCreateFailedError
from superset.databases.commands.exceptions import (
DatabaseConnectionFailedError,
DatabaseCreateFailedError,
DatabaseExistsValidationError,
DatabaseInvalidError,
DatabaseRequiredFieldValidationError,
)
from superset.databases.dao import DatabaseDAO
from superset.extensions import db, security_manager
logger = logging.getLogger(__name__)
class CreateDatabaseCommand(BaseCommand):
def __init__(self, user: User, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy()
def run(self) -> Model:
self.validate()
try:
database = DatabaseDAO.create(self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
# adding a new database we always want to force refresh schema list
# TODO Improve this simplistic implementation for catching DB conn fails
try:
schemas = database.get_all_schema_names()
except Exception:
db.session.rollback()
raise DatabaseConnectionFailedError()
for schema in schemas:
security_manager.add_permission_view_menu(
"schema_access", security_manager.get_schema_perm(database, schema)
)
security_manager.add_permission_view_menu("database_access", database.perm)
db.session.commit()
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise DatabaseCreateFailedError()
return database
def validate(self) -> None:
exceptions: List[ValidationError] = list()
sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri")
database_name: Optional[str] = self._properties.get("database_name")
if not sqlalchemy_uri:
exceptions.append(DatabaseRequiredFieldValidationError("sqlalchemy_uri"))
if not database_name:
exceptions.append(DatabaseRequiredFieldValidationError("database_name"))
else:
# Check database_name uniqueness
if not DatabaseDAO.validate_uniqueness(database_name):
exceptions.append(DatabaseExistsValidationError())
if exceptions:
exception = DatabaseInvalidError()
exception.add_list(exceptions)
raise exception

View File

@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAODeleteFailedError
from superset.databases.commands.exceptions import (
DatabaseDeleteDatasetsExistFailedError,
DatabaseDeleteFailedError,
DatabaseNotFoundError,
)
from superset.databases.dao import DatabaseDAO
from superset.models.core import Database
logger = logging.getLogger(__name__)
class DeleteDatabaseCommand(BaseCommand):
def __init__(self, user: User, model_id: int):
self._actor = user
self._model_id = model_id
self._model: Optional[Database] = None
def run(self) -> Model:
self.validate()
try:
database = DatabaseDAO.delete(self._model)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise DatabaseDeleteFailedError()
return database
def validate(self) -> None:
# Validate/populate model exists
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:
raise DatabaseNotFoundError()
# Check if there are datasets for this database
if self._model.tables:
raise DatabaseDeleteDatasetsExistFailedError()

View File

@ -0,0 +1,111 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_babel import lazy_gettext as _
from marshmallow.validate import ValidationError
from superset.commands.exceptions import (
CommandException,
CommandInvalidError,
CreateFailedError,
DeleteFailedError,
UpdateFailedError,
)
class DatabaseInvalidError(CommandInvalidError):
message = _("Dashboard parameters are invalid.")
class DatabaseExistsValidationError(ValidationError):
"""
Marshmallow validation error for dataset already exists
"""
def __init__(self) -> None:
super().__init__(
_("A database with the same name already exists"),
field_name="database_name",
)
class DatabaseRequiredFieldValidationError(ValidationError):
def __init__(self, field_name: str) -> None:
super().__init__(
[_("Field is required")], field_name=field_name,
)
class DatabaseExtraJSONValidationError(ValidationError):
"""
Marshmallow validation error for database encrypted extra must be a valid JSON
"""
def __init__(self, json_error: str = "") -> None:
super().__init__(
[
_(
"Field cannot be decoded by JSON. %{json_error}s",
json_error=json_error,
)
],
field_name="extra",
)
class DatabaseExtraValidationError(ValidationError):
"""
Marshmallow validation error for database encrypted extra must be a valid JSON
"""
def __init__(self, key: str = "") -> None:
super().__init__(
[
_(
"The metadata_params in Extra field "
"is not configured correctly. The key "
"%{key}s is invalid.",
key=key,
)
],
field_name="extra",
)
class DatabaseNotFoundError(CommandException):
message = _("Database not found.")
class DatabaseCreateFailedError(CreateFailedError):
message = _("Database could not be created.")
class DatabaseUpdateFailedError(UpdateFailedError):
message = _("Database could not be updated.")
class DatabaseConnectionFailedError( # pylint: disable=too-many-ancestors
DatabaseCreateFailedError, DatabaseUpdateFailedError,
):
message = _("Could not connect to database.")
class DatabaseDeleteDatasetsExistFailedError(DeleteFailedError):
message = _("Cannot delete a database that has tables attached")
class DatabaseDeleteFailedError(DeleteFailedError):
message = _("Database could not be deleted.")

View File

@ -0,0 +1,87 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError
from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAOUpdateFailedError
from superset.databases.commands.exceptions import (
DatabaseConnectionFailedError,
DatabaseExistsValidationError,
DatabaseInvalidError,
DatabaseNotFoundError,
DatabaseUpdateFailedError,
)
from superset.databases.dao import DatabaseDAO
from superset.extensions import db, security_manager
from superset.models.core import Database
logger = logging.getLogger(__name__)
class UpdateDatabaseCommand(BaseCommand):
def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[Database] = None
def run(self) -> Model:
self.validate()
try:
database = DatabaseDAO.update(self._model, self._properties, commit=False)
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
security_manager.add_permission_view_menu("database_access", database.perm)
# adding a new database we always want to force refresh schema list
# TODO Improve this simplistic implementation for catching DB conn fails
try:
schemas = database.get_all_schema_names()
except Exception:
db.session.rollback()
raise DatabaseConnectionFailedError()
for schema in schemas:
security_manager.add_permission_view_menu(
"schema_access", security_manager.get_schema_perm(database, schema)
)
db.session.commit()
except DAOUpdateFailedError as ex:
logger.exception(ex.exception)
raise DatabaseUpdateFailedError()
return database
def validate(self) -> None:
exceptions: List[ValidationError] = list()
# Validate/populate model exists
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:
raise DatabaseNotFoundError()
database_name: Optional[str] = self._properties.get("database_name")
if database_name:
# Check database_name uniqueness
if not DatabaseDAO.validate_update_uniqueness(
self._model_id, database_name
):
exceptions.append(DatabaseExistsValidationError())
if exceptions:
exception = DatabaseInvalidError()
exception.add_list(exceptions)
raise exception

43
superset/databases/dao.py Normal file
View File

@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from superset.dao.base import BaseDAO
from superset.databases.filters import DatabaseFilter
from superset.extensions import db
from superset.models.core import Database
logger = logging.getLogger(__name__)
class DatabaseDAO(BaseDAO):
model_cls = Database
base_filter = DatabaseFilter
@staticmethod
def validate_uniqueness(database_name: str) -> bool:
database_query = db.session.query(Database).filter(
Database.database_name == database_name
)
return not db.session.query(database_query.exists()).scalar()
@staticmethod
def validate_update_uniqueness(database_id: int, database_name: str) -> bool:
database_query = db.session.query(Database).filter(
Database.database_name == database_name, Database.id != database_id,
)
return not db.session.query(database_query.exists()).scalar()

View File

@ -14,13 +14,266 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import inspect
import json
from flask_babel import lazy_gettext as _
from marshmallow import fields, Schema from marshmallow import fields, Schema
from marshmallow.validate import Length, ValidationError
from sqlalchemy import MetaData
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import ArgumentError
from superset import app
from superset.exceptions import CertificateException
from superset.utils.core import markdown, parse_ssl_cert
database_schemas_query_schema = { database_schemas_query_schema = {
"type": "object", "type": "object",
"properties": {"force": {"type": "boolean"}}, "properties": {"force": {"type": "boolean"}},
} }
database_name_description = "A database name to identify this connection."
cache_timeout_description = (
"Duration (in seconds) of the caching timeout for charts of this database. "
"A timeout of 0 indicates that the cache never expires. "
"Note this defaults to the global timeout if undefined."
)
expose_in_sqllab_description = "Expose this database to SQLLab"
allow_run_async_description = (
"Operate the database in asynchronous mode, meaning "
"that the queries are executed on remote workers as opposed "
"to on the web server itself. "
"This assumes that you have a Celery worker setup as well "
"as a results backend. Refer to the installation docs "
"for more information."
)
allow_csv_upload_description = (
"Allow to upload CSV file data into this database"
"If selected, please set the schemas allowed for csv upload in Extra."
)
allow_ctas_description = "Allow CREATE TABLE AS option in SQL Lab"
allow_cvas_description = "Allow CREATE VIEW AS option in SQL Lab"
allow_dml_description = (
"Allow users to run non-SELECT statements "
"(UPDATE, DELETE, CREATE, ...) "
"in SQL Lab"
)
allow_multi_schema_metadata_fetch_description = (
"Allow SQL Lab to fetch a list of all tables and all views across "
"all database schemas. For large data warehouse with thousands of "
"tables, this can be expensive and put strain on the system."
) # pylint: disable=invalid-name
impersonate_user_description = (
"If Presto, all the queries in SQL Lab are going to be executed as the "
"currently logged on user who must have permission to run them.<br/>"
"If Hive and hive.server2.enable.doAs is enabled, will run the queries as "
"service account, but impersonate the currently logged on user "
"via hive.server2.proxy.user property."
)
force_ctas_schema_description = (
"When allowing CREATE TABLE AS option in SQL Lab, "
"this option forces the table to be created in this schema"
)
encrypted_extra_description = markdown(
"JSON string containing additional connection configuration.<br/>"
"This is used to provide connection information for systems like "
"Hive, Presto, and BigQuery, which do not conform to the username:password "
"syntax normally used by SQLAlchemy.",
True,
)
extra_description = markdown(
"JSON string containing extra configuration elements.<br/>"
"1. The ``engine_params`` object gets unpacked into the "
"[sqlalchemy.create_engine]"
"(https://docs.sqlalchemy.org/en/latest/core/engines.html#"
"sqlalchemy.create_engine) call, while the ``metadata_params`` "
"gets unpacked into the [sqlalchemy.MetaData]"
"(https://docs.sqlalchemy.org/en/rel_1_0/core/metadata.html"
"#sqlalchemy.schema.MetaData) call.<br/>"
"2. The ``metadata_cache_timeout`` is a cache timeout setting "
"in seconds for metadata fetch of this database. Specify it as "
'**"metadata_cache_timeout": {"schema_cache_timeout": 600, '
'"table_cache_timeout": 600}**. '
"If unset, cache will not be enabled for the functionality. "
"A timeout of 0 indicates that the cache never expires.<br/>"
"3. The ``schemas_allowed_for_csv_upload`` is a comma separated list "
"of schemas that CSVs are allowed to upload to. "
'Specify it as **"schemas_allowed_for_csv_upload": '
'["public", "csv_upload"]**. '
"If database flavor does not support schema or any schema is allowed "
"to be accessed, just leave the list empty<br/>"
"4. the ``version`` field is a string specifying the this db's version. "
"This should be used with Presto DBs so that the syntax is correct<br/>"
"5. The ``allows_virtual_table_explore`` field is a boolean specifying "
"whether or not the Explore button in SQL Lab results is shown.",
True,
)
sqlalchemy_uri_description = markdown(
"Refer to the "
"[SqlAlchemy docs]"
"(https://docs.sqlalchemy.org/en/rel_1_2/core/engines.html#"
"database-urls) "
"for more information on how to structure your URI.",
True,
)
server_cert_description = markdown(
"Optional CA_BUNDLE contents to validate HTTPS requests. Only available "
"on certain database engines.",
True,
)
def sqlalchemy_uri_validator(value: str) -> str:
"""
Validate if it's a valid SQLAlchemy URI and refuse SQLLite by default
"""
try:
make_url(value.strip())
except (ArgumentError, AttributeError):
raise ValidationError(
[
_(
"Invalid connection string, a valid string usually follows:"
"'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'"
"<p>"
"Example:'postgresql://user:password@your-postgres-db/database'"
"</p>"
)
]
)
if app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] and value:
if value.startswith("sqlite"):
raise ValidationError(
[
_(
"SQLite database cannot be used as a data source for "
"security reasons."
)
]
)
return value
def server_cert_validator(value: str) -> str:
"""
Validate the server certificate
"""
if value:
try:
parse_ssl_cert(value)
except CertificateException:
raise ValidationError([_("Invalid certificate")])
return value
def encrypted_extra_validator(value: str) -> str:
"""
Validate that encrypted extra is a valid JSON string
"""
if value:
try:
json.loads(value)
except json.JSONDecodeError as ex:
raise ValidationError(
[_("Field cannot be decoded by JSON. %(msg)s", msg=str(ex))]
)
return value
def extra_validator(value: str) -> str:
"""
Validate that extra is a valid JSON string, and that metadata_params
keys are on the call signature for SQLAlchemy Metadata
"""
if value:
try:
extra_ = json.loads(value)
except json.JSONDecodeError as ex:
raise ValidationError(
[_("Field cannot be decoded by JSON. %(msg)s", msg=str(ex))]
)
else:
metadata_signature = inspect.signature(MetaData)
for key in extra_.get("metadata_params", {}):
if key not in metadata_signature.parameters:
raise ValidationError(
[
_(
"The metadata_params in Extra field "
"is not configured correctly. The key "
"%(key)s is invalid.",
key=key,
)
]
)
return value
class DatabasePostSchema(Schema):
database_name = fields.String(
description=database_name_description, required=True, validate=Length(1, 250),
)
cache_timeout = fields.Integer(description=cache_timeout_description)
expose_in_sqllab = fields.Boolean(description=expose_in_sqllab_description)
allow_run_async = fields.Boolean(description=allow_run_async_description)
allow_csv_upload = fields.Boolean(description=allow_csv_upload_description)
allow_ctas = fields.Boolean(description=allow_ctas_description)
allow_cvas = fields.Boolean(description=allow_cvas_description)
allow_dml = fields.Boolean(description=allow_dml_description)
force_ctas_schema = fields.String(
description=force_ctas_schema_description, validate=Length(0, 250)
)
allow_multi_schema_metadata_fetch = fields.Boolean(
description=allow_multi_schema_metadata_fetch_description,
)
impersonate_user = fields.Boolean(description=impersonate_user_description)
encrypted_extra = fields.String(
description=encrypted_extra_description, validate=encrypted_extra_validator
)
extra = fields.String(description=extra_description, validate=extra_validator)
server_cert = fields.String(
description=server_cert_description, validate=server_cert_validator
)
sqlalchemy_uri = fields.String(
description=sqlalchemy_uri_description,
required=True,
validate=[Length(1, 1024), sqlalchemy_uri_validator],
)
class DatabasePutSchema(Schema):
database_name = fields.String(
description=database_name_description, allow_none=True, validate=Length(1, 250),
)
cache_timeout = fields.Integer(description=cache_timeout_description)
expose_in_sqllab = fields.Boolean(description=expose_in_sqllab_description)
allow_run_async = fields.Boolean(description=allow_run_async_description)
allow_csv_upload = fields.Boolean(description=allow_csv_upload_description)
allow_ctas = fields.Boolean(description=allow_ctas_description)
allow_cvas = fields.Boolean(description=allow_cvas_description)
allow_dml = fields.Boolean(description=allow_dml_description)
force_ctas_schema = fields.String(
description=force_ctas_schema_description, validate=Length(0, 250)
)
allow_multi_schema_metadata_fetch = fields.Boolean(
description=allow_multi_schema_metadata_fetch_description
)
impersonate_user = fields.Boolean(description=impersonate_user_description)
encrypted_extra = fields.String(
description=encrypted_extra_description, validate=encrypted_extra_validator
)
extra = fields.String(description=extra_description, validate=extra_validator)
server_cert = fields.String(
description=server_cert_description, validate=server_cert_validator
)
sqlalchemy_uri = fields.String(
description=sqlalchemy_uri_description,
allow_none=True,
validate=[Length(0, 1024), sqlalchemy_uri_validator],
)
class TableMetadataOptionsResponseSchema(Schema): class TableMetadataOptionsResponseSchema(Schema):
deferrable = fields.Bool() deferrable = fields.Bool()

100
superset/databases/utils.py Normal file
View File

@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional
from superset import app
from superset.models.core import Database
custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
def get_foreign_keys_metadata(
database: Database, table_name: str, schema_name: Optional[str]
) -> List[Dict[str, Any]]:
foreign_keys = database.get_foreign_keys(table_name, schema_name)
for fk in foreign_keys:
fk["column_names"] = fk.pop("constrained_columns")
fk["type"] = "fk"
return foreign_keys
def get_indexes_metadata(
database: Database, table_name: str, schema_name: Optional[str]
) -> List[Dict[str, Any]]:
indexes = database.get_indexes(table_name, schema_name)
for idx in indexes:
idx["type"] = "index"
return indexes
def get_col_type(col: Dict[Any, Any]) -> str:
try:
dtype = f"{col['type']}"
except Exception: # pylint: disable=broad-except
# sqla.types.JSON __str__ has a bug, so using __class__.
dtype = col["type"].__class__.__name__
return dtype
def get_table_metadata(
database: Database, table_name: str, schema_name: Optional[str]
) -> Dict[str, Any]:
"""
Get table metadata information, including type, pk, fks.
This function raises SQLAlchemyError when a schema is not found.
:param database: The database model
:param table_name: Table name
:param schema_name: schema name
:return: Dict table metadata ready for API response
"""
keys = []
columns = database.get_columns(table_name, schema_name)
primary_key = database.get_pk_constraint(table_name, schema_name)
if primary_key and primary_key.get("constrained_columns"):
primary_key["column_names"] = primary_key.pop("constrained_columns")
primary_key["type"] = "pk"
keys += [primary_key]
foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
indexes = get_indexes_metadata(database, table_name, schema_name)
keys += foreign_keys + indexes
payload_columns: List[Dict[str, Any]] = []
for col in columns:
dtype = get_col_type(col)
payload_columns.append(
{
"name": col["name"],
"type": dtype.split("(")[0] if "(" in dtype else dtype,
"longType": dtype,
"keys": [k for k in keys if col["name"] in k["column_names"]],
}
)
return {
"name": table_name,
"columns": payload_columns,
"selectStar": database.select_star(
table_name,
schema=schema_name,
show_cols=True,
indent=True,
cols=columns,
latest_partition=True,
),
"primaryKey": primary_key,
"foreignKeys": foreign_keys,
"indexes": keys,
}

View File

@ -25,6 +25,7 @@ from marshmallow import ValidationError
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.constants import RouteMethod from superset.constants import RouteMethod
from superset.databases.filters import DatabaseFilter
from superset.datasets.commands.create import CreateDatasetCommand from superset.datasets.commands.create import CreateDatasetCommand
from superset.datasets.commands.delete import DeleteDatasetCommand from superset.datasets.commands.delete import DeleteDatasetCommand
from superset.datasets.commands.exceptions import ( from superset.datasets.commands.exceptions import (
@ -51,7 +52,6 @@ from superset.views.base_api import (
RelatedFieldFilter, RelatedFieldFilter,
statsd_metrics, statsd_metrics,
) )
from superset.views.database.filters import DatabaseFilter
from superset.views.filters import FilterRelatedOwners from superset.views.filters import FilterRelatedOwners
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -106,6 +106,9 @@ class BaseSupersetModelRestApi(ModelRestApi):
"data": "list", "data": "list",
"viz_types": "list", "viz_types": "list",
"related_objects": "list", "related_objects": "list",
"table_metadata": "list",
"select_star": "list",
"schemas": "list",
} }
order_rel_fields: Dict[str, Tuple[str, str]] = {} order_rel_fields: Dict[str, Tuple[str, str]] = {}

View File

@ -68,6 +68,7 @@ from superset.connectors.sqla.models import (
TableColumn, TableColumn,
) )
from superset.dashboards.dao import DashboardDAO from superset.dashboards.dao import DashboardDAO
from superset.databases.filters import DatabaseFilter
from superset.exceptions import ( from superset.exceptions import (
CertificateException, CertificateException,
DatabaseNotFound, DatabaseNotFound,
@ -109,7 +110,6 @@ from superset.views.base import (
json_success, json_success,
validate_sqlatable, validate_sqlatable,
) )
from superset.views.database.filters import DatabaseFilter
from superset.views.utils import ( from superset.views.utils import (
_deserialize_results_payload, _deserialize_results_payload,
apply_display_max_row_limit, apply_display_max_row_limit,

View File

@ -21,11 +21,11 @@ from flask_babel import lazy_gettext as _
from sqlalchemy import MetaData from sqlalchemy import MetaData
from superset import app, security_manager from superset import app, security_manager
from superset.databases.filters import DatabaseFilter
from superset.exceptions import SupersetException from superset.exceptions import SupersetException
from superset.models.core import Database from superset.models.core import Database
from superset.security.analytics_db_safety import check_sqlalchemy_uri from superset.security.analytics_db_safety import check_sqlalchemy_uri
from superset.utils import core as utils from superset.utils import core as utils
from superset.views.database.filters import DatabaseFilter
class DatabaseMixin: class DatabaseMixin:
@ -240,7 +240,7 @@ class DatabaseMixin:
extra = database.get_extra() extra = database.get_extra()
except Exception as ex: except Exception as ex:
raise Exception( raise Exception(
_("Extra field cannot be decoded by JSON. %{msg}s", msg=str(ex)) _("Extra field cannot be decoded by JSON. %(msg)s", msg=str(ex))
) )
# this will check whether 'metadata_params' is configured correctly # this will check whether 'metadata_params' is configured correctly
@ -264,5 +264,5 @@ class DatabaseMixin:
database.get_encrypted_extra() database.get_encrypted_extra()
except Exception as ex: except Exception as ex:
raise Exception( raise Exception(
_("Extra field cannot be decoded by JSON. %{msg}s", msg=str(ex)) _("Extra field cannot be decoded by JSON. %(msg)s", msg=str(ex))
) )

View File

@ -36,11 +36,15 @@ def sqlalchemy_uri_validator(
make_url(uri.strip()) make_url(uri.strip())
except (ArgumentError, AttributeError): except (ArgumentError, AttributeError):
raise exception( raise exception(
_( [
"Invalid connection string, a valid string usually follows:" _(
"'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'" "Invalid connection string, a valid string usually follows:"
"<p>Example:'postgresql://user:password@your-postgres-db/database'</p>" "'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'"
) "<p>"
"Example:'postgresql://user:password@your-postgres-db/database'"
"</p>"
)
]
) )

View File

@ -1,268 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import json
from unittest import mock
import prison
from sqlalchemy.sql import func
import tests.test_app
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.utils.core import get_example_database, get_main_database
from .base_tests import SupersetTestCase
class TestDatabaseApi(SupersetTestCase):
def test_get_items(self):
"""
Database API: Test get items
"""
self.login(username="admin")
uri = "api/v1/database/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
expected_columns = [
"allow_csv_upload",
"allow_ctas",
"allow_cvas",
"allow_dml",
"allow_multi_schema_metadata_fetch",
"allow_run_async",
"allows_cost_estimate",
"allows_subquery",
"allows_virtual_table_explore",
"backend",
"database_name",
"explore_database_id",
"expose_in_sqllab",
"force_ctas_schema",
"function_names",
"id",
]
self.assertEqual(response["count"], 2)
self.assertEqual(list(response["result"][0].keys()), expected_columns)
def test_get_items_filter(self):
"""
Database API: Test get items with filter
"""
fake_db = (
db.session.query(Database).filter_by(database_name="fake_db_100").one()
)
old_expose_in_sqllab = fake_db.expose_in_sqllab
fake_db.expose_in_sqllab = False
db.session.commit()
self.login(username="admin")
arguments = {
"keys": ["none"],
"filters": [{"col": "expose_in_sqllab", "opr": "eq", "value": True}],
"order_columns": "database_name",
"order_direction": "asc",
"page": 0,
"page_size": -1,
}
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(response["count"], 1)
fake_db = (
db.session.query(Database).filter_by(database_name="fake_db_100").one()
)
fake_db.expose_in_sqllab = old_expose_in_sqllab
db.session.commit()
def test_get_items_not_allowed(self):
"""
Database API: Test get items not allowed
"""
self.login(username="gamma")
uri = f"api/v1/database/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["count"], 0)
def test_get_table_metadata(self):
"""
Database API: Test get table metadata info
"""
example_db = get_example_database()
self.login(username="admin")
uri = f"api/v1/database/{example_db.id}/table/birth_names/null/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["name"], "birth_names")
self.assertTrue(len(response["columns"]) > 5)
self.assertTrue(response.get("selectStar").startswith("SELECT"))
def test_get_invalid_database_table_metadata(self):
"""
Database API: Test get invalid database from table metadata
"""
database_id = 1000
self.login(username="admin")
uri = f"api/v1/database/{database_id}/table/some_table/some_schema/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
uri = f"api/v1/database/some_database/table/some_table/some_schema/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_get_invalid_table_table_metadata(self):
"""
Database API: Test get invalid table from table metadata
"""
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/wrong_table/null/"
self.login(username="admin")
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_get_table_metadata_no_db_permission(self):
"""
Database API: Test get table metadata from not permitted db
"""
self.login(username="gamma")
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/birth_names/null/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_get_select_star(self):
"""
Database API: Test get select star
"""
self.login(username="admin")
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
self.assertIn("gender", response["result"])
def test_get_select_star_not_allowed(self):
"""
Database API: Test get select star not allowed
"""
self.login(username="gamma")
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_get_select_star_datasource_access(self):
"""
Database API: Test get select star with datasource access
"""
session = db.session
table = SqlaTable(
schema="main", table_name="ab_permission", database=get_main_database()
)
session.add(table)
session.commit()
tmp_table_perm = security_manager.find_permission_view_menu(
"datasource_access", table.get_perm()
)
gamma_role = security_manager.find_role("Gamma")
security_manager.add_permission_role(gamma_role, tmp_table_perm)
self.login(username="gamma")
main_db = get_main_database()
uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
# rollback changes
security_manager.del_permission_role(gamma_role, tmp_table_perm)
db.session.delete(table)
db.session.delete(main_db)
db.session.commit()
def test_get_select_star_not_found_database(self):
"""
Database API: Test get select star not found database
"""
self.login(username="admin")
max_id = db.session.query(func.max(Database.id)).scalar()
uri = f"api/v1/database/{max_id + 1}/select_star/birth_names/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_get_select_star_not_found_table(self):
"""
Database API: Test get select star not found database
"""
self.login(username="admin")
example_db = get_example_database()
# sqllite will not raise a NoSuchTableError
if example_db.backend == "sqlite":
return
uri = f"api/v1/database/{example_db.id}/select_star/table_does_not_exist/"
rv = self.client.get(uri)
# TODO(bkyryliuk): investigate why presto returns 500
self.assertEqual(rv.status_code, 404 if example_db.backend != "presto" else 500)
def test_database_schemas(self):
"""
Database API: Test database schemas
"""
self.login("admin")
database = db.session.query(Database).first()
schemas = database.get_all_schema_names()
rv = self.client.get(f"api/v1/database/{database.id}/schemas/")
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(schemas, response["result"])
rv = self.client.get(
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}"
)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(schemas, response["result"])
def test_database_schemas_not_found(self):
"""
Database API: Test database schemas not found
"""
self.logout()
self.login(username="gamma")
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/schemas/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_database_schemas_invalid_query(self):
"""
Database API: Test database schemas with invalid query
"""
self.login("admin")
database = db.session.query(Database).first()
rv = self.client.get(
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': 'nop'})}"
)
self.assertEqual(rv.status_code, 400)

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,650 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import json
import prison
from sqlalchemy.sql import func
import tests.test_app
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.utils.core import get_example_database, get_main_database
from tests.base_tests import SupersetTestCase
from tests.fixtures.certificates import ssl_certificate
class TestDatabaseApi(SupersetTestCase):
def insert_database(
self,
database_name: str,
sqlalchemy_uri: str,
extra: str = "",
encrypted_extra: str = "",
server_cert: str = "",
expose_in_sqllab: bool = False,
) -> Database:
database = Database(
database_name=database_name,
sqlalchemy_uri=sqlalchemy_uri,
extra=extra,
encrypted_extra=encrypted_extra,
server_cert=server_cert,
expose_in_sqllab=expose_in_sqllab,
)
db.session.add(database)
db.session.commit()
return database
def test_get_items(self):
"""
Database API: Test get items
"""
self.login(username="admin")
uri = "api/v1/database/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
expected_columns = [
"allow_csv_upload",
"allow_ctas",
"allow_cvas",
"allow_dml",
"allow_multi_schema_metadata_fetch",
"allow_run_async",
"allows_cost_estimate",
"allows_subquery",
"allows_virtual_table_explore",
"backend",
"database_name",
"explore_database_id",
"expose_in_sqllab",
"force_ctas_schema",
"function_names",
"id",
]
self.assertEqual(response["count"], 2)
self.assertEqual(list(response["result"][0].keys()), expected_columns)
def test_get_items_filter(self):
"""
Database API: Test get items with filter
"""
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted, expose_in_sqllab=True
)
dbs = db.session.query(Database).filter_by(expose_in_sqllab=True).all()
self.login(username="admin")
arguments = {
"keys": ["none"],
"filters": [{"col": "expose_in_sqllab", "opr": "eq", "value": True}],
"order_columns": "database_name",
"order_direction": "asc",
"page": 0,
"page_size": -1,
}
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(response["count"], len(dbs))
# Cleanup
db.session.delete(test_database)
db.session.commit()
def test_get_items_not_allowed(self):
"""
Database API: Test get items not allowed
"""
self.login(username="gamma")
uri = f"api/v1/database/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["count"], 0)
def test_create_database(self):
"""
Database API: Test create
"""
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_csv_upload": [],
}
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
database_data = {
"database_name": "test-database",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"server_cert": ssl_certificate,
"extra": json.dumps(extra),
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
def test_create_database_server_cert_validate(self):
"""
Database API: Test create server cert validation
"""
example_db = get_example_database()
if example_db.backend == "sqlite":
return
self.login(username="admin")
database_data = {
"database_name": "test-database",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"server_cert": "INVALID CERT",
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": {"server_cert": ["Invalid certificate"]}}
self.assertEqual(rv.status_code, 400)
self.assertEqual(response, expected_response)
def test_create_database_json_validate(self):
"""
Database API: Test create encrypted extra and extra validation
"""
example_db = get_example_database()
if example_db.backend == "sqlite":
return
self.login(username="admin")
database_data = {
"database_name": "test-database",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"encrypted_extra": '{"A": "a", "B", "C"}',
"extra": '["A": "a", "B", "C"]',
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": {
"encrypted_extra": [
"Field cannot be decoded by JSON. Expecting ':' "
"delimiter: line 1 column 15 (char 14)"
],
"extra": [
"Field cannot be decoded by JSON. Expecting ','"
" delimiter: line 1 column 5 (char 4)"
],
}
}
self.assertEqual(rv.status_code, 400)
self.assertEqual(response, expected_response)
def test_create_database_extra_metadata_validate(self):
"""
Database API: Test create extra metadata_params validation
"""
example_db = get_example_database()
if example_db.backend == "sqlite":
return
extra = {
"metadata_params": {"wrong_param": "some_value"},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_csv_upload": [],
}
self.login(username="admin")
database_data = {
"database_name": "test-database",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"extra": json.dumps(extra),
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": {
"extra": [
"The metadata_params in Extra field is not configured correctly."
" The key wrong_param is invalid."
]
}
}
self.assertEqual(rv.status_code, 400)
self.assertEqual(response, expected_response)
def test_create_database_unique_validate(self):
"""
Database API: Test create database_name already exists
"""
example_db = get_example_database()
if example_db.backend == "sqlite":
return
self.login(username="admin")
database_data = {
"database_name": "examples",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": {"database_name": "A database with the same name already exists"}
}
self.assertEqual(rv.status_code, 422)
self.assertEqual(response, expected_response)
def test_create_database_uri_validate(self):
"""
Database API: Test create fail validate sqlalchemy uri
"""
self.login(username="admin")
database_data = {
"database_name": "test-database",
"sqlalchemy_uri": "wrong_uri",
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
expected_response = {
"message": {
"sqlalchemy_uri": [
"Invalid connection string, a valid string usually "
"follows:'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'"
"<p>Example:'postgresql://user:password@your-postgres-db/database'"
"</p>"
]
}
}
self.assertEqual(response, expected_response)
def test_create_database_fail_sqllite(self):
"""
Database API: Test create fail with sqllite
"""
database_data = {
"database_name": "test-database",
"sqlalchemy_uri": "sqlite:////some.db",
}
uri = "api/v1/database/"
self.login(username="admin")
response = self.client.post(uri, json=database_data)
response_data = json.loads(response.data.decode("utf-8"))
expected_response = {
"message": {
"sqlalchemy_uri": [
"SQLite database cannot be used as a data source "
"for security reasons."
]
}
}
self.assertEqual(response.status_code, 400)
self.assertEqual(response_data, expected_response)
def test_create_database_conn_fail(self):
"""
Database API: Test create fails connection
"""
example_db = get_example_database()
if example_db.backend in ("sqlite", "hive", "presto"):
return
example_db.password = "wrong_password"
database_data = {
"database_name": "test-database",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
uri = "api/v1/database/"
self.login(username="admin")
response = self.client.post(uri, json=database_data)
response_data = json.loads(response.data.decode("utf-8"))
expected_response = {"message": "Could not connect to database."}
self.assertEqual(response.status_code, 422)
self.assertEqual(response_data, expected_response)
def test_update_database(self):
"""
Database API: Test update
"""
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted
)
self.login(username="admin")
database_data = {"database_name": "test-database-updated"}
uri = f"api/v1/database/{test_database.id}"
rv = self.client.put(uri, json=database_data)
self.assertEqual(rv.status_code, 200)
# Cleanup
model = db.session.query(Database).get(test_database.id)
db.session.delete(model)
db.session.commit()
def test_update_database_conn_fail(self):
"""
Database API: Test update fails connection
"""
example_db = get_example_database()
if example_db.backend in ("sqlite", "hive", "presto"):
return
test_database = self.insert_database(
"test-database1", example_db.sqlalchemy_uri_decrypted
)
example_db.password = "wrong_password"
database_data = {
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
uri = f"api/v1/database/{test_database.id}"
self.login(username="admin")
rv = self.client.put(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": "Could not connect to database."}
self.assertEqual(rv.status_code, 422)
self.assertEqual(response, expected_response)
# Cleanup
model = db.session.query(Database).get(test_database.id)
db.session.delete(model)
db.session.commit()
def test_update_database_uniqueness(self):
"""
Database API: Test update uniqueness
"""
example_db = get_example_database()
test_database1 = self.insert_database(
"test-database1", example_db.sqlalchemy_uri_decrypted
)
test_database2 = self.insert_database(
"test-database2", example_db.sqlalchemy_uri_decrypted
)
self.login(username="admin")
database_data = {"database_name": "test-database2"}
uri = f"api/v1/database/{test_database1.id}"
rv = self.client.put(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": {"database_name": "A database with the same name already exists"}
}
self.assertEqual(rv.status_code, 422)
self.assertEqual(response, expected_response)
# Cleanup
db.session.delete(test_database1)
db.session.delete(test_database2)
db.session.commit()
def test_update_database_invalid(self):
"""
Database API: Test update invalid request
"""
self.login(username="admin")
database_data = {"database_name": "test-database-updated"}
uri = f"api/v1/database/invalid"
rv = self.client.put(uri, json=database_data)
self.assertEqual(rv.status_code, 404)
def test_update_database_uri_validate(self):
"""
Database API: Test update sqlalchemy_uri validate
"""
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted
)
self.login(username="admin")
database_data = {
"database_name": "test-database-updated",
"sqlalchemy_uri": "wrong_uri",
}
uri = f"api/v1/database/{test_database.id}"
rv = self.client.put(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
expected_response = {
"message": {
"sqlalchemy_uri": [
"Invalid connection string, a valid string usually "
"follows:'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'"
"<p>Example:'postgresql://user:password@your-postgres-db/database'"
"</p>"
]
}
}
self.assertEqual(response, expected_response)
def test_delete_database(self):
"""
Database API: Test delete
"""
database_id = self.insert_database("test-database", "test_uri").id
self.login(username="admin")
uri = f"api/v1/database/{database_id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 200)
model = db.session.query(Database).get(database_id)
self.assertEqual(model, None)
def test_delete_database_not_found(self):
"""
Database API: Test delete not found
"""
max_id = db.session.query(func.max(Database.id)).scalar()
self.login(username="admin")
uri = f"api/v1/database/{max_id + 1}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 404)
def test_delete_database_with_datasets(self):
"""
Database API: Test delete fails because it has depending datasets
"""
database_id = (
db.session.query(Database).filter_by(database_name="examples").one()
).id
self.login(username="admin")
uri = f"api/v1/database/{database_id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 422)
def test_get_table_metadata(self):
"""
Database API: Test get table metadata info
"""
example_db = get_example_database()
self.login(username="admin")
uri = f"api/v1/database/{example_db.id}/table/birth_names/null/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["name"], "birth_names")
self.assertTrue(len(response["columns"]) > 5)
self.assertTrue(response.get("selectStar").startswith("SELECT"))
def test_get_invalid_database_table_metadata(self):
"""
Database API: Test get invalid database from table metadata
"""
database_id = 1000
self.login(username="admin")
uri = f"api/v1/database/{database_id}/table/some_table/some_schema/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
uri = f"api/v1/database/some_database/table/some_table/some_schema/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_get_invalid_table_table_metadata(self):
"""
Database API: Test get invalid table from table metadata
"""
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/wrong_table/null/"
self.login(username="admin")
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_get_table_metadata_no_db_permission(self):
"""
Database API: Test get table metadata from not permitted db
"""
self.login(username="gamma")
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/birth_names/null/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_get_select_star(self):
"""
Database API: Test get select star
"""
self.login(username="admin")
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
self.assertIn("gender", response["result"])
def test_get_select_star_not_allowed(self):
"""
Database API: Test get select star not allowed
"""
self.login(username="gamma")
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_get_select_star_datasource_access(self):
"""
Database API: Test get select star with datasource access
"""
session = db.session
table = SqlaTable(
schema="main", table_name="ab_permission", database=get_main_database()
)
session.add(table)
session.commit()
tmp_table_perm = security_manager.find_permission_view_menu(
"datasource_access", table.get_perm()
)
gamma_role = security_manager.find_role("Gamma")
security_manager.add_permission_role(gamma_role, tmp_table_perm)
self.login(username="gamma")
main_db = get_main_database()
uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
# rollback changes
security_manager.del_permission_role(gamma_role, tmp_table_perm)
db.session.delete(table)
db.session.delete(main_db)
db.session.commit()
def test_get_select_star_not_found_database(self):
"""
Database API: Test get select star not found database
"""
self.login(username="admin")
max_id = db.session.query(func.max(Database.id)).scalar()
uri = f"api/v1/database/{max_id + 1}/select_star/birth_names/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_get_select_star_not_found_table(self):
"""
Database API: Test get select star not found database
"""
self.login(username="admin")
example_db = get_example_database()
# sqllite will not raise a NoSuchTableError
if example_db.backend == "sqlite":
return
uri = f"api/v1/database/{example_db.id}/select_star/table_does_not_exist/"
rv = self.client.get(uri)
# TODO(bkyryliuk): investigate why presto returns 500
self.assertEqual(rv.status_code, 404 if example_db.backend != "presto" else 500)
def test_database_schemas(self):
"""
Database API: Test database schemas
"""
self.login("admin")
database = db.session.query(Database).first()
schemas = database.get_all_schema_names()
rv = self.client.get(f"api/v1/database/{database.id}/schemas/")
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(schemas, response["result"])
rv = self.client.get(
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}"
)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(schemas, response["result"])
def test_database_schemas_not_found(self):
"""
Database API: Test database schemas not found
"""
self.logout()
self.login(username="gamma")
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/schemas/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_database_schemas_invalid_query(self):
"""
Database API: Test database schemas with invalid query
"""
self.login("admin")
database = db.session.query(Database).first()
rv = self.client.get(
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': 'nop'})}"
)
self.assertEqual(rv.status_code, 400)