mirror of https://github.com/apache/superset.git
feat(database): POST, PUT, DELETE API endpoints (#10741)
* feat(database): POST, PUT, DELETE API endpoints * post tests * more tests * lint * lint * debug ci * fix test * fix test * fix test * fix test * fix test * fix test * cleanup * handle db connection failures * lint * skip hive and presto for connection fail test * fix typo
This commit is contained in:
parent
b5aecaff5c
commit
77a3167412
|
@ -14,126 +14,78 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any, Dict, List, Optional
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import g, request, Response
|
||||
from flask_appbuilder.api import expose, protect, rison, safe
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
from marshmallow import ValidationError
|
||||
from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError
|
||||
|
||||
from superset import event_logger
|
||||
from superset.constants import RouteMethod
|
||||
from superset.databases.commands.create import CreateDatabaseCommand
|
||||
from superset.databases.commands.delete import DeleteDatabaseCommand
|
||||
from superset.databases.commands.exceptions import (
|
||||
DatabaseCreateFailedError,
|
||||
DatabaseDeleteDatasetsExistFailedError,
|
||||
DatabaseDeleteFailedError,
|
||||
DatabaseInvalidError,
|
||||
DatabaseNotFoundError,
|
||||
DatabaseUpdateFailedError,
|
||||
)
|
||||
from superset.databases.commands.update import UpdateDatabaseCommand
|
||||
from superset.databases.decorators import check_datasource_access
|
||||
from superset.databases.filters import DatabaseFilter
|
||||
from superset.databases.schemas import (
|
||||
database_schemas_query_schema,
|
||||
DatabasePostSchema,
|
||||
DatabasePutSchema,
|
||||
SchemasResponseSchema,
|
||||
SelectStarResponseSchema,
|
||||
TableMetadataResponseSchema,
|
||||
)
|
||||
from superset.databases.utils import get_table_metadata
|
||||
from superset.extensions import security_manager
|
||||
from superset.models.core import Database
|
||||
from superset.typing import FlaskResponse
|
||||
from superset.utils.core import error_msg_from_exception
|
||||
from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics
|
||||
from superset.views.database.filters import DatabaseFilter
|
||||
from superset.views.database.validators import sqlalchemy_uri_validator
|
||||
|
||||
|
||||
def get_foreign_keys_metadata(
|
||||
database: Database, table_name: str, schema_name: Optional[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
foreign_keys = database.get_foreign_keys(table_name, schema_name)
|
||||
for fk in foreign_keys:
|
||||
fk["column_names"] = fk.pop("constrained_columns")
|
||||
fk["type"] = "fk"
|
||||
return foreign_keys
|
||||
|
||||
|
||||
def get_indexes_metadata(
|
||||
database: Database, table_name: str, schema_name: Optional[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
indexes = database.get_indexes(table_name, schema_name)
|
||||
for idx in indexes:
|
||||
idx["type"] = "index"
|
||||
return indexes
|
||||
|
||||
|
||||
def get_col_type(col: Dict[Any, Any]) -> str:
|
||||
try:
|
||||
dtype = f"{col['type']}"
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# sqla.types.JSON __str__ has a bug, so using __class__.
|
||||
dtype = col["type"].__class__.__name__
|
||||
return dtype
|
||||
|
||||
|
||||
def get_table_metadata(
|
||||
database: Database, table_name: str, schema_name: Optional[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get table metadata information, including type, pk, fks.
|
||||
This function raises SQLAlchemyError when a schema is not found.
|
||||
|
||||
:param database: The database model
|
||||
:param table_name: Table name
|
||||
:param schema_name: schema name
|
||||
:return: Dict table metadata ready for API response
|
||||
"""
|
||||
keys = []
|
||||
columns = database.get_columns(table_name, schema_name)
|
||||
primary_key = database.get_pk_constraint(table_name, schema_name)
|
||||
if primary_key and primary_key.get("constrained_columns"):
|
||||
primary_key["column_names"] = primary_key.pop("constrained_columns")
|
||||
primary_key["type"] = "pk"
|
||||
keys += [primary_key]
|
||||
foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
|
||||
indexes = get_indexes_metadata(database, table_name, schema_name)
|
||||
keys += foreign_keys + indexes
|
||||
payload_columns: List[Dict[str, Any]] = []
|
||||
for col in columns:
|
||||
dtype = get_col_type(col)
|
||||
payload_columns.append(
|
||||
{
|
||||
"name": col["name"],
|
||||
"type": dtype.split("(")[0] if "(" in dtype else dtype,
|
||||
"longType": dtype,
|
||||
"keys": [k for k in keys if col["name"] in k["column_names"]],
|
||||
}
|
||||
)
|
||||
return {
|
||||
"name": table_name,
|
||||
"columns": payload_columns,
|
||||
"selectStar": database.select_star(
|
||||
table_name,
|
||||
schema=schema_name,
|
||||
show_cols=True,
|
||||
indent=True,
|
||||
cols=columns,
|
||||
latest_partition=True,
|
||||
),
|
||||
"primaryKey": primary_key,
|
||||
"foreignKeys": foreign_keys,
|
||||
"indexes": keys,
|
||||
}
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
datamodel = SQLAInterface(Database)
|
||||
|
||||
include_route_methods = {
|
||||
"get_list",
|
||||
include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {
|
||||
"table_metadata",
|
||||
"select_star",
|
||||
"schemas",
|
||||
}
|
||||
class_permission_name = "DatabaseView"
|
||||
method_permission_name = {
|
||||
"get_list": "list",
|
||||
"table_metadata": "list",
|
||||
"select_star": "list",
|
||||
"schemas": "list",
|
||||
}
|
||||
resource_name = "database"
|
||||
allow_browser_login = True
|
||||
base_filters = [["id", DatabaseFilter, lambda: []]]
|
||||
show_columns = [
|
||||
"id",
|
||||
"database_name",
|
||||
"cache_timeout",
|
||||
"expose_in_sqllab",
|
||||
"allow_run_async",
|
||||
"allow_csv_upload",
|
||||
"allow_ctas",
|
||||
"allow_cvas",
|
||||
"allow_dml",
|
||||
"force_ctas_schema",
|
||||
"allow_multi_schema_metadata_fetch",
|
||||
"impersonate_user",
|
||||
"encrypted_extra",
|
||||
"extra",
|
||||
"server_cert",
|
||||
"sqlalchemy_uri",
|
||||
]
|
||||
list_columns = [
|
||||
"id",
|
||||
"database_name",
|
||||
|
@ -152,10 +104,30 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
"backend",
|
||||
"function_names",
|
||||
]
|
||||
add_columns = [
|
||||
"database_name",
|
||||
"sqlalchemy_uri",
|
||||
"cache_timeout",
|
||||
"expose_in_sqllab",
|
||||
"allow_run_async",
|
||||
"allow_csv_upload",
|
||||
"allow_ctas",
|
||||
"allow_cvas",
|
||||
"allow_dml",
|
||||
"force_ctas_schema",
|
||||
"impersonate_user",
|
||||
"allow_multi_schema_metadata_fetch",
|
||||
"extra",
|
||||
"encrypted_extra",
|
||||
"server_cert",
|
||||
]
|
||||
edit_columns = add_columns
|
||||
|
||||
list_select_columns = list_columns + ["extra", "sqlalchemy_uri", "password"]
|
||||
# Removes the local limit for the page size
|
||||
max_page_size = -1
|
||||
validators_columns = {"sqlalchemy_uri": sqlalchemy_uri_validator}
|
||||
add_model_schema = DatabasePostSchema()
|
||||
edit_model_schema = DatabasePutSchema()
|
||||
|
||||
apispec_parameter_schemas = {
|
||||
"database_schemas_query_schema": database_schemas_query_schema,
|
||||
|
@ -167,6 +139,186 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
SchemasResponseSchema,
|
||||
)
|
||||
|
||||
@expose("/", methods=["POST"])
|
||||
@protect()
|
||||
@safe
|
||||
@statsd_metrics
|
||||
def post(self) -> Response:
|
||||
"""Creates a new Database
|
||||
---
|
||||
post:
|
||||
description: >-
|
||||
Create a new Database.
|
||||
requestBody:
|
||||
description: Database schema
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/{{self.__class__.__name__}}.post'
|
||||
responses:
|
||||
201:
|
||||
description: Database added
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: number
|
||||
result:
|
||||
$ref: '#/components/schemas/{{self.__class__.__name__}}.post'
|
||||
302:
|
||||
description: Redirects to the current digest
|
||||
400:
|
||||
$ref: '#/components/responses/400'
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
if not request.is_json:
|
||||
return self.response_400(message="Request is not JSON")
|
||||
try:
|
||||
item = self.add_model_schema.load(request.json)
|
||||
# This validates custom Schema with custom validations
|
||||
except ValidationError as error:
|
||||
return self.response_400(message=error.messages)
|
||||
try:
|
||||
new_model = CreateDatabaseCommand(g.user, item).run()
|
||||
# Return censored version for sqlalchemy URI
|
||||
item["sqlalchemy_uri"] = new_model.sqlalchemy_uri
|
||||
return self.response(201, id=new_model.id, result=item)
|
||||
except DatabaseInvalidError as ex:
|
||||
return self.response_422(message=ex.normalized_messages())
|
||||
except DatabaseCreateFailedError as ex:
|
||||
logger.error(
|
||||
"Error creating model %s: %s", self.__class__.__name__, str(ex)
|
||||
)
|
||||
return self.response_422(message=str(ex))
|
||||
|
||||
@expose("/<int:pk>", methods=["PUT"])
|
||||
@protect()
|
||||
@safe
|
||||
@statsd_metrics
|
||||
def put( # pylint: disable=too-many-return-statements, arguments-differ
|
||||
self, pk: int
|
||||
) -> Response:
|
||||
"""Changes a Database
|
||||
---
|
||||
put:
|
||||
description: >-
|
||||
Changes a Database.
|
||||
parameters:
|
||||
- in: path
|
||||
schema:
|
||||
type: integer
|
||||
name: pk
|
||||
requestBody:
|
||||
description: Database schema
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/{{self.__class__.__name__}}.put'
|
||||
responses:
|
||||
200:
|
||||
description: Database changed
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: number
|
||||
result:
|
||||
$ref: '#/components/schemas/{{self.__class__.__name__}}.put'
|
||||
400:
|
||||
$ref: '#/components/responses/400'
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
403:
|
||||
$ref: '#/components/responses/403'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
422:
|
||||
$ref: '#/components/responses/422'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
if not request.is_json:
|
||||
return self.response_400(message="Request is not JSON")
|
||||
try:
|
||||
item = self.edit_model_schema.load(request.json)
|
||||
# This validates custom Schema with custom validations
|
||||
except ValidationError as error:
|
||||
return self.response_400(message=error.messages)
|
||||
try:
|
||||
changed_model = UpdateDatabaseCommand(g.user, pk, item).run()
|
||||
# Return censored version for sqlalchemy URI
|
||||
item["sqlalchemy_uri"] = changed_model.sqlalchemy_uri
|
||||
return self.response(200, id=changed_model.id, result=item)
|
||||
except DatabaseNotFoundError:
|
||||
return self.response_404()
|
||||
except DatabaseInvalidError as ex:
|
||||
return self.response_422(message=ex.normalized_messages())
|
||||
except DatabaseUpdateFailedError as ex:
|
||||
logger.error(
|
||||
"Error updating model %s: %s", self.__class__.__name__, str(ex)
|
||||
)
|
||||
return self.response_422(message=str(ex))
|
||||
|
||||
@expose("/<int:pk>", methods=["DELETE"])
|
||||
@protect()
|
||||
@safe
|
||||
@statsd_metrics
|
||||
def delete(self, pk: int) -> Response: # pylint: disable=arguments-differ
|
||||
"""Deletes a Database
|
||||
---
|
||||
delete:
|
||||
description: >-
|
||||
Deletes a Database.
|
||||
parameters:
|
||||
- in: path
|
||||
schema:
|
||||
type: integer
|
||||
name: pk
|
||||
responses:
|
||||
200:
|
||||
description: Database deleted
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
message:
|
||||
type: string
|
||||
401:
|
||||
$ref: '#/components/responses/401'
|
||||
403:
|
||||
$ref: '#/components/responses/403'
|
||||
404:
|
||||
$ref: '#/components/responses/404'
|
||||
422:
|
||||
$ref: '#/components/responses/422'
|
||||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
try:
|
||||
DeleteDatabaseCommand(g.user, pk).run()
|
||||
return self.response(200, message="OK")
|
||||
except DatabaseNotFoundError:
|
||||
return self.response_404()
|
||||
except DatabaseDeleteDatasetsExistFailedError as ex:
|
||||
return self.response_422(message=str(ex))
|
||||
except DatabaseDeleteFailedError as ex:
|
||||
logger.error(
|
||||
"Error deleting model %s: %s", self.__class__.__name__, str(ex)
|
||||
)
|
||||
return self.response_422(message=str(ex))
|
||||
|
||||
@expose("/<int:pk>/schemas/")
|
||||
@protect()
|
||||
@safe
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,84 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.dao.exceptions import DAOCreateFailedError
|
||||
from superset.databases.commands.exceptions import (
|
||||
DatabaseConnectionFailedError,
|
||||
DatabaseCreateFailedError,
|
||||
DatabaseExistsValidationError,
|
||||
DatabaseInvalidError,
|
||||
DatabaseRequiredFieldValidationError,
|
||||
)
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.extensions import db, security_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreateDatabaseCommand(BaseCommand):
|
||||
def __init__(self, user: User, data: Dict[str, Any]):
|
||||
self._actor = user
|
||||
self._properties = data.copy()
|
||||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
try:
|
||||
database = DatabaseDAO.create(self._properties, commit=False)
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
# adding a new database we always want to force refresh schema list
|
||||
# TODO Improve this simplistic implementation for catching DB conn fails
|
||||
try:
|
||||
schemas = database.get_all_schema_names()
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
raise DatabaseConnectionFailedError()
|
||||
for schema in schemas:
|
||||
security_manager.add_permission_view_menu(
|
||||
"schema_access", security_manager.get_schema_perm(database, schema)
|
||||
)
|
||||
security_manager.add_permission_view_menu("database_access", database.perm)
|
||||
db.session.commit()
|
||||
except DAOCreateFailedError as ex:
|
||||
logger.exception(ex.exception)
|
||||
raise DatabaseCreateFailedError()
|
||||
return database
|
||||
|
||||
def validate(self) -> None:
|
||||
exceptions: List[ValidationError] = list()
|
||||
sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri")
|
||||
database_name: Optional[str] = self._properties.get("database_name")
|
||||
|
||||
if not sqlalchemy_uri:
|
||||
exceptions.append(DatabaseRequiredFieldValidationError("sqlalchemy_uri"))
|
||||
if not database_name:
|
||||
exceptions.append(DatabaseRequiredFieldValidationError("database_name"))
|
||||
else:
|
||||
# Check database_name uniqueness
|
||||
if not DatabaseDAO.validate_uniqueness(database_name):
|
||||
exceptions.append(DatabaseExistsValidationError())
|
||||
|
||||
if exceptions:
|
||||
exception = DatabaseInvalidError()
|
||||
exception.add_list(exceptions)
|
||||
raise exception
|
|
@ -0,0 +1,58 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.dao.exceptions import DAODeleteFailedError
|
||||
from superset.databases.commands.exceptions import (
|
||||
DatabaseDeleteDatasetsExistFailedError,
|
||||
DatabaseDeleteFailedError,
|
||||
DatabaseNotFoundError,
|
||||
)
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.models.core import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeleteDatabaseCommand(BaseCommand):
|
||||
def __init__(self, user: User, model_id: int):
|
||||
self._actor = user
|
||||
self._model_id = model_id
|
||||
self._model: Optional[Database] = None
|
||||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
try:
|
||||
database = DatabaseDAO.delete(self._model)
|
||||
except DAODeleteFailedError as ex:
|
||||
logger.exception(ex.exception)
|
||||
raise DatabaseDeleteFailedError()
|
||||
return database
|
||||
|
||||
def validate(self) -> None:
|
||||
# Validate/populate model exists
|
||||
self._model = DatabaseDAO.find_by_id(self._model_id)
|
||||
if not self._model:
|
||||
raise DatabaseNotFoundError()
|
||||
# Check if there are datasets for this database
|
||||
if self._model.tables:
|
||||
raise DatabaseDeleteDatasetsExistFailedError()
|
|
@ -0,0 +1,111 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from flask_babel import lazy_gettext as _
|
||||
from marshmallow.validate import ValidationError
|
||||
|
||||
from superset.commands.exceptions import (
|
||||
CommandException,
|
||||
CommandInvalidError,
|
||||
CreateFailedError,
|
||||
DeleteFailedError,
|
||||
UpdateFailedError,
|
||||
)
|
||||
|
||||
|
||||
class DatabaseInvalidError(CommandInvalidError):
|
||||
message = _("Dashboard parameters are invalid.")
|
||||
|
||||
|
||||
class DatabaseExistsValidationError(ValidationError):
|
||||
"""
|
||||
Marshmallow validation error for dataset already exists
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(
|
||||
_("A database with the same name already exists"),
|
||||
field_name="database_name",
|
||||
)
|
||||
|
||||
|
||||
class DatabaseRequiredFieldValidationError(ValidationError):
|
||||
def __init__(self, field_name: str) -> None:
|
||||
super().__init__(
|
||||
[_("Field is required")], field_name=field_name,
|
||||
)
|
||||
|
||||
|
||||
class DatabaseExtraJSONValidationError(ValidationError):
|
||||
"""
|
||||
Marshmallow validation error for database encrypted extra must be a valid JSON
|
||||
"""
|
||||
|
||||
def __init__(self, json_error: str = "") -> None:
|
||||
super().__init__(
|
||||
[
|
||||
_(
|
||||
"Field cannot be decoded by JSON. %{json_error}s",
|
||||
json_error=json_error,
|
||||
)
|
||||
],
|
||||
field_name="extra",
|
||||
)
|
||||
|
||||
|
||||
class DatabaseExtraValidationError(ValidationError):
|
||||
"""
|
||||
Marshmallow validation error for database encrypted extra must be a valid JSON
|
||||
"""
|
||||
|
||||
def __init__(self, key: str = "") -> None:
|
||||
super().__init__(
|
||||
[
|
||||
_(
|
||||
"The metadata_params in Extra field "
|
||||
"is not configured correctly. The key "
|
||||
"%{key}s is invalid.",
|
||||
key=key,
|
||||
)
|
||||
],
|
||||
field_name="extra",
|
||||
)
|
||||
|
||||
|
||||
class DatabaseNotFoundError(CommandException):
|
||||
message = _("Database not found.")
|
||||
|
||||
|
||||
class DatabaseCreateFailedError(CreateFailedError):
|
||||
message = _("Database could not be created.")
|
||||
|
||||
|
||||
class DatabaseUpdateFailedError(UpdateFailedError):
|
||||
message = _("Database could not be updated.")
|
||||
|
||||
|
||||
class DatabaseConnectionFailedError( # pylint: disable=too-many-ancestors
|
||||
DatabaseCreateFailedError, DatabaseUpdateFailedError,
|
||||
):
|
||||
message = _("Could not connect to database.")
|
||||
|
||||
|
||||
class DatabaseDeleteDatasetsExistFailedError(DeleteFailedError):
|
||||
message = _("Cannot delete a database that has tables attached")
|
||||
|
||||
|
||||
class DatabaseDeleteFailedError(DeleteFailedError):
|
||||
message = _("Database could not be deleted.")
|
|
@ -0,0 +1,87 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
from marshmallow import ValidationError
|
||||
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.dao.exceptions import DAOUpdateFailedError
|
||||
from superset.databases.commands.exceptions import (
|
||||
DatabaseConnectionFailedError,
|
||||
DatabaseExistsValidationError,
|
||||
DatabaseInvalidError,
|
||||
DatabaseNotFoundError,
|
||||
DatabaseUpdateFailedError,
|
||||
)
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.extensions import db, security_manager
|
||||
from superset.models.core import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UpdateDatabaseCommand(BaseCommand):
|
||||
def __init__(self, user: User, model_id: int, data: Dict[str, Any]):
|
||||
self._actor = user
|
||||
self._properties = data.copy()
|
||||
self._model_id = model_id
|
||||
self._model: Optional[Database] = None
|
||||
|
||||
def run(self) -> Model:
|
||||
self.validate()
|
||||
try:
|
||||
database = DatabaseDAO.update(self._model, self._properties, commit=False)
|
||||
database.set_sqlalchemy_uri(database.sqlalchemy_uri)
|
||||
security_manager.add_permission_view_menu("database_access", database.perm)
|
||||
# adding a new database we always want to force refresh schema list
|
||||
# TODO Improve this simplistic implementation for catching DB conn fails
|
||||
try:
|
||||
schemas = database.get_all_schema_names()
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
raise DatabaseConnectionFailedError()
|
||||
for schema in schemas:
|
||||
security_manager.add_permission_view_menu(
|
||||
"schema_access", security_manager.get_schema_perm(database, schema)
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
except DAOUpdateFailedError as ex:
|
||||
logger.exception(ex.exception)
|
||||
raise DatabaseUpdateFailedError()
|
||||
return database
|
||||
|
||||
def validate(self) -> None:
|
||||
exceptions: List[ValidationError] = list()
|
||||
# Validate/populate model exists
|
||||
self._model = DatabaseDAO.find_by_id(self._model_id)
|
||||
if not self._model:
|
||||
raise DatabaseNotFoundError()
|
||||
database_name: Optional[str] = self._properties.get("database_name")
|
||||
if database_name:
|
||||
# Check database_name uniqueness
|
||||
if not DatabaseDAO.validate_update_uniqueness(
|
||||
self._model_id, database_name
|
||||
):
|
||||
exceptions.append(DatabaseExistsValidationError())
|
||||
if exceptions:
|
||||
exception = DatabaseInvalidError()
|
||||
exception.add_list(exceptions)
|
||||
raise exception
|
|
@ -0,0 +1,43 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
|
||||
from superset.dao.base import BaseDAO
|
||||
from superset.databases.filters import DatabaseFilter
|
||||
from superset.extensions import db
|
||||
from superset.models.core import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseDAO(BaseDAO):
|
||||
model_cls = Database
|
||||
base_filter = DatabaseFilter
|
||||
|
||||
@staticmethod
|
||||
def validate_uniqueness(database_name: str) -> bool:
|
||||
database_query = db.session.query(Database).filter(
|
||||
Database.database_name == database_name
|
||||
)
|
||||
return not db.session.query(database_query.exists()).scalar()
|
||||
|
||||
@staticmethod
|
||||
def validate_update_uniqueness(database_id: int, database_name: str) -> bool:
|
||||
database_query = db.session.query(Database).filter(
|
||||
Database.database_name == database_name, Database.id != database_id,
|
||||
)
|
||||
return not db.session.query(database_query.exists()).scalar()
|
|
@ -14,13 +14,266 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import inspect
|
||||
import json
|
||||
|
||||
from flask_babel import lazy_gettext as _
|
||||
from marshmallow import fields, Schema
|
||||
from marshmallow.validate import Length, ValidationError
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.exc import ArgumentError
|
||||
|
||||
from superset import app
|
||||
from superset.exceptions import CertificateException
|
||||
from superset.utils.core import markdown, parse_ssl_cert
|
||||
|
||||
database_schemas_query_schema = {
|
||||
"type": "object",
|
||||
"properties": {"force": {"type": "boolean"}},
|
||||
}
|
||||
|
||||
database_name_description = "A database name to identify this connection."
|
||||
cache_timeout_description = (
|
||||
"Duration (in seconds) of the caching timeout for charts of this database. "
|
||||
"A timeout of 0 indicates that the cache never expires. "
|
||||
"Note this defaults to the global timeout if undefined."
|
||||
)
|
||||
expose_in_sqllab_description = "Expose this database to SQLLab"
|
||||
allow_run_async_description = (
|
||||
"Operate the database in asynchronous mode, meaning "
|
||||
"that the queries are executed on remote workers as opposed "
|
||||
"to on the web server itself. "
|
||||
"This assumes that you have a Celery worker setup as well "
|
||||
"as a results backend. Refer to the installation docs "
|
||||
"for more information."
|
||||
)
|
||||
allow_csv_upload_description = (
|
||||
"Allow to upload CSV file data into this database"
|
||||
"If selected, please set the schemas allowed for csv upload in Extra."
|
||||
)
|
||||
allow_ctas_description = "Allow CREATE TABLE AS option in SQL Lab"
|
||||
allow_cvas_description = "Allow CREATE VIEW AS option in SQL Lab"
|
||||
allow_dml_description = (
|
||||
"Allow users to run non-SELECT statements "
|
||||
"(UPDATE, DELETE, CREATE, ...) "
|
||||
"in SQL Lab"
|
||||
)
|
||||
allow_multi_schema_metadata_fetch_description = (
|
||||
"Allow SQL Lab to fetch a list of all tables and all views across "
|
||||
"all database schemas. For large data warehouse with thousands of "
|
||||
"tables, this can be expensive and put strain on the system."
|
||||
) # pylint: disable=invalid-name
|
||||
impersonate_user_description = (
|
||||
"If Presto, all the queries in SQL Lab are going to be executed as the "
|
||||
"currently logged on user who must have permission to run them.<br/>"
|
||||
"If Hive and hive.server2.enable.doAs is enabled, will run the queries as "
|
||||
"service account, but impersonate the currently logged on user "
|
||||
"via hive.server2.proxy.user property."
|
||||
)
|
||||
force_ctas_schema_description = (
|
||||
"When allowing CREATE TABLE AS option in SQL Lab, "
|
||||
"this option forces the table to be created in this schema"
|
||||
)
|
||||
encrypted_extra_description = markdown(
|
||||
"JSON string containing additional connection configuration.<br/>"
|
||||
"This is used to provide connection information for systems like "
|
||||
"Hive, Presto, and BigQuery, which do not conform to the username:password "
|
||||
"syntax normally used by SQLAlchemy.",
|
||||
True,
|
||||
)
|
||||
extra_description = markdown(
|
||||
"JSON string containing extra configuration elements.<br/>"
|
||||
"1. The ``engine_params`` object gets unpacked into the "
|
||||
"[sqlalchemy.create_engine]"
|
||||
"(https://docs.sqlalchemy.org/en/latest/core/engines.html#"
|
||||
"sqlalchemy.create_engine) call, while the ``metadata_params`` "
|
||||
"gets unpacked into the [sqlalchemy.MetaData]"
|
||||
"(https://docs.sqlalchemy.org/en/rel_1_0/core/metadata.html"
|
||||
"#sqlalchemy.schema.MetaData) call.<br/>"
|
||||
"2. The ``metadata_cache_timeout`` is a cache timeout setting "
|
||||
"in seconds for metadata fetch of this database. Specify it as "
|
||||
'**"metadata_cache_timeout": {"schema_cache_timeout": 600, '
|
||||
'"table_cache_timeout": 600}**. '
|
||||
"If unset, cache will not be enabled for the functionality. "
|
||||
"A timeout of 0 indicates that the cache never expires.<br/>"
|
||||
"3. The ``schemas_allowed_for_csv_upload`` is a comma separated list "
|
||||
"of schemas that CSVs are allowed to upload to. "
|
||||
'Specify it as **"schemas_allowed_for_csv_upload": '
|
||||
'["public", "csv_upload"]**. '
|
||||
"If database flavor does not support schema or any schema is allowed "
|
||||
"to be accessed, just leave the list empty<br/>"
|
||||
"4. the ``version`` field is a string specifying the this db's version. "
|
||||
"This should be used with Presto DBs so that the syntax is correct<br/>"
|
||||
"5. The ``allows_virtual_table_explore`` field is a boolean specifying "
|
||||
"whether or not the Explore button in SQL Lab results is shown.",
|
||||
True,
|
||||
)
|
||||
sqlalchemy_uri_description = markdown(
|
||||
"Refer to the "
|
||||
"[SqlAlchemy docs]"
|
||||
"(https://docs.sqlalchemy.org/en/rel_1_2/core/engines.html#"
|
||||
"database-urls) "
|
||||
"for more information on how to structure your URI.",
|
||||
True,
|
||||
)
|
||||
server_cert_description = markdown(
|
||||
"Optional CA_BUNDLE contents to validate HTTPS requests. Only available "
|
||||
"on certain database engines.",
|
||||
True,
|
||||
)
|
||||
|
||||
|
||||
def sqlalchemy_uri_validator(value: str) -> str:
|
||||
"""
|
||||
Validate if it's a valid SQLAlchemy URI and refuse SQLLite by default
|
||||
"""
|
||||
try:
|
||||
make_url(value.strip())
|
||||
except (ArgumentError, AttributeError):
|
||||
raise ValidationError(
|
||||
[
|
||||
_(
|
||||
"Invalid connection string, a valid string usually follows:"
|
||||
"'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'"
|
||||
"<p>"
|
||||
"Example:'postgresql://user:password@your-postgres-db/database'"
|
||||
"</p>"
|
||||
)
|
||||
]
|
||||
)
|
||||
if app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] and value:
|
||||
if value.startswith("sqlite"):
|
||||
raise ValidationError(
|
||||
[
|
||||
_(
|
||||
"SQLite database cannot be used as a data source for "
|
||||
"security reasons."
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def server_cert_validator(value: str) -> str:
|
||||
"""
|
||||
Validate the server certificate
|
||||
"""
|
||||
if value:
|
||||
try:
|
||||
parse_ssl_cert(value)
|
||||
except CertificateException:
|
||||
raise ValidationError([_("Invalid certificate")])
|
||||
return value
|
||||
|
||||
|
||||
def encrypted_extra_validator(value: str) -> str:
|
||||
"""
|
||||
Validate that encrypted extra is a valid JSON string
|
||||
"""
|
||||
if value:
|
||||
try:
|
||||
json.loads(value)
|
||||
except json.JSONDecodeError as ex:
|
||||
raise ValidationError(
|
||||
[_("Field cannot be decoded by JSON. %(msg)s", msg=str(ex))]
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def extra_validator(value: str) -> str:
|
||||
"""
|
||||
Validate that extra is a valid JSON string, and that metadata_params
|
||||
keys are on the call signature for SQLAlchemy Metadata
|
||||
"""
|
||||
if value:
|
||||
try:
|
||||
extra_ = json.loads(value)
|
||||
except json.JSONDecodeError as ex:
|
||||
raise ValidationError(
|
||||
[_("Field cannot be decoded by JSON. %(msg)s", msg=str(ex))]
|
||||
)
|
||||
else:
|
||||
metadata_signature = inspect.signature(MetaData)
|
||||
for key in extra_.get("metadata_params", {}):
|
||||
if key not in metadata_signature.parameters:
|
||||
raise ValidationError(
|
||||
[
|
||||
_(
|
||||
"The metadata_params in Extra field "
|
||||
"is not configured correctly. The key "
|
||||
"%(key)s is invalid.",
|
||||
key=key,
|
||||
)
|
||||
]
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class DatabasePostSchema(Schema):
|
||||
database_name = fields.String(
|
||||
description=database_name_description, required=True, validate=Length(1, 250),
|
||||
)
|
||||
cache_timeout = fields.Integer(description=cache_timeout_description)
|
||||
expose_in_sqllab = fields.Boolean(description=expose_in_sqllab_description)
|
||||
allow_run_async = fields.Boolean(description=allow_run_async_description)
|
||||
allow_csv_upload = fields.Boolean(description=allow_csv_upload_description)
|
||||
allow_ctas = fields.Boolean(description=allow_ctas_description)
|
||||
allow_cvas = fields.Boolean(description=allow_cvas_description)
|
||||
allow_dml = fields.Boolean(description=allow_dml_description)
|
||||
force_ctas_schema = fields.String(
|
||||
description=force_ctas_schema_description, validate=Length(0, 250)
|
||||
)
|
||||
allow_multi_schema_metadata_fetch = fields.Boolean(
|
||||
description=allow_multi_schema_metadata_fetch_description,
|
||||
)
|
||||
impersonate_user = fields.Boolean(description=impersonate_user_description)
|
||||
encrypted_extra = fields.String(
|
||||
description=encrypted_extra_description, validate=encrypted_extra_validator
|
||||
)
|
||||
extra = fields.String(description=extra_description, validate=extra_validator)
|
||||
server_cert = fields.String(
|
||||
description=server_cert_description, validate=server_cert_validator
|
||||
)
|
||||
sqlalchemy_uri = fields.String(
|
||||
description=sqlalchemy_uri_description,
|
||||
required=True,
|
||||
validate=[Length(1, 1024), sqlalchemy_uri_validator],
|
||||
)
|
||||
|
||||
|
||||
class DatabasePutSchema(Schema):
|
||||
database_name = fields.String(
|
||||
description=database_name_description, allow_none=True, validate=Length(1, 250),
|
||||
)
|
||||
cache_timeout = fields.Integer(description=cache_timeout_description)
|
||||
expose_in_sqllab = fields.Boolean(description=expose_in_sqllab_description)
|
||||
allow_run_async = fields.Boolean(description=allow_run_async_description)
|
||||
allow_csv_upload = fields.Boolean(description=allow_csv_upload_description)
|
||||
allow_ctas = fields.Boolean(description=allow_ctas_description)
|
||||
allow_cvas = fields.Boolean(description=allow_cvas_description)
|
||||
allow_dml = fields.Boolean(description=allow_dml_description)
|
||||
force_ctas_schema = fields.String(
|
||||
description=force_ctas_schema_description, validate=Length(0, 250)
|
||||
)
|
||||
allow_multi_schema_metadata_fetch = fields.Boolean(
|
||||
description=allow_multi_schema_metadata_fetch_description
|
||||
)
|
||||
impersonate_user = fields.Boolean(description=impersonate_user_description)
|
||||
encrypted_extra = fields.String(
|
||||
description=encrypted_extra_description, validate=encrypted_extra_validator
|
||||
)
|
||||
extra = fields.String(description=extra_description, validate=extra_validator)
|
||||
server_cert = fields.String(
|
||||
description=server_cert_description, validate=server_cert_validator
|
||||
)
|
||||
sqlalchemy_uri = fields.String(
|
||||
description=sqlalchemy_uri_description,
|
||||
allow_none=True,
|
||||
validate=[Length(0, 1024), sqlalchemy_uri_validator],
|
||||
)
|
||||
|
||||
|
||||
class TableMetadataOptionsResponseSchema(Schema):
|
||||
deferrable = fields.Bool()
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from superset import app
|
||||
from superset.models.core import Database
|
||||
|
||||
custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
|
||||
|
||||
|
||||
def get_foreign_keys_metadata(
|
||||
database: Database, table_name: str, schema_name: Optional[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
foreign_keys = database.get_foreign_keys(table_name, schema_name)
|
||||
for fk in foreign_keys:
|
||||
fk["column_names"] = fk.pop("constrained_columns")
|
||||
fk["type"] = "fk"
|
||||
return foreign_keys
|
||||
|
||||
|
||||
def get_indexes_metadata(
|
||||
database: Database, table_name: str, schema_name: Optional[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
indexes = database.get_indexes(table_name, schema_name)
|
||||
for idx in indexes:
|
||||
idx["type"] = "index"
|
||||
return indexes
|
||||
|
||||
|
||||
def get_col_type(col: Dict[Any, Any]) -> str:
|
||||
try:
|
||||
dtype = f"{col['type']}"
|
||||
except Exception: # pylint: disable=broad-except
|
||||
# sqla.types.JSON __str__ has a bug, so using __class__.
|
||||
dtype = col["type"].__class__.__name__
|
||||
return dtype
|
||||
|
||||
|
||||
def get_table_metadata(
|
||||
database: Database, table_name: str, schema_name: Optional[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get table metadata information, including type, pk, fks.
|
||||
This function raises SQLAlchemyError when a schema is not found.
|
||||
|
||||
:param database: The database model
|
||||
:param table_name: Table name
|
||||
:param schema_name: schema name
|
||||
:return: Dict table metadata ready for API response
|
||||
"""
|
||||
keys = []
|
||||
columns = database.get_columns(table_name, schema_name)
|
||||
primary_key = database.get_pk_constraint(table_name, schema_name)
|
||||
if primary_key and primary_key.get("constrained_columns"):
|
||||
primary_key["column_names"] = primary_key.pop("constrained_columns")
|
||||
primary_key["type"] = "pk"
|
||||
keys += [primary_key]
|
||||
foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
|
||||
indexes = get_indexes_metadata(database, table_name, schema_name)
|
||||
keys += foreign_keys + indexes
|
||||
payload_columns: List[Dict[str, Any]] = []
|
||||
for col in columns:
|
||||
dtype = get_col_type(col)
|
||||
payload_columns.append(
|
||||
{
|
||||
"name": col["name"],
|
||||
"type": dtype.split("(")[0] if "(" in dtype else dtype,
|
||||
"longType": dtype,
|
||||
"keys": [k for k in keys if col["name"] in k["column_names"]],
|
||||
}
|
||||
)
|
||||
return {
|
||||
"name": table_name,
|
||||
"columns": payload_columns,
|
||||
"selectStar": database.select_star(
|
||||
table_name,
|
||||
schema=schema_name,
|
||||
show_cols=True,
|
||||
indent=True,
|
||||
cols=columns,
|
||||
latest_partition=True,
|
||||
),
|
||||
"primaryKey": primary_key,
|
||||
"foreignKeys": foreign_keys,
|
||||
"indexes": keys,
|
||||
}
|
|
@ -25,6 +25,7 @@ from marshmallow import ValidationError
|
|||
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.constants import RouteMethod
|
||||
from superset.databases.filters import DatabaseFilter
|
||||
from superset.datasets.commands.create import CreateDatasetCommand
|
||||
from superset.datasets.commands.delete import DeleteDatasetCommand
|
||||
from superset.datasets.commands.exceptions import (
|
||||
|
@ -51,7 +52,6 @@ from superset.views.base_api import (
|
|||
RelatedFieldFilter,
|
||||
statsd_metrics,
|
||||
)
|
||||
from superset.views.database.filters import DatabaseFilter
|
||||
from superset.views.filters import FilterRelatedOwners
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
|
@ -106,6 +106,9 @@ class BaseSupersetModelRestApi(ModelRestApi):
|
|||
"data": "list",
|
||||
"viz_types": "list",
|
||||
"related_objects": "list",
|
||||
"table_metadata": "list",
|
||||
"select_star": "list",
|
||||
"schemas": "list",
|
||||
}
|
||||
|
||||
order_rel_fields: Dict[str, Tuple[str, str]] = {}
|
||||
|
|
|
@ -68,6 +68,7 @@ from superset.connectors.sqla.models import (
|
|||
TableColumn,
|
||||
)
|
||||
from superset.dashboards.dao import DashboardDAO
|
||||
from superset.databases.filters import DatabaseFilter
|
||||
from superset.exceptions import (
|
||||
CertificateException,
|
||||
DatabaseNotFound,
|
||||
|
@ -109,7 +110,6 @@ from superset.views.base import (
|
|||
json_success,
|
||||
validate_sqlatable,
|
||||
)
|
||||
from superset.views.database.filters import DatabaseFilter
|
||||
from superset.views.utils import (
|
||||
_deserialize_results_payload,
|
||||
apply_display_max_row_limit,
|
||||
|
|
|
@ -21,11 +21,11 @@ from flask_babel import lazy_gettext as _
|
|||
from sqlalchemy import MetaData
|
||||
|
||||
from superset import app, security_manager
|
||||
from superset.databases.filters import DatabaseFilter
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.models.core import Database
|
||||
from superset.security.analytics_db_safety import check_sqlalchemy_uri
|
||||
from superset.utils import core as utils
|
||||
from superset.views.database.filters import DatabaseFilter
|
||||
|
||||
|
||||
class DatabaseMixin:
|
||||
|
@ -240,7 +240,7 @@ class DatabaseMixin:
|
|||
extra = database.get_extra()
|
||||
except Exception as ex:
|
||||
raise Exception(
|
||||
_("Extra field cannot be decoded by JSON. %{msg}s", msg=str(ex))
|
||||
_("Extra field cannot be decoded by JSON. %(msg)s", msg=str(ex))
|
||||
)
|
||||
|
||||
# this will check whether 'metadata_params' is configured correctly
|
||||
|
@ -264,5 +264,5 @@ class DatabaseMixin:
|
|||
database.get_encrypted_extra()
|
||||
except Exception as ex:
|
||||
raise Exception(
|
||||
_("Extra field cannot be decoded by JSON. %{msg}s", msg=str(ex))
|
||||
_("Extra field cannot be decoded by JSON. %(msg)s", msg=str(ex))
|
||||
)
|
||||
|
|
|
@ -36,11 +36,15 @@ def sqlalchemy_uri_validator(
|
|||
make_url(uri.strip())
|
||||
except (ArgumentError, AttributeError):
|
||||
raise exception(
|
||||
_(
|
||||
"Invalid connection string, a valid string usually follows:"
|
||||
"'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'"
|
||||
"<p>Example:'postgresql://user:password@your-postgres-db/database'</p>"
|
||||
)
|
||||
[
|
||||
_(
|
||||
"Invalid connection string, a valid string usually follows:"
|
||||
"'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'"
|
||||
"<p>"
|
||||
"Example:'postgresql://user:password@your-postgres-db/database'"
|
||||
"</p>"
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -1,268 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# isort:skip_file
|
||||
"""Unit tests for Superset"""
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
import prison
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
import tests.test_app
|
||||
from superset import db, security_manager
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import get_example_database, get_main_database
|
||||
|
||||
from .base_tests import SupersetTestCase
|
||||
|
||||
|
||||
class TestDatabaseApi(SupersetTestCase):
|
||||
def test_get_items(self):
|
||||
"""
|
||||
Database API: Test get items
|
||||
"""
|
||||
self.login(username="admin")
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_columns = [
|
||||
"allow_csv_upload",
|
||||
"allow_ctas",
|
||||
"allow_cvas",
|
||||
"allow_dml",
|
||||
"allow_multi_schema_metadata_fetch",
|
||||
"allow_run_async",
|
||||
"allows_cost_estimate",
|
||||
"allows_subquery",
|
||||
"allows_virtual_table_explore",
|
||||
"backend",
|
||||
"database_name",
|
||||
"explore_database_id",
|
||||
"expose_in_sqllab",
|
||||
"force_ctas_schema",
|
||||
"function_names",
|
||||
"id",
|
||||
]
|
||||
self.assertEqual(response["count"], 2)
|
||||
self.assertEqual(list(response["result"][0].keys()), expected_columns)
|
||||
|
||||
def test_get_items_filter(self):
|
||||
"""
|
||||
Database API: Test get items with filter
|
||||
"""
|
||||
fake_db = (
|
||||
db.session.query(Database).filter_by(database_name="fake_db_100").one()
|
||||
)
|
||||
old_expose_in_sqllab = fake_db.expose_in_sqllab
|
||||
fake_db.expose_in_sqllab = False
|
||||
db.session.commit()
|
||||
self.login(username="admin")
|
||||
arguments = {
|
||||
"keys": ["none"],
|
||||
"filters": [{"col": "expose_in_sqllab", "opr": "eq", "value": True}],
|
||||
"order_columns": "database_name",
|
||||
"order_direction": "asc",
|
||||
"page": 0,
|
||||
"page_size": -1,
|
||||
}
|
||||
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(response["count"], 1)
|
||||
|
||||
fake_db = (
|
||||
db.session.query(Database).filter_by(database_name="fake_db_100").one()
|
||||
)
|
||||
fake_db.expose_in_sqllab = old_expose_in_sqllab
|
||||
db.session.commit()
|
||||
|
||||
def test_get_items_not_allowed(self):
|
||||
"""
|
||||
Database API: Test get items not allowed
|
||||
"""
|
||||
self.login(username="gamma")
|
||||
uri = f"api/v1/database/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["count"], 0)
|
||||
|
||||
def test_get_table_metadata(self):
|
||||
"""
|
||||
Database API: Test get table metadata info
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
self.login(username="admin")
|
||||
uri = f"api/v1/database/{example_db.id}/table/birth_names/null/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["name"], "birth_names")
|
||||
self.assertTrue(len(response["columns"]) > 5)
|
||||
self.assertTrue(response.get("selectStar").startswith("SELECT"))
|
||||
|
||||
def test_get_invalid_database_table_metadata(self):
|
||||
"""
|
||||
Database API: Test get invalid database from table metadata
|
||||
"""
|
||||
database_id = 1000
|
||||
self.login(username="admin")
|
||||
uri = f"api/v1/database/{database_id}/table/some_table/some_schema/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
uri = f"api/v1/database/some_database/table/some_table/some_schema/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_get_invalid_table_table_metadata(self):
|
||||
"""
|
||||
Database API: Test get invalid table from table metadata
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/wrong_table/null/"
|
||||
self.login(username="admin")
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_get_table_metadata_no_db_permission(self):
|
||||
"""
|
||||
Database API: Test get table metadata from not permitted db
|
||||
"""
|
||||
self.login(username="gamma")
|
||||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/birth_names/null/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_get_select_star(self):
|
||||
"""
|
||||
Database API: Test get select star
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertIn("gender", response["result"])
|
||||
|
||||
def test_get_select_star_not_allowed(self):
|
||||
"""
|
||||
Database API: Test get select star not allowed
|
||||
"""
|
||||
self.login(username="gamma")
|
||||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_get_select_star_datasource_access(self):
|
||||
"""
|
||||
Database API: Test get select star with datasource access
|
||||
"""
|
||||
session = db.session
|
||||
table = SqlaTable(
|
||||
schema="main", table_name="ab_permission", database=get_main_database()
|
||||
)
|
||||
session.add(table)
|
||||
session.commit()
|
||||
|
||||
tmp_table_perm = security_manager.find_permission_view_menu(
|
||||
"datasource_access", table.get_perm()
|
||||
)
|
||||
gamma_role = security_manager.find_role("Gamma")
|
||||
security_manager.add_permission_role(gamma_role, tmp_table_perm)
|
||||
|
||||
self.login(username="gamma")
|
||||
main_db = get_main_database()
|
||||
uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
|
||||
# rollback changes
|
||||
security_manager.del_permission_role(gamma_role, tmp_table_perm)
|
||||
db.session.delete(table)
|
||||
db.session.delete(main_db)
|
||||
db.session.commit()
|
||||
|
||||
def test_get_select_star_not_found_database(self):
|
||||
"""
|
||||
Database API: Test get select star not found database
|
||||
"""
|
||||
self.login(username="admin")
|
||||
max_id = db.session.query(func.max(Database.id)).scalar()
|
||||
uri = f"api/v1/database/{max_id + 1}/select_star/birth_names/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_get_select_star_not_found_table(self):
|
||||
"""
|
||||
Database API: Test get select star not found database
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
# sqllite will not raise a NoSuchTableError
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
uri = f"api/v1/database/{example_db.id}/select_star/table_does_not_exist/"
|
||||
rv = self.client.get(uri)
|
||||
# TODO(bkyryliuk): investigate why presto returns 500
|
||||
self.assertEqual(rv.status_code, 404 if example_db.backend != "presto" else 500)
|
||||
|
||||
def test_database_schemas(self):
|
||||
"""
|
||||
Database API: Test database schemas
|
||||
"""
|
||||
self.login("admin")
|
||||
database = db.session.query(Database).first()
|
||||
schemas = database.get_all_schema_names()
|
||||
|
||||
rv = self.client.get(f"api/v1/database/{database.id}/schemas/")
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(schemas, response["result"])
|
||||
|
||||
rv = self.client.get(
|
||||
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}"
|
||||
)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(schemas, response["result"])
|
||||
|
||||
def test_database_schemas_not_found(self):
|
||||
"""
|
||||
Database API: Test database schemas not found
|
||||
"""
|
||||
self.logout()
|
||||
self.login(username="gamma")
|
||||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/schemas/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_database_schemas_invalid_query(self):
|
||||
"""
|
||||
Database API: Test database schemas with invalid query
|
||||
"""
|
||||
self.login("admin")
|
||||
database = db.session.query(Database).first()
|
||||
rv = self.client.get(
|
||||
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': 'nop'})}"
|
||||
)
|
||||
self.assertEqual(rv.status_code, 400)
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,650 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# isort:skip_file
|
||||
"""Unit tests for Superset"""
|
||||
import json
|
||||
|
||||
import prison
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
import tests.test_app
|
||||
from superset import db, security_manager
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import get_example_database, get_main_database
|
||||
from tests.base_tests import SupersetTestCase
|
||||
from tests.fixtures.certificates import ssl_certificate
|
||||
|
||||
|
||||
class TestDatabaseApi(SupersetTestCase):
|
||||
def insert_database(
|
||||
self,
|
||||
database_name: str,
|
||||
sqlalchemy_uri: str,
|
||||
extra: str = "",
|
||||
encrypted_extra: str = "",
|
||||
server_cert: str = "",
|
||||
expose_in_sqllab: bool = False,
|
||||
) -> Database:
|
||||
database = Database(
|
||||
database_name=database_name,
|
||||
sqlalchemy_uri=sqlalchemy_uri,
|
||||
extra=extra,
|
||||
encrypted_extra=encrypted_extra,
|
||||
server_cert=server_cert,
|
||||
expose_in_sqllab=expose_in_sqllab,
|
||||
)
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
return database
|
||||
|
||||
def test_get_items(self):
|
||||
"""
|
||||
Database API: Test get items
|
||||
"""
|
||||
self.login(username="admin")
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_columns = [
|
||||
"allow_csv_upload",
|
||||
"allow_ctas",
|
||||
"allow_cvas",
|
||||
"allow_dml",
|
||||
"allow_multi_schema_metadata_fetch",
|
||||
"allow_run_async",
|
||||
"allows_cost_estimate",
|
||||
"allows_subquery",
|
||||
"allows_virtual_table_explore",
|
||||
"backend",
|
||||
"database_name",
|
||||
"explore_database_id",
|
||||
"expose_in_sqllab",
|
||||
"force_ctas_schema",
|
||||
"function_names",
|
||||
"id",
|
||||
]
|
||||
self.assertEqual(response["count"], 2)
|
||||
self.assertEqual(list(response["result"][0].keys()), expected_columns)
|
||||
|
||||
def test_get_items_filter(self):
|
||||
"""
|
||||
Database API: Test get items with filter
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
test_database = self.insert_database(
|
||||
"test-database", example_db.sqlalchemy_uri_decrypted, expose_in_sqllab=True
|
||||
)
|
||||
dbs = db.session.query(Database).filter_by(expose_in_sqllab=True).all()
|
||||
|
||||
self.login(username="admin")
|
||||
arguments = {
|
||||
"keys": ["none"],
|
||||
"filters": [{"col": "expose_in_sqllab", "opr": "eq", "value": True}],
|
||||
"order_columns": "database_name",
|
||||
"order_direction": "asc",
|
||||
"page": 0,
|
||||
"page_size": -1,
|
||||
}
|
||||
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
|
||||
rv = self.client.get(uri)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
self.assertEqual(response["count"], len(dbs))
|
||||
|
||||
# Cleanup
|
||||
db.session.delete(test_database)
|
||||
db.session.commit()
|
||||
|
||||
def test_get_items_not_allowed(self):
|
||||
"""
|
||||
Database API: Test get items not allowed
|
||||
"""
|
||||
self.login(username="gamma")
|
||||
uri = f"api/v1/database/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["count"], 0)
|
||||
|
||||
def test_create_database(self):
|
||||
"""
|
||||
Database API: Test create
|
||||
"""
|
||||
extra = {
|
||||
"metadata_params": {},
|
||||
"engine_params": {},
|
||||
"metadata_cache_timeout": {},
|
||||
"schemas_allowed_for_csv_upload": [],
|
||||
}
|
||||
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
database_data = {
|
||||
"database_name": "test-database",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"server_cert": ssl_certificate,
|
||||
"extra": json.dumps(extra),
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
def test_create_database_server_cert_validate(self):
|
||||
"""
|
||||
Database API: Test create server cert validation
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
|
||||
self.login(username="admin")
|
||||
database_data = {
|
||||
"database_name": "test-database",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"server_cert": "INVALID CERT",
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {"message": {"server_cert": ["Invalid certificate"]}}
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertEqual(response, expected_response)
|
||||
|
||||
def test_create_database_json_validate(self):
|
||||
"""
|
||||
Database API: Test create encrypted extra and extra validation
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
|
||||
self.login(username="admin")
|
||||
database_data = {
|
||||
"database_name": "test-database",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"encrypted_extra": '{"A": "a", "B", "C"}',
|
||||
"extra": '["A": "a", "B", "C"]',
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {
|
||||
"message": {
|
||||
"encrypted_extra": [
|
||||
"Field cannot be decoded by JSON. Expecting ':' "
|
||||
"delimiter: line 1 column 15 (char 14)"
|
||||
],
|
||||
"extra": [
|
||||
"Field cannot be decoded by JSON. Expecting ','"
|
||||
" delimiter: line 1 column 5 (char 4)"
|
||||
],
|
||||
}
|
||||
}
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertEqual(response, expected_response)
|
||||
|
||||
def test_create_database_extra_metadata_validate(self):
|
||||
"""
|
||||
Database API: Test create extra metadata_params validation
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
|
||||
extra = {
|
||||
"metadata_params": {"wrong_param": "some_value"},
|
||||
"engine_params": {},
|
||||
"metadata_cache_timeout": {},
|
||||
"schemas_allowed_for_csv_upload": [],
|
||||
}
|
||||
self.login(username="admin")
|
||||
database_data = {
|
||||
"database_name": "test-database",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"extra": json.dumps(extra),
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {
|
||||
"message": {
|
||||
"extra": [
|
||||
"The metadata_params in Extra field is not configured correctly."
|
||||
" The key wrong_param is invalid."
|
||||
]
|
||||
}
|
||||
}
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
self.assertEqual(response, expected_response)
|
||||
|
||||
def test_create_database_unique_validate(self):
|
||||
"""
|
||||
Database API: Test create database_name already exists
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
|
||||
self.login(username="admin")
|
||||
database_data = {
|
||||
"database_name": "examples",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {
|
||||
"message": {"database_name": "A database with the same name already exists"}
|
||||
}
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
self.assertEqual(response, expected_response)
|
||||
|
||||
def test_create_database_uri_validate(self):
|
||||
"""
|
||||
Database API: Test create fail validate sqlalchemy uri
|
||||
"""
|
||||
self.login(username="admin")
|
||||
database_data = {
|
||||
"database_name": "test-database",
|
||||
"sqlalchemy_uri": "wrong_uri",
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
expected_response = {
|
||||
"message": {
|
||||
"sqlalchemy_uri": [
|
||||
"Invalid connection string, a valid string usually "
|
||||
"follows:'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'"
|
||||
"<p>Example:'postgresql://user:password@your-postgres-db/database'"
|
||||
"</p>"
|
||||
]
|
||||
}
|
||||
}
|
||||
self.assertEqual(response, expected_response)
|
||||
|
||||
def test_create_database_fail_sqllite(self):
|
||||
"""
|
||||
Database API: Test create fail with sqllite
|
||||
"""
|
||||
database_data = {
|
||||
"database_name": "test-database",
|
||||
"sqlalchemy_uri": "sqlite:////some.db",
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
self.login(username="admin")
|
||||
response = self.client.post(uri, json=database_data)
|
||||
response_data = json.loads(response.data.decode("utf-8"))
|
||||
expected_response = {
|
||||
"message": {
|
||||
"sqlalchemy_uri": [
|
||||
"SQLite database cannot be used as a data source "
|
||||
"for security reasons."
|
||||
]
|
||||
}
|
||||
}
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(response_data, expected_response)
|
||||
|
||||
def test_create_database_conn_fail(self):
|
||||
"""
|
||||
Database API: Test create fails connection
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
if example_db.backend in ("sqlite", "hive", "presto"):
|
||||
return
|
||||
example_db.password = "wrong_password"
|
||||
database_data = {
|
||||
"database_name": "test-database",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
self.login(username="admin")
|
||||
response = self.client.post(uri, json=database_data)
|
||||
response_data = json.loads(response.data.decode("utf-8"))
|
||||
expected_response = {"message": "Could not connect to database."}
|
||||
self.assertEqual(response.status_code, 422)
|
||||
self.assertEqual(response_data, expected_response)
|
||||
|
||||
def test_update_database(self):
|
||||
"""
|
||||
Database API: Test update
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
test_database = self.insert_database(
|
||||
"test-database", example_db.sqlalchemy_uri_decrypted
|
||||
)
|
||||
|
||||
self.login(username="admin")
|
||||
database_data = {"database_name": "test-database-updated"}
|
||||
uri = f"api/v1/database/{test_database.id}"
|
||||
rv = self.client.put(uri, json=database_data)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(test_database.id)
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
def test_update_database_conn_fail(self):
|
||||
"""
|
||||
Database API: Test update fails connection
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
if example_db.backend in ("sqlite", "hive", "presto"):
|
||||
return
|
||||
|
||||
test_database = self.insert_database(
|
||||
"test-database1", example_db.sqlalchemy_uri_decrypted
|
||||
)
|
||||
example_db.password = "wrong_password"
|
||||
database_data = {
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
}
|
||||
|
||||
uri = f"api/v1/database/{test_database.id}"
|
||||
self.login(username="admin")
|
||||
rv = self.client.put(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {"message": "Could not connect to database."}
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
self.assertEqual(response, expected_response)
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(test_database.id)
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
def test_update_database_uniqueness(self):
|
||||
"""
|
||||
Database API: Test update uniqueness
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
test_database1 = self.insert_database(
|
||||
"test-database1", example_db.sqlalchemy_uri_decrypted
|
||||
)
|
||||
test_database2 = self.insert_database(
|
||||
"test-database2", example_db.sqlalchemy_uri_decrypted
|
||||
)
|
||||
|
||||
self.login(username="admin")
|
||||
database_data = {"database_name": "test-database2"}
|
||||
uri = f"api/v1/database/{test_database1.id}"
|
||||
rv = self.client.put(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
expected_response = {
|
||||
"message": {"database_name": "A database with the same name already exists"}
|
||||
}
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
self.assertEqual(response, expected_response)
|
||||
# Cleanup
|
||||
db.session.delete(test_database1)
|
||||
db.session.delete(test_database2)
|
||||
db.session.commit()
|
||||
|
||||
def test_update_database_invalid(self):
|
||||
"""
|
||||
Database API: Test update invalid request
|
||||
"""
|
||||
self.login(username="admin")
|
||||
database_data = {"database_name": "test-database-updated"}
|
||||
uri = f"api/v1/database/invalid"
|
||||
rv = self.client.put(uri, json=database_data)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_update_database_uri_validate(self):
|
||||
"""
|
||||
Database API: Test update sqlalchemy_uri validate
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
test_database = self.insert_database(
|
||||
"test-database", example_db.sqlalchemy_uri_decrypted
|
||||
)
|
||||
|
||||
self.login(username="admin")
|
||||
database_data = {
|
||||
"database_name": "test-database-updated",
|
||||
"sqlalchemy_uri": "wrong_uri",
|
||||
}
|
||||
uri = f"api/v1/database/{test_database.id}"
|
||||
rv = self.client.put(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 400)
|
||||
expected_response = {
|
||||
"message": {
|
||||
"sqlalchemy_uri": [
|
||||
"Invalid connection string, a valid string usually "
|
||||
"follows:'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'"
|
||||
"<p>Example:'postgresql://user:password@your-postgres-db/database'"
|
||||
"</p>"
|
||||
]
|
||||
}
|
||||
}
|
||||
self.assertEqual(response, expected_response)
|
||||
|
||||
def test_delete_database(self):
|
||||
"""
|
||||
Database API: Test delete
|
||||
"""
|
||||
database_id = self.insert_database("test-database", "test_uri").id
|
||||
self.login(username="admin")
|
||||
uri = f"api/v1/database/{database_id}"
|
||||
rv = self.delete_assert_metric(uri, "delete")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
model = db.session.query(Database).get(database_id)
|
||||
self.assertEqual(model, None)
|
||||
|
||||
def test_delete_database_not_found(self):
|
||||
"""
|
||||
Database API: Test delete not found
|
||||
"""
|
||||
max_id = db.session.query(func.max(Database.id)).scalar()
|
||||
self.login(username="admin")
|
||||
uri = f"api/v1/database/{max_id + 1}"
|
||||
rv = self.delete_assert_metric(uri, "delete")
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_delete_database_with_datasets(self):
|
||||
"""
|
||||
Database API: Test delete fails because it has depending datasets
|
||||
"""
|
||||
database_id = (
|
||||
db.session.query(Database).filter_by(database_name="examples").one()
|
||||
).id
|
||||
self.login(username="admin")
|
||||
uri = f"api/v1/database/{database_id}"
|
||||
rv = self.delete_assert_metric(uri, "delete")
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
|
||||
def test_get_table_metadata(self):
|
||||
"""
|
||||
Database API: Test get table metadata info
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
self.login(username="admin")
|
||||
uri = f"api/v1/database/{example_db.id}/table/birth_names/null/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(response["name"], "birth_names")
|
||||
self.assertTrue(len(response["columns"]) > 5)
|
||||
self.assertTrue(response.get("selectStar").startswith("SELECT"))
|
||||
|
||||
def test_get_invalid_database_table_metadata(self):
|
||||
"""
|
||||
Database API: Test get invalid database from table metadata
|
||||
"""
|
||||
database_id = 1000
|
||||
self.login(username="admin")
|
||||
uri = f"api/v1/database/{database_id}/table/some_table/some_schema/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
uri = f"api/v1/database/some_database/table/some_table/some_schema/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_get_invalid_table_table_metadata(self):
|
||||
"""
|
||||
Database API: Test get invalid table from table metadata
|
||||
"""
|
||||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/wrong_table/null/"
|
||||
self.login(username="admin")
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_get_table_metadata_no_db_permission(self):
|
||||
"""
|
||||
Database API: Test get table metadata from not permitted db
|
||||
"""
|
||||
self.login(username="gamma")
|
||||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/birth_names/null/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_get_select_star(self):
|
||||
"""
|
||||
Database API: Test get select star
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertIn("gender", response["result"])
|
||||
|
||||
def test_get_select_star_not_allowed(self):
|
||||
"""
|
||||
Database API: Test get select star not allowed
|
||||
"""
|
||||
self.login(username="gamma")
|
||||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_get_select_star_datasource_access(self):
|
||||
"""
|
||||
Database API: Test get select star with datasource access
|
||||
"""
|
||||
session = db.session
|
||||
table = SqlaTable(
|
||||
schema="main", table_name="ab_permission", database=get_main_database()
|
||||
)
|
||||
session.add(table)
|
||||
session.commit()
|
||||
|
||||
tmp_table_perm = security_manager.find_permission_view_menu(
|
||||
"datasource_access", table.get_perm()
|
||||
)
|
||||
gamma_role = security_manager.find_role("Gamma")
|
||||
security_manager.add_permission_role(gamma_role, tmp_table_perm)
|
||||
|
||||
self.login(username="gamma")
|
||||
main_db = get_main_database()
|
||||
uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
|
||||
# rollback changes
|
||||
security_manager.del_permission_role(gamma_role, tmp_table_perm)
|
||||
db.session.delete(table)
|
||||
db.session.delete(main_db)
|
||||
db.session.commit()
|
||||
|
||||
def test_get_select_star_not_found_database(self):
|
||||
"""
|
||||
Database API: Test get select star not found database
|
||||
"""
|
||||
self.login(username="admin")
|
||||
max_id = db.session.query(func.max(Database.id)).scalar()
|
||||
uri = f"api/v1/database/{max_id + 1}/select_star/birth_names/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_get_select_star_not_found_table(self):
|
||||
"""
|
||||
Database API: Test get select star not found database
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
# sqllite will not raise a NoSuchTableError
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
uri = f"api/v1/database/{example_db.id}/select_star/table_does_not_exist/"
|
||||
rv = self.client.get(uri)
|
||||
# TODO(bkyryliuk): investigate why presto returns 500
|
||||
self.assertEqual(rv.status_code, 404 if example_db.backend != "presto" else 500)
|
||||
|
||||
def test_database_schemas(self):
|
||||
"""
|
||||
Database API: Test database schemas
|
||||
"""
|
||||
self.login("admin")
|
||||
database = db.session.query(Database).first()
|
||||
schemas = database.get_all_schema_names()
|
||||
|
||||
rv = self.client.get(f"api/v1/database/{database.id}/schemas/")
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(schemas, response["result"])
|
||||
|
||||
rv = self.client.get(
|
||||
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}"
|
||||
)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(schemas, response["result"])
|
||||
|
||||
def test_database_schemas_not_found(self):
|
||||
"""
|
||||
Database API: Test database schemas not found
|
||||
"""
|
||||
self.logout()
|
||||
self.login(username="gamma")
|
||||
example_db = get_example_database()
|
||||
uri = f"api/v1/database/{example_db.id}/schemas/"
|
||||
rv = self.client.get(uri)
|
||||
self.assertEqual(rv.status_code, 404)
|
||||
|
||||
def test_database_schemas_invalid_query(self):
|
||||
"""
|
||||
Database API: Test database schemas with invalid query
|
||||
"""
|
||||
self.login("admin")
|
||||
database = db.session.query(Database).first()
|
||||
rv = self.client.get(
|
||||
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': 'nop'})}"
|
||||
)
|
||||
self.assertEqual(rv.status_code, 400)
|
Loading…
Reference in New Issue