chore: refactor import command (#19216)

This commit is contained in:
Beto Dealmeida 2022-03-16 18:01:31 -07:00 committed by GitHub
parent b7a0559aaf
commit a4848a2f46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 46 deletions

View File

@ -24,9 +24,11 @@ from superset import db
from superset.commands.base import BaseCommand
from superset.commands.exceptions import CommandException, CommandInvalidError
from superset.commands.importers.v1.utils import (
load_configs,
load_metadata,
load_yaml,
METADATA_FILE_NAME,
validate_metadata_type,
)
from superset.dao.base import BaseDAO
from superset.models.core import Database
@ -78,9 +80,13 @@ class ImportModelsCommand(BaseCommand):
except ValidationError as exc:
exceptions.append(exc)
metadata = None
if self.dao.model_cls:
validate_metadata_type(metadata, self.dao.model_cls.__name__, exceptions)
self._validate_metadata_type(metadata, exceptions)
self._load__configs(exceptions)
# load the configs and make sure we have confirmation to overwrite existing models
self._configs = load_configs(
self.contents, self.schemas, self.passwords, exceptions
)
self._prevent_overwrite_existing_model(exceptions)
if exceptions:
@ -88,49 +94,6 @@ class ImportModelsCommand(BaseCommand):
exception.add_list(exceptions)
raise exception
def _validate_metadata_type(
self, metadata: Optional[Dict[str, str]], exceptions: List[ValidationError]
) -> None:
"""Validate that the type declared in METADATA_FILE_NAME is correct"""
if metadata and "type" in metadata:
type_validator = validate.Equal(self.dao.model_cls.__name__) # type: ignore
try:
type_validator(metadata["type"])
except ValidationError as exc:
exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}}
exceptions.append(exc)
def _load__configs(self, exceptions: List[ValidationError]) -> None:
# load existing databases so we can apply the password validation
db_passwords: Dict[str, str] = {
str(uuid): password
for uuid, password in db.session.query(
Database.uuid, Database.password
).all()
}
for file_name, content in self.contents.items():
# skip directories
if not content:
continue
prefix = file_name.split("/")[0]
schema = self.schemas.get(f"{prefix}/")
if schema:
try:
config = load_yaml(file_name, content)
# populate passwords from the request or from existing DBs
if file_name in self.passwords:
config["password"] = self.passwords[file_name]
elif prefix == "databases" and config["uuid"] in db_passwords:
config["password"] = db_passwords[config["uuid"]]
schema.load(config)
self._configs[file_name] = config
except ValidationError as exc:
exc.messages = {file_name: exc.messages}
exceptions.append(exc)
def _prevent_overwrite_existing_model( # pylint: disable=invalid-name
self, exceptions: List[ValidationError]
) -> None:

View File

@ -15,14 +15,16 @@
import logging
from pathlib import Path
from typing import Any, Dict
from typing import Any, Dict, List, Optional
from zipfile import ZipFile
import yaml
from marshmallow import fields, Schema, validate
from marshmallow.exceptions import ValidationError
from superset import db
from superset.commands.importers.exceptions import IncorrectVersionError
from superset.models.core import Database
METADATA_FILE_NAME = "metadata.yaml"
IMPORT_VERSION = "1.0.0"
@ -76,6 +78,58 @@ def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:
return metadata
def validate_metadata_type(
metadata: Optional[Dict[str, str]], type_: str, exceptions: List[ValidationError],
) -> None:
"""Validate that the type declared in METADATA_FILE_NAME is correct"""
if metadata and "type" in metadata:
type_validator = validate.Equal(type_)
try:
type_validator(metadata["type"])
except ValidationError as exc:
exc.messages = {METADATA_FILE_NAME: {"type": exc.messages}}
exceptions.append(exc)
def load_configs(
contents: Dict[str, str],
schemas: Dict[str, Schema],
passwords: Dict[str, str],
exceptions: List[ValidationError],
) -> Dict[str, Any]:
configs: Dict[str, Any] = {}
# load existing databases so we can apply the password validation
db_passwords: Dict[str, str] = {
str(uuid): password
for uuid, password in db.session.query(Database.uuid, Database.password).all()
}
for file_name, content in contents.items():
# skip directories
if not content:
continue
prefix = file_name.split("/")[0]
schema = schemas.get(f"{prefix}/")
if schema:
try:
config = load_yaml(file_name, content)
# populate passwords from the request or from existing DBs
if file_name in passwords:
config["password"] = passwords[file_name]
elif prefix == "databases" and config["uuid"] in db_passwords:
config["password"] = db_passwords[config["uuid"]]
schema.load(config)
configs[file_name] = config
except ValidationError as exc:
exc.messages = {file_name: exc.messages}
exceptions.append(exc)
return configs
def is_valid_config(file_name: str) -> bool:
path = Path(file_name)