# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-self-use, invalid-name, line-too-long

from operator import itemgetter
from typing import Any, List
from unittest.mock import patch

import pytest
import yaml

from superset import db, security_manager
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.exceptions import IncorrectVersionError
from superset.connectors.sqla.models import SqlaTable
from superset.databases.commands.importers.v1 import ImportDatabasesCommand
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.datasets.commands.export import ExportDatasetsCommand
from superset.datasets.commands.importers import v0, v1
from superset.models.core import Database
from superset.utils.core import get_example_database
from tests.base_tests import SupersetTestCase
from tests.fixtures.energy_dashboard import load_energy_table_with_slice
from tests.fixtures.importexport import (
    database_config,
    database_metadata_config,
    dataset_cli_export,
    dataset_config,
    dataset_metadata_config,
    dataset_ui_export,
)
from tests.fixtures.world_bank_dashboard import load_world_bank_dashboard_with_slices


class TestExportDatasetsCommand(SupersetTestCase):
    @patch("superset.security.manager.g")
    @pytest.mark.usefixtures("load_energy_table_with_slice")
    def test_export_dataset_command(self, mock_g):
        mock_g.user = security_manager.find_user("admin")

        example_db = get_example_database()
        example_dataset = _get_table_from_list_by_name(
            "energy_usage", example_db.tables
        )
        command = ExportDatasetsCommand([example_dataset.id])
        contents = dict(command.run())

        assert list(contents.keys()) == [
            "metadata.yaml",
            "datasets/examples/energy_usage.yaml",
            "databases/examples.yaml",
        ]

        metadata = yaml.safe_load(contents["datasets/examples/energy_usage.yaml"])

        # sort columns and metrics for deterministic comparison
        metadata["columns"] = sorted(metadata["columns"], key=itemgetter("column_name"))
        metadata["metrics"] = sorted(metadata["metrics"], key=itemgetter("metric_name"))

        # types are different depending on the backend
        type_map = {
            column.column_name: str(column.type) for column in example_dataset.columns
        }

        assert metadata == {
            "cache_timeout": None,
            "columns": [
                {
                    "column_name": "source",
                    "description": None,
                    "expression": "",
                    "filterable": True,
                    "groupby": True,
                    "is_active": True,
                    "is_dttm": False,
                    "python_date_format": None,
                    "type": type_map["source"],
                    "verbose_name": None,
                },
                {
                    "column_name": "target",
                    "description": None,
                    "expression": "",
                    "filterable": True,
                    "groupby": True,
                    "is_active": True,
                    "is_dttm": False,
                    "python_date_format": None,
                    "type": type_map["target"],
                    "verbose_name": None,
                },
                {
                    "column_name": "value",
                    "description": None,
                    "expression": "",
                    "filterable": True,
                    "groupby": True,
                    "is_active": True,
                    "is_dttm": False,
"python_date_format": None, "type": type_map["value"], "verbose_name": None, }, ], "database_uuid": str(example_db.uuid), "default_endpoint": None, "description": "Energy consumption", "extra": None, "fetch_values_predicate": None, "filter_select_enabled": False, "main_dttm_col": None, "metrics": [ { "d3format": None, "description": None, "expression": "COUNT(*)", "extra": None, "metric_name": "count", "metric_type": "count", "verbose_name": "COUNT(*)", "warning_text": None, }, { "d3format": None, "description": None, "expression": "SUM(value)", "extra": None, "metric_name": "sum__value", "metric_type": None, "verbose_name": None, "warning_text": None, }, ], "offset": 0, "params": None, "schema": None, "sql": None, "table_name": "energy_usage", "template_params": None, "uuid": str(example_dataset.uuid), "version": "1.0.0", } @patch("superset.security.manager.g") def test_export_dataset_command_no_access(self, mock_g): """Test that users can't export datasets they don't have access to""" mock_g.user = security_manager.find_user("gamma") example_db = get_example_database() example_dataset = example_db.tables[0] command = ExportDatasetsCommand([example_dataset.id]) contents = command.run() with self.assertRaises(DatasetNotFoundError): next(contents) @patch("superset.security.manager.g") def test_export_dataset_command_invalid_dataset(self, mock_g): """Test that an error is raised when exporting an invalid dataset""" mock_g.user = security_manager.find_user("admin") command = ExportDatasetsCommand([-1]) contents = command.run() with self.assertRaises(DatasetNotFoundError): next(contents) @patch("superset.security.manager.g") @pytest.mark.usefixtures("load_energy_table_with_slice") def test_export_dataset_command_key_order(self, mock_g): """Test that they keys in the YAML have the same order as export_fields""" mock_g.user = security_manager.find_user("admin") example_db = get_example_database() example_dataset = _get_table_from_list_by_name( "energy_usage", example_db.tables ) command = ExportDatasetsCommand([example_dataset.id]) contents = dict(command.run()) metadata = yaml.safe_load(contents["datasets/examples/energy_usage.yaml"]) assert list(metadata.keys()) == [ "table_name", "main_dttm_col", "description", "default_endpoint", "offset", "cache_timeout", "schema", "sql", "params", "template_params", "filter_select_enabled", "fetch_values_predicate", "extra", "uuid", "metrics", "columns", "version", "database_uuid", ] class TestImportDatasetsCommand(SupersetTestCase): @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") def test_import_v0_dataset_cli_export(self): num_datasets = db.session.query(SqlaTable).count() contents = { "20201119_181105.yaml": yaml.safe_dump(dataset_cli_export), } command = v0.ImportDatasetsCommand(contents) command.run() new_num_datasets = db.session.query(SqlaTable).count() assert new_num_datasets == num_datasets + 1 dataset = ( db.session.query(SqlaTable).filter_by(table_name="birth_names_2").one() ) assert ( dataset.params == '{"remote_id": 3, "database_name": "examples", "import_time": 1604342885}' ) assert len(dataset.metrics) == 2 assert dataset.main_dttm_col == "ds" assert dataset.filter_select_enabled dataset.columns.sort(key=lambda obj: obj.column_name) expected_columns = [ "num_california", "ds", "state", "gender", "name", "num_boys", "num_girls", "num", ] expected_columns.sort() assert [col.column_name for col in dataset.columns] == expected_columns db.session.delete(dataset) db.session.commit() 
    @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
    def test_import_v0_dataset_ui_export(self):
        num_datasets = db.session.query(SqlaTable).count()

        contents = {
            "20201119_181105.yaml": yaml.safe_dump(dataset_ui_export),
        }
        command = v0.ImportDatasetsCommand(contents)
        command.run()

        new_num_datasets = db.session.query(SqlaTable).count()
        assert new_num_datasets == num_datasets + 1

        dataset = (
            db.session.query(SqlaTable).filter_by(table_name="birth_names_2").one()
        )
        assert (
            dataset.params
            == '{"remote_id": 3, "database_name": "examples", "import_time": 1604342885}'
        )
        assert len(dataset.metrics) == 2
        assert dataset.main_dttm_col == "ds"
        assert dataset.filter_select_enabled
        assert set(col.column_name for col in dataset.columns) == {
            "num_california",
            "ds",
            "state",
            "gender",
            "name",
            "num_boys",
            "num_girls",
            "num",
        }

        db.session.delete(dataset)
        db.session.commit()

    def test_import_v1_dataset(self):
        """Test that we can import a dataset"""
        contents = {
            "metadata.yaml": yaml.safe_dump(dataset_metadata_config),
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
            "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
        }
        command = v1.ImportDatasetsCommand(contents)
        command.run()

        dataset = (
            db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
        )
        assert dataset.table_name == "imported_dataset"
        assert dataset.main_dttm_col is None
        assert dataset.description == "This is a dataset that was exported"
        assert dataset.default_endpoint == ""
        assert dataset.offset == 66
        assert dataset.cache_timeout == 55
        assert dataset.schema == ""
        assert dataset.sql == ""
        assert dataset.params is None
        assert dataset.template_params == "{}"
        assert dataset.filter_select_enabled
        assert dataset.fetch_values_predicate is None
        assert dataset.extra is None

        # database is also imported
        assert str(dataset.database.uuid) == "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89"

        assert len(dataset.metrics) == 1
        metric = dataset.metrics[0]
        assert metric.metric_name == "count"
        assert metric.verbose_name == ""
        assert metric.metric_type is None
        assert metric.expression == "count(1)"
        assert metric.description is None
        assert metric.d3format is None
        assert metric.extra is None
        assert metric.warning_text is None

        assert len(dataset.columns) == 1
        column = dataset.columns[0]
        assert column.column_name == "cnt"
        assert column.verbose_name == "Count of something"
        assert not column.is_dttm
        assert column.is_active  # imported columns are set to active
        assert column.type == "NUMBER"
        assert not column.groupby
        assert column.filterable
        assert column.expression == ""
        assert column.description is None
        assert column.python_date_format is None

        db.session.delete(dataset)
        db.session.delete(dataset.database)
        db.session.commit()

    def test_import_v1_dataset_multiple(self):
        """Test that a dataset can be imported multiple times"""
        contents = {
            "metadata.yaml": yaml.safe_dump(dataset_metadata_config),
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
            "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
        }
        command = v1.ImportDatasetsCommand(contents, overwrite=True)
        command.run()
        command.run()
        dataset = (
            db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
        )
        assert dataset.table_name == "imported_dataset"

        # test that columns and metrics sync, i.e., old ones not in the import
        # are removed
        new_config = dataset_config.copy()
        new_config["metrics"][0]["metric_name"] = "count2"
        new_config["columns"][0]["column_name"] = "cnt2"
        contents = {
            "metadata.yaml": yaml.safe_dump(dataset_metadata_config),
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
            "datasets/imported_dataset.yaml": yaml.safe_dump(new_config),
        }
        command = v1.ImportDatasetsCommand(contents, overwrite=True)
        command.run()

        dataset = (
            db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
        )
        assert len(dataset.metrics) == 1
        assert dataset.metrics[0].metric_name == "count2"
        assert len(dataset.columns) == 1
        assert dataset.columns[0].column_name == "cnt2"

        db.session.delete(dataset)
        db.session.delete(dataset.database)
        db.session.commit()

    def test_import_v1_dataset_validation(self):
        """Test different validations applied when importing a dataset"""
        # metadata.yaml must be present
        contents = {
            "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
        }
        command = v1.ImportDatasetsCommand(contents)
        with pytest.raises(IncorrectVersionError) as excinfo:
            command.run()
        assert str(excinfo.value) == "Missing metadata.yaml"

        # version should be 1.0.0
        contents["metadata.yaml"] = yaml.safe_dump(
            {
                "version": "2.0.0",
                "type": "SqlaTable",
                "timestamp": "2020-11-04T21:27:44.423819+00:00",
            }
        )
        command = v1.ImportDatasetsCommand(contents)
        with pytest.raises(IncorrectVersionError) as excinfo:
            command.run()
        assert str(excinfo.value) == "Must be equal to 1.0.0."

        # type should be SqlaTable
        contents["metadata.yaml"] = yaml.safe_dump(database_metadata_config)
        command = v1.ImportDatasetsCommand(contents)
        with pytest.raises(CommandInvalidError) as excinfo:
            command.run()
        assert str(excinfo.value) == "Error importing dataset"
        assert excinfo.value.normalized_messages() == {
            "metadata.yaml": {"type": ["Must be equal to SqlaTable."]}
        }

        # must also validate databases
        broken_config = database_config.copy()
        del broken_config["database_name"]
        contents["metadata.yaml"] = yaml.safe_dump(dataset_metadata_config)
        contents["databases/imported_database.yaml"] = yaml.safe_dump(broken_config)
        command = v1.ImportDatasetsCommand(contents)
        with pytest.raises(CommandInvalidError) as excinfo:
            command.run()
        assert str(excinfo.value) == "Error importing dataset"
        assert excinfo.value.normalized_messages() == {
            "databases/imported_database.yaml": {
                "database_name": ["Missing data for required field."],
            }
        }

    def test_import_v1_dataset_existing_database(self):
        """Test that a dataset can be imported when the database already exists"""
        # first import the database...
        contents = {
            "metadata.yaml": yaml.safe_dump(database_metadata_config),
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
        }
        command = ImportDatabasesCommand(contents)
        command.run()

        database = (
            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
        )
        assert len(database.tables) == 0

        # ...then the dataset
        contents = {
            "metadata.yaml": yaml.safe_dump(dataset_metadata_config),
            "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
        }
        command = v1.ImportDatasetsCommand(contents, overwrite=True)
        command.run()

        database = (
            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
        )
        assert len(database.tables) == 1

        db.session.delete(database.tables[0])
        db.session.delete(database)
        db.session.commit()


def _get_table_from_list_by_name(name: str, tables: List[Any]):
    for table in tables:
        if table.table_name == name:
            return table
    raise ValueError(f"Table {name} does not exist in the database")