superset/tests/databases/api_tests.py
Beto Dealmeida 9785667a0d
feat: add UUID column to ImportMixin (#11098)
* Add UUID column to ImportMixin

* Fix default value

* Fix lint

* Fix order of downgrade

* Add logging when downgrade fails

* Migrate position_json to contain UUIDs, and add schedule tables

* Save UUID when adding charts to dashboard

* Fix heads

* Rename migration file

* Fix dashboard serialization

* Fix migration script with Postgres

* Fix unique constraint name

* Handle UUID when exporting dashboard

* Fix Dataset PUT

* Add UUID JSON serialization

* Fix tests

* Simplify logic

* Try binary=True
2020-10-07 09:00:55 -07:00

804 lines
29 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import datetime
import json
import pandas as pd
import prison
import pytest
import random
from sqlalchemy import String, Date, Float
from sqlalchemy.sql import func
from superset import db, security_manager, ConnectorRegistry
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.utils.core import get_example_database, get_main_database
from tests.base_tests import SupersetTestCase
from tests.dashboard_utils import (
create_table_for_dashboard,
create_dashboard,
)
from tests.fixtures.certificates import ssl_certificate
from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_position
from tests.test_app import app
class TestDatabaseApi(SupersetTestCase):
def insert_database(
    self,
    database_name: str,
    sqlalchemy_uri: str,
    extra: str = "",
    encrypted_extra: str = "",
    server_cert: str = "",
    expose_in_sqllab: bool = False,
) -> Database:
    """
    Build a ``Database`` model from the given fields, commit it and return it.

    Nothing in this helper removes the row again, so tests that call it
    need to delete the returned object themselves.
    """
    fields = dict(
        database_name=database_name,
        sqlalchemy_uri=sqlalchemy_uri,
        extra=extra,
        encrypted_extra=encrypted_extra,
        server_cert=server_cert,
        expose_in_sqllab=expose_in_sqllab,
    )
    new_database = Database(**fields)
    db.session.add(new_database)
    db.session.commit()
    return new_database
def test_get_items(self):
    """
    Database API: Test get items

    Lists databases as admin and checks both the reported count and the
    exact, ordered set of keys on the first result.
    """
    # Keys each list entry is expected to expose, in this exact order.
    expected_columns = [
        "allow_csv_upload",
        "allow_ctas",
        "allow_cvas",
        "allow_dml",
        "allow_multi_schema_metadata_fetch",
        "allow_run_async",
        "allows_cost_estimate",
        "allows_subquery",
        "allows_virtual_table_explore",
        "backend",
        "changed_on",
        "changed_on_delta_humanized",
        "created_by",
        "database_name",
        "explore_database_id",
        "expose_in_sqllab",
        "force_ctas_schema",
        "function_names",
        "id",
    ]
    self.login(username="admin")
    list_response = self.client.get("api/v1/database/")
    self.assertEqual(list_response.status_code, 200)
    payload = json.loads(list_response.data.decode("utf-8"))
    self.assertEqual(payload["count"], 2)
    first_item_keys = list(payload["result"][0].keys())
    self.assertEqual(first_item_keys, expected_columns)
def test_get_items_filter(self):
    """
    Database API: Test get items with filter

    Inserts a database exposed in SQL Lab, then checks that filtering the
    list endpoint on ``expose_in_sqllab`` returns as many rows as a direct
    query with the same filter.
    """
    example_db = get_example_database()
    test_database = self.insert_database(
        "test-database", example_db.sqlalchemy_uri_decrypted, expose_in_sqllab=True
    )
    try:
        dbs = db.session.query(Database).filter_by(expose_in_sqllab=True).all()
        self.login(username="admin")
        arguments = {
            "keys": ["none"],
            "filters": [{"col": "expose_in_sqllab", "opr": "eq", "value": True}],
            "order_columns": "database_name",
            "order_direction": "asc",
            "page": 0,
            "page_size": -1,
        }
        uri = f"api/v1/database/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(response["count"], len(dbs))
    finally:
        # Cleanup runs even when an assertion fails, so the inserted
        # "test-database" row cannot leak into later tests.
        db.session.delete(test_database)
        db.session.commit()
def test_get_items_not_allowed(self):
    """
    Database API: Test get items not allowed

    A gamma user still gets a 200 from the list endpoint, but the
    response is expected to contain no databases.
    """
    self.login(username="gamma")
    # Plain literal: the original f-string had no placeholders.
    uri = "api/v1/database/"
    rv = self.client.get(uri)
    self.assertEqual(rv.status_code, 200)
    response = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(response["count"], 0)
def test_create_database(self):
    """
    Database API: Test create

    Posts a new database that reuses the example database URI and
    expects a 201. Returns early without asserting on sqlite.
    """
    extra_config = {
        "metadata_params": {},
        "engine_params": {},
        "metadata_cache_timeout": {},
        "schemas_allowed_for_csv_upload": [],
    }
    self.login(username="admin")
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    payload = {
        "database_name": "test-database",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "server_cert": ssl_certificate,
        "extra": json.dumps(extra_config),
    }
    create_response = self.client.post("api/v1/database/", json=payload)
    body = json.loads(create_response.data.decode("utf-8"))
    self.assertEqual(create_response.status_code, 201)
    # Cleanup: remove the row the API just created.
    created = db.session.query(Database).get(body.get("id"))
    db.session.delete(created)
    db.session.commit()
def test_create_database_server_cert_validate(self):
    """
    Database API: Test create server cert validation

    Posting a malformed ``server_cert`` must be rejected with a 400 and
    a field-level error. Returns early without asserting on sqlite.
    """
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    self.login(username="admin")
    payload = {
        "database_name": "test-database",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "server_cert": "INVALID CERT",
    }
    create_response = self.client.post("api/v1/database/", json=payload)
    body = json.loads(create_response.data.decode("utf-8"))
    self.assertEqual(create_response.status_code, 400)
    self.assertEqual(
        body, {"message": {"server_cert": ["Invalid certificate"]}}
    )
def test_create_database_json_validate(self):
    """
    Database API: Test create encrypted extra and extra validation

    Both ``encrypted_extra`` and ``extra`` are sent as syntactically
    invalid JSON; each must produce its own decode error in the 400
    response. Returns early without asserting on sqlite.
    """
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    self.login(username="admin")
    payload = {
        "database_name": "test-database",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "encrypted_extra": '{"A": "a", "B", "C"}',
        "extra": '["A": "a", "B", "C"]',
    }
    create_response = self.client.post("api/v1/database/", json=payload)
    body = json.loads(create_response.data.decode("utf-8"))
    encrypted_extra_error = (
        "Field cannot be decoded by JSON. Expecting ':' "
        "delimiter: line 1 column 15 (char 14)"
    )
    extra_error = (
        "Field cannot be decoded by JSON. Expecting ','"
        " delimiter: line 1 column 5 (char 4)"
    )
    expected_body = {
        "message": {
            "encrypted_extra": [encrypted_extra_error],
            "extra": [extra_error],
        }
    }
    self.assertEqual(create_response.status_code, 400)
    self.assertEqual(body, expected_body)
def test_create_database_extra_metadata_validate(self):
    """
    Database API: Test create extra metadata_params validation

    An unknown key inside ``extra.metadata_params`` must be rejected
    with a 400. Returns early without asserting on sqlite.
    """
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    extra_config = {
        "metadata_params": {"wrong_param": "some_value"},
        "engine_params": {},
        "metadata_cache_timeout": {},
        "schemas_allowed_for_csv_upload": [],
    }
    self.login(username="admin")
    payload = {
        "database_name": "test-database",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "extra": json.dumps(extra_config),
    }
    create_response = self.client.post("api/v1/database/", json=payload)
    body = json.loads(create_response.data.decode("utf-8"))
    extra_error = (
        "The metadata_params in Extra field is not configured correctly."
        " The key wrong_param is invalid."
    )
    self.assertEqual(create_response.status_code, 400)
    self.assertEqual(body, {"message": {"extra": [extra_error]}})
def test_create_database_unique_validate(self):
    """
    Database API: Test create database_name already exists

    Creating a database named "examples" is expected to return a 422
    with a duplicate-name message. Returns early without asserting on
    sqlite.
    """
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    self.login(username="admin")
    payload = {
        "database_name": "examples",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
    }
    create_response = self.client.post("api/v1/database/", json=payload)
    body = json.loads(create_response.data.decode("utf-8"))
    duplicate_name_error = {
        "message": {"database_name": "A database with the same name already exists"}
    }
    self.assertEqual(create_response.status_code, 422)
    self.assertEqual(body, duplicate_name_error)
def test_create_database_uri_validate(self):
    """
    Database API: Test create fail validate sqlalchemy uri

    A string that is not a SQLAlchemy URI must be rejected with a 400
    and the connection-string help message.
    """
    self.login(username="admin")
    payload = {
        "database_name": "test-database",
        "sqlalchemy_uri": "wrong_uri",
    }
    create_response = self.client.post("api/v1/database/", json=payload)
    body = json.loads(create_response.data.decode("utf-8"))
    self.assertEqual(create_response.status_code, 400)
    uri_error = (
        "Invalid connection string, a valid string usually "
        "follows:'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'"
        "<p>Example:'postgresql://user:password@your-postgres-db/database'"
        "</p>"
    )
    self.assertEqual(body, {"message": {"sqlalchemy_uri": [uri_error]}})
def test_create_database_fail_sqllite(self):
    """
    Database API: Test create fail with SQLite

    Posting a ``sqlite://`` URI must be refused with a 400 and the
    security-reasons message.
    """
    payload = {
        "database_name": "test-database",
        "sqlalchemy_uri": "sqlite:////some.db",
    }
    self.login(username="admin")
    create_response = self.client.post("api/v1/database/", json=payload)
    body = json.loads(create_response.data.decode("utf-8"))
    sqlite_error = (
        "SQLite database cannot be used as a data source "
        "for security reasons."
    )
    self.assertEqual(body, {"message": {"sqlalchemy_uri": [sqlite_error]}})
    self.assertEqual(create_response.status_code, 400)
def test_create_database_conn_fail(self):
    """
    Database API: Test create fails connection

    Uses the example database URI with a wrong password and expects the
    create call to answer 422. Returns early without asserting on
    sqlite, hive and presto.
    """
    example_db = get_example_database()
    skipped_backends = ("sqlite", "hive", "presto")
    if example_db.backend in skipped_backends:
        return
    # NOTE(review): this mutates the example database object in place only
    # to derive a URI with bad credentials; confirm the change is never
    # flushed to the metadata store.
    example_db.password = "wrong_password"
    payload = {
        "database_name": "test-database",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
    }
    self.login(username="admin")
    create_response = self.client.post("api/v1/database/", json=payload)
    body = json.loads(create_response.data.decode("utf-8"))
    self.assertEqual(create_response.status_code, 422)
    self.assertEqual(body, {"message": "Could not connect to database."})
def test_update_database(self):
    """
    Database API: Test update

    Inserts a database, renames it through PUT and expects a 200.
    """
    example_db = get_example_database()
    test_database = self.insert_database(
        "test-database", example_db.sqlalchemy_uri_decrypted
    )
    try:
        self.login(username="admin")
        database_data = {"database_name": "test-database-updated"}
        uri = f"api/v1/database/{test_database.id}"
        rv = self.client.put(uri, json=database_data)
        self.assertEqual(rv.status_code, 200)
    finally:
        # Cleanup runs even when the assertion fails, so the inserted row
        # cannot leak into later tests.
        model = db.session.query(Database).get(test_database.id)
        db.session.delete(model)
        db.session.commit()
def test_update_database_conn_fail(self):
    """
    Database API: Test update fails connection

    Updates an inserted database with a URI carrying a wrong password
    and expects a 422. Returns early without asserting on sqlite, hive
    and presto.
    """
    example_db = get_example_database()
    if example_db.backend in ("sqlite", "hive", "presto"):
        return
    test_database = self.insert_database(
        "test-database1", example_db.sqlalchemy_uri_decrypted
    )
    try:
        # Changed only to derive a URI with bad credentials.
        example_db.password = "wrong_password"
        database_data = {
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        }
        uri = f"api/v1/database/{test_database.id}"
        self.login(username="admin")
        rv = self.client.put(uri, json=database_data)
        response = json.loads(rv.data.decode("utf-8"))
        expected_response = {"message": "Could not connect to database."}
        self.assertEqual(rv.status_code, 422)
        self.assertEqual(response, expected_response)
    finally:
        # Cleanup runs even when an assertion fails, so the inserted row
        # cannot leak into later tests.
        model = db.session.query(Database).get(test_database.id)
        db.session.delete(model)
        db.session.commit()
def test_update_database_uniqueness(self):
    """
    Database API: Test update uniqueness

    Renaming one inserted database to the name of another must be
    rejected with a 422 and a duplicate-name message.
    """
    example_db = get_example_database()
    test_database1 = self.insert_database(
        "test-database1", example_db.sqlalchemy_uri_decrypted
    )
    test_database2 = self.insert_database(
        "test-database2", example_db.sqlalchemy_uri_decrypted
    )
    try:
        self.login(username="admin")
        database_data = {"database_name": "test-database2"}
        uri = f"api/v1/database/{test_database1.id}"
        rv = self.client.put(uri, json=database_data)
        response = json.loads(rv.data.decode("utf-8"))
        expected_response = {
            "message": {
                "database_name": "A database with the same name already exists"
            }
        }
        self.assertEqual(rv.status_code, 422)
        self.assertEqual(response, expected_response)
    finally:
        # Cleanup runs even when an assertion fails, so neither inserted
        # row can leak into later tests.
        db.session.delete(test_database1)
        db.session.delete(test_database2)
        db.session.commit()
def test_update_database_invalid(self):
    """
    Database API: Test update invalid request

    A PUT against a non-numeric database id must answer 404.
    """
    self.login(username="admin")
    database_data = {"database_name": "test-database-updated"}
    # Plain literal: the original f-string had no placeholders.
    uri = "api/v1/database/invalid"
    rv = self.client.put(uri, json=database_data)
    self.assertEqual(rv.status_code, 404)
def test_update_database_uri_validate(self):
    """
    Database API: Test update sqlalchemy_uri validate

    Updating an inserted database with a string that is not a
    SQLAlchemy URI must be rejected with a 400 and the connection-string
    help message.
    """
    example_db = get_example_database()
    test_database = self.insert_database(
        "test-database", example_db.sqlalchemy_uri_decrypted
    )
    try:
        self.login(username="admin")
        database_data = {
            "database_name": "test-database-updated",
            "sqlalchemy_uri": "wrong_uri",
        }
        uri = f"api/v1/database/{test_database.id}"
        rv = self.client.put(uri, json=database_data)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 400)
        expected_response = {
            "message": {
                "sqlalchemy_uri": [
                    "Invalid connection string, a valid string usually "
                    "follows:'DRIVER://USER:PASSWORD@DB-HOST/DATABASE-NAME'"
                    "<p>Example:'postgresql://user:password@your-postgres-db/database'"
                    "</p>"
                ]
            }
        }
        self.assertEqual(response, expected_response)
    finally:
        # Cleanup: previously the inserted "test-database" row was never
        # removed, leaving an extra database behind after this test.
        db.session.delete(test_database)
        db.session.commit()
def test_delete_database(self):
    """
    Database API: Test delete

    Deletes a freshly inserted database through the API and checks the
    row can no longer be loaded.
    """
    database_id = self.insert_database("test-database", "test_uri").id
    self.login(username="admin")
    uri = f"api/v1/database/{database_id}"
    rv = self.delete_assert_metric(uri, "delete")
    self.assertEqual(rv.status_code, 200)
    model = db.session.query(Database).get(database_id)
    # assertIsNone checks identity and reports a clearer failure than
    # comparing against None with assertEqual.
    self.assertIsNone(model)
def test_delete_database_not_found(self):
    """
    Database API: Test delete not found

    Deleting an id one past the current maximum must answer 404.
    """
    max_id = db.session.query(func.max(Database.id)).scalar()
    # MAX() over an empty table yields None; fall back to 0 so the id
    # arithmetic cannot raise a TypeError.
    missing_id = (max_id or 0) + 1
    self.login(username="admin")
    uri = f"api/v1/database/{missing_id}"
    rv = self.delete_assert_metric(uri, "delete")
    self.assertEqual(rv.status_code, 404)
def test_delete_database_with_datasets(self):
    """
    Database API: Test delete fails because it has depending datasets

    Attempts to delete the "examples" database and expects a 422.
    """
    examples_db = (
        db.session.query(Database).filter_by(database_name="examples").one()
    )
    self.login(username="admin")
    delete_response = self.delete_assert_metric(
        f"api/v1/database/{examples_db.id}", "delete"
    )
    self.assertEqual(delete_response.status_code, 422)
def test_get_table_metadata(self):
    """
    Database API: Test get table metadata info

    Fetches metadata for ``birth_names`` on the example database and
    checks the name, comment, column count and generated SELECT.
    """
    example_db = get_example_database()
    self.login(username="admin")
    uri = f"api/v1/database/{example_db.id}/table/birth_names/null/"
    rv = self.client.get(uri)
    self.assertEqual(rv.status_code, 200)
    response = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(response["name"], "birth_names")
    self.assertIsNone(response["comment"])
    # assertGreater reports both operands on failure, unlike assertTrue.
    self.assertGreater(len(response["columns"]), 5)
    # Index directly so a missing key fails with a KeyError naming it,
    # not an AttributeError on None.
    self.assertTrue(response["selectStar"].startswith("SELECT"))
def test_get_invalid_database_table_metadata(self):
    """
    Database API: Test get invalid database from table metadata

    Both an id that is assumed not to exist (1000) and a non-numeric
    database identifier must answer 404.
    """
    database_id = 1000
    self.login(username="admin")
    uri = f"api/v1/database/{database_id}/table/some_table/some_schema/"
    rv = self.client.get(uri)
    self.assertEqual(rv.status_code, 404)
    # Plain literal: the original f-string had no placeholders.
    uri = "api/v1/database/some_database/table/some_table/some_schema/"
    rv = self.client.get(uri)
    self.assertEqual(rv.status_code, 404)
def test_get_invalid_table_table_metadata(self):
    """
    Database API: Test get invalid table from table metadata

    Requests metadata for a table name that is assumed not to exist and
    expects a 404.
    """
    example_db = get_example_database()
    # NOTE(review): unlike test_get_table_metadata, this path has no
    # "table/" segment, so the 404 may come from URL routing rather than
    # from the table lookup -- TODO confirm.
    metadata_path = f"api/v1/database/{example_db.id}/wrong_table/null/"
    self.login(username="admin")
    metadata_response = self.client.get(metadata_path)
    self.assertEqual(metadata_response.status_code, 404)
def test_get_table_metadata_no_db_permission(self):
    """
    Database API: Test get table metadata from not permitted db

    A gamma user asking for table metadata on the example database is
    expected to get a 404.
    """
    self.login(username="gamma")
    example_db = get_example_database()
    # NOTE(review): unlike test_get_table_metadata, this path has no
    # "table/" segment, so the 404 may come from URL routing rather than
    # from the permission check -- TODO confirm.
    metadata_path = f"api/v1/database/{example_db.id}/birth_names/null/"
    metadata_response = self.client.get(metadata_path)
    self.assertEqual(metadata_response.status_code, 404)
def test_get_select_star(self):
    """
    Database API: Test get select star

    The generated statement for ``birth_names`` must mention the
    ``gender`` column.
    """
    self.login(username="admin")
    example_db = get_example_database()
    select_star_response = self.client.get(
        f"api/v1/database/{example_db.id}/select_star/birth_names/"
    )
    self.assertEqual(select_star_response.status_code, 200)
    payload = json.loads(select_star_response.data.decode("utf-8"))
    self.assertIn("gender", payload["result"])
def test_get_select_star_not_allowed(self):
    """
    Database API: Test get select star not allowed

    A gamma user requesting select star on the example database is
    expected to get a 404.
    """
    self.login(username="gamma")
    example_db = get_example_database()
    select_star_response = self.client.get(
        f"api/v1/database/{example_db.id}/select_star/birth_names/"
    )
    self.assertEqual(select_star_response.status_code, 404)
def test_get_select_star_datasource_access(self):
    """
    Database API: Test get select star with datasource access

    Registers ``ab_permission`` as a table on the main database, grants
    the Gamma role ``datasource_access`` on it, and checks that a gamma
    user can then request select star for that table.
    """
    permission_table = SqlaTable(
        schema="main", table_name="ab_permission", database=get_main_database()
    )
    db.session.add(permission_table)
    db.session.commit()
    table_access_perm = security_manager.find_permission_view_menu(
        "datasource_access", permission_table.get_perm()
    )
    gamma_role = security_manager.find_role("Gamma")
    security_manager.add_permission_role(gamma_role, table_access_perm)
    self.login(username="gamma")
    main_db = get_main_database()
    select_star_response = self.client.get(
        f"api/v1/database/{main_db.id}/select_star/ab_permission/"
    )
    self.assertEqual(select_star_response.status_code, 200)
    # rollback changes: revoke the grant, then drop the table and the
    # main database rows.
    security_manager.del_permission_role(gamma_role, table_access_perm)
    db.session.delete(permission_table)
    db.session.delete(main_db)
    db.session.commit()
def test_get_select_star_not_found_database(self):
    """
    Database API: Test get select star not found database

    Requesting select star on an id one past the current maximum must
    answer 404.
    """
    self.login(username="admin")
    max_id = db.session.query(func.max(Database.id)).scalar()
    # MAX() over an empty table yields None; fall back to 0 so the id
    # arithmetic cannot raise a TypeError.
    missing_id = (max_id or 0) + 1
    uri = f"api/v1/database/{missing_id}/select_star/birth_names/"
    rv = self.client.get(uri)
    self.assertEqual(rv.status_code, 404)
def test_get_select_star_not_found_table(self):
    """
    Database API: Test get select star not found table

    Requests select star for a table name that does not exist on the
    example database. Returns early without asserting on sqlite.
    """
    self.login(username="admin")
    example_db = get_example_database()
    # SQLite will not raise a NoSuchTableError
    if example_db.backend == "sqlite":
        return
    select_star_response = self.client.get(
        f"api/v1/database/{example_db.id}/select_star/table_does_not_exist/"
    )
    # TODO(bkyryliuk): investigate why presto returns 500
    expected_status = 404 if example_db.backend != "presto" else 500
    self.assertEqual(select_star_response.status_code, expected_status)
def test_database_schemas(self):
    """
    Database API: Test database schemas

    The schemas endpoint must return the same list as
    ``get_all_schema_names()``, both without a query and with
    ``force`` set to True.
    """
    self.login("admin")
    first_database = db.session.query(Database).first()
    expected_schemas = first_database.get_all_schema_names()
    schemas_path = f"api/v1/database/{first_database.id}/schemas/"

    plain_response = self.client.get(schemas_path)
    plain_payload = json.loads(plain_response.data.decode("utf-8"))
    self.assertEqual(expected_schemas, plain_payload["result"])

    force_query = prison.dumps({"force": True})
    forced_response = self.client.get(f"{schemas_path}?q={force_query}")
    forced_payload = json.loads(forced_response.data.decode("utf-8"))
    self.assertEqual(expected_schemas, forced_payload["result"])
def test_database_schemas_not_found(self):
    """
    Database API: Test database schemas not found

    A gamma user asking for the example database's schemas is expected
    to get a 404.
    """
    # Drop any existing login before switching to the gamma user.
    self.logout()
    self.login(username="gamma")
    example_db = get_example_database()
    schemas_response = self.client.get(
        f"api/v1/database/{example_db.id}/schemas/"
    )
    self.assertEqual(schemas_response.status_code, 404)
def test_database_schemas_invalid_query(self):
    """
    Database API: Test database schemas with invalid query

    Sending a non-boolean value for ``force`` must answer 400.
    """
    self.login("admin")
    first_database = db.session.query(Database).first()
    bad_query = prison.dumps({"force": "nop"})
    schemas_response = self.client.get(
        f"api/v1/database/{first_database.id}/schemas/?q={bad_query}"
    )
    self.assertEqual(schemas_response.status_code, 400)
def test_test_connection(self):
    """
    Database API: Test test connection

    Posts to the test_connection endpoint twice, first with the
    password-masked URI and then with the decrypted one, and expects a
    200 JSON response both times.
    """
    extra = {
        "metadata_params": {},
        "engine_params": {},
        "metadata_cache_timeout": {},
        "schemas_allowed_for_csv_upload": [],
    }
    # need to temporarily allow sqlite dbs, teardown will undo this
    # NOTE(review): no teardown is visible in this file; confirm the base
    # class restores PREVENT_UNSAFE_DB_CONNECTIONS.
    app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
    self.login("admin")
    example_db = get_example_database()
    # validate that the endpoint works with the password-masked sqlalchemy uri
    data = {
        "database_name": "examples",
        "encrypted_extra": "{}",
        "extra": json.dumps(extra),
        "impersonate_user": False,
        "sqlalchemy_uri": example_db.safe_sqlalchemy_uri(),
        "server_cert": ssl_certificate,
    }
    # Plain literal: the original f-string had no placeholders.
    url = "api/v1/database/test_connection"
    rv = self.post_assert_metric(url, data, "test_connection")
    self.assertEqual(rv.status_code, 200)
    self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
    # validate that the endpoint works with the decrypted sqlalchemy uri
    data = {
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "database_name": "examples",
        "impersonate_user": False,
        "extra": json.dumps(extra),
        "server_cert": None,
    }
    rv = self.post_assert_metric(url, data, "test_connection")
    self.assertEqual(rv.status_code, 200)
    self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
def test_test_connection_failed(self):
    """
    Database API: Test test connection failed

    URIs whose driver cannot be loaded must answer 400 with a JSON body
    naming the driver.
    """
    self.login("admin")
    data = {
        "sqlalchemy_uri": "broken://url",
        "database_name": "examples",
        "impersonate_user": False,
        "server_cert": None,
    }
    # Plain literal: the original f-string had no placeholders.
    url = "api/v1/database/test_connection"
    rv = self.post_assert_metric(url, data, "test_connection")
    self.assertEqual(rv.status_code, 400)
    self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
    response = json.loads(rv.data.decode("utf-8"))
    expected_response = {
        "driver_name": "broken",
        "message": "Could not load database driver: broken",
    }
    self.assertEqual(response, expected_response)
    # Second case: a dialect+driver form of URI.
    data = {
        "sqlalchemy_uri": "mssql+pymssql://url",
        "database_name": "examples",
        "impersonate_user": False,
        "server_cert": None,
    }
    rv = self.post_assert_metric(url, data, "test_connection")
    self.assertEqual(rv.status_code, 400)
    self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
    response = json.loads(rv.data.decode("utf-8"))
    expected_response = {
        "driver_name": "mssql+pymssql",
        "message": "Could not load database driver: mssql+pymssql",
    }
    self.assertEqual(response, expected_response)
def test_test_connection_unsafe_uri(self):
    """
    Database API: Test test connection with unsafe uri

    With PREVENT_UNSAFE_DB_CONNECTIONS enabled, testing a sqlite URI
    must answer 400 with the security-reasons message.
    """
    self.login("admin")
    app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True
    data = {
        "sqlalchemy_uri": "sqlite:///home/superset/unsafe.db",
        "database_name": "unsafe",
        "impersonate_user": False,
        "server_cert": None,
    }
    # Plain literal: the original f-string had no placeholders.
    url = "api/v1/database/test_connection"
    rv = self.post_assert_metric(url, data, "test_connection")
    self.assertEqual(rv.status_code, 400)
    response = json.loads(rv.data.decode("utf-8"))
    expected_response = {
        "message": {
            "sqlalchemy_uri": [
                "SQLite database cannot be used as a data source for security reasons."
            ]
        }
    }
    self.assertEqual(response, expected_response)
@pytest.mark.usefixtures("load_unicode_dashboard_with_position")
def test_get_database_related_objects(self):
    """
    Database API: Test get chart and dashboard count related to a database

    With the unicode dashboard fixture loaded, the example database is
    expected to report 33 charts and 6 dashboards.
    :return:
    """
    self.login(username="admin")
    example_db = get_example_database()
    related_response = self.get_assert_metric(
        f"api/v1/database/{example_db.id}/related_objects/", "related_objects"
    )
    self.assertEqual(related_response.status_code, 200)
    payload = json.loads(related_response.data.decode("utf-8"))
    chart_count = payload["charts"]["count"]
    self.assertEqual(chart_count, 33)
    dashboard_count = payload["dashboards"]["count"]
    self.assertEqual(dashboard_count, 6)
def test_get_database_related_objects_not_found(self):
    """
    Database API: Test related objects not found

    Covers two 404 cases: an id that does not exist (as admin) and the
    example database requested by a gamma user.
    """
    max_id = db.session.query(func.max(Database.id)).scalar()
    # id does not exist and we get 404
    # MAX() over an empty table yields None; fall back to 0 so the id
    # arithmetic cannot raise a TypeError.
    invalid_id = (max_id or 0) + 1
    uri = f"api/v1/database/{invalid_id}/related_objects/"
    self.login(username="admin")
    rv = self.client.get(uri)
    self.assertEqual(rv.status_code, 404)
    # Same endpoint for the example database, this time as gamma.
    self.logout()
    self.login(username="gamma")
    database = get_example_database()
    uri = f"api/v1/database/{database.id}/related_objects/"
    rv = self.client.get(uri)
    self.assertEqual(rv.status_code, 404)