# superset/tests/integration_tests/databases/api_tests.py
#
# 3028 lines
# 110 KiB
# Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import dataclasses
import json
from collections import defaultdict
from io import BytesIO
from unittest import mock
from unittest.mock import patch, MagicMock
from zipfile import is_zipfile, ZipFile
from operator import itemgetter
import prison
import pytest
import yaml
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import DBAPIError
from sqlalchemy.sql import func
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.db_engine_specs.redshift import RedshiftEngineSpec
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
from superset.db_engine_specs.hana import HanaEngineSpec
from superset.errors import SupersetError
from superset.models.core import Database, ConfigurationMethod
from superset.reports.models import ReportSchedule, ReportScheduleType
from superset.utils.database import get_example_database, get_main_database
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
load_birth_names_data,
)
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.energy_dashboard import (
load_energy_table_with_slice,
load_energy_table_data,
)
from tests.integration_tests.fixtures.world_bank_dashboard import (
load_world_bank_dashboard_with_slices,
load_world_bank_data,
)
from tests.integration_tests.fixtures.importexport import (
database_config,
dataset_config,
database_metadata_config,
dataset_metadata_config,
)
from tests.integration_tests.fixtures.unicode_dashboard import (
load_unicode_dashboard_with_position,
load_unicode_data,
)
from tests.integration_tests.test_app import app
# SQL validator class names, keyed by database engine name.
SQL_VALIDATORS_BY_ENGINE = dict(
    presto="PrestoDBSQLValidator",
    postgresql="PostgreSQLValidator",
)
# Variant of the mapping above that routes every listed engine to the
# Presto validator.
PRESTO_SQL_VALIDATORS_BY_ENGINE = {
    engine: "PrestoDBSQLValidator"
    for engine in ("presto", "sqlite", "postgresql", "mysql")
}
class TestDatabaseApi(SupersetTestCase):
def insert_database(
    self,
    database_name: str,
    sqlalchemy_uri: str,
    extra: str = "",
    encrypted_extra: str = "",
    server_cert: str = "",
    expose_in_sqllab: bool = False,
    allow_file_upload: bool = False,
) -> Database:
    """
    Persist a new ``Database`` row built from the given attributes and return it.

    The row is added to ``db.session`` and committed immediately, so the caller
    is responsible for deleting it during cleanup.
    """
    fields = dict(
        database_name=database_name,
        sqlalchemy_uri=sqlalchemy_uri,
        extra=extra,
        encrypted_extra=encrypted_extra,
        server_cert=server_cert,
        expose_in_sqllab=expose_in_sqllab,
        allow_file_upload=allow_file_upload,
    )
    new_database = Database(**fields)
    db.session.add(new_database)
    db.session.commit()
    return new_database
@pytest.fixture()
def create_database_with_report(self):
    """
    Fixture: yield a database, created from the examples URI, that has an ALERT
    report schedule attached to it. Both rows are deleted on teardown.
    """
    with self.create_app().app_context():
        example_db = get_example_database()
        database = self.insert_database(
            "database_with_report",
            example_db.sqlalchemy_uri_decrypted,
            expose_in_sqllab=True,
        )
        report = ReportSchedule(
            type=ReportScheduleType.ALERT,
            name="report_with_database",
            crontab="* * * * *",
            database=database,
        )
        db.session.add(report)
        db.session.commit()

        yield database

        # Teardown: delete the report before the database it references.
        for row in (report, database):
            db.session.delete(row)
        db.session.commit()
@pytest.fixture()
def create_database_with_dataset(self):
    """
    Fixture: create a database from the examples URI plus one ``SqlaTable``
    dataset on it, and expose the database as ``self._database``. Both rows are
    removed and the attribute is reset on teardown.
    """
    with self.create_app().app_context():
        example_db = get_example_database()
        self._database = self.insert_database(
            "database_with_dataset",
            example_db.sqlalchemy_uri_decrypted,
            expose_in_sqllab=True,
        )
        dataset = SqlaTable(
            schema="main", table_name="ab_permission", database=self._database
        )
        db.session.add(dataset)
        db.session.commit()

        yield self._database

        # Teardown: delete the dataset before its database, then clear the handle.
        db.session.delete(dataset)
        db.session.delete(self._database)
        db.session.commit()
        self._database = None
def create_database_import(self):
    """
    Build an in-memory ZIP bundle holding the export metadata, one database
    YAML file and one dataset YAML file. Return the buffer rewound to offset 0.
    """
    bundle_members = (
        ("database_export/metadata.yaml", database_metadata_config),
        ("database_export/databases/imported_database.yaml", database_config),
        ("database_export/datasets/imported_dataset.yaml", dataset_config),
    )
    buf = BytesIO()
    with ZipFile(buf, "w") as bundle:
        for member_path, config in bundle_members:
            with bundle.open(member_path, "w") as fp:
                fp.write(yaml.safe_dump(config).encode())
    buf.seek(0)
    return buf
def test_get_items(self):
    """
    Database API: Test get items returns the expected, ordered item keys
    """
    self.login(username="admin")
    rv = self.client.get("api/v1/database/")
    self.assertEqual(rv.status_code, 200)
    payload = json.loads(rv.data.decode("utf-8"))
    # The exact list of keys, in order, that each item of the list endpoint has.
    expected_columns = [
        "allow_ctas",
        "allow_cvas",
        "allow_dml",
        "allow_file_upload",
        "allow_run_async",
        "allows_cost_estimate",
        "allows_subquery",
        "allows_virtual_table_explore",
        "backend",
        "changed_on",
        "changed_on_delta_humanized",
        "created_by",
        "database_name",
        "disable_data_preview",
        "engine_information",
        "explore_database_id",
        "expose_in_sqllab",
        "extra",
        "force_ctas_schema",
        "id",
        "uuid",
    ]
    self.assertGreater(payload["count"], 0)
    first_item = payload["result"][0]
    self.assertEqual(list(first_item.keys()), expected_columns)
def test_get_items_filter(self):
    """
    Database API: Test that filtering the list on ``expose_in_sqllab`` returns
    as many items as the database holds matching rows
    """
    example_db = get_example_database()
    test_database = self.insert_database(
        "test-database", example_db.sqlalchemy_uri_decrypted, expose_in_sqllab=True
    )
    exposed_dbs = db.session.query(Database).filter_by(expose_in_sqllab=True).all()
    self.login(username="admin")
    query_args = {
        "keys": ["none"],
        "filters": [{"col": "expose_in_sqllab", "opr": "eq", "value": True}],
        "order_columns": "database_name",
        "order_direction": "asc",
        "page": 0,
        "page_size": -1,
    }
    rv = self.client.get(f"api/v1/database/?q={prison.dumps(query_args)}")
    payload = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 200)
    self.assertEqual(payload["count"], len(exposed_dbs))
    # Cleanup
    db.session.delete(test_database)
    db.session.commit()
def test_get_items_not_allowed(self):
    """
    Database API: Test that the gamma user gets an empty database list
    """
    self.login(username="gamma")
    rv = self.client.get("api/v1/database/")
    self.assertEqual(rv.status_code, 200)
    payload = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(payload["count"], 0)
def test_create_database(self):
    """
    Database API: Test create stores the requested configuration method
    """
    self.login(username="admin")
    example_db = get_example_database()
    # Nothing is checked when the examples database runs on SQLite.
    if example_db.backend == "sqlite":
        return
    extra = {
        "metadata_params": {},
        "engine_params": {},
        "metadata_cache_timeout": {},
        "schemas_allowed_for_file_upload": [],
    }
    database_data = {
        "database_name": "test-create-database",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        "server_cert": None,
        "extra": json.dumps(extra),
    }
    rv = self.client.post("api/v1/database/", json=database_data)
    payload = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 201)
    # Cleanup
    created = db.session.query(Database).get(payload.get("id"))
    assert created.configuration_method == ConfigurationMethod.SQLALCHEMY_FORM
    db.session.delete(created)
    db.session.commit()
@mock.patch(
    "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
    "superset.models.core.Database.get_all_schema_names",
)
def test_create_database_with_ssh_tunnel(
    self, mock_get_all_schema_names, mock_test_connection_database_command_run
):
    """
    Database API: Test create with SSH Tunnel

    Check that the tunnel row is persisted and linked to the new database, and
    that the tunnel password is masked in the response.
    """
    # mock.patch decorators inject their mocks bottom-up, so the innermost patch
    # (get_all_schema_names) is the first mock argument.
    self.login(username="admin")
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    ssh_tunnel_properties = {
        "server_address": "123.132.123.1",
        "server_port": 8080,
        "username": "foo",
        "password": "bar",
    }
    database_data = {
        "database_name": "test-db-with-ssh-tunnel",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "ssh_tunnel": ssh_tunnel_properties,
    }
    uri = "api/v1/database/"
    rv = self.client.post(uri, json=database_data)
    response = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 201)
    model_ssh_tunnel = (
        db.session.query(SSHTunnel)
        .filter(SSHTunnel.database_id == response.get("id"))
        .one()
    )
    self.assertEqual(response.get("result")["ssh_tunnel"]["password"], "XXXXXXXXXX")
    self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
    # Cleanup
    model = db.session.query(Database).get(response.get("id"))
    db.session.delete(model)
    db.session.commit()
@mock.patch(
    "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
    "superset.models.core.Database.get_all_schema_names",
)
def test_update_database_with_ssh_tunnel(
    self, mock_get_all_schema_names, mock_test_connection_database_command_run
):
    """
    Database API: Test update with SSH Tunnel

    Create a database without a tunnel, then PUT the same payload with an
    ``ssh_tunnel`` section. A tunnel row linked to the database must now exist.
    """
    # mock.patch decorators inject their mocks bottom-up, so the innermost patch
    # (get_all_schema_names) is the first mock argument.
    self.login(username="admin")
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    ssh_tunnel_properties = {
        "server_address": "123.132.123.1",
        "server_port": 8080,
        "username": "foo",
        "password": "bar",
    }
    database_data = {
        "database_name": "test-db-with-ssh-tunnel",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
    }
    database_data_with_ssh_tunnel = {
        "database_name": "test-db-with-ssh-tunnel",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "ssh_tunnel": ssh_tunnel_properties,
    }
    uri = "api/v1/database/"
    rv = self.client.post(uri, json=database_data)
    response = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 201)
    uri = "api/v1/database/{}".format(response.get("id"))
    rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
    response_update = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 200)
    model_ssh_tunnel = (
        db.session.query(SSHTunnel)
        .filter(SSHTunnel.database_id == response_update.get("id"))
        .one()
    )
    self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
    # Cleanup
    model = db.session.query(Database).get(response.get("id"))
    db.session.delete(model)
    db.session.commit()
@mock.patch(
    "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
    "superset.models.core.Database.get_all_schema_names",
)
def test_update_ssh_tunnel_via_database_api(
    self, mock_get_all_schema_names, mock_test_connection_database_command_run
):
    """
    Database API: Test updating an existing SSH Tunnel through the database API

    Create a database with a tunnel, then PUT changed tunnel credentials. The
    stored tunnel must reflect the new username, and the password must be
    masked in the update response.
    """
    # mock.patch decorators inject their mocks bottom-up, so the innermost patch
    # (get_all_schema_names) is the first mock argument.
    self.login(username="admin")
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    initial_ssh_tunnel_properties = {
        "server_address": "123.132.123.1",
        "server_port": 8080,
        "username": "foo",
        "password": "bar",
    }
    updated_ssh_tunnel_properties = {
        "server_address": "123.132.123.1",
        "server_port": 8080,
        "username": "Test",
        "password": "new_bar",
    }
    database_data_with_ssh_tunnel = {
        "database_name": "test-db-with-ssh-tunnel",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "ssh_tunnel": initial_ssh_tunnel_properties,
    }
    database_data_with_ssh_tunnel_update = {
        "database_name": "test-db-with-ssh-tunnel",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "ssh_tunnel": updated_ssh_tunnel_properties,
    }
    uri = "api/v1/database/"
    rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
    response = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 201)
    model_ssh_tunnel = (
        db.session.query(SSHTunnel)
        .filter(SSHTunnel.database_id == response.get("id"))
        .one()
    )
    self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
    self.assertEqual(model_ssh_tunnel.username, "foo")
    uri = "api/v1/database/{}".format(response.get("id"))
    rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update)
    response_update = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 200)
    model_ssh_tunnel = (
        db.session.query(SSHTunnel)
        .filter(SSHTunnel.database_id == response_update.get("id"))
        .one()
    )
    self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
    self.assertEqual(
        response_update.get("result")["ssh_tunnel"]["password"], "XXXXXXXXXX"
    )
    self.assertEqual(model_ssh_tunnel.username, "Test")
    self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1")
    self.assertEqual(model_ssh_tunnel.server_port, 8080)
    # Cleanup
    model = db.session.query(Database).get(response.get("id"))
    db.session.delete(model)
    db.session.commit()
@mock.patch(
    "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
    "superset.models.core.Database.get_all_schema_names",
)
def test_cascade_delete_ssh_tunnel(
    self, mock_get_all_schema_names, mock_test_connection_database_command_run
):
    """
    Database API: Test that deleting a database also deletes its SSH Tunnel

    Create a database with a tunnel, delete the database through the ORM
    session, and check that no tunnel row remains for that database id.
    """
    # mock.patch decorators inject their mocks bottom-up, so the innermost patch
    # (get_all_schema_names) is the first mock argument.
    self.login(username="admin")
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    ssh_tunnel_properties = {
        "server_address": "123.132.123.1",
        "server_port": 8080,
        "username": "foo",
        "password": "bar",
    }
    database_data = {
        "database_name": "test-db-with-ssh-tunnel",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "ssh_tunnel": ssh_tunnel_properties,
    }
    uri = "api/v1/database/"
    rv = self.client.post(uri, json=database_data)
    response = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 201)
    model_ssh_tunnel = (
        db.session.query(SSHTunnel)
        .filter(SSHTunnel.database_id == response.get("id"))
        .one()
    )
    self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
    # Cleanup doubles as the action under test: delete the database...
    model = db.session.query(Database).get(response.get("id"))
    db.session.delete(model)
    db.session.commit()
    # ...and verify its tunnel went with it.
    model_ssh_tunnel = (
        db.session.query(SSHTunnel)
        .filter(SSHTunnel.database_id == response.get("id"))
        .one_or_none()
    )
    assert model_ssh_tunnel is None
@mock.patch(
    "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
    "superset.models.core.Database.get_all_schema_names",
)
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
    self, mock_get_all_schema_names, mock_test_connection_database_command_run
):
    """
    Database API: Test that no database is created when the SSH Tunnel is invalid

    The tunnel payload only has ``server_address``. The API must reply 422 with
    a validation message and must not leave a tunnel or a database row behind.
    """
    # mock.patch decorators inject their mocks bottom-up, so the innermost patch
    # (get_all_schema_names) is the first mock argument.
    self.login(username="admin")
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    ssh_tunnel_properties = {
        "server_address": "123.132.123.1",
    }
    database_data = {
        "database_name": "test-db-failure-ssh-tunnel",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "ssh_tunnel": ssh_tunnel_properties,
    }
    fail_message = {"message": "SSH Tunnel parameters are invalid."}
    uri = "api/v1/database/"
    rv = self.client.post(uri, json=database_data)
    response = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 422)
    # The error payload carries no "id", so response.get("id") is None and this
    # filter looks for orphan tunnels whose database_id IS NULL.
    model_ssh_tunnel = (
        db.session.query(SSHTunnel)
        .filter(SSHTunnel.database_id == response.get("id"))
        .one_or_none()
    )
    assert model_ssh_tunnel is None
    self.assertEqual(response, fail_message)
    # Verify nothing was persisted (there is nothing to clean up)
    model = (
        db.session.query(Database)
        .filter(Database.database_name == "test-db-failure-ssh-tunnel")
        .one_or_none()
    )
    # the DB should not be created
    assert model is None
@mock.patch(
    "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
    "superset.models.core.Database.get_all_schema_names",
)
def test_get_database_returns_related_ssh_tunnel(
    self, mock_get_all_schema_names, mock_test_connection_database_command_run
):
    """
    Database API: Test that the database payload includes its related SSH Tunnel

    The ``ssh_tunnel`` section of the create (POST) response must echo the
    tunnel settings with the password masked.
    """
    # NOTE(review): despite the test name, no GET request is issued; only the
    # POST response is inspected. Consider adding a GET on
    # api/v1/database/<id> if that endpoint is meant to return the tunnel.
    # mock.patch decorators inject their mocks bottom-up, so the innermost patch
    # (get_all_schema_names) is the first mock argument.
    self.login(username="admin")
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    ssh_tunnel_properties = {
        "server_address": "123.132.123.1",
        "server_port": 8080,
        "username": "foo",
        "password": "bar",
    }
    database_data = {
        "database_name": "test-db-with-ssh-tunnel",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "ssh_tunnel": ssh_tunnel_properties,
    }
    response_ssh_tunnel = {
        "server_address": "123.132.123.1",
        "server_port": 8080,
        "username": "foo",
        "password": "XXXXXXXXXX",
    }
    uri = "api/v1/database/"
    rv = self.client.post(uri, json=database_data)
    response = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 201)
    model_ssh_tunnel = (
        db.session.query(SSHTunnel)
        .filter(SSHTunnel.database_id == response.get("id"))
        .one()
    )
    self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
    self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel)
    # Cleanup
    model = db.session.query(Database).get(response.get("id"))
    db.session.delete(model)
    db.session.commit()
def test_create_database_invalid_configuration_method(self):
    """
    Database API: Test create rejects an unknown ``configuration_method`` (400)
    """
    self.login(username="admin")
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    extra_json = json.dumps(
        {
            "metadata_params": {},
            "engine_params": {},
            "metadata_cache_timeout": {},
            "schemas_allowed_for_file_upload": [],
        }
    )
    payload = {
        "database_name": "test-create-database",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "configuration_method": "BAD_FORM",
        "server_cert": None,
        "extra": extra_json,
    }
    rv = self.client.post("api/v1/database/", json=payload)
    body = json.loads(rv.data.decode("utf-8"))
    expected = {"message": {"configuration_method": ["Invalid enum value BAD_FORM"]}}
    assert body == expected
    assert rv.status_code == 400
def test_create_database_no_configuration_method(self):
    """
    Database API: Test create with no config method.

    When ``configuration_method`` is omitted, the created database must report
    ``sqlalchemy_form``. The created row is deleted afterwards so its name
    ("test-create-database", also used by other create tests) can be reused.
    """
    extra = {
        "metadata_params": {},
        "engine_params": {},
        "metadata_cache_timeout": {},
        "schemas_allowed_for_file_upload": [],
    }
    self.login(username="admin")
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    database_data = {
        "database_name": "test-create-database",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "server_cert": None,
        "extra": json.dumps(extra),
    }
    uri = "api/v1/database/"
    rv = self.client.post(uri, json=database_data)
    response = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 201
    self.assertIn("sqlalchemy_form", response["result"]["configuration_method"])
    # Cleanup: the original test leaked this row, and a leftover row with the
    # same name makes later creates fail the uniqueness check.
    model = db.session.query(Database).get(response.get("id"))
    db.session.delete(model)
    db.session.commit()
def test_create_database_server_cert_validate(self):
    """
    Database API: Test create rejects an invalid server certificate (400)
    """
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    self.login(username="admin")
    rv = self.client.post(
        "api/v1/database/",
        json={
            "database_name": "test-create-database-invalid-cert",
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
            "server_cert": "INVALID CERT",
        },
    )
    body = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 400)
    self.assertEqual(body, {"message": {"server_cert": ["Invalid certificate"]}})
def test_create_database_json_validate(self):
    """
    Database API: Test create rejects malformed JSON in ``extra`` and
    ``masked_encrypted_extra`` (400)
    """
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    self.login(username="admin")
    payload = {
        "database_name": "test-create-database-invalid-json",
        "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        # Both fields below are deliberately invalid JSON.
        "masked_encrypted_extra": '{"A": "a", "B", "C"}',
        "extra": '["A": "a", "B", "C"]',
    }
    rv = self.client.post("api/v1/database/", json=payload)
    body = json.loads(rv.data.decode("utf-8"))
    encrypted_extra_error = (
        "Field cannot be decoded by JSON. Expecting ':' "
        "delimiter: line 1 column 15 (char 14)"
    )
    extra_error = (
        "Field cannot be decoded by JSON. Expecting ','"
        " delimiter: line 1 column 5 (char 4)"
    )
    self.assertEqual(rv.status_code, 400)
    self.assertEqual(
        body,
        {
            "message": {
                "masked_encrypted_extra": [encrypted_extra_error],
                "extra": [extra_error],
            }
        },
    )
def test_create_database_extra_metadata_validate(self):
    """
    Database API: Test create rejects unknown keys in ``extra.metadata_params``
    """
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    extra = {
        "metadata_params": {"wrong_param": "some_value"},
        "engine_params": {},
        "metadata_cache_timeout": {},
        "schemas_allowed_for_file_upload": [],
    }
    self.login(username="admin")
    rv = self.client.post(
        "api/v1/database/",
        json={
            "database_name": "test-create-database-invalid-extra",
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
            "extra": json.dumps(extra),
        },
    )
    body = json.loads(rv.data.decode("utf-8"))
    extra_error = (
        "The metadata_params in Extra field is not configured correctly."
        " The key wrong_param is invalid."
    )
    self.assertEqual(rv.status_code, 400)
    self.assertEqual(body, {"message": {"extra": [extra_error]}})
def test_create_database_unique_validate(self):
    """
    Database API: Test create rejects a ``database_name`` that already exists (422)
    """
    example_db = get_example_database()
    if example_db.backend == "sqlite":
        return
    self.login(username="admin")
    # "examples" is presumably the name the examples database is registered under.
    rv = self.client.post(
        "api/v1/database/",
        json={
            "database_name": "examples",
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        },
    )
    body = json.loads(rv.data.decode("utf-8"))
    duplicate_error = "A database with the same name already exists."
    self.assertEqual(rv.status_code, 422)
    self.assertEqual(body, {"message": {"database_name": duplicate_error}})
def test_create_database_uri_validate(self):
    """
    Database API: Test create rejects an unparsable SQLAlchemy URI (400)
    """
    self.login(username="admin")
    rv = self.client.post(
        "api/v1/database/",
        json={
            "database_name": "test-database-invalid-uri",
            "sqlalchemy_uri": "wrong_uri",
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        },
    )
    body = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 400)
    uri_errors = body["message"]["sqlalchemy_uri"]
    self.assertIn("Invalid connection string", uri_errors[0])
@mock.patch(
    "superset.views.core.app.config",
    {**app.config, "PREVENT_UNSAFE_DB_CONNECTIONS": True},
)
def test_create_database_fail_sqllite(self):
    """
    Database API: Test create rejects a SQLite URI while
    ``PREVENT_UNSAFE_DB_CONNECTIONS`` is enabled
    """
    self.login(username="admin")
    rv = self.client.post(
        "api/v1/database/",
        json={
            "database_name": "test-create-sqlite-database",
            "sqlalchemy_uri": "sqlite:////some.db",
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        },
    )
    body = json.loads(rv.data.decode("utf-8"))
    sqlite_error = (
        "SQLiteDialect_pysqlite cannot be used as a data source "
        "for security reasons."
    )
    self.assertEqual(body, {"message": {"sqlalchemy_uri": [sqlite_error]}})
    self.assertEqual(rv.status_code, 400)
def test_create_database_conn_fail(self):
    """
    Database API: Test create fails connection

    The examples database password is temporarily set to a wrong value, so the
    URI sent to the API cannot authenticate. The API must answer 500 with an
    engine-specific ``SupersetError`` payload. The original password is
    restored right after the request, even if the request raises, so the
    examples ``Database`` object is not left modified for later commits or tests.
    """
    example_db = get_example_database()
    if example_db.backend in ("sqlite", "hive", "presto"):
        return
    original_password = example_db.password
    example_db.password = "wrong_password"
    try:
        database_data = {
            "database_name": "test-create-database-wrong-password",
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        }
        uri = "api/v1/database/"
        self.login(username="admin")
        response = self.client.post(uri, json=database_data)
    finally:
        # Put the real password back; the wrong one was only needed to build the URI.
        example_db.password = original_password
    response_data = json.loads(response.data.decode("utf-8"))
    # NOTE(review): the doubled "Issue 1015 - " prefix below mirrors the message
    # the test currently expects; confirm against the issue-code catalogue before
    # changing it.
    superset_error_mysql = SupersetError(
        message='Either the username "superset" or the password is incorrect.',
        error_type="CONNECTION_ACCESS_DENIED_ERROR",
        level="error",
        extra={
            "engine_name": "MySQL",
            "invalid": ["username", "password"],
            "issue_codes": [
                {
                    "code": 1014,
                    "message": (
                        "Issue 1014 - Either the username or the password is wrong."
                    ),
                },
                {
                    "code": 1015,
                    "message": (
                        "Issue 1015 - Issue 1015 - Either the database is spelled incorrectly or does not exist."
                    ),
                },
            ],
        },
    )
    superset_error_postgres = SupersetError(
        message='The password provided for username "superset" is incorrect.',
        error_type="CONNECTION_INVALID_PASSWORD_ERROR",
        level="error",
        extra={
            "engine_name": "PostgreSQL",
            "invalid": ["username", "password"],
            "issue_codes": [
                {
                    "code": 1013,
                    "message": (
                        "Issue 1013 - The password provided when connecting to a database is not valid."
                    ),
                }
            ],
        },
    )
    expected_response_mysql = {"errors": [dataclasses.asdict(superset_error_mysql)]}
    expected_response_postgres = {
        "errors": [dataclasses.asdict(superset_error_postgres)]
    }
    self.assertEqual(response.status_code, 500)
    if example_db.backend == "mysql":
        self.assertEqual(response_data, expected_response_mysql)
    else:
        self.assertEqual(response_data, expected_response_postgres)
def test_update_database(self):
    """
    Database API: Test update renames an existing database
    """
    example_db = get_example_database()
    test_database = self.insert_database(
        "test-database", example_db.sqlalchemy_uri_decrypted
    )
    self.login(username="admin")
    rv = self.client.put(
        f"api/v1/database/{test_database.id}",
        json={
            "database_name": "test-database-updated",
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        },
    )
    self.assertEqual(rv.status_code, 200)
    # Cleanup
    updated = db.session.query(Database).get(test_database.id)
    db.session.delete(updated)
    db.session.commit()
def test_update_database_conn_fail(self):
    """
    Database API: Test update fails connection

    A copy of the examples database is created, then updated with a URI that
    carries a wrong password; the API must reply 422. The examples database
    password is restored right after the request, so the cleanup ``commit()``
    at the end cannot persist the wrong value.
    """
    example_db = get_example_database()
    if example_db.backend in ("sqlite", "hive", "presto"):
        return
    test_database = self.insert_database(
        "test-database1", example_db.sqlalchemy_uri_decrypted
    )
    # Temporarily change the examples database password (presumably a
    # session-attached ORM object) only to build a bad URI.
    original_password = example_db.password
    example_db.password = "wrong_password"
    try:
        database_data = {
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        }
        uri = f"api/v1/database/{test_database.id}"
        self.login(username="admin")
        rv = self.client.put(uri, json=database_data)
    finally:
        example_db.password = original_password
    response = json.loads(rv.data.decode("utf-8"))
    expected_response = {
        "message": "Connection failed, please check your connection settings"
    }
    self.assertEqual(rv.status_code, 422)
    self.assertEqual(response, expected_response)
    # Cleanup
    model = db.session.query(Database).get(test_database.id)
    db.session.delete(model)
    db.session.commit()
def test_update_database_uniqueness(self):
    """
    Database API: Test update rejects renaming a database to an existing name (422)
    """
    example_db = get_example_database()
    first_db = self.insert_database(
        "test-database1", example_db.sqlalchemy_uri_decrypted
    )
    second_db = self.insert_database(
        "test-database2", example_db.sqlalchemy_uri_decrypted
    )
    self.login(username="admin")
    # Try to give the first database the second one's name.
    rv = self.client.put(
        f"api/v1/database/{first_db.id}",
        json={"database_name": "test-database2"},
    )
    body = json.loads(rv.data.decode("utf-8"))
    duplicate_error = "A database with the same name already exists."
    self.assertEqual(rv.status_code, 422)
    self.assertEqual(body, {"message": {"database_name": duplicate_error}})
    # Cleanup
    for created in (first_db, second_db):
        db.session.delete(created)
    db.session.commit()
def test_update_database_invalid(self):
    """
    Database API: Test update on a non-numeric database id returns 404
    """
    self.login(username="admin")
    rv = self.client.put(
        "api/v1/database/invalid", json={"database_name": "test-database-updated"}
    )
    self.assertEqual(rv.status_code, 404)
def test_update_database_uri_validate(self):
    """
    Database API: Test update rejects an unparsable SQLAlchemy URI (400)
    """
    example_db = get_example_database()
    test_database = self.insert_database(
        "test-database", example_db.sqlalchemy_uri_decrypted
    )
    self.login(username="admin")
    rv = self.client.put(
        f"api/v1/database/{test_database.id}",
        json={
            "database_name": "test-database-updated",
            "sqlalchemy_uri": "wrong_uri",
        },
    )
    body = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 400)
    uri_errors = body["message"]["sqlalchemy_uri"]
    self.assertIn("Invalid connection string", uri_errors[0])
    # Cleanup
    db.session.delete(test_database)
    db.session.commit()
def test_update_database_with_invalid_configuration_method(self):
    """
    Database API: Test update rejects an unknown ``configuration_method`` (400)
    """
    example_db = get_example_database()
    test_database = self.insert_database(
        "test-database", example_db.sqlalchemy_uri_decrypted
    )
    self.login(username="admin")
    rv = self.client.put(
        f"api/v1/database/{test_database.id}",
        json={
            "database_name": "test-database-updated",
            "configuration_method": "BAD_FORM",
        },
    )
    body = json.loads(rv.data.decode("utf-8"))
    expected = {"message": {"configuration_method": ["Invalid enum value BAD_FORM"]}}
    assert body == expected
    assert rv.status_code == 400
    # Cleanup
    db.session.delete(test_database)
    db.session.commit()
def test_update_database_with_no_configuration_method(self):
    """
    Database API: Test update succeeds when ``configuration_method`` is omitted
    """
    example_db = get_example_database()
    test_database = self.insert_database(
        "test-database", example_db.sqlalchemy_uri_decrypted
    )
    self.login(username="admin")
    rv = self.client.put(
        f"api/v1/database/{test_database.id}",
        json={"database_name": "test-database-updated"},
    )
    assert rv.status_code == 200
    # Cleanup
    db.session.delete(test_database)
    db.session.commit()
def test_delete_database(self):
    """
    Database API: Test delete removes the database row
    """
    database_id = self.insert_database("test-database", "test_uri").id
    self.login(username="admin")
    rv = self.delete_assert_metric(f"api/v1/database/{database_id}", "delete")
    self.assertEqual(rv.status_code, 200)
    self.assertEqual(db.session.query(Database).get(database_id), None)
def test_delete_database_not_found(self):
    """
    Database API: Test delete not found

    Deleting an id one past the current maximum must return 404.
    """
    max_id = db.session.query(func.max(Database.id)).scalar()
    # MAX() yields NULL (None) on an empty table; treat that as 0 so the
    # arithmetic below cannot raise TypeError.
    missing_id = (max_id or 0) + 1
    self.login(username="admin")
    uri = f"api/v1/database/{missing_id}"
    rv = self.delete_assert_metric(uri, "delete")
    self.assertEqual(rv.status_code, 404)
@pytest.mark.usefixtures("create_database_with_dataset")
def test_delete_database_with_datasets(self):
    """
    Database API: Test delete is refused (422) while a dataset depends on the
    database
    """
    self.login(username="admin")
    target_uri = f"api/v1/database/{self._database.id}"
    rv = self.delete_assert_metric(target_uri, "delete")
    self.assertEqual(rv.status_code, 422)
@pytest.mark.usefixtures("create_database_with_report")
def test_delete_database_with_report(self):
    """
    Database API: Test delete is refused (422) while a report uses the database
    """
    self.login(username="admin")
    database = (
        db.session.query(Database)
        .filter_by(database_name="database_with_report")
        .one_or_none()
    )
    rv = self.client.delete(f"api/v1/database/{database.id}")
    body = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 422)
    self.assertEqual(
        body,
        {"message": "There are associated alerts or reports: report_with_database"},
    )
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_table_metadata(self):
    """
    Database API: Test get table metadata info for ``birth_names``
    """
    example_db = get_example_database()
    self.login(username="admin")
    rv = self.client.get(f"api/v1/database/{example_db.id}/table/birth_names/null/")
    self.assertEqual(rv.status_code, 200)
    metadata = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(metadata["name"], "birth_names")
    self.assertIsNone(metadata["comment"])
    self.assertGreater(len(metadata["columns"]), 5)
    self.assertTrue(metadata.get("selectStar").startswith("SELECT"))
def test_info_security_database(self):
    """
    Database API: Test ``_info`` reports the admin's database permissions
    """
    self.login(username="admin")
    query = prison.dumps({"keys": ["permissions"]})
    rv = self.get_assert_metric(f"api/v1/database/_info?q={query}", "info")
    info = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 200
    assert set(info["permissions"]) == {"can_read", "can_write", "can_export"}
def test_get_invalid_database_table_metadata(self):
    """
    Database API: Table metadata for an unknown or non-numeric database id is 404
    """
    self.login(username="admin")
    for database_ref in (1000, "some_database"):
        rv = self.client.get(
            f"api/v1/database/{database_ref}/table/some_table/some_schema/"
        )
        self.assertEqual(rv.status_code, 404)
def test_get_invalid_table_table_metadata(self):
    """
    Database API: Table metadata for a table that does not exist.

    On SQLite the endpoint answers 200 with an empty reflection; other backends
    answer 422 with the table name as message (backtick-quoted on MySQL).
    """
    example_db = get_example_database()
    uri = f"api/v1/database/{example_db.id}/table/wrong_table/null/"
    self.login(username="admin")
    rv = self.client.get(uri)
    data = json.loads(rv.data.decode("utf-8"))

    backend = example_db.backend
    if backend == "sqlite":
        expected_status = 200
        expected_data = {
            "columns": [],
            "comment": None,
            "foreignKeys": [],
            "indexes": [],
            "name": "wrong_table",
            "primaryKey": {"constrained_columns": None, "name": None},
            "selectStar": "SELECT\nFROM wrong_table\nLIMIT 100\nOFFSET 0",
        }
    elif backend == "mysql":
        expected_status = 422
        expected_data = {"message": "`wrong_table`"}
    else:
        expected_status = 422
        expected_data = {"message": "wrong_table"}

    self.assertEqual(rv.status_code, expected_status)
    self.assertEqual(data, expected_data)
def test_get_table_metadata_no_db_permission(self):
    """
    Database API: Test get table metadata from not permitted db

    Gamma requests the table-metadata endpoint of the examples database and is
    expected to get 404.
    """
    self.login(username="gamma")
    example_db = get_example_database()
    # The ``/table/`` segment is required: without it the URL matches no route
    # and the 404 would come from routing, not from the access check.
    uri = f"api/v1/database/{example_db.id}/table/birth_names/null/"
    rv = self.client.get(uri)
    self.assertEqual(rv.status_code, 404)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_table_extra_metadata(self):
    """
    Database API: birth_names reports an empty extra-metadata object
    """
    examples = get_example_database()
    self.login(username="admin")
    rv = self.client.get(
        f"api/v1/database/{examples.id}/table_extra/birth_names/null/"
    )
    self.assertEqual(rv.status_code, 200)
    self.assertEqual(json.loads(rv.data.decode("utf-8")), {})
def test_get_invalid_database_table_extra_metadata(self):
    """
    Database API: Extra metadata for an unknown or non-numeric database id is 404
    """
    self.login(username="admin")
    for database_ref in (1000, "some_database"):
        rv = self.client.get(
            f"api/v1/database/{database_ref}/table_extra/some_table/some_schema/"
        )
        self.assertEqual(rv.status_code, 404)
def test_get_invalid_table_table_extra_metadata(self):
    """
    Database API: Extra metadata for an unknown table is an empty object (200)
    """
    examples = get_example_database()
    endpoint = f"api/v1/database/{examples.id}/table_extra/wrong_table/null/"
    self.login(username="admin")
    rv = self.client.get(endpoint)
    body = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 200)
    self.assertEqual(body, {})
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_select_star(self):
    """
    Database API: The generated SELECT * for birth_names mentions "gender"
    """
    self.login(username="admin")
    examples = get_example_database()
    rv = self.client.get(f"api/v1/database/{examples.id}/select_star/birth_names/")
    self.assertEqual(rv.status_code, 200)
    select_star = json.loads(rv.data.decode("utf-8"))["result"]
    self.assertIn("gender", select_star)
def test_get_select_star_not_allowed(self):
    """
    Database API: Gamma gets 404 for SELECT * on the examples database
    """
    self.login(username="gamma")
    database_id = get_example_database().id
    rv = self.client.get(f"api/v1/database/{database_id}/select_star/birth_names/")
    self.assertEqual(rv.status_code, 404)
def test_get_select_star_datasource_access(self):
    """
    Database API: Test get select star with datasource access

    Grants Gamma ``datasource_access`` on a temporary ``main.ab_permission``
    dataset and expects the select-star endpoint to answer 200. The grant and
    the temporary rows are rolled back in ``finally`` so that a failing
    assertion does not leak the permission into other tests.
    """
    session = db.session
    table = SqlaTable(
        schema="main", table_name="ab_permission", database=get_main_database()
    )
    session.add(table)
    session.commit()

    tmp_table_perm = security_manager.find_permission_view_menu(
        "datasource_access", table.get_perm()
    )
    gamma_role = security_manager.find_role("Gamma")
    security_manager.add_permission_role(gamma_role, tmp_table_perm)
    main_db = get_main_database()
    try:
        self.login(username="gamma")
        uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
    finally:
        # rollback changes
        # NOTE(review): the main database row is deleted as in the original
        # test; presumably it's recreated on demand by get_main_database —
        # confirm.
        security_manager.del_permission_role(gamma_role, tmp_table_perm)
        db.session.delete(table)
        db.session.delete(main_db)
        db.session.commit()
def test_get_select_star_not_found_database(self):
    """
    Database API: SELECT * against a database id past the current max is 404
    """
    self.login(username="admin")
    missing_id = db.session.query(func.max(Database.id)).scalar() + 1
    rv = self.client.get(f"api/v1/database/{missing_id}/select_star/birth_names/")
    self.assertEqual(rv.status_code, 404)
def test_get_select_star_not_found_table(self):
    """
    Database API: Test get select star not found table

    Requesting SELECT * for a table that does not exist is expected to return
    404 (500 on presto). The check is skipped on SQLite.
    """
    self.login(username="admin")
    example_db = get_example_database()
    # sqlite will not raise a NoSuchTableError
    if example_db.backend == "sqlite":
        return
    uri = f"api/v1/database/{example_db.id}/select_star/table_does_not_exist/"
    rv = self.client.get(uri)
    # TODO(bkyryliuk): investigate why presto returns 500
    self.assertEqual(rv.status_code, 404 if example_db.backend != "presto" else 500)
def test_get_allow_file_upload_filter(self):
    """
    Database API: Test filter for allow file upload checks for schemas

    A database with ``allow_file_upload=True`` and ``["public"]`` listed in
    ``schemas_allowed_for_file_upload`` is expected to be the single match of
    the ``upload_is_enabled`` filter for admin.
    """
    # NOTE: this method must not contain ``yield``. A generator test method is
    # never executed by unittest, so the test would pass vacuously.
    with self.create_app().app_context():
        example_db = get_example_database()
        extra = {
            "metadata_params": {},
            "engine_params": {},
            "metadata_cache_timeout": {},
            "schemas_allowed_for_file_upload": ["public"],
        }
        self.login(username="admin")
        database = self.insert_database(
            "database_with_upload",
            example_db.sqlalchemy_uri_decrypted,
            extra=json.dumps(extra),
            allow_file_upload=True,
        )
        db.session.commit()
        try:
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 1
        finally:
            # always drop the temporary database; sibling tests insert the
            # same name
            db.session.delete(database)
            db.session.commit()
def test_get_allow_file_upload_filter_no_schema(self):
    """
    Database API: Test filter for allow file upload checks for schemas.
    This test has allow_file_upload but no schemas.

    No database is expected to match the ``upload_is_enabled`` filter.
    """
    # NOTE: this method must not contain ``yield``. A generator test method is
    # never executed by unittest, so the test would pass vacuously.
    with self.create_app().app_context():
        example_db = get_example_database()
        extra = {
            "metadata_params": {},
            "engine_params": {},
            "metadata_cache_timeout": {},
            "schemas_allowed_for_file_upload": [],
        }
        self.login(username="admin")
        database = self.insert_database(
            "database_with_upload",
            example_db.sqlalchemy_uri_decrypted,
            extra=json.dumps(extra),
            allow_file_upload=True,
        )
        db.session.commit()
        try:
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 0
        finally:
            # always drop the temporary database; sibling tests insert the
            # same name
            db.session.delete(database)
            db.session.commit()
def test_get_allow_file_upload_filter_allow_file_false(self):
    """
    Database API: Test filter for allow file upload checks for schemas.
    This has a schema but does not allow_file_upload

    No database is expected to match the ``upload_is_enabled`` filter.
    """
    # NOTE: this method must not contain ``yield``. A generator test method is
    # never executed by unittest, so the test would pass vacuously.
    with self.create_app().app_context():
        example_db = get_example_database()
        extra = {
            "metadata_params": {},
            "engine_params": {},
            "metadata_cache_timeout": {},
            "schemas_allowed_for_file_upload": ["public"],
        }
        self.login(username="admin")
        database = self.insert_database(
            "database_with_upload",
            example_db.sqlalchemy_uri_decrypted,
            extra=json.dumps(extra),
            allow_file_upload=False,
        )
        db.session.commit()
        try:
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 0
        finally:
            # always drop the temporary database; sibling tests insert the
            # same name
            db.session.delete(database)
            db.session.commit()
def test_get_allow_file_upload_false(self):
    """
    Database API: Test filter for allow file upload checks for schemas.

    The inserted database has ``allow_file_upload=False`` and an empty
    ``schemas_allowed_for_file_upload``. No database is expected to match the
    ``upload_is_enabled`` filter.
    """
    # NOTE: this method must not contain ``yield``. A generator test method is
    # never executed by unittest, so the test would pass vacuously.
    with self.create_app().app_context():
        example_db = get_example_database()
        extra = {
            "metadata_params": {},
            "engine_params": {},
            "metadata_cache_timeout": {},
            "schemas_allowed_for_file_upload": [],
        }
        self.login(username="admin")
        database = self.insert_database(
            "database_with_upload",
            example_db.sqlalchemy_uri_decrypted,
            extra=json.dumps(extra),
            allow_file_upload=False,
        )
        db.session.commit()
        try:
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 0
        finally:
            # always drop the temporary database; sibling tests insert the
            # same name
            db.session.delete(database)
            db.session.commit()
def test_get_allow_file_upload_false_no_extra(self):
    """
    Database API: Test filter for allow file upload checks for schemas.

    The inserted database has ``allow_file_upload=False`` and no ``extra`` at
    all. No database is expected to match the ``upload_is_enabled`` filter.
    """
    # NOTE: this method must not contain ``yield``. A generator test method is
    # never executed by unittest, so the test would pass vacuously.
    with self.create_app().app_context():
        example_db = get_example_database()
        self.login(username="admin")
        database = self.insert_database(
            "database_with_upload",
            example_db.sqlalchemy_uri_decrypted,
            allow_file_upload=False,
        )
        db.session.commit()
        try:
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 0
        finally:
            # always drop the temporary database; sibling tests insert the
            # same name
            db.session.delete(database)
            db.session.commit()
def mock_csv_function(d, user):
    """Stand-in for ALLOWED_USER_CSV_SCHEMA_FUNC: permit every schema of ``d``."""
    schema_names = d.get_all_schema_names()
    return schema_names
@mock.patch(
    "superset.views.core.app.config",
    {**app.config, "ALLOWED_USER_CSV_SCHEMA_FUNC": mock_csv_function},
)
def test_get_allow_file_upload_true_csv(self):
    """
    Database API: Test filter for allow file upload checks for schemas.

    ``ALLOWED_USER_CSV_SCHEMA_FUNC`` is patched to return every schema of the
    database. A database with ``allow_file_upload=True`` but no configured
    upload schemas is then expected to be the single match of the
    ``upload_is_enabled`` filter.
    """
    # NOTE: this method must not contain ``yield``. A generator test method is
    # never executed by unittest, so the test would pass vacuously.
    with self.create_app().app_context():
        example_db = get_example_database()
        extra = {
            "metadata_params": {},
            "engine_params": {},
            "metadata_cache_timeout": {},
            "schemas_allowed_for_file_upload": [],
        }
        self.login(username="admin")
        database = self.insert_database(
            "database_with_upload",
            example_db.sqlalchemy_uri_decrypted,
            extra=json.dumps(extra),
            allow_file_upload=True,
        )
        db.session.commit()
        try:
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 1
        finally:
            # always drop the temporary database; sibling tests insert the
            # same name
            db.session.delete(database)
            db.session.commit()
def mock_empty_csv_function(d, user):
    """Stand-in for ALLOWED_USER_CSV_SCHEMA_FUNC that permits no schema at all."""
    return list()
@mock.patch(
    "superset.views.core.app.config",
    {**app.config, "ALLOWED_USER_CSV_SCHEMA_FUNC": mock_empty_csv_function},
)
def test_get_allow_file_upload_false_csv(self):
    """
    Database API: Test filter for allow file upload checks for schemas.

    ``ALLOWED_USER_CSV_SCHEMA_FUNC`` is patched to return no schemas. No extra
    database is inserted, and the ``upload_is_enabled`` filter is expected to
    report a count of 1.
    """
    with self.create_app().app_context():
        self.login(username="admin")
        arguments = {
            "columns": ["allow_file_upload"],
            "filters": [
                {
                    "col": "allow_file_upload",
                    "opr": "upload_is_enabled",
                    "value": True,
                }
            ],
        }
        uri = f"api/v1/database/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        data = json.loads(rv.data.decode("utf-8"))
        assert data["count"] == 1
def test_get_allow_file_upload_filter_no_permission(self):
    """
    Database API: Test filter for allow file upload checks for schemas

    Logged in as Gamma, an upload-enabled database is expected NOT to be
    returned by the ``upload_is_enabled`` filter.
    """
    # NOTE: this method must not contain ``yield``. A generator test method is
    # never executed by unittest, so the test would pass vacuously.
    with self.create_app().app_context():
        example_db = get_example_database()
        extra = {
            "metadata_params": {},
            "engine_params": {},
            "metadata_cache_timeout": {},
            "schemas_allowed_for_file_upload": ["public"],
        }
        self.login(username="gamma")
        database = self.insert_database(
            "database_with_upload",
            example_db.sqlalchemy_uri_decrypted,
            extra=json.dumps(extra),
            allow_file_upload=True,
        )
        db.session.commit()
        try:
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 0
        finally:
            # always drop the temporary database; sibling tests insert the
            # same name
            db.session.delete(database)
            db.session.commit()
def test_get_allow_file_upload_filter_with_permission(self):
    """
    Database API: Test filter for allow file upload checks for schemas

    Enables uploads on the main database and grants Gamma
    ``datasource_access`` on a temporary ``public.ab_permission`` dataset.
    Gamma is then expected to see exactly one database through the
    ``upload_is_enabled`` filter. The grant and the temporary rows are rolled
    back in ``finally`` so that a failing assertion does not leak them.
    """
    with self.create_app().app_context():
        main_db = get_main_database()
        main_db.allow_file_upload = True
        session = db.session
        table = SqlaTable(
            schema="public",
            table_name="ab_permission",
            database=get_main_database(),
        )
        session.add(table)
        session.commit()
        tmp_table_perm = security_manager.find_permission_view_menu(
            "datasource_access", table.get_perm()
        )
        gamma_role = security_manager.find_role("Gamma")
        security_manager.add_permission_role(gamma_role, tmp_table_perm)
        try:
            self.login(username="gamma")
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 1
        finally:
            # rollback changes
            security_manager.del_permission_role(gamma_role, tmp_table_perm)
            db.session.delete(table)
            db.session.delete(main_db)
            db.session.commit()
def test_database_schemas(self):
    """
    Database API: The schemas endpoint lists the same schemas as the model,
    both with and without ``force``.
    """
    self.login(username="admin")
    database = db.session.query(Database).filter_by(database_name="examples").one()
    expected_schemas = database.get_all_schema_names()
    base_uri = f"api/v1/database/{database.id}/schemas/"
    forced_uri = f"{base_uri}?q={prison.dumps({'force': True})}"
    for uri in (base_uri, forced_uri):
        rv = self.client.get(uri)
        result = json.loads(rv.data.decode("utf-8"))["result"]
        self.assertEqual(expected_schemas, result)
def test_database_schemas_not_found(self):
    """
    Database API: Gamma gets 404 when listing schemas of the examples database
    """
    self.logout()
    self.login(username="gamma")
    database_id = get_example_database().id
    rv = self.client.get(f"api/v1/database/{database_id}/schemas/")
    self.assertEqual(rv.status_code, 404)
def test_database_schemas_invalid_query(self):
    """
    Database API: A non-boolean ``force`` argument is rejected with 400
    """
    self.login("admin")
    first_db = db.session.query(Database).first()
    bad_query = prison.dumps({"force": "nop"})
    rv = self.client.get(f"api/v1/database/{first_db.id}/schemas/?q={bad_query}")
    self.assertEqual(rv.status_code, 400)
def test_test_connection(self):
    """
    Database API: Test test connection

    The endpoint must succeed with both the password-masked and the decrypted
    SQLAlchemy URI of the examples database.
    """
    extra = {
        "metadata_params": {},
        "engine_params": {},
        "metadata_cache_timeout": {},
        "schemas_allowed_for_file_upload": [],
    }
    # need to temporarily allow sqlite dbs, teardown will undo this
    app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
    self.login("admin")
    example_db = get_example_database()
    url = "api/v1/database/test_connection/"

    payloads = [
        # password-masked sqlalchemy uri
        {
            "database_name": "examples",
            "masked_encrypted_extra": "{}",
            "extra": json.dumps(extra),
            "impersonate_user": False,
            "sqlalchemy_uri": example_db.safe_sqlalchemy_uri(),
            "server_cert": None,
        },
        # decrypted sqlalchemy uri
        {
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
            "database_name": "examples",
            "impersonate_user": False,
            "extra": json.dumps(extra),
            "server_cert": None,
        },
    ]
    for payload in payloads:
        rv = self.post_assert_metric(url, payload, "test_connection")
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(
            rv.headers["Content-Type"], "application/json; charset=utf-8"
        )
def test_test_connection_failed(self):
    """
    Database API: Test test connection failed

    Both ``broken://url`` and ``mssql+pymssql://url`` are expected to fail
    with 422 and a "Could not load database driver" error naming the engine
    spec.
    """
    self.login("admin")
    url = "api/v1/database/test_connection/"
    cases = [
        ("broken://url", "Could not load database driver: BaseEngineSpec"),
        ("mssql+pymssql://url", "Could not load database driver: MssqlEngineSpec"),
    ]
    for sqlalchemy_uri, message in cases:
        payload = {
            "sqlalchemy_uri": sqlalchemy_uri,
            "database_name": "examples",
            "impersonate_user": False,
            "server_cert": None,
        }
        rv = self.post_assert_metric(url, payload, "test_connection")
        self.assertEqual(rv.status_code, 422)
        self.assertEqual(
            rv.headers["Content-Type"], "application/json; charset=utf-8"
        )
        body = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(
            body,
            {
                "errors": [
                    {
                        "message": message,
                        "error_type": "GENERIC_COMMAND_ERROR",
                        "level": "warning",
                        "extra": {
                            "issue_codes": [
                                {
                                    "code": 1010,
                                    "message": "Issue 1010 - Superset encountered an error while running a command.",
                                }
                            ]
                        },
                    }
                ]
            },
        )
def test_test_connection_unsafe_uri(self):
    """
    Database API: Test test connection with unsafe uri

    With ``PREVENT_UNSAFE_DB_CONNECTIONS`` enabled, a SQLite URI is rejected
    with 400. The flag is reset in ``finally`` so that a failing assertion
    cannot leave the global config switched on for later tests.
    """
    self.login("admin")
    app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True
    try:
        data = {
            "sqlalchemy_uri": "sqlite:///home/superset/unsafe.db",
            "database_name": "unsafe",
            "impersonate_user": False,
            "server_cert": None,
        }
        url = "api/v1/database/test_connection/"
        rv = self.post_assert_metric(url, data, "test_connection")
        self.assertEqual(rv.status_code, 400)
        response = json.loads(rv.data.decode("utf-8"))
        expected_response = {
            "message": {
                "sqlalchemy_uri": [
                    "SQLiteDialect_pysqlite cannot be used as a data source for security reasons."
                ]
            }
        }
        self.assertEqual(response, expected_response)
    finally:
        app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
@mock.patch(
    "superset.databases.commands.test_connection.DatabaseDAO.build_db_for_connection_test",
)
@mock.patch(
    "superset.databases.commands.test_connection.event_logger",
)
def test_test_connection_failed_invalid_hostname(
    self, mock_event_logger, mock_build_db
):
    """
    Database API: Test test connection failed due to invalid hostname

    The built database raises a DBAPIError, and the engine spec is mocked to
    translate it into a CONNECTION_INVALID_HOSTNAME_ERROR. The endpoint is
    expected to return that error with status 500.
    """
    built_db = mock_build_db.return_value
    dbapi_message = 'psql: error: could not translate host name "locahost" to address: nodename nor servname provided, or not known'
    built_db.set_sqlalchemy_uri.side_effect = DBAPIError(dbapi_message, None, None)
    built_db.db_engine_spec.__name__ = "Some name"

    hostname_issue = {
        "code": 1007,
        "message": ("Issue 1007 - The hostname provided can't be resolved."),
    }
    superset_error = SupersetError(
        message='Unable to resolve hostname "locahost".',
        error_type="CONNECTION_INVALID_HOSTNAME_ERROR",
        level="error",
        extra={"hostname": "locahost", "issue_codes": [hostname_issue]},
    )
    built_db.db_engine_spec.extract_errors.return_value = [superset_error]

    self.login("admin")
    payload = {
        "sqlalchemy_uri": "postgres://username:password@locahost:12345/db",
        "database_name": "examples",
        "impersonate_user": False,
        "server_cert": None,
    }
    rv = self.post_assert_metric(
        "api/v1/database/test_connection/", payload, "test_connection"
    )
    assert rv.status_code == 500
    assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
    body = json.loads(rv.data.decode("utf-8"))
    assert body == {"errors": [dataclasses.asdict(superset_error)]}
@pytest.mark.usefixtures(
    "load_unicode_dashboard_with_position",
    "load_energy_table_with_slice",
    "load_world_bank_dashboard_with_slices",
    "load_birth_names_dashboard_with_slices",
)
def test_get_database_related_objects(self):
    """
    Database API: Test get chart and dashboard count related to a database
    """
    self.login(username="admin")
    database = get_example_database()
    rv = self.get_assert_metric(
        f"api/v1/database/{database.id}/related_objects/", "related_objects"
    )
    self.assertEqual(rv.status_code, 200)
    related = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(related["charts"]["count"], 34)
    self.assertEqual(related["dashboards"]["count"], 3)
def test_get_database_related_objects_not_found(self):
    """
    Database API: Test related objects not found

    Covers a non-existent id (as admin) and the examples database requested
    by Gamma. Both are expected to answer 404.
    """
    # an id past the current maximum does not exist
    invalid_id = db.session.query(func.max(Database.id)).scalar() + 1
    self.login(username="admin")
    rv = self.get_assert_metric(
        f"api/v1/database/{invalid_id}/related_objects/", "related_objects"
    )
    self.assertEqual(rv.status_code, 404)

    self.logout()
    self.login(username="gamma")
    example_id = get_example_database().id
    rv = self.get_assert_metric(
        f"api/v1/database/{example_id}/related_objects/", "related_objects"
    )
    self.assertEqual(rv.status_code, 404)
def test_export_database(self):
    """
    Database API: Exporting the examples database returns a ZIP archive
    """
    self.login(username="admin")
    ids = prison.dumps([get_example_database().id])
    rv = self.get_assert_metric(f"api/v1/database/export/?q={ids}", "export")
    assert rv.status_code == 200
    assert is_zipfile(BytesIO(rv.data))
def test_export_database_not_allowed(self):
    """
    Database API: Gamma is forbidden (403) from exporting databases
    """
    self.login(username="gamma")
    ids = prison.dumps([get_example_database().id])
    rv = self.client.get(f"api/v1/database/export/?q={ids}")
    assert rv.status_code == 403
def test_export_database_non_existing(self):
    """
    Database API: Test export database non existing

    Exporting a database id past the current maximum is expected to return 404.
    """
    max_id = db.session.query(func.max(Database.id)).scalar()
    # id does not exist and we get 404
    invalid_id = max_id + 1

    self.login(username="admin")
    argument = [invalid_id]
    uri = f"api/v1/database/export/?q={prison.dumps(argument)}"
    rv = self.get_assert_metric(uri, "export")
    assert rv.status_code == 404
def test_import_database(self):
    """
    Database API: Importing a bundle creates the database and its dataset
    """
    self.login(username="admin")
    form_data = {
        "formData": (self.create_database_import(), "database_export.zip"),
    }
    rv = self.client.post(
        "api/v1/database/import/", data=form_data, content_type="multipart/form-data"
    )
    body = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 200
    assert body == {"message": "OK"}

    imported = (
        db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
    )
    assert imported.database_name == "imported_database"
    assert len(imported.tables) == 1
    (dataset,) = imported.tables
    assert dataset.table_name == "imported_dataset"
    assert str(dataset.uuid) == dataset_config["uuid"]

    # clean up: detach owners before deleting the dataset, then the database
    dataset.owners = []
    db.session.delete(dataset)
    db.session.commit()
    db.session.delete(imported)
    db.session.commit()
def test_import_database_overwrite(self):
    """
    Database API: Test import existing database

    Re-importing without ``overwrite`` fails with 422; with
    ``overwrite=true`` it succeeds.
    """
    self.login(username="admin")
    uri = "api/v1/database/import/"

    def post_bundle(**extra_fields):
        # every request uploads a freshly generated export bundle
        form_data = {
            "formData": (self.create_database_import(), "database_export.zip"),
        }
        form_data.update(extra_fields)
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        return rv, json.loads(rv.data.decode("utf-8"))

    rv, body = post_bundle()
    assert rv.status_code == 200
    assert body == {"message": "OK"}

    # import again without overwrite flag
    rv, body = post_bundle()
    assert rv.status_code == 422
    already_exists = "Database already exists and `overwrite=true` was not passed"
    assert body == {
        "errors": [
            {
                "message": "Error importing database",
                "error_type": "GENERIC_COMMAND_ERROR",
                "level": "warning",
                "extra": {
                    "databases/imported_database.yaml": already_exists,
                    "issue_codes": [
                        {
                            "code": 1010,
                            "message": (
                                "Issue 1010 - Superset encountered an "
                                "error while running a command."
                            ),
                        }
                    ],
                },
            }
        ]
    }

    # import with overwrite flag
    rv, body = post_bundle(overwrite="true")
    assert rv.status_code == 200
    assert body == {"message": "OK"}

    # clean up
    imported = (
        db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
    )
    dataset = imported.tables[0]
    dataset.owners = []
    db.session.delete(dataset)
    db.session.commit()
    db.session.delete(imported)
    db.session.commit()
def test_import_database_invalid(self):
    """
    Database API: Test import invalid database

    The bundle's metadata declares a dataset export, so the database import
    is expected to fail with 422.
    """
    self.login(username="admin")
    uri = "api/v1/database/import/"

    bundle_files = {
        "database_export/metadata.yaml": dataset_metadata_config,
        "database_export/databases/imported_database.yaml": database_config,
        "database_export/datasets/imported_dataset.yaml": dataset_config,
    }
    buf = BytesIO()
    with ZipFile(buf, "w") as bundle:
        for path, config in bundle_files.items():
            with bundle.open(path, "w") as fp:
                fp.write(yaml.safe_dump(config).encode())
    buf.seek(0)

    rv = self.client.post(
        uri,
        data={"formData": (buf, "database_export.zip")},
        content_type="multipart/form-data",
    )
    body = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 422
    assert body == {
        "errors": [
            {
                "message": "Error importing database",
                "error_type": "GENERIC_COMMAND_ERROR",
                "level": "warning",
                "extra": {
                    "metadata.yaml": {"type": ["Must be equal to Database."]},
                    "issue_codes": [
                        {
                            "code": 1010,
                            "message": (
                                "Issue 1010 - Superset encountered an "
                                "error while running a command."
                            ),
                        }
                    ],
                },
            }
        ]
    }
def test_import_database_masked_password(self):
    """
    Database API: Test import database with masked password

    A masked password with no replacement supplied is expected to fail the
    import with 422.
    """
    self.login(username="admin")
    uri = "api/v1/database/import/"

    masked_database_config = {
        **database_config,
        "sqlalchemy_uri": "postgresql://username:XXXXXXXXXX@host:12345/db",
    }
    bundle_files = {
        "database_export/metadata.yaml": database_metadata_config,
        "database_export/databases/imported_database.yaml": masked_database_config,
        "database_export/datasets/imported_dataset.yaml": dataset_config,
    }
    buf = BytesIO()
    with ZipFile(buf, "w") as bundle:
        for path, config in bundle_files.items():
            with bundle.open(path, "w") as fp:
                fp.write(yaml.safe_dump(config).encode())
    buf.seek(0)

    rv = self.client.post(
        uri,
        data={"formData": (buf, "database_export.zip")},
        content_type="multipart/form-data",
    )
    body = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 422
    assert body == {
        "errors": [
            {
                "message": "Error importing database",
                "error_type": "GENERIC_COMMAND_ERROR",
                "level": "warning",
                "extra": {
                    "databases/imported_database.yaml": {
                        "_schema": ["Must provide a password for the database"]
                    },
                    "issue_codes": [
                        {
                            "code": 1010,
                            "message": (
                                "Issue 1010 - Superset encountered an "
                                "error while running a command."
                            ),
                        }
                    ],
                },
            }
        ]
    }
def test_import_database_masked_password_provided(self):
    """
    Database API: Test import database with masked password provided

    The password is supplied through the ``passwords`` form field. The stored
    URI keeps the mask, and the password is stored separately.
    """
    self.login(username="admin")
    uri = "api/v1/database/import/"
    masked_uri = "vertica+vertica_python://hackathon:XXXXXXXXXX@host:5433/dbname?ssl=1"
    masked_database_config = {**database_config, "sqlalchemy_uri": masked_uri}

    buf = BytesIO()
    with ZipFile(buf, "w") as bundle:
        with bundle.open("database_export/metadata.yaml", "w") as fp:
            fp.write(yaml.safe_dump(database_metadata_config).encode())
        with bundle.open(
            "database_export/databases/imported_database.yaml", "w"
        ) as fp:
            fp.write(yaml.safe_dump(masked_database_config).encode())
    buf.seek(0)

    passwords = json.dumps({"databases/imported_database.yaml": "SECRET"})
    rv = self.client.post(
        uri,
        data={"formData": (buf, "database_export.zip"), "passwords": passwords},
        content_type="multipart/form-data",
    )
    body = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 200
    assert body == {"message": "OK"}

    imported = (
        db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
    )
    assert imported.database_name == "imported_database"
    assert imported.sqlalchemy_uri == masked_uri
    assert imported.password == "SECRET"

    db.session.delete(imported)
    db.session.commit()
@mock.patch(
    "superset.db_engine_specs.base.BaseEngineSpec.get_function_names",
)
def test_function_names(self, mock_get_function_names):
    """
    Database API: Test the function_names endpoint

    ``BaseEngineSpec.get_function_names`` is mocked, and the endpoint is
    expected to return its result for the examples database. Returns early on
    hive and presto backends.
    """
    example_db = get_example_database()
    if example_db.backend in {"hive", "presto"}:
        return

    mock_get_function_names.return_value = ["AVG", "MAX", "SUM"]

    self.login(username="admin")
    # address the examples database by its real id instead of assuming it is 1
    uri = f"api/v1/database/{example_db.id}/function_names/"
    rv = self.client.get(uri)
    response = json.loads(rv.data.decode("utf-8"))

    assert rv.status_code == 200
    assert response == {"function_names": ["AVG", "MAX", "SUM"]}
@mock.patch("superset.databases.api.get_available_engine_specs")
@mock.patch("superset.databases.api.app")
def test_available(self, app, get_available_engine_specs):
    """
    Database API: Test the list of available database engines

    Engines listed in PREFERRED_DATABASES come first and are flagged
    ``preferred``. Engines exposing connection parameters include their JSON
    schema.
    """
    app.config = {"PREFERRED_DATABASES": ["PostgreSQL", "Google BigQuery"]}
    get_available_engine_specs.return_value = {
        PostgresEngineSpec: {"psycopg2"},
        BigQueryEngineSpec: {"bigquery"},
        MySQLEngineSpec: {"mysqlconnector", "mysqldb"},
        GSheetsEngineSpec: {"apsw"},
        RedshiftEngineSpec: {"psycopg2"},
        HanaEngineSpec: {""},
    }

    self.login(username="admin")
    rv = self.client.get("api/v1/database/available/")
    response = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 200

    # PostgreSQL, Redshift and MySQL share an identical parameters schema
    basic_parameters = {
        "properties": {
            "database": {"description": "Database name", "type": "string"},
            "encryption": {
                "description": "Use an encrypted connection to the database",
                "type": "boolean",
            },
            "host": {"description": "Hostname or IP address", "type": "string"},
            "password": {
                "description": "Password",
                "nullable": True,
                "type": "string",
            },
            "port": {
                "description": "Database port",
                "format": "int32",
                "maximum": 65536,
                "minimum": 0,
                "type": "integer",
            },
            "query": {
                "additionalProperties": {},
                "description": "Additional parameters",
                "type": "object",
            },
            "username": {
                "description": "Username",
                "nullable": True,
                "type": "string",
            },
        },
        "required": ["database", "host", "port", "username"],
        "type": "object",
    }
    upload_with_ssh = {"supports_file_upload": True, "disable_ssh_tunneling": False}

    expected_databases = [
        {
            "available_drivers": ["psycopg2"],
            "default_driver": "psycopg2",
            "engine": "postgresql",
            "name": "PostgreSQL",
            "parameters": basic_parameters,
            "preferred": True,
            "sqlalchemy_uri_placeholder": "postgresql://user:password@host:port/dbname[?key=value&key=value...]",
            "engine_information": upload_with_ssh,
        },
        {
            "available_drivers": ["bigquery"],
            "default_driver": "bigquery",
            "engine": "bigquery",
            "name": "Google BigQuery",
            "parameters": {
                "properties": {
                    "credentials_info": {
                        "description": "Contents of BigQuery JSON credentials.",
                        "type": "string",
                        "x-encrypted-extra": True,
                    },
                    "query": {"type": "object"},
                },
                "type": "object",
            },
            "preferred": True,
            "sqlalchemy_uri_placeholder": "bigquery://{project_id}",
            "engine_information": {
                "supports_file_upload": True,
                "disable_ssh_tunneling": True,
            },
        },
        {
            "available_drivers": ["psycopg2"],
            "default_driver": "psycopg2",
            "engine": "redshift",
            "name": "Amazon Redshift",
            "parameters": basic_parameters,
            "preferred": False,
            "sqlalchemy_uri_placeholder": "redshift+psycopg2://user:password@host:port/dbname[?key=value&key=value...]",
            "engine_information": upload_with_ssh,
        },
        {
            "available_drivers": ["apsw"],
            "default_driver": "apsw",
            "engine": "gsheets",
            "name": "Google Sheets",
            "parameters": {
                "properties": {
                    "catalog": {"type": "object"},
                    "service_account_info": {
                        "description": "Contents of GSheets JSON credentials.",
                        "type": "string",
                        "x-encrypted-extra": True,
                    },
                },
                "type": "object",
            },
            "preferred": False,
            "sqlalchemy_uri_placeholder": "gsheets://",
            "engine_information": {
                "supports_file_upload": False,
                "disable_ssh_tunneling": True,
            },
        },
        {
            "available_drivers": ["mysqlconnector", "mysqldb"],
            "default_driver": "mysqldb",
            "engine": "mysql",
            "name": "MySQL",
            "parameters": basic_parameters,
            "preferred": False,
            "sqlalchemy_uri_placeholder": "mysql://user:password@host:port/dbname[?key=value&key=value...]",
            "engine_information": upload_with_ssh,
        },
        {
            "available_drivers": [""],
            "engine": "hana",
            "name": "SAP HANA",
            "preferred": False,
            "engine_information": upload_with_ssh,
        },
    ]
    assert response == {"databases": expected_databases}
@mock.patch("superset.databases.api.get_available_engine_specs")
@mock.patch("superset.databases.api.app")
def test_available_no_default(self, app, get_available_engine_specs):
    """
    Database API: /available/ lists engines whose installed drivers do not
    include the engine's default driver.

    MySQL only has ``mysqlconnector`` installed (its default is ``mysqldb``).
    HANA exposes a single unnamed driver. MySQL is listed as preferred.
    """
    app.config = {"PREFERRED_DATABASES": ["MySQL"]}
    get_available_engine_specs.return_value = {
        MySQLEngineSpec: {"mysqlconnector"},
        HanaEngineSpec: {""},
    }

    mysql_entry = {
        "available_drivers": ["mysqlconnector"],
        "default_driver": "mysqldb",
        "engine": "mysql",
        "name": "MySQL",
        "preferred": True,
        "engine_information": {
            "supports_file_upload": True,
            "disable_ssh_tunneling": False,
        },
    }
    # HANA reports no "default_driver" key at all
    hana_entry = {
        "available_drivers": [""],
        "engine": "hana",
        "name": "SAP HANA",
        "preferred": False,
        "engine_information": {
            "supports_file_upload": True,
            "disable_ssh_tunneling": False,
        },
    }

    self.login(username="admin")
    rv = self.client.get("api/v1/database/available/")
    body = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 200
    assert body == {"databases": [mysql_entry, hana_entry]}
def test_validate_parameters_invalid_payload_format(self):
    """
    Database API: validate_parameters rejects a request body that is not JSON
    with a 400 and an INVALID_PAYLOAD_FORMAT_ERROR.
    """
    format_issue = {
        "code": 1019,
        "message": "Issue 1019 - The submitted payload has the incorrect format.",
    }
    expected_error = {
        "message": "Request is not JSON",
        "error_type": "INVALID_PAYLOAD_FORMAT_ERROR",
        "level": "error",
        "extra": {"issue_codes": [format_issue]},
    }

    self.login(username="admin")
    rv = self.client.post(
        "api/v1/database/validate_parameters/",
        data="INVALID",
        content_type="text/plain",
    )
    body = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 400
    assert body == {"errors": [expected_error]}
def test_validate_parameters_invalid_payload_schema(self):
    """
    Database API: validate_parameters reports a schema error for each required
    field ("configuration_method", "engine") absent from the payload.
    """

    def missing_field_error(field):
        # Every missing-field error has the same shape except for the field name
        return {
            "message": "Missing data for required field.",
            "error_type": "INVALID_PAYLOAD_SCHEMA_ERROR",
            "level": "error",
            "extra": {
                "invalid": [field],
                "issue_codes": [
                    {
                        "code": 1020,
                        "message": "Issue 1020 - The submitted payload has "
                        "the incorrect schema.",
                    }
                ],
            },
        }

    self.login(username="admin")
    rv = self.client.post(
        "api/v1/database/validate_parameters/", json={"foo": "bar"}
    )
    body = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 422
    # Sort by the offending field so the comparison does not depend on the
    # order in which the errors were emitted
    body["errors"].sort(key=lambda err: err["extra"]["invalid"][0])
    assert body == {
        "errors": [
            missing_field_error("configuration_method"),
            missing_field_error("engine"),
        ]
    }
def test_validate_parameters_missing_fields(self):
    """
    Database API: validate_parameters returns a single warning listing the
    required connection parameters that were sent blank.
    """
    connection_params = defaultdict(dict)
    connection_params.update(
        {
            "host": "",
            "port": 5432,
            "username": "",
            "password": "",
            "database": "",
            "query": {},
        }
    )
    payload = {
        "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        "engine": "postgresql",
        "parameters": connection_params,
    }
    expected_error = {
        "message": "One or more parameters are missing: database, host, username",
        "error_type": "CONNECTION_MISSING_PARAMETERS_ERROR",
        "level": "warning",
        "extra": {
            "missing": ["database", "host", "username"],
            "issue_codes": [
                {
                    "code": 1018,
                    "message": "Issue 1018 - One or more parameters needed "
                    "to configure a database are missing.",
                }
            ],
        },
    }

    self.login(username="admin")
    rv = self.client.post("api/v1/database/validate_parameters/", json=payload)
    body = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 422
    assert body == {"errors": [expected_error]}
@mock.patch("superset.db_engine_specs.base.is_hostname_valid")
@mock.patch("superset.db_engine_specs.base.is_port_open")
@mock.patch("superset.databases.api.ValidateDatabaseParametersCommand")
def test_validate_parameters_valid_payload(
    self, ValidateDatabaseParametersCommand, is_port_open, is_hostname_valid
):
    """
    Database API: validate_parameters answers 200 {"message": "OK"} for a
    complete payload.

    ``ValidateDatabaseParametersCommand`` is replaced by a mock, so this test
    covers the endpoint wiring, not the validation logic. It therefore also
    asserts that the command was actually run. Otherwise the test would pass
    even if the endpoint skipped validation entirely.
    """
    is_hostname_valid.return_value = True
    is_port_open.return_value = True
    self.login(username="admin")
    url = "api/v1/database/validate_parameters/"
    payload = {
        "engine": "postgresql",
        "parameters": defaultdict(dict),
        "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
    }
    payload["parameters"].update(
        {
            "host": "localhost",
            "port": 6789,
            "username": "superset",
            "password": "XXX",
            "database": "test",
            "query": {},
        }
    )
    rv = self.client.post(url, json=payload)
    response = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 200
    assert response == {"message": "OK"}
    # NOTE(review): assumes the endpoint instantiates the command and calls
    # run() on the instance, as Superset commands do.
    ValidateDatabaseParametersCommand.return_value.run.assert_called_once()
def test_validate_parameters_invalid_port(self):
    """
    Database API: validate_parameters rejects a non-numeric port and reports
    both the type error and the range error for it.
    """
    connection_params = defaultdict(dict)
    connection_params.update(
        {
            "host": "localhost",
            "port": "string",
            "username": "superset",
            "password": "XXX",
            "database": "test",
            "query": {},
        }
    )
    payload = {
        "engine": "postgresql",
        "parameters": connection_params,
        "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
    }
    port_issue_codes = [
        {"code": 1034, "message": "Issue 1034 - The port number is invalid."}
    ]

    def port_error(message):
        # Both expected errors point at "port" and share issue 1034
        return {
            "message": message,
            "error_type": "CONNECTION_INVALID_PORT_ERROR",
            "level": "error",
            "extra": {"invalid": ["port"], "issue_codes": port_issue_codes},
        }

    self.login(username="admin")
    rv = self.client.post("api/v1/database/validate_parameters/", json=payload)
    body = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 422
    assert body == {
        "errors": [
            port_error("Port must be a valid integer."),
            port_error(
                "The port must be an integer between 0 and 65535 (inclusive)."
            ),
        ]
    }
@mock.patch("superset.db_engine_specs.base.is_hostname_valid")
def test_validate_parameters_invalid_host(self, is_hostname_valid):
    """
    Database API: validate_parameters reports an unresolvable host alongside
    the blank required parameters.
    """
    # Make the hostname check fail regardless of the host sent
    is_hostname_valid.return_value = False

    connection_params = defaultdict(dict)
    connection_params.update(
        {
            "host": "localhost",
            "port": 5432,
            "username": "",
            "password": "",
            "database": "",
            "query": {},
        }
    )
    payload = {
        "engine": "postgresql",
        "parameters": connection_params,
        "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
    }
    missing_error = {
        "message": "One or more parameters are missing: database, username",
        "error_type": "CONNECTION_MISSING_PARAMETERS_ERROR",
        "level": "warning",
        "extra": {
            "missing": ["database", "username"],
            "issue_codes": [
                {
                    "code": 1018,
                    "message": "Issue 1018 - One or more parameters needed "
                    "to configure a database are missing.",
                }
            ],
        },
    }
    hostname_error = {
        "message": "The hostname provided can't be resolved.",
        "error_type": "CONNECTION_INVALID_HOSTNAME_ERROR",
        "level": "error",
        "extra": {
            "invalid": ["host"],
            "issue_codes": [
                {
                    "code": 1007,
                    "message": "Issue 1007 - The hostname provided can't "
                    "be resolved.",
                }
            ],
        },
    }

    self.login(username="admin")
    rv = self.client.post("api/v1/database/validate_parameters/", json=payload)
    body = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 422
    assert body == {"errors": [missing_error, hostname_error]}
@mock.patch("superset.db_engine_specs.base.is_hostname_valid")
def test_validate_parameters_invalid_port_range(self, is_hostname_valid):
    """
    Database API: validate_parameters rejects port 65536 (one past the maximum)
    and also reports the blank required parameters.
    """
    # Host resolution succeeds so only the missing fields and the port fail
    is_hostname_valid.return_value = True

    connection_params = defaultdict(dict)
    connection_params.update(
        {
            "host": "localhost",
            "port": 65536,
            "username": "",
            "password": "",
            "database": "",
            "query": {},
        }
    )
    payload = {
        "engine": "postgresql",
        "parameters": connection_params,
        "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
    }
    missing_error = {
        "message": "One or more parameters are missing: database, username",
        "error_type": "CONNECTION_MISSING_PARAMETERS_ERROR",
        "level": "warning",
        "extra": {
            "missing": ["database", "username"],
            "issue_codes": [
                {
                    "code": 1018,
                    "message": "Issue 1018 - One or more parameters needed "
                    "to configure a database are missing.",
                }
            ],
        },
    }
    port_error = {
        "message": "The port must be an integer between 0 and 65535 (inclusive).",
        "error_type": "CONNECTION_INVALID_PORT_ERROR",
        "level": "error",
        "extra": {
            "invalid": ["port"],
            "issue_codes": [
                {
                    "code": 1034,
                    "message": "Issue 1034 - The port number is invalid.",
                }
            ],
        },
    }

    self.login(username="admin")
    rv = self.client.post("api/v1/database/validate_parameters/", json=payload)
    body = json.loads(rv.data.decode("utf-8"))
    assert rv.status_code == 422
    assert body == {"errors": [missing_error, port_error]}
def test_get_related_objects(self):
    """
    Database API: related_objects for the example database returns the
    "charts", "dashboards" and "sqllab_tab_states" keys.
    """
    database_id = get_example_database().id
    self.login(username="admin")
    rv = self.client.get(f"api/v1/database/{database_id}/related_objects/")
    assert rv.status_code == 200
    for section in ("charts", "dashboards", "sqllab_tab_states"):
        assert section in rv.json
@patch.dict(
    "superset.config.SQL_VALIDATORS_BY_ENGINE",
    SQL_VALIDATORS_BY_ENGINE,
    clear=True,
)
def test_validate_sql(self):
    """
    Database API: validate SQL success

    Valid SQL against the example database yields an empty result list.
    Skipped unless the example database runs on presto or PostgreSQL.
    """
    example_db = get_example_database()
    if example_db.backend not in ("presto", "postgresql"):
        pytest.skip("Only presto and PG are implemented")

    self.login(username="admin")
    rv = self.client.post(
        f"api/v1/database/{example_db.id}/validate_sql/",
        json={
            "sql": "SELECT * from birth_names",
            "schema": None,
            "template_params": None,
        },
    )
    body = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 200)
    self.assertEqual(body["result"], [])
@patch.dict(
    "superset.config.SQL_VALIDATORS_BY_ENGINE",
    SQL_VALIDATORS_BY_ENGINE,
    clear=True,
)
def test_validate_sql_errors(self):
    """
    Database API: validate SQL with errors

    A syntax error ("froma") is reported as a single annotation on line 1.
    """
    example_db = get_example_database()
    if example_db.backend not in ("presto", "postgresql"):
        pytest.skip("Only presto and PG are implemented")

    # NOTE(review): the expected message below looks PostgreSQL-specific even
    # though presto is not skipped; confirm what the presto validator returns.
    expected_annotations = [
        {
            "end_column": None,
            "line_number": 1,
            "message": 'ERROR: syntax error at or near "table1"',
            "start_column": None,
        }
    ]

    self.login(username="admin")
    rv = self.client.post(
        f"api/v1/database/{example_db.id}/validate_sql/",
        json={
            "sql": "SELECT col1 froma table1",
            "schema": None,
            "template_params": None,
        },
    )
    body = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 200)
    self.assertEqual(body["result"], expected_annotations)
@patch.dict(
    "superset.config.SQL_VALIDATORS_BY_ENGINE",
    SQL_VALIDATORS_BY_ENGINE,
    clear=True,
)
def test_validate_sql_not_found(self):
    """
    Database API: validate SQL database not found

    Posting a well-formed payload to a database id that does not exist
    returns 404.
    """
    self.login(username="admin")
    missing_id = self.get_nonexistent_numeric_id(Database)
    rv = self.client.post(
        f"api/v1/database/{missing_id}/validate_sql/",
        json={
            "sql": "SELECT * from birth_names",
            "schema": None,
            "template_params": None,
        },
    )
    self.assertEqual(rv.status_code, 404)
@patch.dict(
    "superset.config.SQL_VALIDATORS_BY_ENGINE",
    SQL_VALIDATORS_BY_ENGINE,
    clear=True,
)
def test_validate_sql_validation_fails(self):
    """
    Database API: validate SQL database payload validation fails

    A null "sql" field is rejected with 400 even though the target database
    id does not exist. The test expects payload validation to win over the
    404 lookup.
    """
    self.login(username="admin")
    missing_id = self.get_nonexistent_numeric_id(Database)
    rv = self.client.post(
        f"api/v1/database/{missing_id}/validate_sql/",
        json={"sql": None, "schema": None, "template_params": None},
    )
    body = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 400)
    self.assertEqual(body, {"message": {"sql": ["Field may not be null."]}})
@patch.dict(
    "superset.config.SQL_VALIDATORS_BY_ENGINE",
    {},
    clear=True,
)
def test_validate_sql_endpoint_noconfig(self):
    """Assert that validate_sql errors out with 422 when no SQL validators are
    configured for any engine (SQL_VALIDATORS_BY_ENGINE is patched to {}).
    """
    request_payload = {
        "sql": "SELECT col1 from table1",
        "schema": None,
        "template_params": None,
    }
    self.login("admin")
    example_db = get_example_database()
    uri = f"api/v1/database/{example_db.id}/validate_sql/"
    rv = self.client.post(uri, json=request_payload)
    response = json.loads(rv.data.decode("utf-8"))
    self.assertEqual(rv.status_code, 422)
    self.assertEqual(
        response,
        {
            "errors": [
                {
                    # A single f-string; the first literal half of the
                    # original had a stray "f" prefix with no placeholders.
                    "message": f"no SQL validator is configured for {example_db.backend}",
                    "error_type": "GENERIC_DB_ENGINE_ERROR",
                    "level": "error",
                    "extra": {
                        "issue_codes": [
                            {
                                "code": 1002,
                                "message": "Issue 1002 - The database returned an "
                                "unexpected error.",
                            }
                        ]
                    },
                }
            ]
        },
    )
@patch("superset.databases.commands.validate_sql.get_validator_by_name")
@patch.dict(
    "superset.config.SQL_VALIDATORS_BY_ENGINE",
    PRESTO_SQL_VALIDATORS_BY_ENGINE,
    clear=True,
)
def test_validate_sql_endpoint_failure(self, get_validator_by_name):
    """Assert that validate_sql errors out with 422 when the selected validator
    raises an unexpected exception, and that the exception text is surfaced
    in the first error message.
    """
    request_payload = {
        "sql": "SELECT * FROM birth_names",
        "schema": None,
        "template_params": None,
    }
    # Whatever validator the command looks up raises a generic exception
    validator = MagicMock()
    validator.validate.side_effect = Exception("Kaboom!")
    get_validator_by_name.return_value = validator

    self.login("admin")
    example_db = get_example_database()
    uri = f"api/v1/database/{example_db.id}/validate_sql/"
    rv = self.client.post(uri, json=request_payload)
    response = json.loads(rv.data.decode("utf-8"))
    # TODO(bkyryliuk): properly handle hive error
    if example_db.backend == "hive":
        # Report the gap as a skip instead of silently passing
        pytest.skip("validate_sql error handling is not implemented for hive")
    self.assertEqual(rv.status_code, 422)
    self.assertIn("Kaboom!", response["errors"][0]["message"])