From fbcfaacda308c7a89527caefbf06e5635b8c0ab1 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 30 Oct 2020 11:52:11 -0700 Subject: [PATCH] feat: create base class for export commands (#11463) * Add UUID to saved_query * Reuse function from previous migration * Point to new head * feat: add backend to export saved queries using new format * Rename ImportMixin to ImportExportMixin * Create base class for exports * Add saved queries as well * Add constant, small fixes * Fix wrong import * Fix lint --- superset/charts/commands/export.py | 39 ++++------- superset/dashboards/commands/export.py | 31 +++------ superset/databases/commands/export.py | 35 +++------- superset/datasets/commands/export.py | 40 +++-------- superset/importexport/commands/base.py | 69 +++++++++++++++++++ .../queries/saved_queries/commands/export.py | 39 ++++------- tests/charts/commands_tests.py | 9 +-- tests/dashboards/commands_tests.py | 9 +-- tests/databases/commands_tests.py | 9 +-- tests/datasets/commands_tests.py | 9 +-- tests/importexport/commands_tests.py | 48 +++++++++++++ tests/queries/saved_queries/commands_tests.py | 9 +-- 12 files changed, 199 insertions(+), 147 deletions(-) create mode 100644 superset/importexport/commands/base.py create mode 100644 tests/importexport/commands_tests.py diff --git a/superset/charts/commands/export.py b/superset/charts/commands/export.py index db90e742da..23bdb55b24 100644 --- a/superset/charts/commands/export.py +++ b/superset/charts/commands/export.py @@ -18,16 +18,16 @@ import json import logging -from typing import Iterator, List, Tuple +from typing import Iterator, Tuple import yaml -from superset.commands.base import BaseCommand from superset.charts.commands.exceptions import ChartNotFoundError from superset.charts.dao import ChartDAO from superset.datasets.commands.export import ExportDatasetsCommand -from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize +from superset.importexport.commands.base import ExportModelsCommand from superset.models.slice import Slice +from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize logger = logging.getLogger(__name__) @@ -36,19 +36,17 @@ logger = logging.getLogger(__name__) REMOVE_KEYS = ["datasource_type", "datasource_name"] -class ExportChartsCommand(BaseCommand): - def __init__(self, chart_ids: List[int]): - self.chart_ids = chart_ids +class ExportChartsCommand(ExportModelsCommand): - # this will be set when calling validate() - self._models: List[Slice] = [] + dao = ChartDAO + not_found = ChartNotFoundError @staticmethod - def export_chart(chart: Slice) -> Iterator[Tuple[str, str]]: - chart_slug = sanitize(chart.slice_name) + def export(model: Slice) -> Iterator[Tuple[str, str]]: + chart_slug = sanitize(model.slice_name) file_name = f"charts/{chart_slug}.yaml" - payload = chart.export_to_dict( + payload = model.export_to_dict( recursive=False, include_parent_ref=False, include_defaults=True, @@ -65,22 +63,11 @@ class ExportChartsCommand(BaseCommand): logger.info("Unable to decode `params` field: %s", payload["params"]) payload["version"] = IMPORT_EXPORT_VERSION - if chart.table: - payload["dataset_uuid"] = str(chart.table.uuid) + if model.table: + payload["dataset_uuid"] = str(model.table.uuid) file_content = yaml.safe_dump(payload, sort_keys=False) yield file_name, file_content - if chart.table: - yield from ExportDatasetsCommand([chart.table.id]).run() - - def run(self) -> Iterator[Tuple[str, str]]: - self.validate() - - for chart in self._models: - yield from 
self.export_chart(chart) - - def validate(self) -> None: - self._models = ChartDAO.find_by_ids(self.chart_ids) - if len(self._models) != len(self.chart_ids): - raise ChartNotFoundError() + if model.table: + yield from ExportDatasetsCommand([model.table.id]).run() diff --git a/superset/dashboards/commands/export.py b/superset/dashboards/commands/export.py index f769a67480..ba55b64be4 100644 --- a/superset/dashboards/commands/export.py +++ b/superset/dashboards/commands/export.py @@ -18,14 +18,14 @@ import json import logging -from typing import Iterator, List, Tuple +from typing import Iterator, Tuple import yaml -from superset.commands.base import BaseCommand from superset.charts.commands.export import ExportChartsCommand from superset.dashboards.commands.exceptions import DashboardNotFoundError from superset.dashboards.dao import DashboardDAO +from superset.importexport.commands.base import ExportModelsCommand from superset.models.dashboard import Dashboard from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize @@ -36,19 +36,17 @@ logger = logging.getLogger(__name__) JSON_KEYS = {"position_json": "position", "json_metadata": "metadata"} -class ExportDashboardsCommand(BaseCommand): - def __init__(self, dashboard_ids: List[int]): - self.dashboard_ids = dashboard_ids +class ExportDashboardsCommand(ExportModelsCommand): - # this will be set when calling validate() - self._models: List[Dashboard] = [] + dao = DashboardDAO + not_found = DashboardNotFoundError @staticmethod - def export_dashboard(dashboard: Dashboard) -> Iterator[Tuple[str, str]]: - dashboard_slug = sanitize(dashboard.dashboard_title) + def export(model: Dashboard) -> Iterator[Tuple[str, str]]: + dashboard_slug = sanitize(model.dashboard_title) file_name = f"dashboards/{dashboard_slug}.yaml" - payload = dashboard.export_to_dict( + payload = model.export_to_dict( recursive=False, include_parent_ref=False, include_defaults=True, @@ -69,16 +67,5 @@ class ExportDashboardsCommand(BaseCommand): file_content = yaml.safe_dump(payload, sort_keys=False) yield file_name, file_content - chart_ids = [chart.id for chart in dashboard.slices] + chart_ids = [chart.id for chart in model.slices] yield from ExportChartsCommand(chart_ids).run() - - def run(self) -> Iterator[Tuple[str, str]]: - self.validate() - - for dashboard in self._models: - yield from self.export_dashboard(dashboard) - - def validate(self) -> None: - self._models = DashboardDAO.find_by_ids(self.dashboard_ids) - if len(self._models) != len(self.dashboard_ids): - raise DashboardNotFoundError() diff --git a/superset/databases/commands/export.py b/superset/databases/commands/export.py index 9009d200fa..a7715f5c0e 100644 --- a/superset/databases/commands/export.py +++ b/superset/databases/commands/export.py @@ -18,32 +18,30 @@ import json import logging -from typing import Iterator, List, Tuple +from typing import Iterator, Tuple import yaml -from superset.commands.base import BaseCommand from superset.databases.commands.exceptions import DatabaseNotFoundError from superset.databases.dao import DatabaseDAO -from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize +from superset.importexport.commands.base import ExportModelsCommand from superset.models.core import Database +from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize logger = logging.getLogger(__name__) -class ExportDatabasesCommand(BaseCommand): - def __init__(self, database_ids: List[int]): - self.database_ids = database_ids +class 
ExportDatabasesCommand(ExportModelsCommand): - # this will be set when calling validate() - self._models: List[Database] = [] + dao = DatabaseDAO + not_found = DatabaseNotFoundError @staticmethod - def export_database(database: Database) -> Iterator[Tuple[str, str]]: - database_slug = sanitize(database.database_name) + def export(model: Database) -> Iterator[Tuple[str, str]]: + database_slug = sanitize(model.database_name) file_name = f"databases/{database_slug}.yaml" - payload = database.export_to_dict( + payload = model.export_to_dict( recursive=False, include_parent_ref=False, include_defaults=True, @@ -62,7 +60,7 @@ class ExportDatabasesCommand(BaseCommand): file_content = yaml.safe_dump(payload, sort_keys=False) yield file_name, file_content - for dataset in database.tables: + for dataset in model.tables: dataset_slug = sanitize(dataset.table_name) file_name = f"datasets/{database_slug}/{dataset_slug}.yaml" @@ -73,18 +71,7 @@ class ExportDatabasesCommand(BaseCommand): export_uuids=True, ) payload["version"] = IMPORT_EXPORT_VERSION - payload["database_uuid"] = str(database.uuid) + payload["database_uuid"] = str(model.uuid) file_content = yaml.safe_dump(payload, sort_keys=False) yield file_name, file_content - - def run(self) -> Iterator[Tuple[str, str]]: - self.validate() - - for database in self._models: - yield from self.export_database(database) - - def validate(self) -> None: - self._models = DatabaseDAO.find_by_ids(self.database_ids) - if len(self._models) != len(self.database_ids): - raise DatabaseNotFoundError() diff --git a/superset/datasets/commands/export.py b/superset/datasets/commands/export.py index cc8be2eef0..a14cdcd67e 100644 --- a/superset/datasets/commands/export.py +++ b/superset/datasets/commands/export.py @@ -18,33 +18,31 @@ import json import logging -from typing import Iterator, List, Tuple +from typing import Iterator, Tuple import yaml -from superset.commands.base import BaseCommand from superset.connectors.sqla.models import SqlaTable from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.datasets.dao import DatasetDAO +from superset.importexport.commands.base import ExportModelsCommand from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize logger = logging.getLogger(__name__) -class ExportDatasetsCommand(BaseCommand): - def __init__(self, dataset_ids: List[int]): - self.dataset_ids = dataset_ids +class ExportDatasetsCommand(ExportModelsCommand): - # this will be set when calling validate() - self._models: List[SqlaTable] = [] + dao = DatasetDAO + not_found = DatasetNotFoundError @staticmethod - def export_dataset(dataset: SqlaTable) -> Iterator[Tuple[str, str]]: - database_slug = sanitize(dataset.database.database_name) - dataset_slug = sanitize(dataset.table_name) + def export(model: SqlaTable) -> Iterator[Tuple[str, str]]: + database_slug = sanitize(model.database.database_name) + dataset_slug = sanitize(model.table_name) file_name = f"datasets/{database_slug}/{dataset_slug}.yaml" - payload = dataset.export_to_dict( + payload = model.export_to_dict( recursive=True, include_parent_ref=False, include_defaults=True, @@ -52,7 +50,7 @@ class ExportDatasetsCommand(BaseCommand): ) payload["version"] = IMPORT_EXPORT_VERSION - payload["database_uuid"] = str(dataset.database.uuid) + payload["database_uuid"] = str(model.database.uuid) file_content = yaml.safe_dump(payload, sort_keys=False) yield file_name, file_content @@ -60,7 +58,7 @@ class ExportDatasetsCommand(BaseCommand): # include database as well file_name = 
f"databases/{database_slug}.yaml" - payload = dataset.database.export_to_dict( + payload = model.database.export_to_dict( recursive=False, include_parent_ref=False, include_defaults=True, @@ -78,19 +76,3 @@ class ExportDatasetsCommand(BaseCommand): file_content = yaml.safe_dump(payload, sort_keys=False) yield file_name, file_content - - def run(self) -> Iterator[Tuple[str, str]]: - self.validate() - - seen = set() - for dataset in self._models: - for file_name, file_content in self.export_dataset(dataset): - # ignore repeated databases - if file_name not in seen: - yield file_name, file_content - seen.add(file_name) - - def validate(self) -> None: - self._models = DatasetDAO.find_by_ids(self.dataset_ids) - if len(self._models) != len(self.dataset_ids): - raise DatasetNotFoundError() diff --git a/superset/importexport/commands/base.py b/superset/importexport/commands/base.py new file mode 100644 index 0000000000..1c687fbea2 --- /dev/null +++ b/superset/importexport/commands/base.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# isort:skip_file + +from datetime import datetime +from datetime import timezone +from typing import Iterator, List, Tuple + +import yaml +from flask_appbuilder import Model + +from superset.commands.base import BaseCommand +from superset.commands.exceptions import CommandException +from superset.dao.base import BaseDAO +from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION + +METADATA_FILE_NAME = "metadata.yaml" + + +class ExportModelsCommand(BaseCommand): + + dao = BaseDAO + not_found = CommandException + + def __init__(self, model_ids: List[int]): + self.model_ids = model_ids + + # this will be set when calling validate() + self._models: List[Model] = [] + + @staticmethod + def export(model: Model) -> Iterator[Tuple[str, str]]: + raise NotImplementedError("Subclasses MUST implement export") + + def run(self) -> Iterator[Tuple[str, str]]: + self.validate() + + metadata = { + "version": IMPORT_EXPORT_VERSION, + "type": self.dao.model_cls.__name__, # type: ignore + "timestamp": datetime.now(tz=timezone.utc).isoformat(), + } + yield METADATA_FILE_NAME, yaml.safe_dump(metadata, sort_keys=False) + + seen = set() + for model in self._models: + for file_name, file_content in self.export(model): + if file_name not in seen: + yield file_name, file_content + seen.add(file_name) + + def validate(self) -> None: + self._models = self.dao.find_by_ids(self.model_ids) + if len(self._models) != len(self.model_ids): + raise self.not_found() diff --git a/superset/queries/saved_queries/commands/export.py b/superset/queries/saved_queries/commands/export.py index 44c0c1f604..33dfffc86e 100644 --- a/superset/queries/saved_queries/commands/export.py +++ b/superset/queries/saved_queries/commands/export.py @@ -18,42 +18,40 @@ import json import logging -from typing import Iterator, List, Tuple +from typing import Iterator, Tuple import yaml -from superset.commands.base import BaseCommand +from superset.importexport.commands.base import ExportModelsCommand +from superset.models.sql_lab import SavedQuery from superset.queries.saved_queries.commands.exceptions import SavedQueryNotFoundError from superset.queries.saved_queries.dao import SavedQueryDAO from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize -from superset.models.sql_lab import SavedQuery logger = logging.getLogger(__name__) -class ExportSavedQueriesCommand(BaseCommand): - def __init__(self, query_ids: List[int]): - self.query_ids = query_ids +class ExportSavedQueriesCommand(ExportModelsCommand): - # this will be set when calling validate() - self._models: List[SavedQuery] = [] + dao = SavedQueryDAO + not_found = SavedQueryNotFoundError @staticmethod - def export_saved_query(query: SavedQuery) -> Iterator[Tuple[str, str]]: + def export(model: SavedQuery) -> Iterator[Tuple[str, str]]: # build filename based on database, optional schema, and label - database_slug = sanitize(query.database.database_name) - schema_slug = sanitize(query.schema) - query_slug = sanitize(query.label) or str(query.uuid) + database_slug = sanitize(model.database.database_name) + schema_slug = sanitize(model.schema) + query_slug = sanitize(model.label) or str(model.uuid) file_name = f"queries/{database_slug}/{schema_slug}/{query_slug}.yaml" - payload = query.export_to_dict( + payload = model.export_to_dict( recursive=False, include_parent_ref=False, include_defaults=True, export_uuids=True, ) payload["version"] = IMPORT_EXPORT_VERSION - payload["database_uuid"] = str(query.database.uuid) + payload["database_uuid"] = str(model.database.uuid) 
file_content = yaml.safe_dump(payload, sort_keys=False) yield file_name, file_content @@ -61,7 +59,7 @@ class ExportSavedQueriesCommand(BaseCommand): # include database as well file_name = f"databases/{database_slug}.yaml" - payload = query.database.export_to_dict( + payload = model.database.export_to_dict( recursive=False, include_parent_ref=False, include_defaults=True, @@ -79,14 +77,3 @@ class ExportSavedQueriesCommand(BaseCommand): file_content = yaml.safe_dump(payload, sort_keys=False) yield file_name, file_content - - def run(self) -> Iterator[Tuple[str, str]]: - self.validate() - - for query in self._models: - yield from self.export_saved_query(query) - - def validate(self) -> None: - self._models = SavedQueryDAO.find_by_ids(self.query_ids) - if len(self._models) != len(self.query_ids): - raise SavedQueryNotFoundError() diff --git a/tests/charts/commands_tests.py b/tests/charts/commands_tests.py index 6923757cb6..13dec7bd30 100644 --- a/tests/charts/commands_tests.py +++ b/tests/charts/commands_tests.py @@ -32,10 +32,11 @@ class TestExportChartsCommand(SupersetTestCase): mock_g.user = security_manager.find_user("admin") example_chart = db.session.query(Slice).all()[0] - command = ExportChartsCommand(chart_ids=[example_chart.id]) + command = ExportChartsCommand([example_chart.id]) contents = dict(command.run()) expected = [ + "metadata.yaml", "charts/energy_sankey.yaml", "datasets/examples/energy_usage.yaml", "databases/examples.yaml", @@ -66,7 +67,7 @@ class TestExportChartsCommand(SupersetTestCase): mock_g.user = security_manager.find_user("gamma") example_chart = db.session.query(Slice).all()[0] - command = ExportChartsCommand(chart_ids=[example_chart.id]) + command = ExportChartsCommand([example_chart.id]) contents = command.run() with self.assertRaises(ChartNotFoundError): next(contents) @@ -75,7 +76,7 @@ class TestExportChartsCommand(SupersetTestCase): def test_export_chart_command_invalid_dataset(self, mock_g): """Test that an error is raised when exporting an invalid dataset""" mock_g.user = security_manager.find_user("admin") - command = ExportChartsCommand(chart_ids=[-1]) + command = ExportChartsCommand([-1]) contents = command.run() with self.assertRaises(ChartNotFoundError): next(contents) @@ -86,7 +87,7 @@ class TestExportChartsCommand(SupersetTestCase): mock_g.user = security_manager.find_user("admin") example_chart = db.session.query(Slice).all()[0] - command = ExportChartsCommand(chart_ids=[example_chart.id]) + command = ExportChartsCommand([example_chart.id]) contents = dict(command.run()) metadata = yaml.safe_load(contents["charts/energy_sankey.yaml"]) diff --git a/tests/dashboards/commands_tests.py b/tests/dashboards/commands_tests.py index 10acf16e81..075d2b825c 100644 --- a/tests/dashboards/commands_tests.py +++ b/tests/dashboards/commands_tests.py @@ -34,10 +34,11 @@ class TestExportDashboardsCommand(SupersetTestCase): mock_g2.user = security_manager.find_user("admin") example_dashboard = db.session.query(Dashboard).filter_by(id=1).one() - command = ExportDashboardsCommand(dashboard_ids=[example_dashboard.id]) + command = ExportDashboardsCommand([example_dashboard.id]) contents = dict(command.run()) expected_paths = { + "metadata.yaml", "dashboards/world_banks_data.yaml", "charts/box_plot.yaml", "datasets/examples/wb_health_population.yaml", @@ -150,7 +151,7 @@ class TestExportDashboardsCommand(SupersetTestCase): mock_g2.user = security_manager.find_user("gamma") example_dashboard = db.session.query(Dashboard).filter_by(id=1).one() - command = 
ExportDashboardsCommand(dashboard_ids=[example_dashboard.id]) + command = ExportDashboardsCommand([example_dashboard.id]) contents = command.run() with self.assertRaises(DashboardNotFoundError): next(contents) @@ -161,7 +162,7 @@ class TestExportDashboardsCommand(SupersetTestCase): """Test that an error is raised when exporting an invalid dataset""" mock_g1.user = security_manager.find_user("admin") mock_g2.user = security_manager.find_user("admin") - command = ExportDashboardsCommand(dashboard_ids=[-1]) + command = ExportDashboardsCommand([-1]) contents = command.run() with self.assertRaises(DashboardNotFoundError): next(contents) @@ -174,7 +175,7 @@ class TestExportDashboardsCommand(SupersetTestCase): mock_g2.user = security_manager.find_user("admin") example_dashboard = db.session.query(Dashboard).filter_by(id=1).one() - command = ExportDashboardsCommand(dashboard_ids=[example_dashboard.id]) + command = ExportDashboardsCommand([example_dashboard.id]) contents = dict(command.run()) metadata = yaml.safe_load(contents["dashboards/world_banks_data.yaml"]) diff --git a/tests/databases/commands_tests.py b/tests/databases/commands_tests.py index 0ad2ad2044..bd4d3438d5 100644 --- a/tests/databases/commands_tests.py +++ b/tests/databases/commands_tests.py @@ -32,12 +32,13 @@ class TestExportDatabasesCommand(SupersetTestCase): mock_g.user = security_manager.find_user("admin") example_db = get_example_database() - command = ExportDatabasesCommand(database_ids=[example_db.id]) + command = ExportDatabasesCommand([example_db.id]) contents = dict(command.run()) # TODO: this list shouldn't depend on the order in which unit tests are run # or on the backend; for now use a stable subset core_files = { + "metadata.yaml", "databases/examples.yaml", "datasets/examples/energy_usage.yaml", "datasets/examples/wb_health_population.yaml", @@ -227,7 +228,7 @@ class TestExportDatabasesCommand(SupersetTestCase): mock_g.user = security_manager.find_user("gamma") example_db = get_example_database() - command = ExportDatabasesCommand(database_ids=[example_db.id]) + command = ExportDatabasesCommand([example_db.id]) contents = command.run() with self.assertRaises(DatabaseNotFoundError): next(contents) @@ -236,7 +237,7 @@ class TestExportDatabasesCommand(SupersetTestCase): def test_export_database_command_invalid_database(self, mock_g): """Test that an error is raised when exporting an invalid database""" mock_g.user = security_manager.find_user("admin") - command = ExportDatabasesCommand(database_ids=[-1]) + command = ExportDatabasesCommand([-1]) contents = command.run() with self.assertRaises(DatabaseNotFoundError): next(contents) @@ -247,7 +248,7 @@ class TestExportDatabasesCommand(SupersetTestCase): mock_g.user = security_manager.find_user("admin") example_db = get_example_database() - command = ExportDatabasesCommand(database_ids=[example_db.id]) + command = ExportDatabasesCommand([example_db.id]) contents = dict(command.run()) metadata = yaml.safe_load(contents["databases/examples.yaml"]) diff --git a/tests/datasets/commands_tests.py b/tests/datasets/commands_tests.py index bd038a4aff..17afe12662 100644 --- a/tests/datasets/commands_tests.py +++ b/tests/datasets/commands_tests.py @@ -35,10 +35,11 @@ class TestExportDatasetsCommand(SupersetTestCase): example_db = get_example_database() example_dataset = example_db.tables[0] - command = ExportDatasetsCommand(dataset_ids=[example_dataset.id]) + command = ExportDatasetsCommand([example_dataset.id]) contents = dict(command.run()) assert list(contents.keys()) == [ + 
"metadata.yaml", "datasets/examples/energy_usage.yaml", "databases/examples.yaml", ] @@ -140,7 +141,7 @@ class TestExportDatasetsCommand(SupersetTestCase): example_db = get_example_database() example_dataset = example_db.tables[0] - command = ExportDatasetsCommand(dataset_ids=[example_dataset.id]) + command = ExportDatasetsCommand([example_dataset.id]) contents = command.run() with self.assertRaises(DatasetNotFoundError): next(contents) @@ -149,7 +150,7 @@ class TestExportDatasetsCommand(SupersetTestCase): def test_export_dataset_command_invalid_dataset(self, mock_g): """Test that an error is raised when exporting an invalid dataset""" mock_g.user = security_manager.find_user("admin") - command = ExportDatasetsCommand(dataset_ids=[-1]) + command = ExportDatasetsCommand([-1]) contents = command.run() with self.assertRaises(DatasetNotFoundError): next(contents) @@ -161,7 +162,7 @@ class TestExportDatasetsCommand(SupersetTestCase): example_db = get_example_database() example_dataset = example_db.tables[0] - command = ExportDatasetsCommand(dataset_ids=[example_dataset.id]) + command = ExportDatasetsCommand([example_dataset.id]) contents = dict(command.run()) metadata = yaml.safe_load(contents["datasets/examples/energy_usage.yaml"]) diff --git a/tests/importexport/commands_tests.py b/tests/importexport/commands_tests.py new file mode 100644 index 0000000000..a8055c1b90 --- /dev/null +++ b/tests/importexport/commands_tests.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from unittest.mock import patch + +import yaml +from freezegun import freeze_time + +from superset import security_manager +from superset.databases.commands.export import ExportDatabasesCommand +from superset.utils.core import get_example_database +from tests.base_tests import SupersetTestCase + + +class TestExportModelsCommand(SupersetTestCase): + @patch("superset.security.manager.g") + def test_export_models_command(self, mock_g): + """Make sure metadata.yaml has the correct content.""" + mock_g.user = security_manager.find_user("admin") + + example_db = get_example_database() + + with freeze_time("2020-01-01T00:00:00Z"): + command = ExportDatabasesCommand([example_db.id]) + contents = dict(command.run()) + + metadata = yaml.safe_load(contents["metadata.yaml"]) + assert metadata == ( + { + "version": "1.0.0", + "type": "Database", + "timestamp": "2020-01-01T00:00:00+00:00", + } + ) diff --git a/tests/queries/saved_queries/commands_tests.py b/tests/queries/saved_queries/commands_tests.py index acd81af142..34f4dbe8e4 100644 --- a/tests/queries/saved_queries/commands_tests.py +++ b/tests/queries/saved_queries/commands_tests.py @@ -49,10 +49,11 @@ class TestExportSavedQueriesCommand(SupersetTestCase): def test_export_query_command(self, mock_g): mock_g.user = security_manager.find_user("admin") - command = ExportSavedQueriesCommand(query_ids=[self.example_query.id]) + command = ExportSavedQueriesCommand([self.example_query.id]) contents = dict(command.run()) expected = [ + "metadata.yaml", "queries/examples/schema1/the_answer.yaml", "databases/examples.yaml", ] @@ -74,7 +75,7 @@ class TestExportSavedQueriesCommand(SupersetTestCase): """Test that users can't export datasets they don't have access to""" mock_g.user = security_manager.find_user("gamma") - command = ExportSavedQueriesCommand(query_ids=[self.example_query.id]) + command = ExportSavedQueriesCommand([self.example_query.id]) contents = command.run() with self.assertRaises(SavedQueryNotFoundError): next(contents) @@ -84,7 +85,7 @@ class TestExportSavedQueriesCommand(SupersetTestCase): """Test that an error is raised when exporting an invalid dataset""" mock_g.user = security_manager.find_user("admin") - command = ExportSavedQueriesCommand(query_ids=[-1]) + command = ExportSavedQueriesCommand([-1]) contents = command.run() with self.assertRaises(SavedQueryNotFoundError): next(contents) @@ -94,7 +95,7 @@ class TestExportSavedQueriesCommand(SupersetTestCase): """Test that they keys in the YAML have the same order as export_fields""" mock_g.user = security_manager.find_user("admin") - command = ExportSavedQueriesCommand(query_ids=[self.example_query.id]) + command = ExportSavedQueriesCommand([self.example_query.id]) contents = dict(command.run()) metadata = yaml.safe_load(contents["queries/examples/schema1/the_answer.yaml"])
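
A minimal sketch of what a new export command looks like once this base class is in place. The Widget model, WidgetDAO, and WidgetNotFoundError below are hypothetical stand-ins used only for illustration and are not part of this patch; the contract itself (a dao, a not_found exception, and a static export() that yields (file_name, file_content) tuples) is exactly what the real subclasses here, for charts, dashboards, databases, datasets, and saved queries, implement.

# Illustrative only -- not part of this patch. "Widget", "WidgetDAO" and
# "WidgetNotFoundError" are hypothetical stand-ins.
from typing import Iterator, Tuple

import yaml

from superset.importexport.commands.base import ExportModelsCommand
from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize
from superset.widgets.commands.exceptions import WidgetNotFoundError  # hypothetical
from superset.widgets.dao import WidgetDAO  # hypothetical
from superset.widgets.models import Widget  # hypothetical


class ExportWidgetsCommand(ExportModelsCommand):

    dao = WidgetDAO                  # used by the inherited validate()
    not_found = WidgetNotFoundError  # raised when any requested id is missing

    @staticmethod
    def export(model: Widget) -> Iterator[Tuple[str, str]]:
        # Yield (relative file name, YAML content) pairs; the inherited run()
        # prepends metadata.yaml and skips duplicate file names.
        file_name = f"widgets/{sanitize(model.widget_name)}.yaml"
        payload = model.export_to_dict(
            recursive=False,
            include_parent_ref=False,
            include_defaults=True,
            export_uuids=True,
        )
        payload["version"] = IMPORT_EXPORT_VERSION
        yield file_name, yaml.safe_dump(payload, sort_keys=False)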
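
For context on how callers consume these commands, here is a sketch that bundles the output of the real ExportDatabasesCommand from this patch into a ZIP file. The generator contract (metadata.yaml first, then one YAML file per exported object, with the not_found exception raised during validation) comes from the patch; the ZIP bundling itself is just one plausible consumer and is not wired up anywhere in this PR.

from zipfile import ZipFile

from superset.databases.commands.export import ExportDatabasesCommand


def export_databases_to_zip(database_ids, path="database_export.zip"):
    # run() first validates the ids (raising DatabaseNotFoundError if any are
    # missing) and then lazily yields (file_name, file_content) tuples,
    # starting with metadata.yaml.
    with ZipFile(path, "w") as bundle:
        for file_name, file_content in ExportDatabasesCommand(database_ids).run():
            bundle.writestr(file_name, file_content)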