# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
import dataclasses
import json
from collections import defaultdict
from io import BytesIO
from unittest import mock
from unittest.mock import patch, MagicMock, Mock
from zipfile import is_zipfile, ZipFile

import prison
import pytest
import yaml
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import DBAPIError
from sqlalchemy.sql import func

from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.db_engine_specs.redshift import RedshiftEngineSpec
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
from superset.db_engine_specs.hana import HanaEngineSpec
from superset.errors import SupersetError
from superset.models.core import Database, ConfigurationMethod
from superset.reports.models import ReportSchedule, ReportScheduleType
from superset.utils.database import get_example_database, get_main_database
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
    load_birth_names_dashboard_with_slices,
    load_birth_names_data,
)
from tests.integration_tests.fixtures.energy_dashboard import (
    load_energy_table_with_slice,
    load_energy_table_data,
)
from tests.integration_tests.fixtures.world_bank_dashboard import (
    load_world_bank_dashboard_with_slices,
    load_world_bank_data,
)
from tests.integration_tests.fixtures.importexport import (
    database_config,
    dataset_config,
    database_metadata_config,
    dataset_metadata_config,
    database_with_ssh_tunnel_config_password,
    database_with_ssh_tunnel_config_private_key,
    database_with_ssh_tunnel_config_mix_credentials,
    database_with_ssh_tunnel_config_no_credentials,
    database_with_ssh_tunnel_config_private_pass_only,
)
from tests.integration_tests.fixtures.unicode_dashboard import (
    load_unicode_dashboard_with_position,
    load_unicode_data,
)
from tests.integration_tests.test_app import app

SQL_VALIDATORS_BY_ENGINE = {
    "presto": "PrestoDBSQLValidator",
    "postgresql": "PostgreSQLValidator",
}

PRESTO_SQL_VALIDATORS_BY_ENGINE = {
    "presto": "PrestoDBSQLValidator",
    "sqlite": "PrestoDBSQLValidator",
    "postgresql": "PrestoDBSQLValidator",
    "mysql": "PrestoDBSQLValidator",
}


class TestDatabaseApi(SupersetTestCase):
    def insert_database(
        self,
        database_name: str,
        sqlalchemy_uri: str,
        extra: str = "",
        encrypted_extra: str = "",
        server_cert: str = "",
        expose_in_sqllab: bool = False,
        allow_file_upload: bool = False,
    ) -> Database:
        database = Database(
            database_name=database_name,
            sqlalchemy_uri=sqlalchemy_uri,
            extra=extra,
            encrypted_extra=encrypted_extra,
            server_cert=server_cert,
            expose_in_sqllab=expose_in_sqllab,
            allow_file_upload=allow_file_upload,
        )
        db.session.add(database)
        db.session.commit()
        return database
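    # A minimal usage sketch of the helper above (illustrative only, not one of
    # the original tests); names like `example_db` simply mirror the patterns
    # used throughout this class:
    #
    #     example_db = get_example_database()
    #     database = self.insert_database(
    #         "my-temp-db", example_db.sqlalchemy_uri_decrypted, expose_in_sqllab=True
    #     )
    #     ...  # exercise the API
    #     db.session.delete(database)
    #     db.session.commit()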
    @pytest.fixture()
    def create_database_with_report(self):
        with self.create_app().app_context():
            example_db = get_example_database()
            database = self.insert_database(
                "database_with_report",
                example_db.sqlalchemy_uri_decrypted,
                expose_in_sqllab=True,
            )
            report_schedule = ReportSchedule(
                type=ReportScheduleType.ALERT,
                name="report_with_database",
                crontab="* * * * *",
                database=database,
            )
            db.session.add(report_schedule)
            db.session.commit()
            yield database

            # rollback changes
            db.session.delete(report_schedule)
            db.session.delete(database)
            db.session.commit()

    @pytest.fixture()
    def create_database_with_dataset(self):
        with self.create_app().app_context():
            example_db = get_example_database()
            self._database = self.insert_database(
                "database_with_dataset",
                example_db.sqlalchemy_uri_decrypted,
                expose_in_sqllab=True,
            )
            table = SqlaTable(
                schema="main", table_name="ab_permission", database=self._database
            )
            db.session.add(table)
            db.session.commit()
            yield self._database

            # rollback changes
            db.session.delete(table)
            db.session.delete(self._database)
            db.session.commit()
            self._database = None

    def create_database_import(self):
        buf = BytesIO()
        with ZipFile(buf, "w") as bundle:
            with bundle.open("database_export/metadata.yaml", "w") as fp:
                fp.write(yaml.safe_dump(database_metadata_config).encode())
            with bundle.open(
                "database_export/databases/imported_database.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(database_config).encode())
            with bundle.open(
                "database_export/datasets/imported_dataset.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(dataset_config).encode())
        buf.seek(0)
        return buf

    def test_get_items(self):
        """
        Database API: Test get items
        """
        self.login(username="admin")
        uri = "api/v1/database/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
        response = json.loads(rv.data.decode("utf-8"))
        expected_columns = [
            "allow_ctas",
            "allow_cvas",
            "allow_dml",
            "allow_file_upload",
            "allow_run_async",
            "allows_cost_estimate",
            "allows_subquery",
            "allows_virtual_table_explore",
            "backend",
            "changed_on",
            "changed_on_delta_humanized",
            "created_by",
            "database_name",
            "disable_data_preview",
            "engine_information",
            "explore_database_id",
            "expose_in_sqllab",
            "extra",
            "force_ctas_schema",
            "id",
            "uuid",
        ]
        self.assertGreater(response["count"], 0)
        self.assertEqual(list(response["result"][0].keys()), expected_columns)

    def test_get_items_filter(self):
        """
        Database API: Test get items with filter
        """
        example_db = get_example_database()
        test_database = self.insert_database(
            "test-database", example_db.sqlalchemy_uri_decrypted, expose_in_sqllab=True
        )
        dbs = db.session.query(Database).filter_by(expose_in_sqllab=True).all()

        self.login(username="admin")
        arguments = {
            "keys": ["none"],
            "filters": [{"col": "expose_in_sqllab", "opr": "eq", "value": True}],
            "order_columns": "database_name",
            "order_direction": "asc",
            "page": 0,
            "page_size": -1,
        }
        uri = f"api/v1/database/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(response["count"], len(dbs))

        # Cleanup
        db.session.delete(test_database)
        db.session.commit()
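    # The `q` parameter above is Rison-encoded by `prison.dumps`; as a rough
    # sketch (not output captured from this test), the arguments dict becomes
    # something like:
    #     (filters:!((col:expose_in_sqllab,opr:eq,value:!t)),keys:!(none),
    #      order_columns:database_name,order_direction:asc,page:0,page_size:-1)
    # which FAB's REST layer decodes back into the same filter structure.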
"api/v1/database/" rv = self.client.get(uri) self.assertEqual(rv.status_code, 200) response = json.loads(rv.data.decode("utf-8")) self.assertEqual(response["count"], 0) def test_create_database(self): """ Database API: Test create """ extra = { "metadata_params": {}, "engine_params": {}, "metadata_cache_timeout": {}, "schemas_allowed_for_file_upload": [], } self.login(username="admin") example_db = get_example_database() if example_db.backend == "sqlite": return database_data = { "database_name": "test-create-database", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM, "server_cert": None, "extra": json.dumps(extra), } uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 201) # Cleanup model = db.session.query(Database).get(response.get("id")) assert model.configuration_method == ConfigurationMethod.SQLALCHEMY_FORM db.session.delete(model) db.session.commit() @mock.patch( "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", ) @mock.patch("superset.databases.commands.create.is_feature_enabled") @mock.patch( "superset.models.core.Database.get_all_schema_names", ) def test_create_database_with_ssh_tunnel( self, mock_test_connection_database_command_run, mock_create_is_feature_enabled, mock_get_all_schema_names, ): """ Database API: Test create with SSH Tunnel """ mock_create_is_feature_enabled.return_value = True self.login(username="admin") example_db = get_example_database() if example_db.backend == "sqlite": return ssh_tunnel_properties = { "server_address": "123.132.123.1", "server_port": 8080, "username": "foo", "password": "bar", } database_data = { "database_name": "test-db-with-ssh-tunnel", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "ssh_tunnel": ssh_tunnel_properties, } uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 201) model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response.get("id")) .one() ) self.assertEqual(response.get("result")["ssh_tunnel"]["password"], "XXXXXXXXXX") self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) # Cleanup model = db.session.query(Database).get(response.get("id")) db.session.delete(model) db.session.commit() @mock.patch( "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", ) @mock.patch("superset.databases.commands.create.is_feature_enabled") @mock.patch("superset.databases.commands.update.is_feature_enabled") @mock.patch( "superset.models.core.Database.get_all_schema_names", ) def test_update_database_with_ssh_tunnel( self, mock_test_connection_database_command_run, mock_create_is_feature_enabled, mock_update_is_feature_enabled, mock_get_all_schema_names, ): """ Database API: Test update Database with SSH Tunnel """ mock_create_is_feature_enabled.return_value = True mock_update_is_feature_enabled.return_value = True self.login(username="admin") example_db = get_example_database() if example_db.backend == "sqlite": return ssh_tunnel_properties = { "server_address": "123.132.123.1", "server_port": 8080, "username": "foo", "password": "bar", } database_data = { "database_name": "test-db-with-ssh-tunnel", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, } database_data_with_ssh_tunnel = { "database_name": "test-db-with-ssh-tunnel", "sqlalchemy_uri": 
    @mock.patch(
        "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
    )
    @mock.patch("superset.databases.commands.create.is_feature_enabled")
    @mock.patch("superset.databases.commands.update.is_feature_enabled")
    @mock.patch(
        "superset.models.core.Database.get_all_schema_names",
    )
    def test_update_database_with_ssh_tunnel(
        self,
        mock_get_all_schema_names,
        mock_update_is_feature_enabled,
        mock_create_is_feature_enabled,
        mock_test_connection_database_command_run,
    ):
        """
        Database API: Test update Database with SSH Tunnel
        """
        mock_create_is_feature_enabled.return_value = True
        mock_update_is_feature_enabled.return_value = True
        self.login(username="admin")
        example_db = get_example_database()
        if example_db.backend == "sqlite":
            return
        ssh_tunnel_properties = {
            "server_address": "123.132.123.1",
            "server_port": 8080,
            "username": "foo",
            "password": "bar",
        }
        database_data = {
            "database_name": "test-db-with-ssh-tunnel",
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        }
        database_data_with_ssh_tunnel = {
            "database_name": "test-db-with-ssh-tunnel",
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
            "ssh_tunnel": ssh_tunnel_properties,
        }

        uri = "api/v1/database/"
        rv = self.client.post(uri, json=database_data)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 201)

        uri = "api/v1/database/{}".format(response.get("id"))
        rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
        response_update = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 200)

        model_ssh_tunnel = (
            db.session.query(SSHTunnel)
            .filter(SSHTunnel.database_id == response_update.get("id"))
            .one()
        )
        self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
        # Cleanup
        model = db.session.query(Database).get(response.get("id"))
        db.session.delete(model)
        db.session.commit()

    @mock.patch(
        "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
    )
    @mock.patch("superset.databases.commands.create.is_feature_enabled")
    @mock.patch("superset.databases.commands.update.is_feature_enabled")
    @mock.patch(
        "superset.models.core.Database.get_all_schema_names",
    )
    def test_update_ssh_tunnel_via_database_api(
        self,
        mock_get_all_schema_names,
        mock_update_is_feature_enabled,
        mock_create_is_feature_enabled,
        mock_test_connection_database_command_run,
    ):
        """
        Database API: Test update SSH Tunnel via Database API
        """
        mock_create_is_feature_enabled.return_value = True
        mock_update_is_feature_enabled.return_value = True
        self.login(username="admin")
        example_db = get_example_database()
        if example_db.backend == "sqlite":
            return
        initial_ssh_tunnel_properties = {
            "server_address": "123.132.123.1",
            "server_port": 8080,
            "username": "foo",
            "password": "bar",
        }
        updated_ssh_tunnel_properties = {
            "server_address": "123.132.123.1",
            "server_port": 8080,
            "username": "Test",
            "password": "new_bar",
        }
        database_data_with_ssh_tunnel = {
            "database_name": "test-db-with-ssh-tunnel",
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
            "ssh_tunnel": initial_ssh_tunnel_properties,
        }
        database_data_with_ssh_tunnel_update = {
            "database_name": "test-db-with-ssh-tunnel",
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
            "ssh_tunnel": updated_ssh_tunnel_properties,
        }

        uri = "api/v1/database/"
        rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 201)
        model_ssh_tunnel = (
            db.session.query(SSHTunnel)
            .filter(SSHTunnel.database_id == response.get("id"))
            .one()
        )
        self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
        self.assertEqual(model_ssh_tunnel.username, "foo")

        uri = "api/v1/database/{}".format(response.get("id"))
        rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update)
        response_update = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 200)
        model_ssh_tunnel = (
            db.session.query(SSHTunnel)
            .filter(SSHTunnel.database_id == response_update.get("id"))
            .one()
        )
        self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
        self.assertEqual(
            response_update.get("result")["ssh_tunnel"]["password"], "XXXXXXXXXX"
        )
        self.assertEqual(model_ssh_tunnel.username, "Test")
        self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1")
        self.assertEqual(model_ssh_tunnel.server_port, 8080)
        # Cleanup
        model = db.session.query(Database).get(response.get("id"))
        db.session.delete(model)
        db.session.commit()
@mock.patch("superset.databases.commands.create.is_feature_enabled") def test_cascade_delete_ssh_tunnel( self, mock_test_connection_database_command_run, mock_get_all_schema_names, mock_create_is_feature_enabled, ): """ Database API: SSH Tunnel gets deleted if Database gets deleted """ mock_create_is_feature_enabled.return_value = True self.login(username="admin") example_db = get_example_database() if example_db.backend == "sqlite": return ssh_tunnel_properties = { "server_address": "123.132.123.1", "server_port": 8080, "username": "foo", "password": "bar", } database_data = { "database_name": "test-db-with-ssh-tunnel", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "ssh_tunnel": ssh_tunnel_properties, } uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 201) model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response.get("id")) .one() ) self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) # Cleanup model = db.session.query(Database).get(response.get("id")) db.session.delete(model) db.session.commit() model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response.get("id")) .one_or_none() ) assert model_ssh_tunnel is None @mock.patch( "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", ) @mock.patch("superset.databases.commands.create.is_feature_enabled") @mock.patch( "superset.models.core.Database.get_all_schema_names", ) def test_do_not_create_database_if_ssh_tunnel_creation_fails( self, mock_test_connection_database_command_run, mock_create_is_feature_enabled, mock_get_all_schema_names, ): """ Database API: Test Database is not created if SSH Tunnel creation fails """ mock_create_is_feature_enabled.return_value = True self.login(username="admin") example_db = get_example_database() if example_db.backend == "sqlite": return ssh_tunnel_properties = { "server_address": "123.132.123.1", } database_data = { "database_name": "test-db-failure-ssh-tunnel", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "ssh_tunnel": ssh_tunnel_properties, } fail_message = {"message": "SSH Tunnel parameters are invalid."} uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 422) model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response.get("id")) .one_or_none() ) assert model_ssh_tunnel is None self.assertEqual(response, fail_message) # Cleanup model = ( db.session.query(Database) .filter(Database.database_name == "test-db-failure-ssh-tunnel") .one_or_none() ) # the DB should not be created assert model is None @mock.patch( "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", ) @mock.patch("superset.databases.commands.create.is_feature_enabled") @mock.patch( "superset.models.core.Database.get_all_schema_names", ) def test_get_database_returns_related_ssh_tunnel( self, mock_test_connection_database_command_run, mock_create_is_feature_enabled, mock_get_all_schema_names, ): """ Database API: Test GET Database returns its related SSH Tunnel """ mock_create_is_feature_enabled.return_value = True self.login(username="admin") example_db = get_example_database() if example_db.backend == "sqlite": return ssh_tunnel_properties = { "server_address": "123.132.123.1", "server_port": 8080, "username": "foo", "password": "bar", } database_data = { 
"database_name": "test-db-with-ssh-tunnel", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "ssh_tunnel": ssh_tunnel_properties, } response_ssh_tunnel = { "server_address": "123.132.123.1", "server_port": 8080, "username": "foo", "password": "XXXXXXXXXX", } uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 201) model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response.get("id")) .one() ) self.assertEqual(model_ssh_tunnel.database_id, response.get("id")) self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel) # Cleanup model = db.session.query(Database).get(response.get("id")) db.session.delete(model) db.session.commit() @mock.patch( "superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run", ) @mock.patch( "superset.models.core.Database.get_all_schema_names", ) def test_if_ssh_tunneling_flag_is_not_active_it_raises_new_exception( self, mock_test_connection_database_command_run, mock_get_all_schema_names, ): """ Database API: Test raises SSHTunneling feature flag not enabled """ self.login(username="admin") example_db = get_example_database() if example_db.backend == "sqlite": return ssh_tunnel_properties = { "server_address": "123.132.123.1", "server_port": 8080, "username": "foo", "password": "bar", } database_data = { "database_name": "test-db-with-ssh-tunnel-7", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "ssh_tunnel": ssh_tunnel_properties, } uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 400) self.assertEqual(response, {"message": "SSH Tunneling is not enabled"}) model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response.get("id")) .one_or_none() ) assert model_ssh_tunnel is None # Cleanup model = ( db.session.query(Database) .filter(Database.database_name == "test-db-with-ssh-tunnel-7") .one_or_none() ) # the DB should not be created assert model is None def test_create_database_invalid_configuration_method(self): """ Database API: Test create with an invalid configuration method. """ extra = { "metadata_params": {}, "engine_params": {}, "metadata_cache_timeout": {}, "schemas_allowed_for_file_upload": [], } self.login(username="admin") example_db = get_example_database() if example_db.backend == "sqlite": return database_data = { "database_name": "test-create-database", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "configuration_method": "BAD_FORM", "server_cert": None, "extra": json.dumps(extra), } uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) assert response == { "message": { "configuration_method": [ "Must be one of: sqlalchemy_form, dynamic_form." ] } } assert rv.status_code == 400 def test_create_database_no_configuration_method(self): """ Database API: Test create with no config method. 
""" extra = { "metadata_params": {}, "engine_params": {}, "metadata_cache_timeout": {}, "schemas_allowed_for_file_upload": [], } self.login(username="admin") example_db = get_example_database() if example_db.backend == "sqlite": return database_data = { "database_name": "test-create-database", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "server_cert": None, "extra": json.dumps(extra), } uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) assert rv.status_code == 201 self.assertIn("sqlalchemy_form", response["result"]["configuration_method"]) def test_create_database_server_cert_validate(self): """ Database API: Test create server cert validation """ example_db = get_example_database() if example_db.backend == "sqlite": return self.login(username="admin") database_data = { "database_name": "test-create-database-invalid-cert", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM, "server_cert": "INVALID CERT", } uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) expected_response = {"message": {"server_cert": ["Invalid certificate"]}} self.assertEqual(rv.status_code, 400) self.assertEqual(response, expected_response) def test_create_database_json_validate(self): """ Database API: Test create encrypted extra and extra validation """ example_db = get_example_database() if example_db.backend == "sqlite": return self.login(username="admin") database_data = { "database_name": "test-create-database-invalid-json", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM, "masked_encrypted_extra": '{"A": "a", "B", "C"}', "extra": '["A": "a", "B", "C"]', } uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) expected_response = { "message": { "masked_encrypted_extra": [ "Field cannot be decoded by JSON. Expecting ':' " "delimiter: line 1 column 15 (char 14)" ], "extra": [ "Field cannot be decoded by JSON. Expecting ','" " delimiter: line 1 column 5 (char 4)" ], } } self.assertEqual(rv.status_code, 400) self.assertEqual(response, expected_response) def test_create_database_extra_metadata_validate(self): """ Database API: Test create extra metadata_params validation """ example_db = get_example_database() if example_db.backend == "sqlite": return extra = { "metadata_params": {"wrong_param": "some_value"}, "engine_params": {}, "metadata_cache_timeout": {}, "schemas_allowed_for_file_upload": [], } self.login(username="admin") database_data = { "database_name": "test-create-database-invalid-extra", "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM, "extra": json.dumps(extra), } uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) expected_response = { "message": { "extra": [ "The metadata_params in Extra field is not configured correctly." " The key wrong_param is invalid." 
    def test_create_database_extra_metadata_validate(self):
        """
        Database API: Test create extra metadata_params validation
        """
        example_db = get_example_database()
        if example_db.backend == "sqlite":
            return

        extra = {
            "metadata_params": {"wrong_param": "some_value"},
            "engine_params": {},
            "metadata_cache_timeout": {},
            "schemas_allowed_for_file_upload": [],
        }
        self.login(username="admin")
        database_data = {
            "database_name": "test-create-database-invalid-extra",
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
            "extra": json.dumps(extra),
        }

        uri = "api/v1/database/"
        rv = self.client.post(uri, json=database_data)
        response = json.loads(rv.data.decode("utf-8"))
        expected_response = {
            "message": {
                "extra": [
                    "The metadata_params in Extra field is not configured correctly."
                    " The key wrong_param is invalid."
                ]
            }
        }
        self.assertEqual(rv.status_code, 400)
        self.assertEqual(response, expected_response)

    def test_create_database_unique_validate(self):
        """
        Database API: Test create database_name already exists
        """
        example_db = get_example_database()
        if example_db.backend == "sqlite":
            return

        self.login(username="admin")
        database_data = {
            "database_name": "examples",
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        }

        uri = "api/v1/database/"
        rv = self.client.post(uri, json=database_data)
        response = json.loads(rv.data.decode("utf-8"))
        expected_response = {
            "message": {
                "database_name": "A database with the same name already exists."
            }
        }
        self.assertEqual(rv.status_code, 422)
        self.assertEqual(response, expected_response)

    def test_create_database_uri_validate(self):
        """
        Database API: Test create fail validate sqlalchemy uri
        """
        self.login(username="admin")
        database_data = {
            "database_name": "test-database-invalid-uri",
            "sqlalchemy_uri": "wrong_uri",
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        }

        uri = "api/v1/database/"
        rv = self.client.post(uri, json=database_data)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 400)
        self.assertIn(
            "Invalid connection string",
            response["message"]["sqlalchemy_uri"][0],
        )

    @mock.patch(
        "superset.views.core.app.config",
        {**app.config, "PREVENT_UNSAFE_DB_CONNECTIONS": True},
    )
    def test_create_database_fail_sqlite(self):
        """
        Database API: Test create fail with sqlite
        """
        database_data = {
            "database_name": "test-create-sqlite-database",
            "sqlalchemy_uri": "sqlite:////some.db",
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        }

        uri = "api/v1/database/"
        self.login(username="admin")
        response = self.client.post(uri, json=database_data)
        response_data = json.loads(response.data.decode("utf-8"))
        expected_response = {
            "message": {
                "sqlalchemy_uri": [
                    "SQLiteDialect_pysqlite cannot be used as a data source "
                    "for security reasons."
                ]
            }
        }
        self.assertEqual(response_data, expected_response)
        self.assertEqual(response.status_code, 400)
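    # PREVENT_UNSAFE_DB_CONNECTIONS blocks file-based dialects such as SQLite
    # from being added as data sources, since a sqlite:// URI can point at
    # arbitrary files on the web server's disk.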
    def test_create_database_conn_fail(self):
        """
        Database API: Test create fails connection
        """
        example_db = get_example_database()
        if example_db.backend in ("sqlite", "hive", "presto"):
            return
        example_db.password = "wrong_password"
        database_data = {
            "database_name": "test-create-database-wrong-password",
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        }

        uri = "api/v1/database/"
        self.login(username="admin")
        response = self.client.post(uri, json=database_data)
        response_data = json.loads(response.data.decode("utf-8"))
        superset_error_mysql = SupersetError(
            message='Either the username "superset" or the password is incorrect.',
            error_type="CONNECTION_ACCESS_DENIED_ERROR",
            level="error",
            extra={
                "engine_name": "MySQL",
                "invalid": ["username", "password"],
                "issue_codes": [
                    {
                        "code": 1014,
                        "message": (
                            "Issue 1014 - Either the username or the password is wrong."
                        ),
                    },
                    {
                        "code": 1015,
                        "message": (
                            "Issue 1015 - Either the database is spelled incorrectly or does not exist."
                        ),
                    },
                ],
            },
        )
        superset_error_postgres = SupersetError(
            message='The password provided for username "superset" is incorrect.',
            error_type="CONNECTION_INVALID_PASSWORD_ERROR",
            level="error",
            extra={
                "engine_name": "PostgreSQL",
                "invalid": ["username", "password"],
                "issue_codes": [
                    {
                        "code": 1013,
                        "message": (
                            "Issue 1013 - The password provided when connecting to a database is not valid."
                        ),
                    }
                ],
            },
        )
        expected_response_mysql = {"errors": [dataclasses.asdict(superset_error_mysql)]}
        expected_response_postgres = {
            "errors": [dataclasses.asdict(superset_error_postgres)]
        }
        self.assertEqual(response.status_code, 500)
        if example_db.backend == "mysql":
            self.assertEqual(response_data, expected_response_mysql)
        else:
            self.assertEqual(response_data, expected_response_postgres)

    def test_update_database(self):
        """
        Database API: Test update
        """
        example_db = get_example_database()
        test_database = self.insert_database(
            "test-database", example_db.sqlalchemy_uri_decrypted
        )
        self.login(username="admin")
        database_data = {
            "database_name": "test-database-updated",
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        }
        uri = f"api/v1/database/{test_database.id}"
        rv = self.client.put(uri, json=database_data)
        self.assertEqual(rv.status_code, 200)
        # Cleanup
        model = db.session.query(Database).get(test_database.id)
        db.session.delete(model)
        db.session.commit()

    def test_update_database_conn_fail(self):
        """
        Database API: Test update fails connection
        """
        example_db = get_example_database()
        if example_db.backend in ("sqlite", "hive", "presto"):
            return

        test_database = self.insert_database(
            "test-database1", example_db.sqlalchemy_uri_decrypted
        )
        example_db.password = "wrong_password"
        database_data = {
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
        }

        uri = f"api/v1/database/{test_database.id}"
        self.login(username="admin")
        rv = self.client.put(uri, json=database_data)
        response = json.loads(rv.data.decode("utf-8"))
        expected_response = {
            "message": "Connection failed, please check your connection settings"
        }
        self.assertEqual(rv.status_code, 422)
        self.assertEqual(response, expected_response)
        # Cleanup
        model = db.session.query(Database).get(test_database.id)
        db.session.delete(model)
        db.session.commit()

    def test_update_database_uniqueness(self):
        """
        Database API: Test update uniqueness
        """
        example_db = get_example_database()
        test_database1 = self.insert_database(
            "test-database1", example_db.sqlalchemy_uri_decrypted
        )
        test_database2 = self.insert_database(
            "test-database2", example_db.sqlalchemy_uri_decrypted
        )

        self.login(username="admin")
        database_data = {"database_name": "test-database2"}
        uri = f"api/v1/database/{test_database1.id}"
        rv = self.client.put(uri, json=database_data)
        response = json.loads(rv.data.decode("utf-8"))
        expected_response = {
            "message": {
                "database_name": "A database with the same name already exists."
            }
        }
        self.assertEqual(rv.status_code, 422)
        self.assertEqual(response, expected_response)
        # Cleanup
        db.session.delete(test_database1)
        db.session.delete(test_database2)
        db.session.commit()
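    # Status-code convention exercised above and below: 400 for malformed
    # input (bad URI, bad enum value), 422 for well-formed requests that are
    # rejected (duplicate names, connections that fail).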
    def test_update_database_invalid(self):
        """
        Database API: Test update invalid request
        """
        self.login(username="admin")
        database_data = {"database_name": "test-database-updated"}
        uri = "api/v1/database/invalid"
        rv = self.client.put(uri, json=database_data)
        self.assertEqual(rv.status_code, 404)

    def test_update_database_uri_validate(self):
        """
        Database API: Test update sqlalchemy_uri validate
        """
        example_db = get_example_database()
        test_database = self.insert_database(
            "test-database", example_db.sqlalchemy_uri_decrypted
        )

        self.login(username="admin")
        database_data = {
            "database_name": "test-database-updated",
            "sqlalchemy_uri": "wrong_uri",
        }
        uri = f"api/v1/database/{test_database.id}"
        rv = self.client.put(uri, json=database_data)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 400)
        self.assertIn(
            "Invalid connection string",
            response["message"]["sqlalchemy_uri"][0],
        )

        db.session.delete(test_database)
        db.session.commit()

    def test_update_database_with_invalid_configuration_method(self):
        """
        Database API: Test update
        """
        example_db = get_example_database()
        test_database = self.insert_database(
            "test-database", example_db.sqlalchemy_uri_decrypted
        )
        self.login(username="admin")
        database_data = {
            "database_name": "test-database-updated",
            "configuration_method": "BAD_FORM",
        }
        uri = f"api/v1/database/{test_database.id}"
        rv = self.client.put(uri, json=database_data)
        response = json.loads(rv.data.decode("utf-8"))
        assert response == {
            "message": {
                "configuration_method": [
                    "Must be one of: sqlalchemy_form, dynamic_form."
                ]
            }
        }
        assert rv.status_code == 400

        db.session.delete(test_database)
        db.session.commit()

    def test_update_database_with_no_configuration_method(self):
        """
        Database API: Test update
        """
        example_db = get_example_database()
        test_database = self.insert_database(
            "test-database", example_db.sqlalchemy_uri_decrypted
        )
        self.login(username="admin")
        database_data = {
            "database_name": "test-database-updated",
        }
        uri = f"api/v1/database/{test_database.id}"
        rv = self.client.put(uri, json=database_data)
        assert rv.status_code == 200

        db.session.delete(test_database)
        db.session.commit()

    def test_delete_database(self):
        """
        Database API: Test delete
        """
        database_id = self.insert_database("test-database", "test_uri").id
        self.login(username="admin")
        uri = f"api/v1/database/{database_id}"
        rv = self.delete_assert_metric(uri, "delete")
        self.assertEqual(rv.status_code, 200)
        model = db.session.query(Database).get(database_id)
        self.assertEqual(model, None)

    def test_delete_database_not_found(self):
        """
        Database API: Test delete not found
        """
        max_id = db.session.query(func.max(Database.id)).scalar()
        self.login(username="admin")
        uri = f"api/v1/database/{max_id + 1}"
        rv = self.delete_assert_metric(uri, "delete")
        self.assertEqual(rv.status_code, 404)

    @pytest.mark.usefixtures("create_database_with_dataset")
    def test_delete_database_with_datasets(self):
        """
        Database API: Test delete fails because it has depending datasets
        """
        self.login(username="admin")
        uri = f"api/v1/database/{self._database.id}"
        rv = self.delete_assert_metric(uri, "delete")
        self.assertEqual(rv.status_code, 422)
    @pytest.mark.usefixtures("create_database_with_report")
    def test_delete_database_with_report(self):
        """
        Database API: Test delete with associated report
        """
        self.login(username="admin")
        database = (
            db.session.query(Database)
            .filter(Database.database_name == "database_with_report")
            .one_or_none()
        )
        uri = f"api/v1/database/{database.id}"
        rv = self.client.delete(uri)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 422)
        expected_response = {
            "message": "There are associated alerts or reports: report_with_database"
        }
        self.assertEqual(response, expected_response)

    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
    def test_get_table_metadata(self):
        """
        Database API: Test get table metadata info
        """
        example_db = get_example_database()
        self.login(username="admin")
        uri = f"api/v1/database/{example_db.id}/table/birth_names/null/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(response["name"], "birth_names")
        self.assertIsNone(response["comment"])
        self.assertTrue(len(response["columns"]) > 5)
        self.assertTrue(response.get("selectStar").startswith("SELECT"))

    def test_info_security_database(self):
        """
        Database API: Test info security
        """
        self.login(username="admin")
        params = {"keys": ["permissions"]}
        uri = f"api/v1/database/_info?q={prison.dumps(params)}"
        rv = self.get_assert_metric(uri, "info")
        data = json.loads(rv.data.decode("utf-8"))
        assert rv.status_code == 200
        assert set(data["permissions"]) == {"can_read", "can_write", "can_export"}

    def test_get_invalid_database_table_metadata(self):
        """
        Database API: Test get invalid database from table metadata
        """
        database_id = 1000
        self.login(username="admin")
        uri = f"api/v1/database/{database_id}/table/some_table/some_schema/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 404)

        uri = "api/v1/database/some_database/table/some_table/some_schema/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 404)

    def test_get_invalid_table_table_metadata(self):
        """
        Database API: Test get invalid table from table metadata
        """
        example_db = get_example_database()
        uri = f"api/v1/database/{example_db.id}/table/wrong_table/null/"
        self.login(username="admin")
        rv = self.client.get(uri)
        data = json.loads(rv.data.decode("utf-8"))
        if example_db.backend == "sqlite":
            self.assertEqual(rv.status_code, 200)
            self.assertEqual(
                data,
                {
                    "columns": [],
                    "comment": None,
                    "foreignKeys": [],
                    "indexes": [],
                    "name": "wrong_table",
                    "primaryKey": {"constrained_columns": None, "name": None},
                    "selectStar": "SELECT\nFROM wrong_table\nLIMIT 100\nOFFSET 0",
                },
            )
        elif example_db.backend == "mysql":
            self.assertEqual(rv.status_code, 422)
            self.assertEqual(data, {"message": "`wrong_table`"})
        else:
            self.assertEqual(rv.status_code, 422)
            self.assertEqual(data, {"message": "wrong_table"})

    def test_get_table_metadata_no_db_permission(self):
        """
        Database API: Test get table metadata from not permitted db
        """
        self.login(username="gamma")
        example_db = get_example_database()
        uri = f"api/v1/database/{example_db.id}/table/birth_names/null/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 404)

    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
    def test_get_table_extra_metadata(self):
        """
        Database API: Test get table extra metadata info
        """
        example_db = get_example_database()
        self.login(username="admin")
        uri = f"api/v1/database/{example_db.id}/table_extra/birth_names/null/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(response, {})
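    # Shape of the table metadata endpoints used above:
    #     GET api/v1/database/<db id>/table/<table name>/<schema name>/
    #     GET api/v1/database/<db id>/table_extra/<table name>/<schema name>/
    # with the literal string "null" standing in for "no schema".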
self.login(username="admin") uri = f"api/v1/database/{database_id}/table_extra/some_table/some_schema/" rv = self.client.get(uri) self.assertEqual(rv.status_code, 404) uri = "api/v1/database/some_database/table_extra/some_table/some_schema/" rv = self.client.get(uri) self.assertEqual(rv.status_code, 404) def test_get_invalid_table_table_extra_metadata(self): """ Database API: Test get invalid table from table extra metadata """ example_db = get_example_database() uri = f"api/v1/database/{example_db.id}/table_extra/wrong_table/null/" self.login(username="admin") rv = self.client.get(uri) data = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 200) self.assertEqual(data, {}) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_get_select_star(self): """ Database API: Test get select star """ self.login(username="admin") example_db = get_example_database() uri = f"api/v1/database/{example_db.id}/select_star/birth_names/" rv = self.client.get(uri) self.assertEqual(rv.status_code, 200) response = json.loads(rv.data.decode("utf-8")) self.assertIn("gender", response["result"]) def test_get_select_star_not_allowed(self): """ Database API: Test get select star not allowed """ self.login(username="gamma") example_db = get_example_database() uri = f"api/v1/database/{example_db.id}/select_star/birth_names/" rv = self.client.get(uri) self.assertEqual(rv.status_code, 404) def test_get_select_star_datasource_access(self): """ Database API: Test get select star with datasource access """ session = db.session table = SqlaTable( schema="main", table_name="ab_permission", database=get_main_database() ) session.add(table) session.commit() tmp_table_perm = security_manager.find_permission_view_menu( "datasource_access", table.get_perm() ) gamma_role = security_manager.find_role("Gamma") security_manager.add_permission_role(gamma_role, tmp_table_perm) self.login(username="gamma") main_db = get_main_database() uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/" rv = self.client.get(uri) self.assertEqual(rv.status_code, 200) # rollback changes security_manager.del_permission_role(gamma_role, tmp_table_perm) db.session.delete(table) db.session.delete(main_db) db.session.commit() def test_get_select_star_not_found_database(self): """ Database API: Test get select star not found database """ self.login(username="admin") max_id = db.session.query(func.max(Database.id)).scalar() uri = f"api/v1/database/{max_id + 1}/select_star/birth_names/" rv = self.client.get(uri) self.assertEqual(rv.status_code, 404) def test_get_select_star_not_found_table(self): """ Database API: Test get select star not found database """ self.login(username="admin") example_db = get_example_database() # sqlite will not raise a NoSuchTableError if example_db.backend == "sqlite": return uri = f"api/v1/database/{example_db.id}/select_star/table_does_not_exist/" rv = self.client.get(uri) # TODO(bkyryliuk): investigate why presto returns 500 self.assertEqual(rv.status_code, 404 if example_db.backend != "presto" else 500) def test_get_allow_file_upload_filter(self): """ Database API: Test filter for allow file upload checks for schemas """ with self.create_app().app_context(): example_db = get_example_database() extra = { "metadata_params": {}, "engine_params": {}, "metadata_cache_timeout": {}, "schemas_allowed_for_file_upload": ["public"], } self.login(username="admin") database = self.insert_database( "database_with_upload", example_db.sqlalchemy_uri_decrypted, extra=json.dumps(extra), 
    def test_get_allow_file_upload_filter(self):
        """
        Database API: Test filter for allow file upload checks for schemas
        """
        with self.create_app().app_context():
            example_db = get_example_database()

            extra = {
                "metadata_params": {},
                "engine_params": {},
                "metadata_cache_timeout": {},
                "schemas_allowed_for_file_upload": ["public"],
            }
            self.login(username="admin")
            database = self.insert_database(
                "database_with_upload",
                example_db.sqlalchemy_uri_decrypted,
                extra=json.dumps(extra),
                allow_file_upload=True,
            )
            db.session.commit()
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 1

            db.session.delete(database)
            db.session.commit()

    def test_get_allow_file_upload_filter_no_schema(self):
        """
        Database API: Test filter for allow file upload checks for schemas.
        This test has allow_file_upload but no schemas.
        """
        with self.create_app().app_context():
            example_db = get_example_database()

            extra = {
                "metadata_params": {},
                "engine_params": {},
                "metadata_cache_timeout": {},
                "schemas_allowed_for_file_upload": [],
            }
            self.login(username="admin")
            database = self.insert_database(
                "database_with_upload",
                example_db.sqlalchemy_uri_decrypted,
                extra=json.dumps(extra),
                allow_file_upload=True,
            )
            db.session.commit()
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 0

            db.session.delete(database)
            db.session.commit()

    def test_get_allow_file_upload_filter_allow_file_false(self):
        """
        Database API: Test filter for allow file upload checks for schemas.
        This has a schema but does not allow_file_upload
        """
        with self.create_app().app_context():
            example_db = get_example_database()

            extra = {
                "metadata_params": {},
                "engine_params": {},
                "metadata_cache_timeout": {},
                "schemas_allowed_for_file_upload": ["public"],
            }
            self.login(username="admin")
            database = self.insert_database(
                "database_with_upload",
                example_db.sqlalchemy_uri_decrypted,
                extra=json.dumps(extra),
                allow_file_upload=False,
            )
            db.session.commit()
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 0

            db.session.delete(database)
            db.session.commit()

    def test_get_allow_file_upload_false(self):
        """
        Database API: Test filter for allow file upload checks for schemas.
        Both databases have false allow_file_upload
        """
        with self.create_app().app_context():
            example_db = get_example_database()

            extra = {
                "metadata_params": {},
                "engine_params": {},
                "metadata_cache_timeout": {},
                "schemas_allowed_for_file_upload": [],
            }
            self.login(username="admin")
            database = self.insert_database(
                "database_with_upload",
                example_db.sqlalchemy_uri_decrypted,
                extra=json.dumps(extra),
                allow_file_upload=False,
            )
            db.session.commit()
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 0

            db.session.delete(database)
            db.session.commit()
    def test_get_allow_file_upload_false_no_extra(self):
        """
        Database API: Test filter for allow file upload checks for schemas.
        Both databases have false allow_file_upload
        """
        with self.create_app().app_context():
            example_db = get_example_database()

            self.login(username="admin")
            database = self.insert_database(
                "database_with_upload",
                example_db.sqlalchemy_uri_decrypted,
                allow_file_upload=False,
            )
            db.session.commit()
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 0

            db.session.delete(database)
            db.session.commit()
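    # ALLOWED_USER_CSV_SCHEMA_FUNC lets a deployment compute the uploadable
    # schemas per (database, user) pair; the plain functions below stand in
    # for that callable (one permissive, one empty) via the patched config.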
    def mock_csv_function(d, user):
        return d.get_all_schema_names()

    @mock.patch(
        "superset.views.core.app.config",
        {**app.config, "ALLOWED_USER_CSV_SCHEMA_FUNC": mock_csv_function},
    )
    def test_get_allow_file_upload_true_csv(self):
        """
        Database API: Test filter for allow file upload checks for schemas.
        Both databases have false allow_file_upload
        """
        with self.create_app().app_context():
            example_db = get_example_database()

            extra = {
                "metadata_params": {},
                "engine_params": {},
                "metadata_cache_timeout": {},
                "schemas_allowed_for_file_upload": [],
            }
            self.login(username="admin")
            database = self.insert_database(
                "database_with_upload",
                example_db.sqlalchemy_uri_decrypted,
                extra=json.dumps(extra),
                allow_file_upload=True,
            )
            db.session.commit()
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 1

            db.session.delete(database)
            db.session.commit()

    def mock_empty_csv_function(d, user):
        return []

    @mock.patch(
        "superset.views.core.app.config",
        {**app.config, "ALLOWED_USER_CSV_SCHEMA_FUNC": mock_empty_csv_function},
    )
    def test_get_allow_file_upload_false_csv(self):
        """
        Database API: Test filter for allow file upload checks for schemas.
        Both databases have false allow_file_upload
        """
        with self.create_app().app_context():
            self.login(username="admin")
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 1

    def test_get_allow_file_upload_filter_no_permission(self):
        """
        Database API: Test filter for allow file upload checks for schemas
        """
        with self.create_app().app_context():
            example_db = get_example_database()

            extra = {
                "metadata_params": {},
                "engine_params": {},
                "metadata_cache_timeout": {},
                "schemas_allowed_for_file_upload": ["public"],
            }
            self.login(username="gamma")
            database = self.insert_database(
                "database_with_upload",
                example_db.sqlalchemy_uri_decrypted,
                extra=json.dumps(extra),
                allow_file_upload=True,
            )
            db.session.commit()
            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 0

            db.session.delete(database)
            db.session.commit()

    def test_get_allow_file_upload_filter_with_permission(self):
        """
        Database API: Test filter for allow file upload checks for schemas
        """
        with self.create_app().app_context():
            main_db = get_main_database()
            main_db.allow_file_upload = True
            session = db.session
            table = SqlaTable(
                schema="public",
                table_name="ab_permission",
                database=get_main_database(),
            )

            session.add(table)
            session.commit()
            tmp_table_perm = security_manager.find_permission_view_menu(
                "datasource_access", table.get_perm()
            )
            gamma_role = security_manager.find_role("Gamma")
            security_manager.add_permission_role(gamma_role, tmp_table_perm)

            self.login(username="gamma")

            arguments = {
                "columns": ["allow_file_upload"],
                "filters": [
                    {
                        "col": "allow_file_upload",
                        "opr": "upload_is_enabled",
                        "value": True,
                    }
                ],
            }
            uri = f"api/v1/database/?q={prison.dumps(arguments)}"
            rv = self.client.get(uri)
            data = json.loads(rv.data.decode("utf-8"))
            assert data["count"] == 1

            # rollback changes
            security_manager.del_permission_role(gamma_role, tmp_table_perm)
            db.session.delete(table)
            db.session.delete(main_db)
            db.session.commit()

    def test_database_schemas(self):
        """
        Database API: Test database schemas
        """
        self.login(username="admin")
        database = db.session.query(Database).filter_by(database_name="examples").one()
        schemas = database.get_all_schema_names()

        rv = self.client.get(f"api/v1/database/{database.id}/schemas/")
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(schemas, response["result"])

        rv = self.client.get(
            f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}"
        )
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(schemas, response["result"])

    def test_database_schemas_not_found(self):
        """
        Database API: Test database schemas not found
        """
        self.logout()
        self.login(username="gamma")
        example_db = get_example_database()
        uri = f"api/v1/database/{example_db.id}/schemas/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 404)

    def test_database_schemas_invalid_query(self):
        """
        Database API: Test database schemas with invalid query
        """
        self.login("admin")
        database = db.session.query(Database).first()
        rv = self.client.get(
            f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': 'nop'})}"
        )
        self.assertEqual(rv.status_code, 400)
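    # The /schemas/ endpoint accepts a Rison payload, roughly `(force:!t)` for
    # {'force': True}, to bypass the schema-name cache; payloads that fail
    # schema validation (e.g. force:'nop' above) come back as a 400.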
    def test_database_tables(self):
        """
        Database API: Test database tables
        """
        self.login(username="admin")
        database = db.session.query(Database).filter_by(database_name="examples").one()

        schema_name = self.default_schema_backend_map[database.backend]
        rv = self.client.get(
            f"api/v1/database/{database.id}/tables/?q={prison.dumps({'schema_name': schema_name})}"
        )

        self.assertEqual(rv.status_code, 200)
        if database.backend == "postgresql":
            response = json.loads(rv.data.decode("utf-8"))
            schemas = [
                s[0] for s in database.get_all_table_names_in_schema(schema_name)
            ]
            self.assertEqual(response["count"], len(schemas))
            for option in response["result"]:
                self.assertEqual(option["extra"], None)
                self.assertEqual(option["type"], "table")
                self.assertTrue(option["value"] in schemas)

    def test_database_tables_not_found(self):
        """
        Database API: Test database tables not found
        """
        self.logout()
        self.login(username="gamma")
        example_db = get_example_database()
        uri = f"api/v1/database/{example_db.id}/tables/?q={prison.dumps({'schema_name': 'non_existent'})}"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 404)

    def test_database_tables_invalid_query(self):
        """
        Database API: Test database tables with invalid query
        """
        self.login("admin")
        database = db.session.query(Database).first()
        rv = self.client.get(
            f"api/v1/database/{database.id}/tables/?q={prison.dumps({'force': 'nop'})}"
        )
        self.assertEqual(rv.status_code, 400)

    @mock.patch("superset.security.manager.SupersetSecurityManager.can_access_database")
    def test_database_tables_unexpected_error(self, mock_can_access_database):
        """
        Database API: Test database tables with unexpected error
        """
        self.login(username="admin")
        database = db.session.query(Database).filter_by(database_name="examples").one()
        mock_can_access_database.side_effect = Exception("Test Error")

        rv = self.client.get(
            f"api/v1/database/{database.id}/tables/?q={prison.dumps({'schema_name': 'main'})}"
        )
        self.assertEqual(rv.status_code, 422)

    def test_test_connection(self):
        """
        Database API: Test test connection
        """
        extra = {
            "metadata_params": {},
            "engine_params": {},
            "metadata_cache_timeout": {},
            "schemas_allowed_for_file_upload": [],
        }
        # need to temporarily allow sqlite dbs, teardown will undo this
        app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
        self.login("admin")
        example_db = get_example_database()
        # validate that the endpoint works with the password-masked sqlalchemy uri
        data = {
            "database_name": "examples",
            "masked_encrypted_extra": "{}",
            "extra": json.dumps(extra),
            "impersonate_user": False,
            "sqlalchemy_uri": example_db.safe_sqlalchemy_uri(),
            "server_cert": None,
        }
        url = "api/v1/database/test_connection/"
        rv = self.post_assert_metric(url, data, "test_connection")
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")

        # validate that the endpoint works with the decrypted sqlalchemy uri
        data = {
            "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
            "database_name": "examples",
            "impersonate_user": False,
            "extra": json.dumps(extra),
            "server_cert": None,
        }
        rv = self.post_assert_metric(url, data, "test_connection")
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
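    # test_connection accepts either form of the URI: the password-masked one
    # from `safe_sqlalchemy_uri()` (for databases that already exist) or the
    # fully decrypted one (when credentials are supplied fresh).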
    def test_test_connection_failed(self):
        """
        Database API: Test test connection failed
        """
        self.login("admin")

        data = {
            "sqlalchemy_uri": "broken://url",
            "database_name": "examples",
            "impersonate_user": False,
            "server_cert": None,
        }
        url = "api/v1/database/test_connection/"
        rv = self.post_assert_metric(url, data, "test_connection")
        self.assertEqual(rv.status_code, 422)
        self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
        response = json.loads(rv.data.decode("utf-8"))
        expected_response = {
            "errors": [
                {
                    "message": "Could not load database driver: BaseEngineSpec",
                    "error_type": "GENERIC_COMMAND_ERROR",
                    "level": "warning",
                    "extra": {
                        "issue_codes": [
                            {
                                "code": 1010,
                                "message": "Issue 1010 - Superset encountered an error while running a command.",
                            }
                        ]
                    },
                }
            ]
        }
        self.assertEqual(response, expected_response)

        data = {
            "sqlalchemy_uri": "mssql+pymssql://url",
            "database_name": "examples",
            "impersonate_user": False,
            "server_cert": None,
        }
        rv = self.post_assert_metric(url, data, "test_connection")
        self.assertEqual(rv.status_code, 422)
        self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
        response = json.loads(rv.data.decode("utf-8"))
        expected_response = {
            "errors": [
                {
                    "message": "Could not load database driver: MssqlEngineSpec",
                    "error_type": "GENERIC_COMMAND_ERROR",
                    "level": "warning",
                    "extra": {
                        "issue_codes": [
                            {
                                "code": 1010,
                                "message": "Issue 1010 - Superset encountered an error while running a command.",
                            }
                        ]
                    },
                }
            ]
        }
        self.assertEqual(response, expected_response)

    def test_test_connection_unsafe_uri(self):
        """
        Database API: Test test connection with unsafe uri
        """
        self.login("admin")

        app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True
        data = {
            "sqlalchemy_uri": "sqlite:///home/superset/unsafe.db",
            "database_name": "unsafe",
            "impersonate_user": False,
            "server_cert": None,
        }
        url = "api/v1/database/test_connection/"
        rv = self.post_assert_metric(url, data, "test_connection")
        self.assertEqual(rv.status_code, 400)
        response = json.loads(rv.data.decode("utf-8"))
        expected_response = {
            "message": {
                "sqlalchemy_uri": [
                    "SQLiteDialect_pysqlite cannot be used as a data source for security reasons."
                ]
            }
        }
        self.assertEqual(response, expected_response)

        app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
    @mock.patch(
        "superset.databases.commands.test_connection.DatabaseDAO.build_db_for_connection_test",
    )
    @mock.patch(
        "superset.databases.commands.test_connection.event_logger",
    )
    def test_test_connection_failed_invalid_hostname(
        self, mock_event_logger, mock_build_db
    ):
        """
        Database API: Test test connection failed due to invalid hostname
        """
        msg = 'psql: error: could not translate host name "locahost" to address: nodename nor servname provided, or not known'
        mock_build_db.return_value.set_sqlalchemy_uri.side_effect = DBAPIError(
            msg, None, None
        )
        mock_build_db.return_value.db_engine_spec.__name__ = "Some name"
        superset_error = SupersetError(
            message='Unable to resolve hostname "locahost".',
            error_type="CONNECTION_INVALID_HOSTNAME_ERROR",
            level="error",
            extra={
                "hostname": "locahost",
                "issue_codes": [
                    {
                        "code": 1007,
                        "message": (
                            "Issue 1007 - The hostname provided can't be resolved."
                        ),
                    }
                ],
            },
        )
        mock_build_db.return_value.db_engine_spec.extract_errors.return_value = [
            superset_error
        ]

        self.login("admin")
        data = {
            "sqlalchemy_uri": "postgres://username:password@locahost:12345/db",
            "database_name": "examples",
            "impersonate_user": False,
            "server_cert": None,
        }
        url = "api/v1/database/test_connection/"
        rv = self.post_assert_metric(url, data, "test_connection")

        assert rv.status_code == 500
        assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
        response = json.loads(rv.data.decode("utf-8"))
        expected_response = {"errors": [dataclasses.asdict(superset_error)]}
        assert response == expected_response

    @pytest.mark.usefixtures(
        "load_unicode_dashboard_with_position",
        "load_energy_table_with_slice",
        "load_world_bank_dashboard_with_slices",
        "load_birth_names_dashboard_with_slices",
    )
    def test_get_database_related_objects(self):
        """
        Database API: Test get chart and dashboard count related to a database
        :return:
        """
        self.login(username="admin")
        database = get_example_database()
        uri = f"api/v1/database/{database.id}/related_objects/"
        rv = self.get_assert_metric(uri, "related_objects")
        self.assertEqual(rv.status_code, 200)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(response["charts"]["count"], 34)
        self.assertEqual(response["dashboards"]["count"], 3)

    def test_get_database_related_objects_not_found(self):
        """
        Database API: Test related objects not found
        """
        max_id = db.session.query(func.max(Database.id)).scalar()
        # id does not exist and we get 404
        invalid_id = max_id + 1
        uri = f"api/v1/database/{invalid_id}/related_objects/"
        self.login(username="admin")
        rv = self.get_assert_metric(uri, "related_objects")
        self.assertEqual(rv.status_code, 404)
        self.logout()
        self.login(username="gamma")
        database = get_example_database()
        uri = f"api/v1/database/{database.id}/related_objects/"
        rv = self.get_assert_metric(uri, "related_objects")
        self.assertEqual(rv.status_code, 404)

    def test_export_database(self):
        """
        Database API: Test export database
        """
        self.login(username="admin")
        database = get_example_database()
        argument = [database.id]
        uri = f"api/v1/database/export/?q={prison.dumps(argument)}"
        rv = self.get_assert_metric(uri, "export")
        assert rv.status_code == 200

        buf = BytesIO(rv.data)
        assert is_zipfile(buf)

    def test_export_database_not_allowed(self):
        """
        Database API: Test export database not allowed
        """
        self.login(username="gamma")
        database = get_example_database()
        argument = [database.id]
        uri = f"api/v1/database/export/?q={prison.dumps(argument)}"
        rv = self.client.get(uri)
        assert rv.status_code == 403

    def test_export_database_non_existing(self):
        """
        Database API: Test export non-existing database
        """
        max_id = db.session.query(func.max(Database.id)).scalar()
        # id does not exist and we get 404
        invalid_id = max_id + 1

        self.login(username="admin")
        argument = [invalid_id]
        uri = f"api/v1/database/export/?q={prison.dumps(argument)}"
        rv = self.get_assert_metric(uri, "export")
        assert rv.status_code == 404
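    # The import tests below all build ZIP bundles with the same layout as
    # create_database_import() at the top of this class:
    #     database_export/
    #         metadata.yaml
    #         databases/imported_database.yaml
    #         datasets/imported_dataset.yaml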
    def test_import_database(self):
        """
        Database API: Test import database
        """
        self.login(username="admin")
        uri = "api/v1/database/import/"

        buf = self.create_database_import()
        form_data = {
            "formData": (buf, "database_export.zip"),
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 200
        assert response == {"message": "OK"}

        database = (
            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
        )
        assert database.database_name == "imported_database"

        assert len(database.tables) == 1
        dataset = database.tables[0]
        assert dataset.table_name == "imported_dataset"
        assert str(dataset.uuid) == dataset_config["uuid"]

        db.session.delete(dataset)
        db.session.commit()
        db.session.delete(database)
        db.session.commit()

    def test_import_database_overwrite(self):
        """
        Database API: Test import existing database
        """
        self.login(username="admin")
        uri = "api/v1/database/import/"

        buf = self.create_database_import()
        form_data = {
            "formData": (buf, "database_export.zip"),
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 200
        assert response == {"message": "OK"}

        # import again without overwrite flag
        buf = self.create_database_import()
        form_data = {
            "formData": (buf, "database_export.zip"),
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 422
        assert response == {
            "errors": [
                {
                    "message": "Error importing database",
                    "error_type": "GENERIC_COMMAND_ERROR",
                    "level": "warning",
                    "extra": {
                        "databases/imported_database.yaml": "Database already exists and `overwrite=true` was not passed",
                        "issue_codes": [
                            {
                                "code": 1010,
                                "message": (
                                    "Issue 1010 - Superset encountered an "
                                    "error while running a command."
                                ),
                            }
                        ],
                    },
                }
            ]
        }

        # import with overwrite flag
        buf = self.create_database_import()
        form_data = {
            "formData": (buf, "database_export.zip"),
            "overwrite": "true",
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 200
        assert response == {"message": "OK"}

        # clean up
        database = (
            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
        )
        dataset = database.tables[0]
        db.session.delete(dataset)
        db.session.commit()
        db.session.delete(database)
        db.session.commit()
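    # Editorial summary: taken together, the import tests in this file exercise
    # the multipart form fields accepted by ``api/v1/database/import/``:
    # ``formData`` (the export ZIP), ``overwrite`` ("true" to replace an
    # existing database), ``passwords``, ``ssh_tunnel_passwords``,
    # ``ssh_tunnel_private_keys`` and ``ssh_tunnel_private_key_passwords``,
    # the latter four being JSON maps keyed by file path inside the bundle.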
    def test_import_database_invalid(self):
        """
        Database API: Test import invalid database
        """
        self.login(username="admin")
        uri = "api/v1/database/import/"

        buf = BytesIO()
        with ZipFile(buf, "w") as bundle:
            with bundle.open("database_export/metadata.yaml", "w") as fp:
                fp.write(yaml.safe_dump(dataset_metadata_config).encode())
            with bundle.open(
                "database_export/databases/imported_database.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(database_config).encode())
            with bundle.open(
                "database_export/datasets/imported_dataset.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(dataset_config).encode())
        buf.seek(0)

        form_data = {
            "formData": (buf, "database_export.zip"),
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 422
        assert response == {
            "errors": [
                {
                    "message": "Error importing database",
                    "error_type": "GENERIC_COMMAND_ERROR",
                    "level": "warning",
                    "extra": {
                        "metadata.yaml": {"type": ["Must be equal to Database."]},
                        "issue_codes": [
                            {
                                "code": 1010,
                                "message": (
                                    "Issue 1010 - Superset encountered an "
                                    "error while running a command."
                                ),
                            }
                        ],
                    },
                }
            ]
        }

    def test_import_database_masked_password(self):
        """
        Database API: Test import database with masked password
        """
        self.login(username="admin")
        uri = "api/v1/database/import/"

        masked_database_config = database_config.copy()
        masked_database_config[
            "sqlalchemy_uri"
        ] = "postgresql://username:XXXXXXXXXX@host:12345/db"

        buf = BytesIO()
        with ZipFile(buf, "w") as bundle:
            with bundle.open("database_export/metadata.yaml", "w") as fp:
                fp.write(yaml.safe_dump(database_metadata_config).encode())
            with bundle.open(
                "database_export/databases/imported_database.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(masked_database_config).encode())
            with bundle.open(
                "database_export/datasets/imported_dataset.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(dataset_config).encode())
        buf.seek(0)

        form_data = {
            "formData": (buf, "database_export.zip"),
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 422
        assert response == {
            "errors": [
                {
                    "message": "Error importing database",
                    "error_type": "GENERIC_COMMAND_ERROR",
                    "level": "warning",
                    "extra": {
                        "databases/imported_database.yaml": {
                            "_schema": ["Must provide a password for the database"]
                        },
                        "issue_codes": [
                            {
                                "code": 1010,
                                "message": (
                                    "Issue 1010 - Superset encountered an "
                                    "error while running a command."
                                ),
                            }
                        ],
                    },
                }
            ]
        }

    def test_import_database_masked_password_provided(self):
        """
        Database API: Test import database with masked password provided
        """
        self.login(username="admin")
        uri = "api/v1/database/import/"

        masked_database_config = database_config.copy()
        masked_database_config[
            "sqlalchemy_uri"
        ] = "vertica+vertica_python://hackathon:XXXXXXXXXX@host:5433/dbname?ssl=1"

        buf = BytesIO()
        with ZipFile(buf, "w") as bundle:
            with bundle.open("database_export/metadata.yaml", "w") as fp:
                fp.write(yaml.safe_dump(database_metadata_config).encode())
            with bundle.open(
                "database_export/databases/imported_database.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(masked_database_config).encode())
        buf.seek(0)

        form_data = {
            "formData": (buf, "database_export.zip"),
            "passwords": json.dumps({"databases/imported_database.yaml": "SECRET"}),
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 200
        assert response == {"message": "OK"}

        database = (
            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
        )
        assert database.database_name == "imported_database"
        assert (
            database.sqlalchemy_uri
            == "vertica+vertica_python://hackathon:XXXXXXXXXX@host:5433/dbname?ssl=1"
        )
        assert database.password == "SECRET"

        db.session.delete(database)
        db.session.commit()
    @mock.patch("superset.databases.schemas.is_feature_enabled")
    def test_import_database_masked_ssh_tunnel_password(
        self, mock_schema_is_feature_enabled
    ):
        """
        Database API: Test import database with masked ssh tunnel password
        """
        self.login(username="admin")
        uri = "api/v1/database/import/"
        mock_schema_is_feature_enabled.return_value = True

        masked_database_config = database_with_ssh_tunnel_config_password.copy()

        buf = BytesIO()
        with ZipFile(buf, "w") as bundle:
            with bundle.open("database_export/metadata.yaml", "w") as fp:
                fp.write(yaml.safe_dump(database_metadata_config).encode())
            with bundle.open(
                "database_export/databases/imported_database.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(masked_database_config).encode())
            with bundle.open(
                "database_export/datasets/imported_dataset.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(dataset_config).encode())
        buf.seek(0)

        form_data = {
            "formData": (buf, "database_export.zip"),
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 422
        assert response == {
            "errors": [
                {
                    "message": "Error importing database",
                    "error_type": "GENERIC_COMMAND_ERROR",
                    "level": "warning",
                    "extra": {
                        "databases/imported_database.yaml": {
                            "_schema": ["Must provide a password for the ssh tunnel"]
                        },
                        "issue_codes": [
                            {
                                "code": 1010,
                                "message": (
                                    "Issue 1010 - Superset encountered an "
                                    "error while running a command."
                                ),
                            }
                        ],
                    },
                }
            ]
        }

    @mock.patch("superset.databases.schemas.is_feature_enabled")
    def test_import_database_masked_ssh_tunnel_password_provided(
        self, mock_schema_is_feature_enabled
    ):
        """
        Database API: Test import database with masked ssh tunnel password provided
        """
        self.login(username="admin")
        uri = "api/v1/database/import/"
        mock_schema_is_feature_enabled.return_value = True

        masked_database_config = database_with_ssh_tunnel_config_password.copy()

        buf = BytesIO()
        with ZipFile(buf, "w") as bundle:
            with bundle.open("database_export/metadata.yaml", "w") as fp:
                fp.write(yaml.safe_dump(database_metadata_config).encode())
            with bundle.open(
                "database_export/databases/imported_database.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(masked_database_config).encode())
        buf.seek(0)

        form_data = {
            "formData": (buf, "database_export.zip"),
            "ssh_tunnel_passwords": json.dumps(
                {"databases/imported_database.yaml": "TEST"}
            ),
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 200
        assert response == {"message": "OK"}

        database = (
            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
        )
        assert database.database_name == "imported_database"
        model_ssh_tunnel = (
            db.session.query(SSHTunnel)
            .filter(SSHTunnel.database_id == database.id)
            .one()
        )
        self.assertEqual(model_ssh_tunnel.password, "TEST")

        db.session.delete(database)
        db.session.commit()
    @mock.patch("superset.databases.schemas.is_feature_enabled")
    def test_import_database_masked_ssh_tunnel_private_key_and_password(
        self, mock_schema_is_feature_enabled
    ):
        """
        Database API: Test import database with masked ssh tunnel private key
        and password
        """
        self.login(username="admin")
        uri = "api/v1/database/import/"
        mock_schema_is_feature_enabled.return_value = True

        masked_database_config = database_with_ssh_tunnel_config_private_key.copy()

        buf = BytesIO()
        with ZipFile(buf, "w") as bundle:
            with bundle.open("database_export/metadata.yaml", "w") as fp:
                fp.write(yaml.safe_dump(database_metadata_config).encode())
            with bundle.open(
                "database_export/databases/imported_database.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(masked_database_config).encode())
            with bundle.open(
                "database_export/datasets/imported_dataset.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(dataset_config).encode())
        buf.seek(0)

        form_data = {
            "formData": (buf, "database_export.zip"),
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 422
        assert response == {
            "errors": [
                {
                    "message": "Error importing database",
                    "error_type": "GENERIC_COMMAND_ERROR",
                    "level": "warning",
                    "extra": {
                        "databases/imported_database.yaml": {
                            "_schema": [
                                "Must provide a private key for the ssh tunnel",
                                "Must provide a private key password for the ssh tunnel",
                            ]
                        },
                        "issue_codes": [
                            {
                                "code": 1010,
                                "message": (
                                    "Issue 1010 - Superset encountered an "
                                    "error while running a command."
                                ),
                            }
                        ],
                    },
                }
            ]
        }

    @mock.patch("superset.databases.schemas.is_feature_enabled")
    def test_import_database_masked_ssh_tunnel_private_key_and_password_provided(
        self, mock_schema_is_feature_enabled
    ):
        """
        Database API: Test import database with masked ssh tunnel private key
        and password provided
        """
        self.login(username="admin")
        uri = "api/v1/database/import/"
        mock_schema_is_feature_enabled.return_value = True

        masked_database_config = database_with_ssh_tunnel_config_private_key.copy()

        buf = BytesIO()
        with ZipFile(buf, "w") as bundle:
            with bundle.open("database_export/metadata.yaml", "w") as fp:
                fp.write(yaml.safe_dump(database_metadata_config).encode())
            with bundle.open(
                "database_export/databases/imported_database.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(masked_database_config).encode())
        buf.seek(0)

        form_data = {
            "formData": (buf, "database_export.zip"),
            "ssh_tunnel_private_keys": json.dumps(
                {"databases/imported_database.yaml": "TestPrivateKey"}
            ),
            "ssh_tunnel_private_key_passwords": json.dumps(
                {"databases/imported_database.yaml": "TEST"}
            ),
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 200
        assert response == {"message": "OK"}

        database = (
            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
        )
        assert database.database_name == "imported_database"
        model_ssh_tunnel = (
            db.session.query(SSHTunnel)
            .filter(SSHTunnel.database_id == database.id)
            .one()
        )
        self.assertEqual(model_ssh_tunnel.private_key, "TestPrivateKey")
        self.assertEqual(model_ssh_tunnel.private_key_password, "TEST")

        db.session.delete(database)
        db.session.commit()
    def test_import_database_masked_ssh_tunnel_feature_flag_disabled(self):
        """
        Database API: Test import database with ssh_tunnel and feature flag disabled
        """
        self.login(username="admin")
        uri = "api/v1/database/import/"

        masked_database_config = database_with_ssh_tunnel_config_private_key.copy()

        buf = BytesIO()
        with ZipFile(buf, "w") as bundle:
            with bundle.open("database_export/metadata.yaml", "w") as fp:
                fp.write(yaml.safe_dump(database_metadata_config).encode())
            with bundle.open(
                "database_export/databases/imported_database.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(masked_database_config).encode())
            with bundle.open(
                "database_export/datasets/imported_dataset.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(dataset_config).encode())
        buf.seek(0)

        form_data = {
            "formData": (buf, "database_export.zip"),
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 400
        assert response == {
            "errors": [
                {
                    "message": "SSH Tunneling is not enabled",
                    "error_type": "GENERIC_COMMAND_ERROR",
                    "level": "warning",
                    "extra": {
                        "issue_codes": [
                            {
                                "code": 1010,
                                "message": (
                                    "Issue 1010 - Superset encountered an "
                                    "error while running a command."
                                ),
                            }
                        ],
                    },
                }
            ]
        }

    @mock.patch("superset.databases.schemas.is_feature_enabled")
    def test_import_database_masked_ssh_tunnel_feature_no_credentials(
        self, mock_schema_is_feature_enabled
    ):
        """
        Database API: Test import database with ssh_tunnel that has no credentials
        """
        self.login(username="admin")
        uri = "api/v1/database/import/"
        mock_schema_is_feature_enabled.return_value = True

        masked_database_config = database_with_ssh_tunnel_config_no_credentials.copy()

        buf = BytesIO()
        with ZipFile(buf, "w") as bundle:
            with bundle.open("database_export/metadata.yaml", "w") as fp:
                fp.write(yaml.safe_dump(database_metadata_config).encode())
            with bundle.open(
                "database_export/databases/imported_database.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(masked_database_config).encode())
            with bundle.open(
                "database_export/datasets/imported_dataset.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(dataset_config).encode())
        buf.seek(0)

        form_data = {
            "formData": (buf, "database_export.zip"),
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 422
        assert response == {
            "errors": [
                {
                    "message": "Must provide credentials for the SSH Tunnel",
                    "error_type": "GENERIC_COMMAND_ERROR",
                    "level": "warning",
                    "extra": {
                        "issue_codes": [
                            {
                                "code": 1010,
                                "message": (
                                    "Issue 1010 - Superset encountered an "
                                    "error while running a command."
                                ),
                            }
                        ],
                    },
                }
            ]
        }
    @mock.patch("superset.databases.schemas.is_feature_enabled")
    def test_import_database_masked_ssh_tunnel_feature_mix_credentials(
        self, mock_schema_is_feature_enabled
    ):
        """
        Database API: Test import database with ssh_tunnel that has mixed credentials
        """
        self.login(username="admin")
        uri = "api/v1/database/import/"
        mock_schema_is_feature_enabled.return_value = True

        masked_database_config = database_with_ssh_tunnel_config_mix_credentials.copy()

        buf = BytesIO()
        with ZipFile(buf, "w") as bundle:
            with bundle.open("database_export/metadata.yaml", "w") as fp:
                fp.write(yaml.safe_dump(database_metadata_config).encode())
            with bundle.open(
                "database_export/databases/imported_database.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(masked_database_config).encode())
            with bundle.open(
                "database_export/datasets/imported_dataset.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(dataset_config).encode())
        buf.seek(0)

        form_data = {
            "formData": (buf, "database_export.zip"),
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 422
        assert response == {
            "errors": [
                {
                    "message": "Cannot have multiple credentials for the SSH Tunnel",
                    "error_type": "GENERIC_COMMAND_ERROR",
                    "level": "warning",
                    "extra": {
                        "issue_codes": [
                            {
                                "code": 1010,
                                "message": (
                                    "Issue 1010 - Superset encountered an "
                                    "error while running a command."
                                ),
                            }
                        ],
                    },
                }
            ]
        }

    @mock.patch("superset.databases.schemas.is_feature_enabled")
    def test_import_database_masked_ssh_tunnel_feature_only_pk_passwd(
        self, mock_schema_is_feature_enabled
    ):
        """
        Database API: Test import database with ssh_tunnel that has only
        a private key password
        """
        self.login(username="admin")
        uri = "api/v1/database/import/"
        mock_schema_is_feature_enabled.return_value = True

        masked_database_config = (
            database_with_ssh_tunnel_config_private_pass_only.copy()
        )

        buf = BytesIO()
        with ZipFile(buf, "w") as bundle:
            with bundle.open("database_export/metadata.yaml", "w") as fp:
                fp.write(yaml.safe_dump(database_metadata_config).encode())
            with bundle.open(
                "database_export/databases/imported_database.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(masked_database_config).encode())
            with bundle.open(
                "database_export/datasets/imported_dataset.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(dataset_config).encode())
        buf.seek(0)

        form_data = {
            "formData": (buf, "database_export.zip"),
        }
        rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
        response = json.loads(rv.data.decode("utf-8"))

        assert rv.status_code == 422
        assert response == {
            "errors": [
                {
                    "message": "Error importing database",
                    "error_type": "GENERIC_COMMAND_ERROR",
                    "level": "warning",
                    "extra": {
                        "databases/imported_database.yaml": {
                            "_schema": [
                                "Must provide a private key for the ssh tunnel",
                                "Must provide a private key password for the ssh tunnel",
                            ]
                        },
                        "issue_codes": [
                            {
                                "code": 1010,
                                "message": (
                                    "Issue 1010 - Superset encountered an "
                                    "error while running a command."
                                ),
                            }
                        ],
                    },
                }
            ]
        }
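    # Editorial sketch (not called by the tests above): the ssh-tunnel import
    # tests all build the same ZIP bundle inline, varying only the database
    # config and whether a dataset file is included. A helper along these lines
    # could factor out that duplication. ``_build_import_bundle`` is a
    # hypothetical name; everything it uses is already imported in this module.
    def _build_import_bundle(self, db_config, include_dataset=True):
        buf = BytesIO()
        with ZipFile(buf, "w") as bundle:
            with bundle.open("database_export/metadata.yaml", "w") as fp:
                fp.write(yaml.safe_dump(database_metadata_config).encode())
            with bundle.open(
                "database_export/databases/imported_database.yaml", "w"
            ) as fp:
                fp.write(yaml.safe_dump(db_config).encode())
            if include_dataset:
                with bundle.open(
                    "database_export/datasets/imported_dataset.yaml", "w"
                ) as fp:
                    fp.write(yaml.safe_dump(dataset_config).encode())
        buf.seek(0)
        return buf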
"description": "Username", "nullable": True, "type": "string", }, }, "required": ["database", "host", "port", "username"], "type": "object", }, "preferred": True, "sqlalchemy_uri_placeholder": "postgresql://user:password@host:port/dbname[?key=value&key=value...]", "engine_information": { "supports_file_upload": True, "disable_ssh_tunneling": False, }, }, { "available_drivers": ["bigquery"], "default_driver": "bigquery", "engine": "bigquery", "name": "Google BigQuery", "parameters": { "properties": { "credentials_info": { "description": "Contents of BigQuery JSON credentials.", "type": "string", "x-encrypted-extra": True, }, "query": {"type": "object"}, }, "type": "object", }, "preferred": True, "sqlalchemy_uri_placeholder": "bigquery://{project_id}", "engine_information": { "supports_file_upload": True, "disable_ssh_tunneling": True, }, }, { "available_drivers": ["psycopg2"], "default_driver": "psycopg2", "engine": "redshift", "name": "Amazon Redshift", "parameters": { "properties": { "database": { "description": "Database name", "type": "string", }, "encryption": { "description": "Use an encrypted connection to the database", "type": "boolean", }, "host": { "description": "Hostname or IP address", "type": "string", }, "password": { "description": "Password", "nullable": True, "type": "string", }, "port": { "description": "Database port", "maximum": 65536, "minimum": 0, "type": "integer", }, "query": { "additionalProperties": {}, "description": "Additional parameters", "type": "object", }, "ssh": { "description": "Use an ssh tunnel connection to the database", "type": "boolean", }, "username": { "description": "Username", "nullable": True, "type": "string", }, }, "required": ["database", "host", "port", "username"], "type": "object", }, "preferred": False, "sqlalchemy_uri_placeholder": "redshift+psycopg2://user:password@host:port/dbname[?key=value&key=value...]", "engine_information": { "supports_file_upload": True, "disable_ssh_tunneling": False, }, }, { "available_drivers": ["apsw"], "default_driver": "apsw", "engine": "gsheets", "name": "Google Sheets", "parameters": { "properties": { "catalog": {"type": "object"}, "service_account_info": { "description": "Contents of GSheets JSON credentials.", "type": "string", "x-encrypted-extra": True, }, }, "type": "object", }, "preferred": False, "sqlalchemy_uri_placeholder": "gsheets://", "engine_information": { "supports_file_upload": False, "disable_ssh_tunneling": True, }, }, { "available_drivers": ["mysqlconnector", "mysqldb"], "default_driver": "mysqldb", "engine": "mysql", "name": "MySQL", "parameters": { "properties": { "database": { "description": "Database name", "type": "string", }, "encryption": { "description": "Use an encrypted connection to the database", "type": "boolean", }, "host": { "description": "Hostname or IP address", "type": "string", }, "password": { "description": "Password", "nullable": True, "type": "string", }, "port": { "description": "Database port", "maximum": 65536, "minimum": 0, "type": "integer", }, "query": { "additionalProperties": {}, "description": "Additional parameters", "type": "object", }, "ssh": { "description": "Use an ssh tunnel connection to the database", "type": "boolean", }, "username": { "description": "Username", "nullable": True, "type": "string", }, }, "required": ["database", "host", "port", "username"], "type": "object", }, "preferred": False, "sqlalchemy_uri_placeholder": "mysql://user:password@host:port/dbname[?key=value&key=value...]", "engine_information": { "supports_file_upload": True, 
"disable_ssh_tunneling": False, }, }, { "available_drivers": [""], "engine": "hana", "name": "SAP HANA", "preferred": False, "engine_information": { "supports_file_upload": True, "disable_ssh_tunneling": False, }, }, ] } @mock.patch("superset.databases.api.get_available_engine_specs") @mock.patch("superset.databases.api.app") def test_available_no_default(self, app, get_available_engine_specs): app.config = {"PREFERRED_DATABASES": ["MySQL"]} get_available_engine_specs.return_value = { MySQLEngineSpec: {"mysqlconnector"}, HanaEngineSpec: {""}, } self.login(username="admin") uri = "api/v1/database/available/" rv = self.client.get(uri) response = json.loads(rv.data.decode("utf-8")) assert rv.status_code == 200 assert response == { "databases": [ { "available_drivers": ["mysqlconnector"], "default_driver": "mysqldb", "engine": "mysql", "name": "MySQL", "preferred": True, "engine_information": { "supports_file_upload": True, "disable_ssh_tunneling": False, }, }, { "available_drivers": [""], "engine": "hana", "name": "SAP HANA", "preferred": False, "engine_information": { "supports_file_upload": True, "disable_ssh_tunneling": False, }, }, ] } def test_validate_parameters_invalid_payload_format(self): self.login(username="admin") url = "api/v1/database/validate_parameters/" rv = self.client.post(url, data="INVALID", content_type="text/plain") response = json.loads(rv.data.decode("utf-8")) assert rv.status_code == 400 assert response == { "errors": [ { "message": "Request is not JSON", "error_type": "INVALID_PAYLOAD_FORMAT_ERROR", "level": "error", "extra": { "issue_codes": [ { "code": 1019, "message": "Issue 1019 - The submitted payload has the incorrect format.", } ] }, } ] } def test_validate_parameters_invalid_payload_schema(self): self.login(username="admin") url = "api/v1/database/validate_parameters/" payload = {"foo": "bar"} rv = self.client.post(url, json=payload) response = json.loads(rv.data.decode("utf-8")) assert rv.status_code == 422 response["errors"].sort(key=lambda error: error["extra"]["invalid"][0]) assert response == { "errors": [ { "message": "Missing data for required field.", "error_type": "INVALID_PAYLOAD_SCHEMA_ERROR", "level": "error", "extra": { "invalid": ["configuration_method"], "issue_codes": [ { "code": 1020, "message": "Issue 1020 - The submitted payload" " has the incorrect schema.", } ], }, }, { "message": "Missing data for required field.", "error_type": "INVALID_PAYLOAD_SCHEMA_ERROR", "level": "error", "extra": { "invalid": ["engine"], "issue_codes": [ { "code": 1020, "message": "Issue 1020 - The submitted payload " "has the incorrect schema.", } ], }, }, ] } def test_validate_parameters_missing_fields(self): self.login(username="admin") url = "api/v1/database/validate_parameters/" payload = { "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM, "engine": "postgresql", "parameters": defaultdict(dict), } payload["parameters"].update( { "host": "", "port": 5432, "username": "", "password": "", "database": "", "query": {}, } ) rv = self.client.post(url, json=payload) response = json.loads(rv.data.decode("utf-8")) assert rv.status_code == 422 assert response == { "errors": [ { "message": "One or more parameters are missing: database, host," " username", "error_type": "CONNECTION_MISSING_PARAMETERS_ERROR", "level": "warning", "extra": { "missing": ["database", "host", "username"], "issue_codes": [ { "code": 1018, "message": "Issue 1018 - One or more parameters " "needed to configure a database are missing.", } ], }, } ] } 
    @mock.patch("superset.db_engine_specs.base.is_hostname_valid")
    @mock.patch("superset.db_engine_specs.base.is_port_open")
    @mock.patch("superset.databases.api.ValidateDatabaseParametersCommand")
    def test_validate_parameters_valid_payload(
        self, ValidateDatabaseParametersCommand, is_port_open, is_hostname_valid
    ):
        is_hostname_valid.return_value = True
        is_port_open.return_value = True

        self.login(username="admin")
        url = "api/v1/database/validate_parameters/"
        payload = {
            "engine": "postgresql",
            "parameters": defaultdict(dict),
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        }
        payload["parameters"].update(
            {
                "host": "localhost",
                "port": 6789,
                "username": "superset",
                "password": "XXX",
                "database": "test",
                "query": {},
            }
        )
        rv = self.client.post(url, json=payload)
        response = json.loads(rv.data.decode("utf-8"))
        assert rv.status_code == 200
        assert response == {"message": "OK"}

    def test_validate_parameters_invalid_port(self):
        self.login(username="admin")
        url = "api/v1/database/validate_parameters/"
        payload = {
            "engine": "postgresql",
            "parameters": defaultdict(dict),
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        }
        payload["parameters"].update(
            {
                "host": "localhost",
                "port": "string",
                "username": "superset",
                "password": "XXX",
                "database": "test",
                "query": {},
            }
        )
        rv = self.client.post(url, json=payload)
        response = json.loads(rv.data.decode("utf-8"))
        assert rv.status_code == 422
        assert response == {
            "errors": [
                {
                    "message": "Port must be a valid integer.",
                    "error_type": "CONNECTION_INVALID_PORT_ERROR",
                    "level": "error",
                    "extra": {
                        "invalid": ["port"],
                        "issue_codes": [
                            {
                                "code": 1034,
                                "message": "Issue 1034 - The port number is invalid.",
                            }
                        ],
                    },
                },
                {
                    "message": "The port must be an integer between "
                    "0 and 65535 (inclusive).",
                    "error_type": "CONNECTION_INVALID_PORT_ERROR",
                    "level": "error",
                    "extra": {
                        "invalid": ["port"],
                        "issue_codes": [
                            {
                                "code": 1034,
                                "message": "Issue 1034 - The port number is invalid.",
                            }
                        ],
                    },
                },
            ]
        }

    @mock.patch("superset.db_engine_specs.base.is_hostname_valid")
    def test_validate_parameters_invalid_host(self, is_hostname_valid):
        is_hostname_valid.return_value = False
        self.login(username="admin")
        url = "api/v1/database/validate_parameters/"
        payload = {
            "engine": "postgresql",
            "parameters": defaultdict(dict),
            "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
        }
        payload["parameters"].update(
            {
                "host": "localhost",
                "port": 5432,
                "username": "",
                "password": "",
                "database": "",
                "query": {},
            }
        )
        rv = self.client.post(url, json=payload)
        response = json.loads(rv.data.decode("utf-8"))
        assert rv.status_code == 422
        assert response == {
            "errors": [
                {
                    "message": "One or more parameters are missing: database, username",
                    "error_type": "CONNECTION_MISSING_PARAMETERS_ERROR",
                    "level": "warning",
                    "extra": {
                        "missing": ["database", "username"],
                        "issue_codes": [
                            {
                                "code": 1018,
                                "message": "Issue 1018 - One or more parameters "
                                "needed to configure a database are missing.",
                            }
                        ],
                    },
                },
                {
                    "message": "The hostname provided can't be resolved.",
                    "error_type": "CONNECTION_INVALID_HOSTNAME_ERROR",
                    "level": "error",
                    "extra": {
                        "invalid": ["host"],
                        "issue_codes": [
                            {
                                "code": 1007,
                                "message": "Issue 1007 - The hostname "
                                "provided can't be resolved.",
                            }
                        ],
                    },
                },
            ]
        }
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM, } payload["parameters"].update( { "host": "localhost", "port": 65536, "username": "", "password": "", "database": "", "query": {}, } ) rv = self.client.post(url, json=payload) response = json.loads(rv.data.decode("utf-8")) assert rv.status_code == 422 assert response == { "errors": [ { "message": "One or more parameters are missing: database, username", "error_type": "CONNECTION_MISSING_PARAMETERS_ERROR", "level": "warning", "extra": { "missing": ["database", "username"], "issue_codes": [ { "code": 1018, "message": "Issue 1018 - One or more parameters needed to configure a database are missing.", } ], }, }, { "message": "The port must be an integer between 0 and 65535 (inclusive).", "error_type": "CONNECTION_INVALID_PORT_ERROR", "level": "error", "extra": { "invalid": ["port"], "issue_codes": [ { "code": 1034, "message": "Issue 1034 - The port number is invalid.", } ], }, }, ] } def test_get_related_objects(self): example_db = get_example_database() self.login(username="admin") uri = f"api/v1/database/{example_db.id}/related_objects/" rv = self.client.get(uri) assert rv.status_code == 200 assert "charts" in rv.json assert "dashboards" in rv.json assert "sqllab_tab_states" in rv.json @patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", SQL_VALIDATORS_BY_ENGINE, clear=True, ) def test_validate_sql(self): """ Database API: validate SQL success """ request_payload = { "sql": "SELECT * from birth_names", "schema": None, "template_params": None, } example_db = get_example_database() if example_db.backend not in ("presto", "postgresql"): pytest.skip("Only presto and PG are implemented") self.login(username="admin") uri = f"api/v1/database/{example_db.id}/validate_sql/" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 200) self.assertEqual(response["result"], []) @patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", SQL_VALIDATORS_BY_ENGINE, clear=True, ) def test_validate_sql_errors(self): """ Database API: validate SQL with errors """ request_payload = { "sql": "SELECT col1 froma table1", "schema": None, "template_params": None, } example_db = get_example_database() if example_db.backend not in ("presto", "postgresql"): pytest.skip("Only presto and PG are implemented") self.login(username="admin") uri = f"api/v1/database/{example_db.id}/validate_sql/" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 200) self.assertEqual( response["result"], [ { "end_column": None, "line_number": 1, "message": 'ERROR: syntax error at or near "table1"', "start_column": None, } ], ) @patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", SQL_VALIDATORS_BY_ENGINE, clear=True, ) def test_validate_sql_not_found(self): """ Database API: validate SQL database not found """ request_payload = { "sql": "SELECT * from birth_names", "schema": None, "template_params": None, } self.login(username="admin") uri = ( f"api/v1/database/{self.get_nonexistent_numeric_id(Database)}/validate_sql/" ) rv = self.client.post(uri, json=request_payload) self.assertEqual(rv.status_code, 404) @patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", SQL_VALIDATORS_BY_ENGINE, clear=True, ) def test_validate_sql_validation_fails(self): """ Database API: validate SQL database payload validation fails """ request_payload = { "sql": None, "schema": None, "template_params": None, } self.login(username="admin") uri = ( 
f"api/v1/database/{self.get_nonexistent_numeric_id(Database)}/validate_sql/" ) rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 400) self.assertEqual(response, {"message": {"sql": ["Field may not be null."]}}) @patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", {}, clear=True, ) def test_validate_sql_endpoint_noconfig(self): """Assert that validate_sql_json errors out when no validators are configured for any db""" request_payload = { "sql": "SELECT col1 from table1", "schema": None, "template_params": None, } self.login("admin") example_db = get_example_database() uri = f"api/v1/database/{example_db.id}/validate_sql/" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 422) self.assertEqual( response, { "errors": [ { "message": f"no SQL validator is configured for " f"{example_db.backend}", "error_type": "GENERIC_DB_ENGINE_ERROR", "level": "error", "extra": { "issue_codes": [ { "code": 1002, "message": "Issue 1002 - The database returned an " "unexpected error.", } ] }, } ] }, ) @patch("superset.databases.commands.validate_sql.get_validator_by_name") @patch.dict( "superset.config.SQL_VALIDATORS_BY_ENGINE", PRESTO_SQL_VALIDATORS_BY_ENGINE, clear=True, ) def test_validate_sql_endpoint_failure(self, get_validator_by_name): """Assert that validate_sql_json errors out when the selected validator raises an unexpected exception""" request_payload = { "sql": "SELECT * FROM birth_names", "schema": None, "template_params": None, } self.login("admin") validator = MagicMock() get_validator_by_name.return_value = validator validator.validate.side_effect = Exception("Kaboom!") self.login("admin") example_db = get_example_database() uri = f"api/v1/database/{example_db.id}/validate_sql/" rv = self.client.post(uri, json=request_payload) response = json.loads(rv.data.decode("utf-8")) # TODO(bkyryliuk): properly handle hive error if get_example_database().backend == "hive": return self.assertEqual(rv.status_code, 422) self.assertIn("Kaboom!", response["errors"][0]["message"]) def test_get_databases_with_extra_filters(self): """ API: Test get database with extra query filter. Here we are testing our default where all databases must be returned if nothing is being set in the config. Then, we're adding the patch for the config to add the filter function and testing it's being applied. 
""" self.login(username="admin") extra = { "metadata_params": {}, "engine_params": {}, "metadata_cache_timeout": {}, "schemas_allowed_for_file_upload": [], } example_db = get_example_database() if example_db.backend == "sqlite": return # Create our two databases database_data = { "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM, "server_cert": None, "extra": json.dumps(extra), } uri = "api/v1/database/" rv = self.client.post( uri, json={**database_data, "database_name": "dyntest-create-database-1"} ) first_response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 201) uri = "api/v1/database/" rv = self.client.post( uri, json={**database_data, "database_name": "create-database-2"} ) second_response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 201) # The filter function def _base_filter(query): from superset.models.core import Database return query.filter(Database.database_name.startswith("dyntest")) # Create the Mock base_filter_mock = Mock(side_effect=_base_filter) dbs = db.session.query(Database).all() expected_names = [db.database_name for db in dbs] expected_names.sort() uri = f"api/v1/database/" # Get the list of databases without filter in the config rv = self.client.get(uri) data = json.loads(rv.data.decode("utf-8")) # All databases must be returned if no filter is present self.assertEqual(data["count"], len(dbs)) database_names = [item["database_name"] for item in data["result"]] database_names.sort() # All Databases because we are an admin self.assertEqual(database_names, expected_names) assert rv.status_code == 200 # Our filter function wasn't get called base_filter_mock.assert_not_called() # Now we patch the config to include our filter function with patch.dict( "superset.views.filters.current_app.config", {"EXTRA_DYNAMIC_QUERY_FILTERS": {"databases": base_filter_mock}}, ): uri = f"api/v1/database/" rv = self.client.get(uri) data = json.loads(rv.data.decode("utf-8")) # Only one database start with dyntest self.assertEqual(data["count"], 1) database_names = [item["database_name"] for item in data["result"]] # Only the database that starts with tests, even if we are an admin self.assertEqual(database_names, ["dyntest-create-database-1"]) assert rv.status_code == 200 # The filter function is called now that it's defined in our config base_filter_mock.assert_called() # Cleanup first_model = db.session.query(Database).get(first_response.get("id")) second_model = db.session.query(Database).get(second_response.get("id")) db.session.delete(first_model) db.session.delete(second_model) db.session.commit()