feat: new import commands for dataset and databases (#11670)

* feat: commands for importing databases and datasets

* Refactor code
Beto Dealmeida 2020-11-16 17:11:20 -08:00 committed by GitHub
parent 871a98abe2
commit 7bc353f8a8
16 changed files with 983 additions and 7 deletions

View File

@@ -304,6 +304,9 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None:
    from superset.datasets.commands.importers.v0 import ImportDatasetsCommand

    sync_array = sync.split(",")
    sync_columns = "columns" in sync_array
    sync_metrics = "metrics" in sync_array

    path_object = Path(path)
    files: List[Path] = []
    if path_object.is_file():
@@ -316,7 +319,7 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None:
        files.extend(path_object.rglob("*.yml"))
    contents = {path.name: open(path).read() for path in files}
    try:
-        ImportDatasetsCommand(contents, sync_array).run()
+        ImportDatasetsCommand(contents, sync_columns, sync_metrics).run()
    except Exception:  # pylint: disable=broad-except
        logger.exception("Error when importing dataset")
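For orientation, a minimal sketch (editor's addition, not part of the commit) of how the reworked --sync option flows into the v0 command; the YAML file name here is hypothetical, while the parsing mirrors the hunk above.

# Editor's sketch (not part of the commit): how --sync maps onto the
# new boolean arguments of the v0 command. The YAML path is hypothetical.
from superset.datasets.commands.importers.v0 import ImportDatasetsCommand

sync = "columns,metrics"  # value passed via the --sync CLI option
sync_array = sync.split(",")
sync_columns = "columns" in sync_array  # True
sync_metrics = "metrics" in sync_array  # True

contents = {"my_dataset.yaml": open("my_dataset.yaml").read()}
ImportDatasetsCommand(contents, sync_columns, sync_metrics).run()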

View File

@@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.commands.exceptions import CommandException


class IncorrectVersionError(CommandException):
    status = 422
    message = "Import has incorrect version"
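The dedicated exception type, with its 422 status, lets a caller tell "wrong bundle version" apart from other failures. A minimal sketch (editor's addition, not part of the commit) of the version-fallback dispatch that later comments in this commit allude to; the ordered list of command classes is an assumption of this sketch.

# Editor's sketch (not part of the commit): dispatching across importer
# versions. `command_versions` is a hypothetical newest-first list of
# command classes, e.g. [ImportDatasetsCommandV1, ImportDatasetsCommandV0].
from typing import Dict, List, Type

from superset.commands.base import BaseCommand
from superset.commands.importers.exceptions import IncorrectVersionError


def dispatch_import(
    contents: Dict[str, str], command_versions: List[Type[BaseCommand]]
) -> None:
    for command_cls in command_versions:
        try:
            command_cls(contents).run()
            return
        except IncorrectVersionError:
            continue  # bundle is not in this version's format; try the next
    raise IncorrectVersionError("No importer accepted these contents")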

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict

import yaml
from marshmallow import fields, Schema, validate
from marshmallow.exceptions import ValidationError

from superset.commands.importers.exceptions import IncorrectVersionError

METADATA_FILE_NAME = "metadata.yaml"
IMPORT_VERSION = "1.0.0"

logger = logging.getLogger(__name__)


class MetadataSchema(Schema):
    version = fields.String(required=True, validate=validate.Equal(IMPORT_VERSION))
    type = fields.String(required=True)
    timestamp = fields.DateTime()


def load_yaml(file_name: str, content: str) -> Dict[str, Any]:
    """Try to load a YAML file"""
    try:
        return yaml.safe_load(content)
    except yaml.parser.ParserError:
        logger.exception("Invalid YAML in %s", METADATA_FILE_NAME)
        raise ValidationError({file_name: "Not a valid YAML file"})


def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:
    """Apply validation and load a metadata file"""
    if METADATA_FILE_NAME not in contents:
        # if the contents have no METADATA_FILE_NAME this is probably
        # an original export without versioning that should not be
        # handled by this command
        raise IncorrectVersionError(f"Missing {METADATA_FILE_NAME}")

    metadata = load_yaml(METADATA_FILE_NAME, contents[METADATA_FILE_NAME])
    try:
        MetadataSchema().load(metadata)
    except ValidationError as exc:
        # if the version doesn't match raise an exception so that the
        # dispatcher can try a different command version
        if "version" in exc.messages:
            raise IncorrectVersionError(exc.messages["version"][0])

        # otherwise we raise the validation error
        exc.messages = {METADATA_FILE_NAME: exc.messages}
        raise exc

    return metadata
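A short usage sketch (editor's addition, not part of the commit); the metadata values mirror the fixtures added at the end of this commit.

# Editor's sketch (not part of the commit): feeding an in-memory bundle
# through load_metadata. Values follow the test fixtures below.
import yaml
from superset.commands.importers.v1.utils import load_metadata

contents = {
    "metadata.yaml": yaml.safe_dump(
        {
            "version": "1.0.0",
            "type": "Database",
            "timestamp": "2020-11-04T21:27:44.423819+00:00",
        }
    )
}
metadata = load_metadata(contents)  # validates version == "1.0.0"
assert metadata["type"] == "Database"
# load_metadata({}) would raise IncorrectVersionError("Missing metadata.yaml"),
# signalling a bundle that predates versioned exports.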

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,116 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List

from marshmallow import Schema, validate
from marshmallow.exceptions import ValidationError
from sqlalchemy.orm import Session

from superset import db
from superset.commands.base import BaseCommand
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.v1.utils import (
    load_metadata,
    load_yaml,
    METADATA_FILE_NAME,
)
from superset.databases.commands.importers.v1.utils import import_database
from superset.databases.schemas import ImportV1DatabaseSchema
from superset.datasets.commands.importers.v1.utils import import_dataset
from superset.datasets.schemas import ImportV1DatasetSchema
from superset.models.core import Database

schemas: Dict[str, Schema] = {
    "databases/": ImportV1DatabaseSchema(),
    "datasets/": ImportV1DatasetSchema(),
}


class ImportDatabasesCommand(BaseCommand):
    """Import databases"""

    # pylint: disable=unused-argument
    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
        self.contents = contents
        self._configs: Dict[str, Any] = {}

    def _import_bundle(self, session: Session) -> None:
        # first import databases
        database_ids: Dict[str, int] = {}
        for file_name, config in self._configs.items():
            if file_name.startswith("databases/"):
                database = import_database(session, config, overwrite=True)
                database_ids[str(database.uuid)] = database.id

        # import related datasets
        for file_name, config in self._configs.items():
            if (
                file_name.startswith("datasets/")
                and config["database_uuid"] in database_ids
            ):
                config["database_id"] = database_ids[config["database_uuid"]]
                # overwrite=False prevents deleting any non-imported columns/metrics
                import_dataset(session, config, overwrite=False)

    def run(self) -> None:
        self.validate()

        # rollback to prevent partial imports
        try:
            self._import_bundle(db.session)
            db.session.commit()
        except Exception as exc:
            db.session.rollback()
            raise exc

    def validate(self) -> None:
        exceptions: List[ValidationError] = []

        # verify that the metadata file is present and valid
        try:
            metadata = load_metadata(self.contents)
        except ValidationError as exc:
            exceptions.append(exc)
            metadata = None

        for file_name, content in self.contents.items():
            prefix = file_name.split("/")[0]
            schema = schemas.get(f"{prefix}/")
            if schema:
                try:
                    config = load_yaml(file_name, content)
                    schema.load(config)
                    self._configs[file_name] = config
                except ValidationError as exc:
                    exc.messages = {file_name: exc.messages}
                    exceptions.append(exc)

        # validate that the type declared in METADATA_FILE_NAME is correct
        if metadata:
            type_validator = validate.Equal(Database.__name__)
            try:
                type_validator(metadata["type"])
            except ValidationError as exc:
                exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}}
                exceptions.append(exc)

        if exceptions:
            exception = CommandInvalidError("Error importing database")
            exception.add_list(exceptions)
            raise exception
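As a usage sketch (editor's addition, not part of the commit), the bundle layout the command expects looks like the following; file names under databases/ are arbitrary, since matching is by path prefix, and the payload mirrors the fixtures added at the end of this commit.

# Editor's sketch (not part of the commit): a minimal bundle accepted by
# ImportDatabasesCommand, assuming a Flask app context. Payload shapes
# follow tests/fixtures/importexport.py.
import yaml
from superset.databases.commands.importers.v1 import ImportDatabasesCommand

contents = {
    "metadata.yaml": yaml.safe_dump(
        {
            "version": "1.0.0",
            "type": "Database",
            "timestamp": "2020-11-04T21:27:44.423819+00:00",
        }
    ),
    "databases/my_database.yaml": yaml.safe_dump(
        {
            "database_name": "my_database",  # hypothetical values
            "sqlalchemy_uri": "sqlite:///test.db",
            "extra": {},
            "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
            "version": "1.0.0",
        }
    ),
}
ImportDatabasesCommand(contents).run()  # validates, imports, commits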

View File

@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from typing import Any, Dict

from sqlalchemy.orm import Session

from superset.models.core import Database


def import_database(
    session: Session, config: Dict[str, Any], overwrite: bool = False
) -> Database:
    existing = session.query(Database).filter_by(uuid=config["uuid"]).first()
    if existing:
        if not overwrite:
            return existing
        config["id"] = existing.id

    # TODO (betodealmeida): move this logic to import_from_dict
    config["extra"] = json.dumps(config["extra"])

    database = Database.import_from_dict(session, config, recursive=False)
    if database.id is None:
        session.flush()

    return database
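This helper is effectively an upsert keyed on uuid. A sketch (editor's addition, not part of the commit) of the overwrite semantics, assuming a Flask app context and an open session; note the copies, since the function rewrites config["extra"] in place.

# Editor's sketch (not part of the commit): import_database's overwrite
# behavior. dict(config) copies avoid double JSON-encoding of "extra",
# which the helper mutates in place on every call.
from superset import db
from superset.databases.commands.importers.v1.utils import import_database

config = {
    "database_name": "imported_database",
    "sqlalchemy_uri": "sqlite:///test.db",
    "extra": {},
    "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
    "version": "1.0.0",
}

created = import_database(db.session, dict(config), overwrite=False)  # new row
same = import_database(db.session, dict(config), overwrite=False)  # no-op
assert same.id == created.id
updated = import_database(db.session, dict(config), overwrite=True)  # updated in place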

View File

@@ -408,3 +408,24 @@ class DatabaseRelatedDashboards(Schema):
class DatabaseRelatedObjectsResponse(Schema):
    charts = fields.Nested(DatabaseRelatedCharts)
    dashboards = fields.Nested(DatabaseRelatedDashboards)


class ImportV1DatabaseExtraSchema(Schema):
    metadata_params = fields.Dict(keys=fields.Str(), values=fields.Raw())
    engine_params = fields.Dict(keys=fields.Str(), values=fields.Raw())
    metadata_cache_timeout = fields.Dict(keys=fields.Str(), values=fields.Integer())
    schemas_allowed_for_csv_upload = fields.List(fields.String)


class ImportV1DatabaseSchema(Schema):
    database_name = fields.String(required=True)
    sqlalchemy_uri = fields.String(required=True)
    cache_timeout = fields.Integer(allow_none=True)
    expose_in_sqllab = fields.Boolean()
    allow_run_async = fields.Boolean()
    allow_ctas = fields.Boolean()
    allow_cvas = fields.Boolean()
    allow_csv_upload = fields.Boolean()
    extra = fields.Nested(ImportV1DatabaseExtraSchema)
    uuid = fields.UUID(required=True)
    version = fields.String(required=True)
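A quick sketch (editor's addition, not part of the commit) of what the schema accepts and rejects; values follow the fixture added later in this commit.

# Editor's sketch (not part of the commit): validating configs against
# the new schema. Required: database_name, sqlalchemy_uri, uuid, version.
from marshmallow.exceptions import ValidationError
from superset.databases.schemas import ImportV1DatabaseSchema

schema = ImportV1DatabaseSchema()
schema.load(
    {
        "database_name": "imported_database",
        "sqlalchemy_uri": "sqlite:///test.db",
        "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
        "version": "1.0.0",
    }
)  # passes: all required fields present

try:
    schema.load({"database_name": "broken"})
except ValidationError as exc:
    print(exc.messages)  # missing sqlalchemy_uri, uuid, version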

View File

@@ -282,9 +282,19 @@ class ImportDatasetsCommand(BaseCommand):
    in Superset.
    """

-    def __init__(self, contents: Dict[str, str], sync: Optional[List[str]] = None):
+    def __init__(
+        self,
+        contents: Dict[str, str],
+        sync_columns: bool = False,
+        sync_metrics: bool = False,
+    ):
        self.contents = contents
-        self.sync = sync
+        self.sync = []
+        if sync_columns:
+            self.sync.append("columns")
+        if sync_metrics:
+            self.sync.append("metrics")

    def run(self) -> None:
        self.validate()

View File

@@ -0,0 +1,121 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Set

from marshmallow import Schema, validate
from marshmallow.exceptions import ValidationError
from sqlalchemy.orm import Session

from superset import db
from superset.commands.base import BaseCommand
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.v1.utils import (
    load_metadata,
    load_yaml,
    METADATA_FILE_NAME,
)
from superset.connectors.sqla.models import SqlaTable
from superset.databases.commands.importers.v1.utils import import_database
from superset.databases.schemas import ImportV1DatabaseSchema
from superset.datasets.commands.importers.v1.utils import import_dataset
from superset.datasets.schemas import ImportV1DatasetSchema

schemas: Dict[str, Schema] = {
    "databases/": ImportV1DatabaseSchema(),
    "datasets/": ImportV1DatasetSchema(),
}


class ImportDatasetsCommand(BaseCommand):
    """Import datasets"""

    # pylint: disable=unused-argument
    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
        self.contents = contents
        self._configs: Dict[str, Any] = {}

    def _import_bundle(self, session: Session) -> None:
        # discover databases associated with datasets
        database_uuids: Set[str] = set()
        for file_name, config in self._configs.items():
            if file_name.startswith("datasets/"):
                database_uuids.add(config["database_uuid"])

        # import related databases
        database_ids: Dict[str, int] = {}
        for file_name, config in self._configs.items():
            if file_name.startswith("databases/") and config["uuid"] in database_uuids:
                database = import_database(session, config, overwrite=False)
                database_ids[str(database.uuid)] = database.id

        # import datasets with the correct parent ref
        for file_name, config in self._configs.items():
            if (
                file_name.startswith("datasets/")
                and config["database_uuid"] in database_ids
            ):
                config["database_id"] = database_ids[config["database_uuid"]]
                import_dataset(session, config, overwrite=True)

    def run(self) -> None:
        self.validate()

        # rollback to prevent partial imports
        try:
            self._import_bundle(db.session)
            db.session.commit()
        except Exception as exc:
            db.session.rollback()
            raise exc

    def validate(self) -> None:
        exceptions: List[ValidationError] = []

        # verify that the metadata file is present and valid
        try:
            metadata = load_metadata(self.contents)
        except ValidationError as exc:
            exceptions.append(exc)
            metadata = None

        for file_name, content in self.contents.items():
            prefix = file_name.split("/")[0]
            schema = schemas.get(f"{prefix}/")
            if schema:
                try:
                    config = load_yaml(file_name, content)
                    schema.load(config)
                    self._configs[file_name] = config
                except ValidationError as exc:
                    exc.messages = {file_name: exc.messages}
                    exceptions.append(exc)

        # validate that the type declared in METADATA_FILE_NAME is correct
        if metadata:
            type_validator = validate.Equal(SqlaTable.__name__)
            try:
                type_validator(metadata["type"])
            except ValidationError as exc:
                exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}}
                exceptions.append(exc)

        if exceptions:
            exception = CommandInvalidError("Error importing dataset")
            exception.add_list(exceptions)
            raise exception

View File

@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict

from sqlalchemy.orm import Session

from superset.connectors.sqla.models import SqlaTable


def import_dataset(
    session: Session, config: Dict[str, Any], overwrite: bool = False
) -> SqlaTable:
    existing = session.query(SqlaTable).filter_by(uuid=config["uuid"]).first()
    if existing:
        if not overwrite:
            return existing
        config["id"] = existing.id

    # should we delete columns and metrics not present in the current import?
    sync = ["columns", "metrics"] if overwrite else []

    # import recursively to include columns and metrics
    dataset = SqlaTable.import_from_dict(session, config, recursive=True, sync=sync)
    if dataset.id is None:
        session.flush()

    return dataset
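Unlike import_database, this helper recurses into children, and overwrite doubles as a sync switch: columns and metrics missing from the payload are deleted from an existing dataset. A sketch (editor's addition, not part of the commit), assuming an app context, an open session, and a config already validated by ImportV1DatasetSchema.

# Editor's sketch (not part of the commit): import_dataset's sync
# behavior. database_id is hypothetical; the importing commands resolve
# it from database_uuid before calling this helper.
from superset import db
from superset.datasets.commands.importers.v1.utils import import_dataset

config = {
    "table_name": "imported_dataset",
    "uuid": "10808100-158b-42c4-842e-f32b99d88dfb",
    "database_id": 1,  # hypothetical parent id
    "columns": [],
    "metrics": [],
}

# overwrite=False: an existing dataset with this uuid is returned untouched.
# overwrite=True: the dataset is updated and, because sync=["columns",
# "metrics"], any column or metric absent from config is deleted.
dataset = import_dataset(db.session, config, overwrite=True)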

View File

@@ -122,3 +122,48 @@ class DatasetRelatedDashboards(Schema):
class DatasetRelatedObjectsResponse(Schema):
    charts = fields.Nested(DatasetRelatedCharts)
    dashboards = fields.Nested(DatasetRelatedDashboards)


class ImportV1ColumnSchema(Schema):
    column_name = fields.String(required=True)
    verbose_name = fields.String()
    is_dttm = fields.Boolean()
    is_active = fields.Boolean(allow_none=True)
    type = fields.String(required=True)
    groupby = fields.Boolean()
    filterable = fields.Boolean()
    expression = fields.String()
    description = fields.String(allow_none=True)
    python_date_format = fields.String(allow_none=True)


class ImportV1MetricSchema(Schema):
    metric_name = fields.String(required=True)
    verbose_name = fields.String()
    metric_type = fields.String(allow_none=True)
    expression = fields.String(required=True)
    description = fields.String(allow_none=True)
    d3format = fields.String(allow_none=True)
    extra = fields.String(allow_none=True)
    warning_text = fields.String(allow_none=True)


class ImportV1DatasetSchema(Schema):
    table_name = fields.String(required=True)
    main_dttm_col = fields.String(allow_none=True)
    description = fields.String()
    default_endpoint = fields.String()
    offset = fields.Integer()
    cache_timeout = fields.Integer()
    schema = fields.String()
    sql = fields.String()
    params = fields.String(allow_none=True)
    template_params = fields.String(allow_none=True)
    filter_select_enabled = fields.Boolean()
    fetch_values_predicate = fields.String(allow_none=True)
    extra = fields.String(allow_none=True)
    uuid = fields.UUID(required=True)
    columns = fields.List(fields.Nested(ImportV1ColumnSchema))
    metrics = fields.List(fields.Nested(ImportV1MetricSchema))
    version = fields.String(required=True)
    database_uuid = fields.UUID(required=True)
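Because columns and metrics are nested schemas, their payloads are validated in place; a quick sketch (editor's addition, not part of the commit).

# Editor's sketch (not part of the commit): nested validation in the
# dataset schema. A column missing its required "type" fails the load.
from marshmallow.exceptions import ValidationError
from superset.datasets.schemas import ImportV1DatasetSchema

schema = ImportV1DatasetSchema()
try:
    schema.load(
        {
            "table_name": "imported_dataset",
            "uuid": "10808100-158b-42c4-842e-f32b99d88dfb",
            "version": "1.0.0",
            "database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
            "columns": [{"column_name": "cnt"}],  # "type" is required
        }
    )
except ValidationError as exc:
    print(exc.messages)  # {'columns': {0: {'type': [...]}}}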

View File

@@ -163,7 +163,7 @@ class ImportExportMixin:
        if sync is None:
            sync = []

        parent_refs = cls.parent_foreign_key_mappings()
-        export_fields = set(cls.export_fields) | set(parent_refs.keys())
+        export_fields = set(cls.export_fields) | set(parent_refs.keys()) | {"uuid"}
        new_children = {c: dict_rep[c] for c in cls.export_children if c in dict_rep}
        unique_constrains = cls._unique_constrains()

View File

@@ -14,16 +14,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-self-use, invalid-name

from unittest.mock import patch

import pytest
import yaml

-from superset import security_manager
+from superset import db, security_manager
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.exceptions import IncorrectVersionError
from superset.connectors.sqla.models import SqlaTable
from superset.databases.commands.exceptions import DatabaseNotFoundError
from superset.databases.commands.export import ExportDatabasesCommand
from superset.databases.commands.importers.v1 import ImportDatabasesCommand
from superset.models.core import Database
from superset.utils.core import backend, get_example_database
from tests.base_tests import SupersetTestCase
from tests.fixtures.importexport import (
    database_config,
    database_metadata_config,
    dataset_config,
    dataset_metadata_config,
)


class TestExportDatabasesCommand(SupersetTestCase):
@@ -265,3 +278,197 @@
            "uuid",
            "version",
        ]

    def test_import_v1_database(self):
        """Test that a database can be imported"""
        contents = {
            "metadata.yaml": yaml.safe_dump(database_metadata_config),
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
        }
        command = ImportDatabasesCommand(contents)
        command.run()

        database = (
            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
        )
        assert database.allow_csv_upload
        assert database.allow_ctas
        assert database.allow_cvas
        assert not database.allow_run_async
        assert database.cache_timeout is None
        assert database.database_name == "imported_database"
        assert database.expose_in_sqllab
        assert database.extra == "{}"
        assert database.sqlalchemy_uri == "sqlite:///test.db"

        db.session.delete(database)
        db.session.commit()

    def test_import_v1_database_multiple(self):
        """Test that a database can be imported multiple times"""
        num_databases = db.session.query(Database).count()

        contents = {
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
            "metadata.yaml": yaml.safe_dump(database_metadata_config),
        }
        command = ImportDatabasesCommand(contents)

        # import twice
        command.run()
        command.run()

        database = (
            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
        )
        assert database.allow_csv_upload

        # update allow_csv_upload to False
        new_config = database_config.copy()
        new_config["allow_csv_upload"] = False
        contents = {
            "databases/imported_database.yaml": yaml.safe_dump(new_config),
            "metadata.yaml": yaml.safe_dump(database_metadata_config),
        }
        command = ImportDatabasesCommand(contents)
        command.run()

        database = (
            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
        )
        assert not database.allow_csv_upload

        # test that only one database was created
        new_num_databases = db.session.query(Database).count()
        assert new_num_databases == num_databases + 1

        db.session.delete(database)
        db.session.commit()

    def test_import_v1_database_with_dataset(self):
        """Test that a database can be imported with datasets"""
        contents = {
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
            "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
            "metadata.yaml": yaml.safe_dump(database_metadata_config),
        }
        command = ImportDatabasesCommand(contents)
        command.run()

        database = (
            db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
        )
        assert len(database.tables) == 1
        assert str(database.tables[0].uuid) == "10808100-158b-42c4-842e-f32b99d88dfb"

        db.session.delete(database.tables[0])
        db.session.delete(database)
        db.session.commit()

    def test_import_v1_database_with_dataset_multiple(self):
        """Test that a database can be imported multiple times w/o changing datasets"""
        contents = {
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
            "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
            "metadata.yaml": yaml.safe_dump(database_metadata_config),
        }
        command = ImportDatabasesCommand(contents)
        command.run()

        dataset = (
            db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
        )
        assert dataset.offset == 66

        new_config = dataset_config.copy()
        new_config["offset"] = 67
        contents = {
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
            "datasets/imported_dataset.yaml": yaml.safe_dump(new_config),
            "metadata.yaml": yaml.safe_dump(database_metadata_config),
        }
        command = ImportDatabasesCommand(contents)
        command.run()

        # the underlying dataset should not be modified by the second import, since
        # we're importing a database, not a dataset
        dataset = (
            db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
        )
        assert dataset.offset == 66

        db.session.delete(dataset)
        db.session.delete(dataset.database)
        db.session.commit()

    def test_import_v1_database_validation(self):
        """Test different validations applied when importing a database"""
        # metadata.yaml must be present
        contents = {
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
        }
        command = ImportDatabasesCommand(contents)
        with pytest.raises(IncorrectVersionError) as excinfo:
            command.run()
        assert str(excinfo.value) == "Missing metadata.yaml"

        # version should be 1.0.0
        contents["metadata.yaml"] = yaml.safe_dump(
            {
                "version": "2.0.0",
                "type": "Database",
                "timestamp": "2020-11-04T21:27:44.423819+00:00",
            }
        )
        command = ImportDatabasesCommand(contents)
        with pytest.raises(IncorrectVersionError) as excinfo:
            command.run()
        assert str(excinfo.value) == "Must be equal to 1.0.0."

        # type should be Database
        contents["metadata.yaml"] = yaml.safe_dump(dataset_metadata_config)
        command = ImportDatabasesCommand(contents)
        with pytest.raises(CommandInvalidError) as excinfo:
            command.run()
        assert str(excinfo.value) == "Error importing database"
        assert excinfo.value.normalized_messages() == {
            "metadata.yaml": {"type": ["Must be equal to Database."]}
        }

        # must also validate datasets
        broken_config = dataset_config.copy()
        del broken_config["table_name"]
        contents["metadata.yaml"] = yaml.safe_dump(database_metadata_config)
        contents["datasets/imported_dataset.yaml"] = yaml.safe_dump(broken_config)
        command = ImportDatabasesCommand(contents)
        with pytest.raises(CommandInvalidError) as excinfo:
            command.run()
        assert str(excinfo.value) == "Error importing database"
        assert excinfo.value.normalized_messages() == {
            "datasets/imported_dataset.yaml": {
                "table_name": ["Missing data for required field."],
            }
        }
    @patch("superset.databases.commands.importers.v1.import_dataset")
    def test_import_v1_rollback(self, mock_import_dataset):
        """Test that on an exception everything is rolled back"""
        num_databases = db.session.query(Database).count()

        # raise an exception when importing the dataset, after the database has
        # already been imported
        mock_import_dataset.side_effect = Exception("A wild exception appears!")

        contents = {
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
            "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
            "metadata.yaml": yaml.safe_dump(database_metadata_config),
        }
        command = ImportDatabasesCommand(contents)
        with pytest.raises(Exception) as excinfo:
            command.run()
        assert str(excinfo.value) == "A wild exception appears!"

        # verify that the database was not added
        new_num_databases = db.session.query(Database).count()
        assert new_num_databases == num_databases

View File

@@ -14,18 +14,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-self-use, invalid-name

from operator import itemgetter
from unittest.mock import patch

import pytest
import yaml

-from superset import security_manager
+from superset import db, security_manager
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.exceptions import IncorrectVersionError
from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.datasets.commands.export import ExportDatasetsCommand
-from superset.utils.core import backend, get_example_database
+from superset.datasets.commands.importers.v1 import ImportDatasetsCommand
+from superset.utils.core import get_example_database
from tests.base_tests import SupersetTestCase
from tests.fixtures.importexport import (
    database_config,
    database_metadata_config,
    dataset_config,
    dataset_metadata_config,
)


class TestExportDatasetsCommand(SupersetTestCase):
@@ -186,3 +197,149 @@
            "version",
            "database_uuid",
        ]

    def test_import_v1_dataset(self):
        """Test that we can import a dataset"""
        contents = {
            "metadata.yaml": yaml.safe_dump(dataset_metadata_config),
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
            "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
        }
        command = ImportDatasetsCommand(contents)
        command.run()

        dataset = (
            db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
        )
        assert dataset.table_name == "imported_dataset"
        assert dataset.main_dttm_col is None
        assert dataset.description == "This is a dataset that was exported"
        assert dataset.default_endpoint == ""
        assert dataset.offset == 66
        assert dataset.cache_timeout == 55
        assert dataset.schema == ""
        assert dataset.sql == ""
        assert dataset.params is None
        assert dataset.template_params is None
        assert dataset.filter_select_enabled
        assert dataset.fetch_values_predicate is None
        assert dataset.extra is None

        # database is also imported
        assert str(dataset.database.uuid) == "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89"

        assert len(dataset.metrics) == 1
        metric = dataset.metrics[0]
        assert metric.metric_name == "count"
        assert metric.verbose_name == ""
        assert metric.metric_type is None
        assert metric.expression == "count(1)"
        assert metric.description is None
        assert metric.d3format is None
        assert metric.extra is None
        assert metric.warning_text is None

        assert len(dataset.columns) == 1
        column = dataset.columns[0]
        assert column.column_name == "cnt"
        assert column.verbose_name == "Count of something"
        assert not column.is_dttm
        assert column.is_active  # imported columns are set to active
        assert column.type == "NUMBER"
        assert not column.groupby
        assert column.filterable
        assert column.expression == ""
        assert column.description is None
        assert column.python_date_format is None

        db.session.delete(dataset)
        db.session.delete(dataset.database)
        db.session.commit()

    def test_import_v1_dataset_multiple(self):
        """Test that a dataset can be imported multiple times"""
        contents = {
            "metadata.yaml": yaml.safe_dump(dataset_metadata_config),
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
            "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
        }
        command = ImportDatasetsCommand(contents)
        command.run()
        command.run()

        dataset = (
            db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
        )
        assert dataset.table_name == "imported_dataset"
        # test that columns and metrics sync, i.e., old ones not in the
        # import are removed
        new_config = dataset_config.copy()
        new_config["metrics"][0]["metric_name"] = "count2"
        new_config["columns"][0]["column_name"] = "cnt2"
        contents = {
            "metadata.yaml": yaml.safe_dump(dataset_metadata_config),
            "databases/imported_database.yaml": yaml.safe_dump(database_config),
            "datasets/imported_dataset.yaml": yaml.safe_dump(new_config),
        }
        command = ImportDatasetsCommand(contents)
        command.run()

        dataset = (
            db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
        )
        assert len(dataset.metrics) == 1
        assert dataset.metrics[0].metric_name == "count2"
        assert len(dataset.columns) == 1
        assert dataset.columns[0].column_name == "cnt2"

        db.session.delete(dataset)
        db.session.delete(dataset.database)
        db.session.commit()

    def test_import_v1_dataset_validation(self):
        """Test different validations applied when importing a dataset"""
        # metadata.yaml must be present
        contents = {
            "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
        }
        command = ImportDatasetsCommand(contents)
        with pytest.raises(IncorrectVersionError) as excinfo:
            command.run()
        assert str(excinfo.value) == "Missing metadata.yaml"

        # version should be 1.0.0
        contents["metadata.yaml"] = yaml.safe_dump(
            {
                "version": "2.0.0",
                "type": "SqlaTable",
                "timestamp": "2020-11-04T21:27:44.423819+00:00",
            }
        )
        command = ImportDatasetsCommand(contents)
        with pytest.raises(IncorrectVersionError) as excinfo:
            command.run()
        assert str(excinfo.value) == "Must be equal to 1.0.0."

        # type should be SqlaTable
        contents["metadata.yaml"] = yaml.safe_dump(database_metadata_config)
        command = ImportDatasetsCommand(contents)
        with pytest.raises(CommandInvalidError) as excinfo:
            command.run()
        assert str(excinfo.value) == "Error importing dataset"
        assert excinfo.value.normalized_messages() == {
            "metadata.yaml": {"type": ["Must be equal to SqlaTable."]}
        }

        # must also validate databases
        broken_config = database_config.copy()
        del broken_config["database_name"]
        contents["metadata.yaml"] = yaml.safe_dump(dataset_metadata_config)
        contents["databases/imported_database.yaml"] = yaml.safe_dump(broken_config)
        command = ImportDatasetsCommand(contents)
        with pytest.raises(CommandInvalidError) as excinfo:
            command.run()
        assert str(excinfo.value) == "Error importing dataset"
        assert excinfo.value.normalized_messages() == {
            "databases/imported_database.yaml": {
                "database_name": ["Missing data for required field."],
            }
        }

tests/fixtures/importexport.py (new file)
View File

@@ -0,0 +1,90 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict

# example YAML files
database_metadata_config: Dict[str, Any] = {
    "version": "1.0.0",
    "type": "Database",
    "timestamp": "2020-11-04T21:27:44.423819+00:00",
}

dataset_metadata_config: Dict[str, Any] = {
    "version": "1.0.0",
    "type": "SqlaTable",
    "timestamp": "2020-11-04T21:27:44.423819+00:00",
}

database_config: Dict[str, Any] = {
    "allow_csv_upload": True,
    "allow_ctas": True,
    "allow_cvas": True,
    "allow_run_async": False,
    "cache_timeout": None,
    "database_name": "imported_database",
    "expose_in_sqllab": True,
    "extra": {},
    "sqlalchemy_uri": "sqlite:///test.db",
    "uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
    "version": "1.0.0",
}

dataset_config: Dict[str, Any] = {
    "table_name": "imported_dataset",
    "main_dttm_col": None,
    "description": "This is a dataset that was exported",
    "default_endpoint": "",
    "offset": 66,
    "cache_timeout": 55,
    "schema": "",
    "sql": "",
    "params": None,
    "template_params": None,
    "filter_select_enabled": True,
    "fetch_values_predicate": None,
    "extra": None,
    "metrics": [
        {
            "metric_name": "count",
            "verbose_name": "",
            "metric_type": None,
            "expression": "count(1)",
            "description": None,
            "d3format": None,
            "extra": None,
            "warning_text": None,
        },
    ],
    "columns": [
        {
            "column_name": "cnt",
            "verbose_name": "Count of something",
            "is_dttm": False,
            "is_active": None,
            "type": "NUMBER",
            "groupby": False,
            "filterable": True,
            "expression": "",
            "description": None,
            "python_date_format": None,
        },
    ],
    "version": "1.0.0",
    "uuid": "10808100-158b-42c4-842e-f32b99d88dfb",
    "database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
}