mirror of https://github.com/apache/superset.git
feat: new import commands for dataset and databases (#11670)
* feat: commands for importing databases and datasets * Refactor code
This commit is contained in:
parent
871a98abe2
commit
7bc353f8a8
|
@ -304,6 +304,9 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None:
|
|||
from superset.datasets.commands.importers.v0 import ImportDatasetsCommand
|
||||
|
||||
sync_array = sync.split(",")
|
||||
sync_columns = "columns" in sync_array
|
||||
sync_metrics = "metrics" in sync_array
|
||||
|
||||
path_object = Path(path)
|
||||
files: List[Path] = []
|
||||
if path_object.is_file():
|
||||
|
@ -316,7 +319,7 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None:
|
|||
files.extend(path_object.rglob("*.yml"))
|
||||
contents = {path.name: open(path).read() for path in files}
|
||||
try:
|
||||
ImportDatasetsCommand(contents, sync_array).run()
|
||||
ImportDatasetsCommand(contents, sync_columns, sync_metrics).run()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
logger.exception("Error when importing dataset")
|
||||
|
||||
|
|
|
@ -0,0 +1,23 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from superset.commands.exceptions import CommandException
|
||||
|
||||
|
||||
class IncorrectVersionError(CommandException):
    """Raised when an import bundle's declared version is missing or unsupported.

    A dispatcher can catch this and retry the import with a different
    command version.
    """

    message = "Import has incorrect version"
    status = 422
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,67 @@
|
|||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
import yaml
|
||||
from marshmallow import fields, Schema, validate
|
||||
from marshmallow.exceptions import ValidationError
|
||||
|
||||
from superset.commands.importers.exceptions import IncorrectVersionError
|
||||
|
||||
METADATA_FILE_NAME = "metadata.yaml"
|
||||
IMPORT_VERSION = "1.0.0"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetadataSchema(Schema):
    """Schema for the bundle's ``metadata.yaml`` file.

    The ``version`` field must equal ``IMPORT_VERSION`` exactly; any other
    value fails validation.
    """

    version = fields.String(required=True, validate=validate.Equal(IMPORT_VERSION))
    type = fields.String(required=True)
    timestamp = fields.DateTime()
|
||||
|
||||
|
||||
def load_yaml(file_name: str, content: str) -> Dict[str, Any]:
    """Parse ``content`` as YAML, reporting errors against ``file_name``.

    :param file_name: name used to key the validation error message
    :param content: raw YAML text to parse
    :returns: the parsed YAML document
    :raises ValidationError: if ``content`` is not valid YAML
    """
    try:
        return yaml.safe_load(content)
    except yaml.error.YAMLError as ex:
        # yaml.YAMLError is the common base of ScannerError and ParserError;
        # the previous code only caught ParserError, letting scanner-level
        # failures escape unhandled.
        # Also log the file that actually failed to parse (the original
        # logged METADATA_FILE_NAME regardless of which file was broken).
        logger.exception("Invalid YAML in %s", file_name)
        raise ValidationError({file_name: "Not a valid YAML file"}) from ex
|
||||
|
||||
|
||||
def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:
    """Apply validation and load a metadata file.

    :param contents: mapping of file name to file content for the bundle
    :returns: the parsed ``metadata.yaml`` payload
    :raises IncorrectVersionError: if the metadata file is missing or its
        declared version does not match, so a dispatcher can try a
        different command version
    :raises ValidationError: for any other schema violation in the file
    """
    if METADATA_FILE_NAME not in contents:
        # if the contents have no METADATA_FILE_NAME this is probably
        # an original export without versioning that should not be
        # handled by this command
        raise IncorrectVersionError(f"Missing {METADATA_FILE_NAME}")

    metadata = load_yaml(METADATA_FILE_NAME, contents[METADATA_FILE_NAME])
    try:
        MetadataSchema().load(metadata)
    except ValidationError as exc:
        # if the version doesn't match raise an exception so that the
        # dispatcher can try a different command version; chain the cause
        # so the original validation context is preserved in tracebacks
        if "version" in exc.messages:
            raise IncorrectVersionError(exc.messages["version"][0]) from exc

        # otherwise we raise the validation error, keyed by the file name
        exc.messages = {METADATA_FILE_NAME: exc.messages}
        raise exc

    return metadata
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,116 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from marshmallow import Schema, validate
|
||||
from marshmallow.exceptions import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import db
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.exceptions import CommandInvalidError
|
||||
from superset.commands.importers.v1.utils import (
|
||||
load_metadata,
|
||||
load_yaml,
|
||||
METADATA_FILE_NAME,
|
||||
)
|
||||
from superset.databases.commands.importers.v1.utils import import_database
|
||||
from superset.databases.schemas import ImportV1DatabaseSchema
|
||||
from superset.datasets.commands.importers.v1.utils import import_dataset
|
||||
from superset.datasets.schemas import ImportV1DatasetSchema
|
||||
from superset.models.core import Database
|
||||
|
||||
schemas: Dict[str, Schema] = {
|
||||
"databases/": ImportV1DatabaseSchema(),
|
||||
"datasets/": ImportV1DatasetSchema(),
|
||||
}
|
||||
|
||||
|
||||
class ImportDatabasesCommand(BaseCommand):

    """Import databases (and any datasets bundled with them)."""

    # pylint: disable=unused-argument
    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
        # mapping of file name -> raw file content for the import bundle
        self.contents = contents
        # file name -> validated config; populated by validate()
        self._configs: Dict[str, Any] = {}

    def _import_bundle(self, session: Session) -> None:
        """Import every database in the bundle, then its related datasets."""
        # first import databases
        database_ids: Dict[str, int] = {}
        for file_name, config in self._configs.items():
            if file_name.startswith("databases/"):
                database = import_database(session, config, overwrite=True)
                database_ids[str(database.uuid)] = database.id

        # import related datasets
        for file_name, config in self._configs.items():
            if (
                file_name.startswith("datasets/")
                and config["database_uuid"] in database_ids
            ):
                config["database_id"] = database_ids[config["database_uuid"]]
                # overwrite=False prevents deleting any non-imported columns/metrics
                import_dataset(session, config, overwrite=False)

    def run(self) -> None:
        """Validate the bundle and import it, committing on success."""
        self.validate()

        # rollback to prevent partial imports
        try:
            self._import_bundle(db.session)
            db.session.commit()
        except Exception:
            db.session.rollback()
            # bare raise preserves the original exception and traceback
            # (re-raising the bound name was redundant)
            raise

    def validate(self) -> None:
        """Collect every validation problem and raise them as a single error.

        :raises CommandInvalidError: if any file in the bundle is invalid
        :raises IncorrectVersionError: propagated from load_metadata() when
            the bundle's version cannot be handled by this command
        """
        exceptions: List[ValidationError] = []

        # verify that the metadata file is present and valid
        try:
            metadata = load_metadata(self.contents)
        except ValidationError as exc:
            exceptions.append(exc)
            metadata = None

        for file_name, content in self.contents.items():
            prefix = file_name.split("/")[0]
            schema = schemas.get(f"{prefix}/")
            if schema:
                try:
                    config = load_yaml(file_name, content)
                    schema.load(config)
                    self._configs[file_name] = config
                except ValidationError as exc:
                    # key the error messages by the offending file
                    exc.messages = {file_name: exc.messages}
                    exceptions.append(exc)

        # validate that the type declared in METADATA_FILE_NAME is correct
        if metadata:
            type_validator = validate.Equal(Database.__name__)
            try:
                type_validator(metadata["type"])
            except ValidationError as exc:
                exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}}
                exceptions.append(exc)

        if exceptions:
            exception = CommandInvalidError("Error importing database")
            exception.add_list(exceptions)
            raise exception
|
|
@ -0,0 +1,42 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset.models.core import Database
|
||||
|
||||
|
||||
def import_database(
    session: Session, config: Dict[str, Any], overwrite: bool = False
) -> Database:
    """Create or update a database (matched by UUID) from an import config.

    :param session: DB session to import into
    :param config: validated ``ImportV1DatabaseSchema`` payload
    :param overwrite: when False, an existing database with the same UUID is
        returned untouched instead of being updated
    :returns: the imported (or pre-existing) ``Database``
    """
    existing = session.query(Database).filter_by(uuid=config["uuid"]).first()
    if existing:
        if not overwrite:
            return existing
        config["id"] = existing.id

    # TODO (betodealmeida): move this logic to import_from_dict
    # "extra" is optional in ImportV1DatabaseSchema, so guard against a
    # missing key instead of raising KeyError on configs without it
    if "extra" in config:
        config["extra"] = json.dumps(config["extra"])

    database = Database.import_from_dict(session, config, recursive=False)
    if database.id is None:
        # flush so the new row gets its primary key assigned
        session.flush()

    return database
|
|
@ -408,3 +408,24 @@ class DatabaseRelatedDashboards(Schema):
|
|||
class DatabaseRelatedObjectsResponse(Schema):
    """Response schema listing charts and dashboards related to a database."""

    charts = fields.Nested(DatabaseRelatedCharts)
    dashboards = fields.Nested(DatabaseRelatedDashboards)
|
||||
|
||||
|
||||
class ImportV1DatabaseExtraSchema(Schema):
    """Schema for the nested ``extra`` configuration of an imported database."""

    metadata_params = fields.Dict(keys=fields.Str(), values=fields.Raw())
    engine_params = fields.Dict(keys=fields.Str(), values=fields.Raw())
    metadata_cache_timeout = fields.Dict(keys=fields.Str(), values=fields.Integer())
    schemas_allowed_for_csv_upload = fields.List(fields.String)
|
||||
|
||||
|
||||
class ImportV1DatabaseSchema(Schema):
    """Schema for a versioned (v1) database import file.

    Only ``database_name``, ``sqlalchemy_uri``, ``uuid`` and ``version``
    are required; all feature flags are optional booleans.
    """

    database_name = fields.String(required=True)
    sqlalchemy_uri = fields.String(required=True)
    cache_timeout = fields.Integer(allow_none=True)
    expose_in_sqllab = fields.Boolean()
    allow_run_async = fields.Boolean()
    allow_ctas = fields.Boolean()
    allow_cvas = fields.Boolean()
    allow_csv_upload = fields.Boolean()
    extra = fields.Nested(ImportV1DatabaseExtraSchema)
    uuid = fields.UUID(required=True)
    version = fields.String(required=True)
|
||||
|
|
|
@ -282,9 +282,19 @@ class ImportDatasetsCommand(BaseCommand):
|
|||
in Superset.
|
||||
"""
|
||||
|
||||
def __init__(self, contents: Dict[str, str], sync: Optional[List[str]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
contents: Dict[str, str],
|
||||
sync_columns: bool = False,
|
||||
sync_metrics: bool = False,
|
||||
):
|
||||
self.contents = contents
|
||||
self.sync = sync
|
||||
|
||||
self.sync = []
|
||||
if sync_columns:
|
||||
self.sync.append("columns")
|
||||
if sync_metrics:
|
||||
self.sync.append("metrics")
|
||||
|
||||
def run(self) -> None:
|
||||
self.validate()
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from marshmallow import Schema, validate
|
||||
from marshmallow.exceptions import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset import db
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.commands.exceptions import CommandInvalidError
|
||||
from superset.commands.importers.v1.utils import (
|
||||
load_metadata,
|
||||
load_yaml,
|
||||
METADATA_FILE_NAME,
|
||||
)
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.commands.importers.v1.utils import import_database
|
||||
from superset.databases.schemas import ImportV1DatabaseSchema
|
||||
from superset.datasets.commands.importers.v1.utils import import_dataset
|
||||
from superset.datasets.schemas import ImportV1DatasetSchema
|
||||
|
||||
schemas: Dict[str, Schema] = {
|
||||
"databases/": ImportV1DatabaseSchema(),
|
||||
"datasets/": ImportV1DatasetSchema(),
|
||||
}
|
||||
|
||||
|
||||
class ImportDatasetsCommand(BaseCommand):

    """Import datasets (creating any missing parent databases)."""

    # pylint: disable=unused-argument
    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
        # mapping of file name -> raw file content for the import bundle
        self.contents = contents
        # file name -> validated config; populated by validate()
        self._configs: Dict[str, Any] = {}

    def _import_bundle(self, session: Session) -> None:
        """Import the bundle's datasets together with their databases."""
        # discover databases associated with datasets
        database_uuids: Set[str] = set()
        for file_name, config in self._configs.items():
            if file_name.startswith("datasets/"):
                database_uuids.add(config["database_uuid"])

        # import related databases; overwrite=False leaves existing ones intact
        database_ids: Dict[str, int] = {}
        for file_name, config in self._configs.items():
            if file_name.startswith("databases/") and config["uuid"] in database_uuids:
                database = import_database(session, config, overwrite=False)
                database_ids[str(database.uuid)] = database.id

        # import datasets with the correct parent ref
        for file_name, config in self._configs.items():
            if (
                file_name.startswith("datasets/")
                and config["database_uuid"] in database_ids
            ):
                config["database_id"] = database_ids[config["database_uuid"]]
                import_dataset(session, config, overwrite=True)

    def run(self) -> None:
        """Validate the bundle and import it, committing on success."""
        self.validate()

        # rollback to prevent partial imports
        try:
            self._import_bundle(db.session)
            db.session.commit()
        except Exception:
            db.session.rollback()
            # bare raise preserves the original exception and traceback
            # (re-raising the bound name was redundant)
            raise

    def validate(self) -> None:
        """Collect every validation problem and raise them as a single error.

        :raises CommandInvalidError: if any file in the bundle is invalid
        :raises IncorrectVersionError: propagated from load_metadata() when
            the bundle's version cannot be handled by this command
        """
        exceptions: List[ValidationError] = []

        # verify that the metadata file is present and valid
        try:
            metadata = load_metadata(self.contents)
        except ValidationError as exc:
            exceptions.append(exc)
            metadata = None

        for file_name, content in self.contents.items():
            prefix = file_name.split("/")[0]
            schema = schemas.get(f"{prefix}/")
            if schema:
                try:
                    config = load_yaml(file_name, content)
                    schema.load(config)
                    self._configs[file_name] = config
                except ValidationError as exc:
                    # key the error messages by the offending file
                    exc.messages = {file_name: exc.messages}
                    exceptions.append(exc)

        # validate that the type declared in METADATA_FILE_NAME is correct
        if metadata:
            type_validator = validate.Equal(SqlaTable.__name__)
            try:
                type_validator(metadata["type"])
            except ValidationError as exc:
                exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}}
                exceptions.append(exc)

        if exceptions:
            exception = CommandInvalidError("Error importing dataset")
            exception.add_list(exceptions)
            raise exception
|
|
@ -0,0 +1,42 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
|
||||
def import_dataset(
    session: Session, config: Dict[str, Any], overwrite: bool = False
) -> SqlaTable:
    """Create or update a dataset (matched by UUID) from an import config.

    When ``overwrite`` is False an existing dataset with the same UUID is
    returned untouched; otherwise it is updated in place.
    """
    current = session.query(SqlaTable).filter_by(uuid=config["uuid"]).first()
    if current:
        if not overwrite:
            # never touch a pre-existing dataset unless overwriting was requested
            return current
        config["id"] = current.id

    # should we delete columns and metrics not present in the current import?
    sync = []
    if overwrite:
        sync = ["columns", "metrics"]

    # import recursively to include columns and metrics
    dataset = SqlaTable.import_from_dict(session, config, recursive=True, sync=sync)
    if dataset.id is None:
        # flush so the new row gets its primary key assigned
        session.flush()

    return dataset
|
|
@ -122,3 +122,48 @@ class DatasetRelatedDashboards(Schema):
|
|||
class DatasetRelatedObjectsResponse(Schema):
|
||||
charts = fields.Nested(DatasetRelatedCharts)
|
||||
dashboards = fields.Nested(DatasetRelatedDashboards)
|
||||
|
||||
|
||||
class ImportV1ColumnSchema(Schema):
    """Schema for a dataset column in a versioned (v1) dataset import file."""

    column_name = fields.String(required=True)
    verbose_name = fields.String()
    is_dttm = fields.Boolean()
    is_active = fields.Boolean(allow_none=True)
    type = fields.String(required=True)
    groupby = fields.Boolean()
    filterable = fields.Boolean()
    expression = fields.String()
    description = fields.String(allow_none=True)
    python_date_format = fields.String(allow_none=True)
|
||||
|
||||
|
||||
class ImportV1MetricSchema(Schema):
    """Schema for a dataset metric in a versioned (v1) dataset import file."""

    metric_name = fields.String(required=True)
    verbose_name = fields.String()
    metric_type = fields.String(allow_none=True)
    expression = fields.String(required=True)
    description = fields.String(allow_none=True)
    d3format = fields.String(allow_none=True)
    extra = fields.String(allow_none=True)
    warning_text = fields.String(allow_none=True)
|
||||
|
||||
|
||||
class ImportV1DatasetSchema(Schema):
    """Schema for a versioned (v1) dataset import file.

    ``database_uuid`` links the dataset to its parent database; nested
    ``columns`` and ``metrics`` use their own schemas.
    """

    table_name = fields.String(required=True)
    main_dttm_col = fields.String(allow_none=True)
    description = fields.String()
    default_endpoint = fields.String()
    offset = fields.Integer()
    cache_timeout = fields.Integer()
    schema = fields.String()
    sql = fields.String()
    params = fields.String(allow_none=True)
    template_params = fields.String(allow_none=True)
    filter_select_enabled = fields.Boolean()
    fetch_values_predicate = fields.String(allow_none=True)
    extra = fields.String(allow_none=True)
    uuid = fields.UUID(required=True)
    columns = fields.List(fields.Nested(ImportV1ColumnSchema))
    metrics = fields.List(fields.Nested(ImportV1MetricSchema))
    version = fields.String(required=True)
    database_uuid = fields.UUID(required=True)
|
||||
|
|
|
@ -163,7 +163,7 @@ class ImportExportMixin:
|
|||
if sync is None:
|
||||
sync = []
|
||||
parent_refs = cls.parent_foreign_key_mappings()
|
||||
export_fields = set(cls.export_fields) | set(parent_refs.keys())
|
||||
export_fields = set(cls.export_fields) | set(parent_refs.keys()) | {"uuid"}
|
||||
new_children = {c: dict_rep[c] for c in cls.export_children if c in dict_rep}
|
||||
unique_constrains = cls._unique_constrains()
|
||||
|
||||
|
|
|
@ -14,16 +14,29 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=no-self-use, invalid-name
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from superset import security_manager
|
||||
from superset import db, security_manager
|
||||
from superset.commands.exceptions import CommandInvalidError
|
||||
from superset.commands.importers.exceptions import IncorrectVersionError
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.commands.exceptions import DatabaseNotFoundError
|
||||
from superset.databases.commands.export import ExportDatabasesCommand
|
||||
from superset.databases.commands.importers.v1 import ImportDatabasesCommand
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import backend, get_example_database
|
||||
from tests.base_tests import SupersetTestCase
|
||||
from tests.fixtures.importexport import (
|
||||
database_config,
|
||||
database_metadata_config,
|
||||
dataset_config,
|
||||
dataset_metadata_config,
|
||||
)
|
||||
|
||||
|
||||
class TestExportDatabasesCommand(SupersetTestCase):
|
||||
|
@ -265,3 +278,197 @@ class TestExportDatabasesCommand(SupersetTestCase):
|
|||
"uuid",
|
||||
"version",
|
||||
]
|
||||
|
||||
def test_import_v1_database(self):
    """Test that a database can be imported"""
    contents = {
        "metadata.yaml": yaml.safe_dump(database_metadata_config),
        "databases/imported_database.yaml": yaml.safe_dump(database_config),
    }
    command = ImportDatabasesCommand(contents)
    command.run()

    # the imported database is looked up by the UUID from the fixture config
    database = (
        db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
    )
    assert database.allow_csv_upload
    assert database.allow_ctas
    assert database.allow_cvas
    assert not database.allow_run_async
    assert database.cache_timeout is None
    assert database.database_name == "imported_database"
    assert database.expose_in_sqllab
    assert database.extra == "{}"
    assert database.sqlalchemy_uri == "sqlite:///test.db"

    # clean up so other tests see an unchanged database count
    db.session.delete(database)
    db.session.commit()
|
||||
|
||||
def test_import_v1_database_multiple(self):
    """Test that a database can be imported multiple times"""
    num_databases = db.session.query(Database).count()

    contents = {
        "databases/imported_database.yaml": yaml.safe_dump(database_config),
        "metadata.yaml": yaml.safe_dump(database_metadata_config),
    }
    command = ImportDatabasesCommand(contents)

    # import twice
    command.run()
    command.run()

    database = (
        db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
    )
    assert database.allow_csv_upload

    # update allow_csv_upload to False
    new_config = database_config.copy()
    new_config["allow_csv_upload"] = False
    contents = {
        "databases/imported_database.yaml": yaml.safe_dump(new_config),
        "metadata.yaml": yaml.safe_dump(database_metadata_config),
    }
    command = ImportDatabasesCommand(contents)
    command.run()

    # re-importing with overwrite semantics should update the existing row
    database = (
        db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
    )
    assert not database.allow_csv_upload

    # test that only one database was created
    new_num_databases = db.session.query(Database).count()
    assert new_num_databases == num_databases + 1

    # clean up
    db.session.delete(database)
    db.session.commit()
|
||||
|
||||
def test_import_v1_database_with_dataset(self):
    """Test that a database can be imported with datasets"""
    contents = {
        "databases/imported_database.yaml": yaml.safe_dump(database_config),
        "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
        "metadata.yaml": yaml.safe_dump(database_metadata_config),
    }
    command = ImportDatabasesCommand(contents)
    command.run()

    # the bundled dataset should be attached to the imported database
    database = (
        db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
    )
    assert len(database.tables) == 1
    assert str(database.tables[0].uuid) == "10808100-158b-42c4-842e-f32b99d88dfb"

    # clean up: delete the dataset first, then its parent database
    db.session.delete(database.tables[0])
    db.session.delete(database)
    db.session.commit()
|
||||
|
||||
def test_import_v1_database_with_dataset_multiple(self):
    """Test that a database can be imported multiple times w/o changing datasets"""
    contents = {
        "databases/imported_database.yaml": yaml.safe_dump(database_config),
        "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
        "metadata.yaml": yaml.safe_dump(database_metadata_config),
    }
    command = ImportDatabasesCommand(contents)
    command.run()

    dataset = (
        db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
    )
    assert dataset.offset == 66

    # re-import with a modified dataset config (offset changed 66 -> 67)
    new_config = dataset_config.copy()
    new_config["offset"] = 67
    contents = {
        "databases/imported_database.yaml": yaml.safe_dump(database_config),
        "datasets/imported_dataset.yaml": yaml.safe_dump(new_config),
        "metadata.yaml": yaml.safe_dump(database_metadata_config),
    }
    command = ImportDatabasesCommand(contents)
    command.run()

    # the underlying dataset should not be modified by the second import, since
    # we're importing a database, not a dataset
    dataset = (
        db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
    )
    assert dataset.offset == 66

    # clean up the dataset and its parent database
    db.session.delete(dataset)
    db.session.delete(dataset.database)
    db.session.commit()
|
||||
|
||||
def test_import_v1_database_validation(self):
    """Test different validations applied when importing a database"""
    # metadata.yaml must be present
    contents = {
        "databases/imported_database.yaml": yaml.safe_dump(database_config),
    }
    command = ImportDatabasesCommand(contents)
    with pytest.raises(IncorrectVersionError) as excinfo:
        command.run()
    assert str(excinfo.value) == "Missing metadata.yaml"

    # version should be 1.0.0
    contents["metadata.yaml"] = yaml.safe_dump(
        {
            "version": "2.0.0",
            "type": "Database",
            "timestamp": "2020-11-04T21:27:44.423819+00:00",
        }
    )
    command = ImportDatabasesCommand(contents)
    with pytest.raises(IncorrectVersionError) as excinfo:
        command.run()
    assert str(excinfo.value) == "Must be equal to 1.0.0."

    # type should be Database (here we pass a dataset metadata file instead)
    contents["metadata.yaml"] = yaml.safe_dump(dataset_metadata_config)
    command = ImportDatabasesCommand(contents)
    with pytest.raises(CommandInvalidError) as excinfo:
        command.run()
    assert str(excinfo.value) == "Error importing database"
    assert excinfo.value.normalized_messages() == {
        "metadata.yaml": {"type": ["Must be equal to Database."],}
    }

    # must also validate datasets
    broken_config = dataset_config.copy()
    del broken_config["table_name"]
    contents["metadata.yaml"] = yaml.safe_dump(database_metadata_config)
    contents["datasets/imported_dataset.yaml"] = yaml.safe_dump(broken_config)
    command = ImportDatabasesCommand(contents)
    with pytest.raises(CommandInvalidError) as excinfo:
        command.run()
    assert str(excinfo.value) == "Error importing database"
    assert excinfo.value.normalized_messages() == {
        "datasets/imported_dataset.yaml": {
            "table_name": ["Missing data for required field."],
        }
    }
|
||||
|
||||
@patch("superset.databases.commands.importers.v1.import_dataset")
def test_import_v1_rollback(self, mock_import_dataset):
    """Test that on an exception everything is rolled back"""
    num_databases = db.session.query(Database).count()

    # raise an exception when importing the dataset, after the database has
    # already been imported
    mock_import_dataset.side_effect = Exception("A wild exception appears!")

    contents = {
        "databases/imported_database.yaml": yaml.safe_dump(database_config),
        "datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
        "metadata.yaml": yaml.safe_dump(database_metadata_config),
    }
    command = ImportDatabasesCommand(contents)
    with pytest.raises(Exception) as excinfo:
        command.run()
    assert str(excinfo.value) == "A wild exception appears!"

    # verify that the database was not added (i.e. the session rolled back)
    new_num_databases = db.session.query(Database).count()
    assert new_num_databases == num_databases
|
||||
|
|
|
@ -14,18 +14,29 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=no-self-use, invalid-name
|
||||
|
||||
from operator import itemgetter
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from superset import security_manager
|
||||
from superset import db, security_manager
|
||||
from superset.commands.exceptions import CommandInvalidError
|
||||
from superset.commands.importers.exceptions import IncorrectVersionError
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.datasets.commands.exceptions import DatasetNotFoundError
|
||||
from superset.datasets.commands.export import ExportDatasetsCommand
|
||||
from superset.utils.core import backend, get_example_database
|
||||
from superset.datasets.commands.importers.v1 import ImportDatasetsCommand
|
||||
from superset.utils.core import get_example_database
|
||||
from tests.base_tests import SupersetTestCase
|
||||
from tests.fixtures.importexport import (
|
||||
database_config,
|
||||
database_metadata_config,
|
||||
dataset_config,
|
||||
dataset_metadata_config,
|
||||
)
|
||||
|
||||
|
||||
class TestExportDatasetsCommand(SupersetTestCase):
|
||||
|
@ -186,3 +197,149 @@ class TestExportDatasetsCommand(SupersetTestCase):
|
|||
"version",
|
||||
"database_uuid",
|
||||
]
|
||||
|
||||
def test_import_v1_dataset(self):
|
||||
"""Test that we can import a dataset"""
|
||||
contents = {
|
||||
"metadata.yaml": yaml.safe_dump(dataset_metadata_config),
|
||||
"databases/imported_database.yaml": yaml.safe_dump(database_config),
|
||||
"datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
|
||||
}
|
||||
command = ImportDatasetsCommand(contents)
|
||||
command.run()
|
||||
|
||||
dataset = (
|
||||
db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
|
||||
)
|
||||
assert dataset.table_name == "imported_dataset"
|
||||
assert dataset.main_dttm_col is None
|
||||
assert dataset.description == "This is a dataset that was exported"
|
||||
assert dataset.default_endpoint == ""
|
||||
assert dataset.offset == 66
|
||||
assert dataset.cache_timeout == 55
|
||||
assert dataset.schema == ""
|
||||
assert dataset.sql == ""
|
||||
assert dataset.params is None
|
||||
assert dataset.template_params is None
|
||||
assert dataset.filter_select_enabled
|
||||
assert dataset.fetch_values_predicate is None
|
||||
assert dataset.extra is None
|
||||
|
||||
# database is also imported
|
||||
assert str(dataset.database.uuid) == "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89"
|
||||
|
||||
assert len(dataset.metrics) == 1
|
||||
metric = dataset.metrics[0]
|
||||
assert metric.metric_name == "count"
|
||||
assert metric.verbose_name == ""
|
||||
assert metric.metric_type is None
|
||||
assert metric.expression == "count(1)"
|
||||
assert metric.description is None
|
||||
assert metric.d3format is None
|
||||
assert metric.extra is None
|
||||
assert metric.warning_text is None
|
||||
|
||||
assert len(dataset.columns) == 1
|
||||
column = dataset.columns[0]
|
||||
assert column.column_name == "cnt"
|
||||
assert column.verbose_name == "Count of something"
|
||||
assert not column.is_dttm
|
||||
assert column.is_active # imported columns are set to active
|
||||
assert column.type == "NUMBER"
|
||||
assert not column.groupby
|
||||
assert column.filterable
|
||||
assert column.expression == ""
|
||||
assert column.description is None
|
||||
assert column.python_date_format is None
|
||||
|
||||
db.session.delete(dataset)
|
||||
db.session.delete(dataset.database)
|
||||
db.session.commit()
|
||||
|
||||
def test_import_v1_dataset_multiple(self):
|
||||
"""Test that a dataset can be imported multiple times"""
|
||||
contents = {
|
||||
"metadata.yaml": yaml.safe_dump(dataset_metadata_config),
|
||||
"databases/imported_database.yaml": yaml.safe_dump(database_config),
|
||||
"datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
|
||||
}
|
||||
command = ImportDatasetsCommand(contents)
|
||||
command.run()
|
||||
command.run()
|
||||
dataset = (
|
||||
db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
|
||||
)
|
||||
assert dataset.table_name == "imported_dataset"
|
||||
|
||||
# test that columns and metrics sync, ie, old ones not the import
|
||||
# are removed
|
||||
new_config = dataset_config.copy()
|
||||
new_config["metrics"][0]["metric_name"] = "count2"
|
||||
new_config["columns"][0]["column_name"] = "cnt2"
|
||||
contents = {
|
||||
"metadata.yaml": yaml.safe_dump(dataset_metadata_config),
|
||||
"databases/imported_database.yaml": yaml.safe_dump(database_config),
|
||||
"datasets/imported_dataset.yaml": yaml.safe_dump(new_config),
|
||||
}
|
||||
command = ImportDatasetsCommand(contents)
|
||||
command.run()
|
||||
dataset = (
|
||||
db.session.query(SqlaTable).filter_by(uuid=dataset_config["uuid"]).one()
|
||||
)
|
||||
assert len(dataset.metrics) == 1
|
||||
assert dataset.metrics[0].metric_name == "count2"
|
||||
assert len(dataset.columns) == 1
|
||||
assert dataset.columns[0].column_name == "cnt2"
|
||||
|
||||
db.session.delete(dataset)
|
||||
db.session.delete(dataset.database)
|
||||
db.session.commit()
|
||||
|
||||
def test_import_v1_dataset_validation(self):
|
||||
"""Test different validations applied when importing a dataset"""
|
||||
# metadata.yaml must be present
|
||||
contents = {
|
||||
"datasets/imported_dataset.yaml": yaml.safe_dump(dataset_config),
|
||||
}
|
||||
command = ImportDatasetsCommand(contents)
|
||||
with pytest.raises(IncorrectVersionError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == "Missing metadata.yaml"
|
||||
|
||||
# version should be 1.0.0
|
||||
contents["metadata.yaml"] = yaml.safe_dump(
|
||||
{
|
||||
"version": "2.0.0",
|
||||
"type": "SqlaTable",
|
||||
"timestamp": "2020-11-04T21:27:44.423819+00:00",
|
||||
}
|
||||
)
|
||||
command = ImportDatasetsCommand(contents)
|
||||
with pytest.raises(IncorrectVersionError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == "Must be equal to 1.0.0."
|
||||
|
||||
# type should be SqlaTable
|
||||
contents["metadata.yaml"] = yaml.safe_dump(database_metadata_config)
|
||||
command = ImportDatasetsCommand(contents)
|
||||
with pytest.raises(CommandInvalidError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == "Error importing dataset"
|
||||
assert excinfo.value.normalized_messages() == {
|
||||
"metadata.yaml": {"type": ["Must be equal to SqlaTable."],}
|
||||
}
|
||||
|
||||
# must also validate databases
|
||||
broken_config = database_config.copy()
|
||||
del broken_config["database_name"]
|
||||
contents["metadata.yaml"] = yaml.safe_dump(dataset_metadata_config)
|
||||
contents["databases/imported_database.yaml"] = yaml.safe_dump(broken_config)
|
||||
command = ImportDatasetsCommand(contents)
|
||||
with pytest.raises(CommandInvalidError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == "Error importing dataset"
|
||||
assert excinfo.value.normalized_messages() == {
|
||||
"databases/imported_database.yaml": {
|
||||
"database_name": ["Missing data for required field."],
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
# example YAML files
|
||||
database_metadata_config: Dict[str, Any] = {
|
||||
"version": "1.0.0",
|
||||
"type": "Database",
|
||||
"timestamp": "2020-11-04T21:27:44.423819+00:00",
|
||||
}
|
||||
|
||||
dataset_metadata_config: Dict[str, Any] = {
|
||||
"version": "1.0.0",
|
||||
"type": "SqlaTable",
|
||||
"timestamp": "2020-11-04T21:27:44.423819+00:00",
|
||||
}
|
||||
|
||||
database_config: Dict[str, Any] = {
|
||||
"allow_csv_upload": True,
|
||||
"allow_ctas": True,
|
||||
"allow_cvas": True,
|
||||
"allow_run_async": False,
|
||||
"cache_timeout": None,
|
||||
"database_name": "imported_database",
|
||||
"expose_in_sqllab": True,
|
||||
"extra": {},
|
||||
"sqlalchemy_uri": "sqlite:///test.db",
|
||||
"uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
|
||||
"version": "1.0.0",
|
||||
}
|
||||
|
||||
dataset_config: Dict[str, Any] = {
|
||||
"table_name": "imported_dataset",
|
||||
"main_dttm_col": None,
|
||||
"description": "This is a dataset that was exported",
|
||||
"default_endpoint": "",
|
||||
"offset": 66,
|
||||
"cache_timeout": 55,
|
||||
"schema": "",
|
||||
"sql": "",
|
||||
"params": None,
|
||||
"template_params": None,
|
||||
"filter_select_enabled": True,
|
||||
"fetch_values_predicate": None,
|
||||
"extra": None,
|
||||
"metrics": [
|
||||
{
|
||||
"metric_name": "count",
|
||||
"verbose_name": "",
|
||||
"metric_type": None,
|
||||
"expression": "count(1)",
|
||||
"description": None,
|
||||
"d3format": None,
|
||||
"extra": None,
|
||||
"warning_text": None,
|
||||
},
|
||||
],
|
||||
"columns": [
|
||||
{
|
||||
"column_name": "cnt",
|
||||
"verbose_name": "Count of something",
|
||||
"is_dttm": False,
|
||||
"is_active": None,
|
||||
"type": "NUMBER",
|
||||
"groupby": False,
|
||||
"filterable": True,
|
||||
"expression": "",
|
||||
"description": None,
|
||||
"python_date_format": None,
|
||||
},
|
||||
],
|
||||
"version": "1.0.0",
|
||||
"uuid": "10808100-158b-42c4-842e-f32b99d88dfb",
|
||||
"database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
|
||||
}
|
Loading…
Reference in New Issue