# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for Superset"""
import json
from typing import List
from unittest.mock import patch

import prison
import yaml
from sqlalchemy.sql import func

from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.dao.exceptions import (
    DAOCreateFailedError,
    DAODeleteFailedError,
    DAOUpdateFailedError,
)
from superset.models.core import Database
from superset.utils.core import get_example_database
from superset.utils.dict_import_export import export_to_dict
from superset.views.base import generate_download_headers
from tests.base_tests import SupersetTestCase


class DatasetApiTests(SupersetTestCase):
    """
    Tests for the ``api/v1/dataset/`` REST endpoints.

    Most tests insert an ``ab_permission`` dataset through the ORM, call
    the API with the test client as one of the ``admin``, ``alpha`` or
    ``gamma`` users, assert on the response, and finally delete the
    dataset they created.
    """

    @staticmethod
    def insert_dataset(
        table_name: str, schema: str, owners: List[int], database: Database
    ) -> SqlaTable:
        """
        Persist a new ``SqlaTable`` and load its metadata.

        :param table_name: name of the table the dataset points to
        :param schema: schema of the table
        :param owners: ids of the users to attach as owners
        :param database: database the table belongs to
        :return: the committed ``SqlaTable`` after ``fetch_metadata()``
        """
        obj_owners = [
            db.session.query(security_manager.user_model).get(owner)
            for owner in owners
        ]
        table = SqlaTable(
            table_name=table_name, schema=schema, owners=obj_owners, database=database
        )
        db.session.add(table)
        db.session.commit()
        table.fetch_metadata()
        return table

    def insert_default_dataset(self):
        """
        Insert the ``ab_permission`` dataset on the example database,
        owned by the ``admin`` user.
        """
        return self.insert_dataset(
            "ab_permission", "", [self.get_user("admin").id], get_example_database()
        )

    @staticmethod
    def get_birth_names_dataset():
        """
        Return the ``birth_names`` dataset of the example database.

        Uses ``.one()``, so it raises unless exactly one row matches.
        """
        example_db = get_example_database()
        return (
            db.session.query(SqlaTable)
            .filter_by(database=example_db, table_name="birth_names")
            .one()
        )

    def test_get_dataset_list(self):
        """
        Dataset API: Test get dataset list
        """
        example_db = get_example_database()
        self.login(username="admin")
        arguments = {
            "filters": [
                {"col": "database", "opr": "rel_o_m", "value": f"{example_db.id}"},
                {"col": "table_name", "opr": "eq", "value": "birth_names"},
            ]
        }
        uri = f"api/v1/dataset/?q={prison.dumps(arguments)}"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(response["count"], 1)
        expected_columns = [
            "changed_by",
            "changed_by_name",
            "changed_by_url",
            "changed_on",
            "database_name",
            "explore_url",
            "id",
            "schema",
            "table_name",
        ]
        self.assertEqual(sorted(response["result"][0].keys()), expected_columns)

    def test_get_dataset_list_gamma(self):
        """
        Dataset API: Test get dataset list gamma
        """
        # NOTE(review): the return value is unused; the call is kept in case
        # it has side effects the test relies on - confirm before removing.
        get_example_database()
        self.login(username="gamma")
        uri = "api/v1/dataset/"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(response["result"], [])

    def test_get_dataset_related_database_gamma(self):
        """
        Dataset API: Test get dataset related databases gamma
        """
        # NOTE(review): the return value is unused; the call is kept in case
        # it has side effects the test relies on - confirm before removing.
        get_example_database()
        self.login(username="gamma")
        uri = "api/v1/dataset/related/database"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
        response = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(response["count"], 0)
        self.assertEqual(response["result"], [])

    def test_get_dataset_item(self):
        """
        Dataset API: Test get dataset item
        """
        table = self.get_birth_names_dataset()
        self.login(username="admin")
        uri = f"api/v1/dataset/{table.id}"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
        response = json.loads(rv.data.decode("utf-8"))
        expected_result = {
            "cache_timeout": None,
            # Use the dataset's real database id rather than assuming it is 1.
            "database": {"database_name": "examples", "id": table.database_id},
            "default_endpoint": None,
            "description": None,
            "fetch_values_predicate": None,
            "filter_select_enabled": True,
            "is_sqllab_view": False,
            "main_dttm_col": "ds",
            "offset": 0,
            "owners": [],
            "schema": None,
            "sql": None,
            "table_name": "birth_names",
            "template_params": None,
        }
        for key, value in expected_result.items():
            self.assertEqual(response["result"][key], value)
        self.assertEqual(len(response["result"]["columns"]), 8)
        self.assertEqual(len(response["result"]["metrics"]), 2)

    def test_get_dataset_info(self):
        """
        Dataset API: Test get dataset info
        """
        self.login(username="admin")
        uri = "api/v1/dataset/_info"
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)

    def test_create_dataset_item(self):
        """
        Dataset API: Test create dataset item
        """
        example_db = get_example_database()
        self.login(username="admin")
        table_data = {
            "database": example_db.id,
            "schema": "",
            "table_name": "ab_permission",
        }
        uri = "api/v1/dataset/"
        rv = self.client.post(uri, json=table_data)
        self.assertEqual(rv.status_code, 201)
        data = json.loads(rv.data.decode("utf-8"))
        table_id = data.get("id")
        model = db.session.query(SqlaTable).get(table_id)
        self.assertEqual(model.table_name, table_data["table_name"])
        self.assertEqual(model.database_id, table_data["database"])

        # Assert that columns were created
        columns = (
            db.session.query(TableColumn)
            .filter_by(table_id=table_id)
            .order_by("column_name")
            .all()
        )
        self.assertEqual(columns[0].column_name, "id")
        self.assertEqual(columns[1].column_name, "name")

        # Assert that metrics were created
        metrics = (
            db.session.query(SqlMetric)
            .filter_by(table_id=table_id)
            .order_by("metric_name")
            .all()
        )
        self.assertEqual(metrics[0].expression, "COUNT(*)")

        db.session.delete(model)
        db.session.commit()

    def test_create_dataset_item_gamma(self):
        """
        Dataset API: Test create dataset item gamma
        """
        self.login(username="gamma")
        example_db = get_example_database()
        table_data = {
            "database": example_db.id,
            "schema": "",
            "table_name": "ab_permission",
        }
        uri = "api/v1/dataset/"
        rv = self.client.post(uri, json=table_data)
        self.assertEqual(rv.status_code, 401)

    def test_create_dataset_item_owner(self):
        """
        Dataset API: Test create item owner
        """
        example_db = get_example_database()
        self.login(username="alpha")
        admin = self.get_user("admin")
        alpha = self.get_user("alpha")

        table_data = {
            "database": example_db.id,
            "schema": "",
            "table_name": "ab_permission",
            "owners": [admin.id],
        }
        uri = "api/v1/dataset/"
        rv = self.client.post(uri, json=table_data)
        self.assertEqual(rv.status_code, 201)
        data = json.loads(rv.data.decode("utf-8"))
        model = db.session.query(SqlaTable).get(data.get("id"))
        # Both the requested owner and the logged-in creator must be owners.
        self.assertIn(admin, model.owners)
        self.assertIn(alpha, model.owners)
        db.session.delete(model)
        db.session.commit()

    def test_create_dataset_item_owners_invalid(self):
        """
        Dataset API: Test create dataset item owner invalid
        """
        admin = self.get_user("admin")
        example_db = get_example_database()
        self.login(username="admin")
        table_data = {
            "database": example_db.id,
            "schema": "",
            "table_name": "ab_permission",
            "owners": [admin.id, 1000],
        }
        uri = "api/v1/dataset/"
        rv = self.client.post(uri, json=table_data)
        self.assertEqual(rv.status_code, 422)
        data = json.loads(rv.data.decode("utf-8"))
        expected_result = {"message": {"owners": ["Owners are invalid"]}}
        self.assertEqual(data, expected_result)

    def test_create_dataset_validate_uniqueness(self):
        """
        Dataset API: Test create dataset validate table uniqueness
        """
        example_db = get_example_database()
        self.login(username="admin")
        table_data = {
            "database": example_db.id,
            "schema": "",
            "table_name": "birth_names",
        }
        uri = "api/v1/dataset/"
        rv = self.client.post(uri, json=table_data)
        self.assertEqual(rv.status_code, 422)
        data = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(
            data, {"message": {"table_name": ["Datasource birth_names already exists"]}}
        )

    def test_create_dataset_validate_database(self):
        """
        Dataset API: Test create dataset validate database exists
        """
        self.login(username="admin")
        dataset_data = {"database": 1000, "schema": "", "table_name": "birth_names"}
        uri = "api/v1/dataset/"
        rv = self.client.post(uri, json=dataset_data)
        self.assertEqual(rv.status_code, 422)
        data = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(data, {"message": {"database": ["Database does not exist"]}})

    def test_create_dataset_validate_tables_exists(self):
        """
        Dataset API: Test create dataset validate table exists
        """
        example_db = get_example_database()
        self.login(username="admin")
        table_data = {
            "database": example_db.id,
            "schema": "",
            "table_name": "does_not_exist",
        }
        uri = "api/v1/dataset/"
        rv = self.client.post(uri, json=table_data)
        self.assertEqual(rv.status_code, 422)

    @patch("superset.datasets.dao.DatasetDAO.create")
    def test_create_dataset_sqlalchemy_error(self, mock_dao_create):
        """
        Dataset API: Test create dataset sqlalchemy error
        """
        mock_dao_create.side_effect = DAOCreateFailedError()
        self.login(username="admin")
        example_db = get_example_database()
        dataset_data = {
            "database": example_db.id,
            "schema": "",
            "table_name": "ab_permission",
        }
        uri = "api/v1/dataset/"
        rv = self.client.post(uri, json=dataset_data)
        data = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 422)
        self.assertEqual(data, {"message": "Dataset could not be created."})

    def test_update_dataset_item(self):
        """
        Dataset API: Test update dataset item
        """
        dataset = self.insert_default_dataset()
        self.login(username="admin")
        dataset_data = {"description": "changed_description"}
        uri = f"api/v1/dataset/{dataset.id}"
        rv = self.client.put(uri, json=dataset_data)
        self.assertEqual(rv.status_code, 200)
        model = db.session.query(SqlaTable).get(dataset.id)
        self.assertEqual(model.description, dataset_data["description"])
        db.session.delete(dataset)
        db.session.commit()

    def test_update_dataset_create_column(self):
        """
        Dataset API: Test update dataset create column
        """
        # Insert the example dataset directly through the ORM helper
        dataset = self.insert_default_dataset()

        new_column_data = {
            "column_name": "new_col",
            "description": "description",
            "expression": "expression",
            "type": "INTEGER",
            "verbose_name": "New Col",
        }
        uri = f"api/v1/dataset/{dataset.id}"
        # Get current cols and append the new column
        self.login(username="admin")
        rv = self.client.get(uri)
        data = json.loads(rv.data.decode("utf-8"))
        data["result"]["columns"].append(new_column_data)
        rv = self.client.put(uri, json={"columns": data["result"]["columns"]})

        self.assertEqual(rv.status_code, 200)

        columns = (
            db.session.query(TableColumn)
            .filter_by(table_id=dataset.id)
            .order_by("column_name")
            .all()
        )
        self.assertEqual(columns[0].column_name, "id")
        self.assertEqual(columns[1].column_name, "name")
        self.assertEqual(columns[2].column_name, new_column_data["column_name"])
        self.assertEqual(columns[2].description, new_column_data["description"])
        self.assertEqual(columns[2].expression, new_column_data["expression"])
        self.assertEqual(columns[2].type, new_column_data["type"])
        self.assertEqual(columns[2].verbose_name, new_column_data["verbose_name"])

        db.session.delete(dataset)
        db.session.commit()

    def test_update_dataset_update_column(self):
        """
        Dataset API: Test update dataset columns
        """
        dataset = self.insert_default_dataset()

        self.login(username="admin")
        uri = f"api/v1/dataset/{dataset.id}"
        # Get current cols and alter one
        rv = self.client.get(uri)
        resp_columns = json.loads(rv.data.decode("utf-8"))["result"]["columns"]
        resp_columns[0]["groupby"] = False
        resp_columns[0]["filterable"] = False
        # Rebind ``rv`` so the status assertion checks the PUT response,
        # not the earlier GET.
        rv = self.client.put(uri, json={"columns": resp_columns})
        self.assertEqual(rv.status_code, 200)
        columns = (
            db.session.query(TableColumn)
            .filter_by(table_id=dataset.id)
            .order_by("column_name")
            .all()
        )
        self.assertEqual(columns[0].column_name, "id")
        self.assertEqual(columns[1].column_name, "name")
        self.assertEqual(columns[0].groupby, False)
        self.assertEqual(columns[0].filterable, False)

        db.session.delete(dataset)
        db.session.commit()

    def test_update_dataset_update_column_uniqueness(self):
        """
        Dataset API: Test update dataset columns uniqueness
        """
        dataset = self.insert_default_dataset()

        self.login(username="admin")
        uri = f"api/v1/dataset/{dataset.id}"
        # try to insert a new column whose name already exists
        data = {"columns": [{"column_name": "id", "type": "INTEGER"}]}
        rv = self.client.put(uri, json=data)
        self.assertEqual(rv.status_code, 422)
        data = json.loads(rv.data.decode("utf-8"))
        expected_result = {
            "message": {"columns": ["One or more columns already exist"]}
        }
        self.assertEqual(data, expected_result)
        db.session.delete(dataset)
        db.session.commit()

    def test_update_dataset_update_metric_uniqueness(self):
        """
        Dataset API: Test update dataset metric uniqueness
        """
        dataset = self.insert_default_dataset()

        self.login(username="admin")
        uri = f"api/v1/dataset/{dataset.id}"
        # try to insert a new metric whose name already exists
        data = {"metrics": [{"metric_name": "count", "expression": "COUNT(*)"}]}
        rv = self.client.put(uri, json=data)
        self.assertEqual(rv.status_code, 422)
        data = json.loads(rv.data.decode("utf-8"))
        expected_result = {
            "message": {"metrics": ["One or more metrics already exist"]}
        }
        self.assertEqual(data, expected_result)
        db.session.delete(dataset)
        db.session.commit()

    def test_update_dataset_update_column_duplicate(self):
        """
        Dataset API: Test update dataset columns duplicate
        """
        dataset = self.insert_default_dataset()

        self.login(username="admin")
        uri = f"api/v1/dataset/{dataset.id}"
        # send the same column name twice in one payload
        data = {
            "columns": [
                {"column_name": "id", "type": "INTEGER"},
                {"column_name": "id", "type": "VARCHAR"},
            ]
        }
        rv = self.client.put(uri, json=data)
        self.assertEqual(rv.status_code, 422)
        data = json.loads(rv.data.decode("utf-8"))
        expected_result = {
            "message": {"columns": ["One or more columns are duplicated"]}
        }
        self.assertEqual(data, expected_result)
        db.session.delete(dataset)
        db.session.commit()

    def test_update_dataset_update_metric_duplicate(self):
        """
        Dataset API: Test update dataset metric duplicate
        """
        dataset = self.insert_default_dataset()

        self.login(username="admin")
        uri = f"api/v1/dataset/{dataset.id}"
        # send the same metric name twice in one payload
        data = {
            "metrics": [
                {"metric_name": "dup", "expression": "COUNT(*)"},
                {"metric_name": "dup", "expression": "DIFF_COUNT(*)"},
            ]
        }
        rv = self.client.put(uri, json=data)
        self.assertEqual(rv.status_code, 422)
        data = json.loads(rv.data.decode("utf-8"))
        expected_result = {
            "message": {"metrics": ["One or more metrics are duplicated"]}
        }
        self.assertEqual(data, expected_result)
        db.session.delete(dataset)
        db.session.commit()

    def test_update_dataset_item_gamma(self):
        """
        Dataset API: Test update dataset item gamma
        """
        dataset = self.insert_default_dataset()
        self.login(username="gamma")
        table_data = {"description": "changed_description"}
        uri = f"api/v1/dataset/{dataset.id}"
        rv = self.client.put(uri, json=table_data)
        self.assertEqual(rv.status_code, 401)
        db.session.delete(dataset)
        db.session.commit()

    def test_update_dataset_item_not_owned(self):
        """
        Dataset API: Test update dataset item not owned
        """
        dataset = self.insert_default_dataset()
        self.login(username="alpha")
        table_data = {"description": "changed_description"}
        uri = f"api/v1/dataset/{dataset.id}"
        rv = self.client.put(uri, json=table_data)
        self.assertEqual(rv.status_code, 403)
        db.session.delete(dataset)
        db.session.commit()

    def test_update_dataset_item_owners_invalid(self):
        """
        Dataset API: Test update dataset item owner invalid
        """
        dataset = self.insert_default_dataset()
        self.login(username="admin")
        table_data = {"description": "changed_description", "owners": [1000]}
        uri = f"api/v1/dataset/{dataset.id}"
        rv = self.client.put(uri, json=table_data)
        self.assertEqual(rv.status_code, 422)
        db.session.delete(dataset)
        db.session.commit()

    def test_update_dataset_item_uniqueness(self):
        """
        Dataset API: Test update dataset uniqueness
        """
        dataset = self.insert_default_dataset()
        self.login(username="admin")
        table_data = {"table_name": "birth_names"}
        uri = f"api/v1/dataset/{dataset.id}"
        rv = self.client.put(uri, json=table_data)
        data = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 422)
        expected_response = {
            "message": {"table_name": ["Datasource birth_names already exists"]}
        }
        self.assertEqual(data, expected_response)
        db.session.delete(dataset)
        db.session.commit()

    @patch("superset.datasets.dao.DatasetDAO.update")
    def test_update_dataset_sqlalchemy_error(self, mock_dao_update):
        """
        Dataset API: Test update dataset sqlalchemy error
        """
        mock_dao_update.side_effect = DAOUpdateFailedError()

        dataset = self.insert_default_dataset()
        self.login(username="admin")
        table_data = {"description": "changed_description"}
        uri = f"api/v1/dataset/{dataset.id}"
        rv = self.client.put(uri, json=table_data)
        data = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 422)
        self.assertEqual(data, {"message": "Dataset could not be updated."})
        db.session.delete(dataset)
        db.session.commit()

    def test_delete_dataset_item(self):
        """
        Dataset API: Test delete dataset item
        """
        dataset = self.insert_default_dataset()
        self.login(username="admin")
        uri = f"api/v1/dataset/{dataset.id}"
        rv = self.client.delete(uri)
        self.assertEqual(rv.status_code, 200)

    def test_delete_item_dataset_not_owned(self):
        """
        Dataset API: Test delete item not owned
        """
        dataset = self.insert_default_dataset()
        self.login(username="alpha")
        uri = f"api/v1/dataset/{dataset.id}"
        rv = self.client.delete(uri)
        self.assertEqual(rv.status_code, 403)
        db.session.delete(dataset)
        db.session.commit()

    def test_delete_dataset_item_not_authorized(self):
        """
        Dataset API: Test delete item not authorized
        """
        dataset = self.insert_default_dataset()
        self.login(username="gamma")
        uri = f"api/v1/dataset/{dataset.id}"
        rv = self.client.delete(uri)
        self.assertEqual(rv.status_code, 401)
        db.session.delete(dataset)
        db.session.commit()

    @patch("superset.datasets.dao.DatasetDAO.delete")
    def test_delete_dataset_sqlalchemy_error(self, mock_dao_delete):
        """
        Dataset API: Test delete dataset sqlalchemy error
        """
        mock_dao_delete.side_effect = DAODeleteFailedError()

        dataset = self.insert_default_dataset()
        self.login(username="admin")
        uri = f"api/v1/dataset/{dataset.id}"
        rv = self.client.delete(uri)
        data = json.loads(rv.data.decode("utf-8"))
        self.assertEqual(rv.status_code, 422)
        self.assertEqual(data, {"message": "Dataset could not be deleted."})
        db.session.delete(dataset)
        db.session.commit()

    def test_dataset_item_refresh(self):
        """
        Dataset API: Test item refresh
        """
        dataset = self.insert_default_dataset()
        # delete a column
        id_column = (
            db.session.query(TableColumn)
            .filter_by(table_id=dataset.id, column_name="id")
            .one()
        )
        db.session.delete(id_column)
        db.session.commit()

        self.login(username="admin")
        uri = f"api/v1/dataset/{dataset.id}/refresh"
        rv = self.client.put(uri)
        self.assertEqual(rv.status_code, 200)
        # Assert the column is restored on refresh
        id_column = (
            db.session.query(TableColumn)
            .filter_by(table_id=dataset.id, column_name="id")
            .one()
        )
        self.assertIsNotNone(id_column)
        db.session.delete(dataset)
        db.session.commit()

    def test_dataset_item_refresh_not_found(self):
        """
        Dataset API: Test item refresh not found dataset
        """
        # One past the highest existing id, so the dataset cannot exist.
        max_id = db.session.query(func.max(SqlaTable.id)).scalar()

        self.login(username="admin")
        uri = f"api/v1/dataset/{max_id + 1}/refresh"
        rv = self.client.put(uri)
        self.assertEqual(rv.status_code, 404)

    def test_dataset_item_refresh_not_owned(self):
        """
        Dataset API: Test item refresh not owned dataset
        """
        dataset = self.insert_default_dataset()
        self.login(username="alpha")
        uri = f"api/v1/dataset/{dataset.id}/refresh"
        rv = self.client.put(uri)
        self.assertEqual(rv.status_code, 403)

        db.session.delete(dataset)
        db.session.commit()

    def test_export_dataset(self):
        """
        Dataset API: Test export dataset
        """
        birth_names_dataset = self.get_birth_names_dataset()

        argument = [birth_names_dataset.id]
        uri = f"api/v1/dataset/export/?q={prison.dumps(argument)}"

        self.login(username="admin")
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 200)
        self.assertEqual(
            rv.headers["Content-Disposition"],
            generate_download_headers("yaml")["Content-Disposition"],
        )

        # The API export is compared against the dict-based export of the
        # same table.
        cli_export = export_to_dict(
            session=db.session,
            recursive=True,
            back_references=False,
            include_defaults=False,
        )
        cli_export_tables = cli_export["databases"][0]["tables"]
        expected_response = []
        for export_table in cli_export_tables:
            if export_table["table_name"] == "birth_names":
                expected_response = export_table
                break
        ui_export = yaml.safe_load(rv.data.decode("utf-8"))
        self.assertEqual(ui_export[0], expected_response)

    def test_export_dataset_not_found(self):
        """
        Dataset API: Test export dataset not found
        """
        max_id = db.session.query(func.max(SqlaTable.id)).scalar()
        # Just one does not exist and we get 404
        argument = [max_id + 1, 1]
        uri = f"api/v1/dataset/export/?q={prison.dumps(argument)}"
        self.login(username="admin")
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 404)

    def test_export_dataset_gamma(self):
        """
        Dataset API: Test export dataset has gamma
        """
        birth_names_dataset = self.get_birth_names_dataset()

        argument = [birth_names_dataset.id]
        uri = f"api/v1/dataset/export/?q={prison.dumps(argument)}"

        self.login(username="gamma")
        rv = self.client.get(uri)
        self.assertEqual(rv.status_code, 401)