feat: create base class for export commands (#11463)

* Add UUID to saved_query

* Reuse function from previous migration

* Point to new head

* feat: add backend to export saved queries using new format

* Rename ImportMixin to ImportExportMixin

* Create base class for exports

* Add saved queries as well

* Add constant, small fixes

* Fix wrong import

* Fix lint
This commit is contained in:
Beto Dealmeida 2020-10-30 11:52:11 -07:00 committed by GitHub
parent ca40877640
commit fbcfaacda3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 199 additions and 147 deletions

View File

@ -18,16 +18,16 @@
import json import json
import logging import logging
from typing import Iterator, List, Tuple from typing import Iterator, Tuple
import yaml import yaml
from superset.commands.base import BaseCommand
from superset.charts.commands.exceptions import ChartNotFoundError from superset.charts.commands.exceptions import ChartNotFoundError
from superset.charts.dao import ChartDAO from superset.charts.dao import ChartDAO
from superset.datasets.commands.export import ExportDatasetsCommand from superset.datasets.commands.export import ExportDatasetsCommand
from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize from superset.importexport.commands.base import ExportModelsCommand
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,19 +36,17 @@ logger = logging.getLogger(__name__)
REMOVE_KEYS = ["datasource_type", "datasource_name"] REMOVE_KEYS = ["datasource_type", "datasource_name"]
class ExportChartsCommand(BaseCommand): class ExportChartsCommand(ExportModelsCommand):
def __init__(self, chart_ids: List[int]):
self.chart_ids = chart_ids
# this will be set when calling validate() dao = ChartDAO
self._models: List[Slice] = [] not_found = ChartNotFoundError
@staticmethod @staticmethod
def export_chart(chart: Slice) -> Iterator[Tuple[str, str]]: def export(model: Slice) -> Iterator[Tuple[str, str]]:
chart_slug = sanitize(chart.slice_name) chart_slug = sanitize(model.slice_name)
file_name = f"charts/{chart_slug}.yaml" file_name = f"charts/{chart_slug}.yaml"
payload = chart.export_to_dict( payload = model.export_to_dict(
recursive=False, recursive=False,
include_parent_ref=False, include_parent_ref=False,
include_defaults=True, include_defaults=True,
@ -65,22 +63,11 @@ class ExportChartsCommand(BaseCommand):
logger.info("Unable to decode `params` field: %s", payload["params"]) logger.info("Unable to decode `params` field: %s", payload["params"])
payload["version"] = IMPORT_EXPORT_VERSION payload["version"] = IMPORT_EXPORT_VERSION
if chart.table: if model.table:
payload["dataset_uuid"] = str(chart.table.uuid) payload["dataset_uuid"] = str(model.table.uuid)
file_content = yaml.safe_dump(payload, sort_keys=False) file_content = yaml.safe_dump(payload, sort_keys=False)
yield file_name, file_content yield file_name, file_content
if chart.table: if model.table:
yield from ExportDatasetsCommand([chart.table.id]).run() yield from ExportDatasetsCommand([model.table.id]).run()
def run(self) -> Iterator[Tuple[str, str]]:
self.validate()
for chart in self._models:
yield from self.export_chart(chart)
def validate(self) -> None:
self._models = ChartDAO.find_by_ids(self.chart_ids)
if len(self._models) != len(self.chart_ids):
raise ChartNotFoundError()

View File

@ -18,14 +18,14 @@
import json import json
import logging import logging
from typing import Iterator, List, Tuple from typing import Iterator, Tuple
import yaml import yaml
from superset.commands.base import BaseCommand
from superset.charts.commands.export import ExportChartsCommand from superset.charts.commands.export import ExportChartsCommand
from superset.dashboards.commands.exceptions import DashboardNotFoundError from superset.dashboards.commands.exceptions import DashboardNotFoundError
from superset.dashboards.dao import DashboardDAO from superset.dashboards.dao import DashboardDAO
from superset.importexport.commands.base import ExportModelsCommand
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize
@ -36,19 +36,17 @@ logger = logging.getLogger(__name__)
JSON_KEYS = {"position_json": "position", "json_metadata": "metadata"} JSON_KEYS = {"position_json": "position", "json_metadata": "metadata"}
class ExportDashboardsCommand(BaseCommand): class ExportDashboardsCommand(ExportModelsCommand):
def __init__(self, dashboard_ids: List[int]):
self.dashboard_ids = dashboard_ids
# this will be set when calling validate() dao = DashboardDAO
self._models: List[Dashboard] = [] not_found = DashboardNotFoundError
@staticmethod @staticmethod
def export_dashboard(dashboard: Dashboard) -> Iterator[Tuple[str, str]]: def export(model: Dashboard) -> Iterator[Tuple[str, str]]:
dashboard_slug = sanitize(dashboard.dashboard_title) dashboard_slug = sanitize(model.dashboard_title)
file_name = f"dashboards/{dashboard_slug}.yaml" file_name = f"dashboards/{dashboard_slug}.yaml"
payload = dashboard.export_to_dict( payload = model.export_to_dict(
recursive=False, recursive=False,
include_parent_ref=False, include_parent_ref=False,
include_defaults=True, include_defaults=True,
@ -69,16 +67,5 @@ class ExportDashboardsCommand(BaseCommand):
file_content = yaml.safe_dump(payload, sort_keys=False) file_content = yaml.safe_dump(payload, sort_keys=False)
yield file_name, file_content yield file_name, file_content
chart_ids = [chart.id for chart in dashboard.slices] chart_ids = [chart.id for chart in model.slices]
yield from ExportChartsCommand(chart_ids).run() yield from ExportChartsCommand(chart_ids).run()
def run(self) -> Iterator[Tuple[str, str]]:
self.validate()
for dashboard in self._models:
yield from self.export_dashboard(dashboard)
def validate(self) -> None:
self._models = DashboardDAO.find_by_ids(self.dashboard_ids)
if len(self._models) != len(self.dashboard_ids):
raise DashboardNotFoundError()

View File

@ -18,32 +18,30 @@
import json import json
import logging import logging
from typing import Iterator, List, Tuple from typing import Iterator, Tuple
import yaml import yaml
from superset.commands.base import BaseCommand
from superset.databases.commands.exceptions import DatabaseNotFoundError from superset.databases.commands.exceptions import DatabaseNotFoundError
from superset.databases.dao import DatabaseDAO from superset.databases.dao import DatabaseDAO
from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize from superset.importexport.commands.base import ExportModelsCommand
from superset.models.core import Database from superset.models.core import Database
from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ExportDatabasesCommand(BaseCommand): class ExportDatabasesCommand(ExportModelsCommand):
def __init__(self, database_ids: List[int]):
self.database_ids = database_ids
# this will be set when calling validate() dao = DatabaseDAO
self._models: List[Database] = [] not_found = DatabaseNotFoundError
@staticmethod @staticmethod
def export_database(database: Database) -> Iterator[Tuple[str, str]]: def export(model: Database) -> Iterator[Tuple[str, str]]:
database_slug = sanitize(database.database_name) database_slug = sanitize(model.database_name)
file_name = f"databases/{database_slug}.yaml" file_name = f"databases/{database_slug}.yaml"
payload = database.export_to_dict( payload = model.export_to_dict(
recursive=False, recursive=False,
include_parent_ref=False, include_parent_ref=False,
include_defaults=True, include_defaults=True,
@ -62,7 +60,7 @@ class ExportDatabasesCommand(BaseCommand):
file_content = yaml.safe_dump(payload, sort_keys=False) file_content = yaml.safe_dump(payload, sort_keys=False)
yield file_name, file_content yield file_name, file_content
for dataset in database.tables: for dataset in model.tables:
dataset_slug = sanitize(dataset.table_name) dataset_slug = sanitize(dataset.table_name)
file_name = f"datasets/{database_slug}/{dataset_slug}.yaml" file_name = f"datasets/{database_slug}/{dataset_slug}.yaml"
@ -73,18 +71,7 @@ class ExportDatabasesCommand(BaseCommand):
export_uuids=True, export_uuids=True,
) )
payload["version"] = IMPORT_EXPORT_VERSION payload["version"] = IMPORT_EXPORT_VERSION
payload["database_uuid"] = str(database.uuid) payload["database_uuid"] = str(model.uuid)
file_content = yaml.safe_dump(payload, sort_keys=False) file_content = yaml.safe_dump(payload, sort_keys=False)
yield file_name, file_content yield file_name, file_content
def run(self) -> Iterator[Tuple[str, str]]:
self.validate()
for database in self._models:
yield from self.export_database(database)
def validate(self) -> None:
self._models = DatabaseDAO.find_by_ids(self.database_ids)
if len(self._models) != len(self.database_ids):
raise DatabaseNotFoundError()

View File

@ -18,33 +18,31 @@
import json import json
import logging import logging
from typing import Iterator, List, Tuple from typing import Iterator, Tuple
import yaml import yaml
from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.datasets.dao import DatasetDAO from superset.datasets.dao import DatasetDAO
from superset.importexport.commands.base import ExportModelsCommand
from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ExportDatasetsCommand(BaseCommand): class ExportDatasetsCommand(ExportModelsCommand):
def __init__(self, dataset_ids: List[int]):
self.dataset_ids = dataset_ids
# this will be set when calling validate() dao = DatasetDAO
self._models: List[SqlaTable] = [] not_found = DatasetNotFoundError
@staticmethod @staticmethod
def export_dataset(dataset: SqlaTable) -> Iterator[Tuple[str, str]]: def export(model: SqlaTable) -> Iterator[Tuple[str, str]]:
database_slug = sanitize(dataset.database.database_name) database_slug = sanitize(model.database.database_name)
dataset_slug = sanitize(dataset.table_name) dataset_slug = sanitize(model.table_name)
file_name = f"datasets/{database_slug}/{dataset_slug}.yaml" file_name = f"datasets/{database_slug}/{dataset_slug}.yaml"
payload = dataset.export_to_dict( payload = model.export_to_dict(
recursive=True, recursive=True,
include_parent_ref=False, include_parent_ref=False,
include_defaults=True, include_defaults=True,
@ -52,7 +50,7 @@ class ExportDatasetsCommand(BaseCommand):
) )
payload["version"] = IMPORT_EXPORT_VERSION payload["version"] = IMPORT_EXPORT_VERSION
payload["database_uuid"] = str(dataset.database.uuid) payload["database_uuid"] = str(model.database.uuid)
file_content = yaml.safe_dump(payload, sort_keys=False) file_content = yaml.safe_dump(payload, sort_keys=False)
yield file_name, file_content yield file_name, file_content
@ -60,7 +58,7 @@ class ExportDatasetsCommand(BaseCommand):
# include database as well # include database as well
file_name = f"databases/{database_slug}.yaml" file_name = f"databases/{database_slug}.yaml"
payload = dataset.database.export_to_dict( payload = model.database.export_to_dict(
recursive=False, recursive=False,
include_parent_ref=False, include_parent_ref=False,
include_defaults=True, include_defaults=True,
@ -78,19 +76,3 @@ class ExportDatasetsCommand(BaseCommand):
file_content = yaml.safe_dump(payload, sort_keys=False) file_content = yaml.safe_dump(payload, sort_keys=False)
yield file_name, file_content yield file_name, file_content
def run(self) -> Iterator[Tuple[str, str]]:
self.validate()
seen = set()
for dataset in self._models:
for file_name, file_content in self.export_dataset(dataset):
# ignore repeated databases
if file_name not in seen:
yield file_name, file_content
seen.add(file_name)
def validate(self) -> None:
self._models = DatasetDAO.find_by_ids(self.dataset_ids)
if len(self._models) != len(self.dataset_ids):
raise DatasetNotFoundError()

View File

@ -0,0 +1,69 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
from datetime import datetime
from datetime import timezone
from typing import Iterator, List, Tuple
import yaml
from flask_appbuilder import Model
from superset.commands.base import BaseCommand
from superset.commands.exceptions import CommandException
from superset.dao.base import BaseDAO
from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION
METADATA_FILE_NAME = "metadata.yaml"
class ExportModelsCommand(BaseCommand):
    """Base class for commands that export models to a bundle of YAML files.

    Subclasses configure two class attributes:

    - ``dao``: the DAO used to look models up by id, and whose ``model_cls``
      name is written into the bundle metadata.
    - ``not_found``: the exception raised when any requested id is missing.

    They also implement the ``export`` static method, which yields
    ``(file_name, file_content)`` tuples for a single model.
    """

    dao = BaseDAO
    not_found = CommandException

    def __init__(self, model_ids: List[int]):
        self.model_ids = model_ids

        # this will be set when calling validate()
        self._models: List[Model] = []

    @staticmethod
    def export(model: Model) -> Iterator[Tuple[str, str]]:
        """Yield ``(file_name, file_content)`` pairs for a single model."""
        raise NotImplementedError("Subclasses MUST implement export")

    def run(self) -> Iterator[Tuple[str, str]]:
        self.validate()

        metadata = {
            "version": IMPORT_EXPORT_VERSION,
            "type": self.dao.model_cls.__name__,  # type: ignore
            "timestamp": datetime.now(tz=timezone.utc).isoformat(),
        }
        yield METADATA_FILE_NAME, yaml.safe_dump(metadata, sort_keys=False)

        # Seed with the metadata file so a subclass's export can never emit a
        # second, conflicting ``metadata.yaml``; also de-duplicates files
        # shared between models (e.g. a database referenced by two datasets).
        seen = {METADATA_FILE_NAME}
        for model in self._models:
            for file_name, file_content in self.export(model):
                if file_name not in seen:
                    yield file_name, file_content
                    seen.add(file_name)

    def validate(self) -> None:
        self._models = self.dao.find_by_ids(self.model_ids)
        if len(self._models) != len(self.model_ids):
            raise self.not_found()

View File

@ -18,42 +18,40 @@
import json import json
import logging import logging
from typing import Iterator, List, Tuple from typing import Iterator, Tuple
import yaml import yaml
from superset.commands.base import BaseCommand from superset.importexport.commands.base import ExportModelsCommand
from superset.models.sql_lab import SavedQuery
from superset.queries.saved_queries.commands.exceptions import SavedQueryNotFoundError from superset.queries.saved_queries.commands.exceptions import SavedQueryNotFoundError
from superset.queries.saved_queries.dao import SavedQueryDAO from superset.queries.saved_queries.dao import SavedQueryDAO
from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize from superset.utils.dict_import_export import IMPORT_EXPORT_VERSION, sanitize
from superset.models.sql_lab import SavedQuery
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ExportSavedQueriesCommand(BaseCommand): class ExportSavedQueriesCommand(ExportModelsCommand):
def __init__(self, query_ids: List[int]):
self.query_ids = query_ids
# this will be set when calling validate() dao = SavedQueryDAO
self._models: List[SavedQuery] = [] not_found = SavedQueryNotFoundError
@staticmethod @staticmethod
def export_saved_query(query: SavedQuery) -> Iterator[Tuple[str, str]]: def export(model: SavedQuery) -> Iterator[Tuple[str, str]]:
# build filename based on database, optional schema, and label # build filename based on database, optional schema, and label
database_slug = sanitize(query.database.database_name) database_slug = sanitize(model.database.database_name)
schema_slug = sanitize(query.schema) schema_slug = sanitize(model.schema)
query_slug = sanitize(query.label) or str(query.uuid) query_slug = sanitize(model.label) or str(model.uuid)
file_name = f"queries/{database_slug}/{schema_slug}/{query_slug}.yaml" file_name = f"queries/{database_slug}/{schema_slug}/{query_slug}.yaml"
payload = query.export_to_dict( payload = model.export_to_dict(
recursive=False, recursive=False,
include_parent_ref=False, include_parent_ref=False,
include_defaults=True, include_defaults=True,
export_uuids=True, export_uuids=True,
) )
payload["version"] = IMPORT_EXPORT_VERSION payload["version"] = IMPORT_EXPORT_VERSION
payload["database_uuid"] = str(query.database.uuid) payload["database_uuid"] = str(model.database.uuid)
file_content = yaml.safe_dump(payload, sort_keys=False) file_content = yaml.safe_dump(payload, sort_keys=False)
yield file_name, file_content yield file_name, file_content
@ -61,7 +59,7 @@ class ExportSavedQueriesCommand(BaseCommand):
# include database as well # include database as well
file_name = f"databases/{database_slug}.yaml" file_name = f"databases/{database_slug}.yaml"
payload = query.database.export_to_dict( payload = model.database.export_to_dict(
recursive=False, recursive=False,
include_parent_ref=False, include_parent_ref=False,
include_defaults=True, include_defaults=True,
@ -79,14 +77,3 @@ class ExportSavedQueriesCommand(BaseCommand):
file_content = yaml.safe_dump(payload, sort_keys=False) file_content = yaml.safe_dump(payload, sort_keys=False)
yield file_name, file_content yield file_name, file_content
def run(self) -> Iterator[Tuple[str, str]]:
self.validate()
for query in self._models:
yield from self.export_saved_query(query)
def validate(self) -> None:
self._models = SavedQueryDAO.find_by_ids(self.query_ids)
if len(self._models) != len(self.query_ids):
raise SavedQueryNotFoundError()

View File

@ -32,10 +32,11 @@ class TestExportChartsCommand(SupersetTestCase):
mock_g.user = security_manager.find_user("admin") mock_g.user = security_manager.find_user("admin")
example_chart = db.session.query(Slice).all()[0] example_chart = db.session.query(Slice).all()[0]
command = ExportChartsCommand(chart_ids=[example_chart.id]) command = ExportChartsCommand([example_chart.id])
contents = dict(command.run()) contents = dict(command.run())
expected = [ expected = [
"metadata.yaml",
"charts/energy_sankey.yaml", "charts/energy_sankey.yaml",
"datasets/examples/energy_usage.yaml", "datasets/examples/energy_usage.yaml",
"databases/examples.yaml", "databases/examples.yaml",
@ -66,7 +67,7 @@ class TestExportChartsCommand(SupersetTestCase):
mock_g.user = security_manager.find_user("gamma") mock_g.user = security_manager.find_user("gamma")
example_chart = db.session.query(Slice).all()[0] example_chart = db.session.query(Slice).all()[0]
command = ExportChartsCommand(chart_ids=[example_chart.id]) command = ExportChartsCommand([example_chart.id])
contents = command.run() contents = command.run()
with self.assertRaises(ChartNotFoundError): with self.assertRaises(ChartNotFoundError):
next(contents) next(contents)
@ -75,7 +76,7 @@ class TestExportChartsCommand(SupersetTestCase):
def test_export_chart_command_invalid_dataset(self, mock_g): def test_export_chart_command_invalid_dataset(self, mock_g):
"""Test that an error is raised when exporting an invalid dataset""" """Test that an error is raised when exporting an invalid dataset"""
mock_g.user = security_manager.find_user("admin") mock_g.user = security_manager.find_user("admin")
command = ExportChartsCommand(chart_ids=[-1]) command = ExportChartsCommand([-1])
contents = command.run() contents = command.run()
with self.assertRaises(ChartNotFoundError): with self.assertRaises(ChartNotFoundError):
next(contents) next(contents)
@ -86,7 +87,7 @@ class TestExportChartsCommand(SupersetTestCase):
mock_g.user = security_manager.find_user("admin") mock_g.user = security_manager.find_user("admin")
example_chart = db.session.query(Slice).all()[0] example_chart = db.session.query(Slice).all()[0]
command = ExportChartsCommand(chart_ids=[example_chart.id]) command = ExportChartsCommand([example_chart.id])
contents = dict(command.run()) contents = dict(command.run())
metadata = yaml.safe_load(contents["charts/energy_sankey.yaml"]) metadata = yaml.safe_load(contents["charts/energy_sankey.yaml"])

View File

@ -34,10 +34,11 @@ class TestExportDashboardsCommand(SupersetTestCase):
mock_g2.user = security_manager.find_user("admin") mock_g2.user = security_manager.find_user("admin")
example_dashboard = db.session.query(Dashboard).filter_by(id=1).one() example_dashboard = db.session.query(Dashboard).filter_by(id=1).one()
command = ExportDashboardsCommand(dashboard_ids=[example_dashboard.id]) command = ExportDashboardsCommand([example_dashboard.id])
contents = dict(command.run()) contents = dict(command.run())
expected_paths = { expected_paths = {
"metadata.yaml",
"dashboards/world_banks_data.yaml", "dashboards/world_banks_data.yaml",
"charts/box_plot.yaml", "charts/box_plot.yaml",
"datasets/examples/wb_health_population.yaml", "datasets/examples/wb_health_population.yaml",
@ -150,7 +151,7 @@ class TestExportDashboardsCommand(SupersetTestCase):
mock_g2.user = security_manager.find_user("gamma") mock_g2.user = security_manager.find_user("gamma")
example_dashboard = db.session.query(Dashboard).filter_by(id=1).one() example_dashboard = db.session.query(Dashboard).filter_by(id=1).one()
command = ExportDashboardsCommand(dashboard_ids=[example_dashboard.id]) command = ExportDashboardsCommand([example_dashboard.id])
contents = command.run() contents = command.run()
with self.assertRaises(DashboardNotFoundError): with self.assertRaises(DashboardNotFoundError):
next(contents) next(contents)
@ -161,7 +162,7 @@ class TestExportDashboardsCommand(SupersetTestCase):
"""Test that an error is raised when exporting an invalid dataset""" """Test that an error is raised when exporting an invalid dataset"""
mock_g1.user = security_manager.find_user("admin") mock_g1.user = security_manager.find_user("admin")
mock_g2.user = security_manager.find_user("admin") mock_g2.user = security_manager.find_user("admin")
command = ExportDashboardsCommand(dashboard_ids=[-1]) command = ExportDashboardsCommand([-1])
contents = command.run() contents = command.run()
with self.assertRaises(DashboardNotFoundError): with self.assertRaises(DashboardNotFoundError):
next(contents) next(contents)
@ -174,7 +175,7 @@ class TestExportDashboardsCommand(SupersetTestCase):
mock_g2.user = security_manager.find_user("admin") mock_g2.user = security_manager.find_user("admin")
example_dashboard = db.session.query(Dashboard).filter_by(id=1).one() example_dashboard = db.session.query(Dashboard).filter_by(id=1).one()
command = ExportDashboardsCommand(dashboard_ids=[example_dashboard.id]) command = ExportDashboardsCommand([example_dashboard.id])
contents = dict(command.run()) contents = dict(command.run())
metadata = yaml.safe_load(contents["dashboards/world_banks_data.yaml"]) metadata = yaml.safe_load(contents["dashboards/world_banks_data.yaml"])

View File

@ -32,12 +32,13 @@ class TestExportDatabasesCommand(SupersetTestCase):
mock_g.user = security_manager.find_user("admin") mock_g.user = security_manager.find_user("admin")
example_db = get_example_database() example_db = get_example_database()
command = ExportDatabasesCommand(database_ids=[example_db.id]) command = ExportDatabasesCommand([example_db.id])
contents = dict(command.run()) contents = dict(command.run())
# TODO: this list shouldn't depend on the order in which unit tests are run # TODO: this list shouldn't depend on the order in which unit tests are run
# or on the backend; for now use a stable subset # or on the backend; for now use a stable subset
core_files = { core_files = {
"metadata.yaml",
"databases/examples.yaml", "databases/examples.yaml",
"datasets/examples/energy_usage.yaml", "datasets/examples/energy_usage.yaml",
"datasets/examples/wb_health_population.yaml", "datasets/examples/wb_health_population.yaml",
@ -227,7 +228,7 @@ class TestExportDatabasesCommand(SupersetTestCase):
mock_g.user = security_manager.find_user("gamma") mock_g.user = security_manager.find_user("gamma")
example_db = get_example_database() example_db = get_example_database()
command = ExportDatabasesCommand(database_ids=[example_db.id]) command = ExportDatabasesCommand([example_db.id])
contents = command.run() contents = command.run()
with self.assertRaises(DatabaseNotFoundError): with self.assertRaises(DatabaseNotFoundError):
next(contents) next(contents)
@ -236,7 +237,7 @@ class TestExportDatabasesCommand(SupersetTestCase):
def test_export_database_command_invalid_database(self, mock_g): def test_export_database_command_invalid_database(self, mock_g):
"""Test that an error is raised when exporting an invalid database""" """Test that an error is raised when exporting an invalid database"""
mock_g.user = security_manager.find_user("admin") mock_g.user = security_manager.find_user("admin")
command = ExportDatabasesCommand(database_ids=[-1]) command = ExportDatabasesCommand([-1])
contents = command.run() contents = command.run()
with self.assertRaises(DatabaseNotFoundError): with self.assertRaises(DatabaseNotFoundError):
next(contents) next(contents)
@ -247,7 +248,7 @@ class TestExportDatabasesCommand(SupersetTestCase):
mock_g.user = security_manager.find_user("admin") mock_g.user = security_manager.find_user("admin")
example_db = get_example_database() example_db = get_example_database()
command = ExportDatabasesCommand(database_ids=[example_db.id]) command = ExportDatabasesCommand([example_db.id])
contents = dict(command.run()) contents = dict(command.run())
metadata = yaml.safe_load(contents["databases/examples.yaml"]) metadata = yaml.safe_load(contents["databases/examples.yaml"])

View File

@ -35,10 +35,11 @@ class TestExportDatasetsCommand(SupersetTestCase):
example_db = get_example_database() example_db = get_example_database()
example_dataset = example_db.tables[0] example_dataset = example_db.tables[0]
command = ExportDatasetsCommand(dataset_ids=[example_dataset.id]) command = ExportDatasetsCommand([example_dataset.id])
contents = dict(command.run()) contents = dict(command.run())
assert list(contents.keys()) == [ assert list(contents.keys()) == [
"metadata.yaml",
"datasets/examples/energy_usage.yaml", "datasets/examples/energy_usage.yaml",
"databases/examples.yaml", "databases/examples.yaml",
] ]
@ -140,7 +141,7 @@ class TestExportDatasetsCommand(SupersetTestCase):
example_db = get_example_database() example_db = get_example_database()
example_dataset = example_db.tables[0] example_dataset = example_db.tables[0]
command = ExportDatasetsCommand(dataset_ids=[example_dataset.id]) command = ExportDatasetsCommand([example_dataset.id])
contents = command.run() contents = command.run()
with self.assertRaises(DatasetNotFoundError): with self.assertRaises(DatasetNotFoundError):
next(contents) next(contents)
@ -149,7 +150,7 @@ class TestExportDatasetsCommand(SupersetTestCase):
def test_export_dataset_command_invalid_dataset(self, mock_g): def test_export_dataset_command_invalid_dataset(self, mock_g):
"""Test that an error is raised when exporting an invalid dataset""" """Test that an error is raised when exporting an invalid dataset"""
mock_g.user = security_manager.find_user("admin") mock_g.user = security_manager.find_user("admin")
command = ExportDatasetsCommand(dataset_ids=[-1]) command = ExportDatasetsCommand([-1])
contents = command.run() contents = command.run()
with self.assertRaises(DatasetNotFoundError): with self.assertRaises(DatasetNotFoundError):
next(contents) next(contents)
@ -161,7 +162,7 @@ class TestExportDatasetsCommand(SupersetTestCase):
example_db = get_example_database() example_db = get_example_database()
example_dataset = example_db.tables[0] example_dataset = example_db.tables[0]
command = ExportDatasetsCommand(dataset_ids=[example_dataset.id]) command = ExportDatasetsCommand([example_dataset.id])
contents = dict(command.run()) contents = dict(command.run())
metadata = yaml.safe_load(contents["datasets/examples/energy_usage.yaml"]) metadata = yaml.safe_load(contents["datasets/examples/energy_usage.yaml"])

View File

@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import patch
import yaml
from freezegun import freeze_time
from superset import security_manager
from superset.databases.commands.export import ExportDatabasesCommand
from superset.utils.core import get_example_database
from tests.base_tests import SupersetTestCase
class TestExportModelsCommand(SupersetTestCase):
    @patch("superset.security.manager.g")
    def test_export_models_command(self, mock_g):
        """Make sure metadata.yaml has the correct content."""
        mock_g.user = security_manager.find_user("admin")

        example_db = get_example_database()
        with freeze_time("2020-01-01T00:00:00Z"):
            contents = dict(ExportDatabasesCommand([example_db.id]).run())

        expected_metadata = {
            "version": "1.0.0",
            "type": "Database",
            "timestamp": "2020-01-01T00:00:00+00:00",
        }
        assert yaml.safe_load(contents["metadata.yaml"]) == expected_metadata

View File

@ -49,10 +49,11 @@ class TestExportSavedQueriesCommand(SupersetTestCase):
def test_export_query_command(self, mock_g): def test_export_query_command(self, mock_g):
mock_g.user = security_manager.find_user("admin") mock_g.user = security_manager.find_user("admin")
command = ExportSavedQueriesCommand(query_ids=[self.example_query.id]) command = ExportSavedQueriesCommand([self.example_query.id])
contents = dict(command.run()) contents = dict(command.run())
expected = [ expected = [
"metadata.yaml",
"queries/examples/schema1/the_answer.yaml", "queries/examples/schema1/the_answer.yaml",
"databases/examples.yaml", "databases/examples.yaml",
] ]
@ -74,7 +75,7 @@ class TestExportSavedQueriesCommand(SupersetTestCase):
"""Test that users can't export datasets they don't have access to""" """Test that users can't export datasets they don't have access to"""
mock_g.user = security_manager.find_user("gamma") mock_g.user = security_manager.find_user("gamma")
command = ExportSavedQueriesCommand(query_ids=[self.example_query.id]) command = ExportSavedQueriesCommand([self.example_query.id])
contents = command.run() contents = command.run()
with self.assertRaises(SavedQueryNotFoundError): with self.assertRaises(SavedQueryNotFoundError):
next(contents) next(contents)
@ -84,7 +85,7 @@ class TestExportSavedQueriesCommand(SupersetTestCase):
"""Test that an error is raised when exporting an invalid dataset""" """Test that an error is raised when exporting an invalid dataset"""
mock_g.user = security_manager.find_user("admin") mock_g.user = security_manager.find_user("admin")
command = ExportSavedQueriesCommand(query_ids=[-1]) command = ExportSavedQueriesCommand([-1])
contents = command.run() contents = command.run()
with self.assertRaises(SavedQueryNotFoundError): with self.assertRaises(SavedQueryNotFoundError):
next(contents) next(contents)
@ -94,7 +95,7 @@ class TestExportSavedQueriesCommand(SupersetTestCase):
"""Test that the keys in the YAML have the same order as export_fields""" """Test that the keys in the YAML have the same order as export_fields"""
mock_g.user = security_manager.find_user("admin") mock_g.user = security_manager.find_user("admin")
command = ExportSavedQueriesCommand(query_ids=[self.example_query.id]) command = ExportSavedQueriesCommand([self.example_query.id])
contents = dict(command.run()) contents = dict(command.run())
metadata = yaml.safe_load(contents["queries/examples/schema1/the_answer.yaml"]) metadata = yaml.safe_load(contents["queries/examples/schema1/the_answer.yaml"])