feat: API endpoint to import charts (#11744)

* ImportChartsCommand

* feat: API endpoint to import charts

* Add dispatcher

* Fix docstring
Beto Dealmeida 2020-11-20 14:40:27 -08:00 committed by GitHub
parent 2f4f87795d
commit a3a2a68f01
5 changed files with 223 additions and 13 deletions


@@ -44,6 +44,7 @@ from superset.charts.commands.exceptions import (
ChartUpdateFailedError,
)
from superset.charts.commands.export import ExportChartsCommand
from superset.charts.commands.importers.dispatcher import ImportChartsCommand
from superset.charts.commands.update import UpdateChartCommand
from superset.charts.dao import ChartDAO
from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter
@@ -59,6 +60,7 @@ from superset.charts.schemas import (
screenshot_query_schema,
thumbnail_query_schema,
)
from superset.commands.exceptions import CommandInvalidError
from superset.constants import RouteMethod
from superset.exceptions import SupersetSecurityException
from superset.extensions import event_logger
@@ -86,6 +88,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | {
RouteMethod.EXPORT,
RouteMethod.IMPORT,
RouteMethod.RELATED,
"bulk_delete", # not using RouteMethod since locally defined
"data",
@@ -823,3 +826,56 @@ class ChartRestApi(BaseSupersetModelRestApi):
for request_id in requested_ids
]
        return self.response(200, result=res)

    @expose("/import/", methods=["POST"])
@protect()
@safe
@statsd_metrics
def import_(self) -> Response:
"""Import chart(s) with associated datasets and databases
---
post:
requestBody:
content:
application/zip:
schema:
type: string
format: binary
responses:
200:
description: Chart import result
content:
application/json:
schema:
type: object
properties:
message:
type: string
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
upload = request.files.get("file")
if not upload:
return self.response_400()
with ZipFile(upload) as bundle:
contents = {
file_name: bundle.read(file_name).decode()
for file_name in bundle.namelist()
}
command = ImportChartsCommand(contents)
try:
command.run()
return self.response(200, message="OK")
except CommandInvalidError as exc:
logger.warning("Import chart failed")
return self.response_422(message=exc.normalized_messages())
except Exception as exc: # pylint: disable=broad-except
logger.exception("Import chart failed")
return self.response_500(message=str(exc))
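
For reference, the handler above reads the upload from the "file" form field, so a client sends the export bundle as multipart/form-data. The following is only a sketch of client-side usage; the host, credentials, and the /api/v1/security/login token flow are assumptions about a typical local setup, not part of this change.

# Hypothetical client-side usage of POST /api/v1/chart/import/ (not part of this commit).
import requests

BASE = "http://localhost:8088"  # assumed local Superset instance

# obtain a JWT access token; the login payload may differ per deployment
token = requests.post(
    f"{BASE}/api/v1/security/login",
    json={"username": "admin", "password": "admin", "provider": "db", "refresh": True},
).json()["access_token"]

# the endpoint expects the export bundle in the "file" field, matching request.files.get("file")
with open("chart_export.zip", "rb") as fp:
    resp = requests.post(
        f"{BASE}/api/v1/chart/import/",
        headers={"Authorization": f"Bearer {token}"},
        files={"file": ("chart_export.zip", fp, "application/zip")},
    )

print(resp.status_code, resp.json())  # 200 and {"message": "OK"} on success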


@@ -0,0 +1,70 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict

from marshmallow.exceptions import ValidationError

from superset.charts.commands.importers import v1
from superset.commands.base import BaseCommand
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.exceptions import IncorrectVersionError

logger = logging.getLogger(__name__)

command_versions = [
    v1.ImportChartsCommand,
]


class ImportChartsCommand(BaseCommand):
    """
    Import charts.

    This command dispatches the import to different versions of the command
    until it finds one that matches.
    """

    # pylint: disable=unused-argument
    def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
        self.contents = contents

    def run(self) -> None:
        # iterate over all commands until we find a version that can
        # handle the contents
        for version in command_versions:
            command = version(self.contents)
            try:
                command.run()
                return
            except IncorrectVersionError:
                # file is not handled by command, skip
                pass
            except (CommandInvalidError, ValidationError) as exc:
                # found right version, but file is invalid
                logger.info("Command failed validation")
                raise exc
            except Exception as exc:
                # validation succeeded but something went wrong
                logger.exception("Error running import command")
                raise exc

        raise CommandInvalidError("Could not find a valid command to import file")

    def validate(self) -> None:
        pass
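
Because the dispatcher only needs the file-name-to-YAML mapping that the API handler builds, it can also be driven directly, for example from a script. A minimal sketch, assuming a chart export ZIP on disk and an initialized Superset app context (the v1 importer writes to the metadata database):

# Hypothetical direct use of the dispatcher; the bundle path and app context are assumptions.
from zipfile import ZipFile

from superset.charts.commands.importers.dispatcher import ImportChartsCommand

with ZipFile("chart_export.zip") as bundle:
    # same file_name -> decoded YAML mapping the API endpoint builds
    contents = {name: bundle.read(name).decode() for name in bundle.namelist()}

# each entry in command_versions (currently only v1) is tried in turn;
# CommandInvalidError is raised if no version recognizes the bundle
ImportChartsCommand(contents).run()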


@@ -21,11 +21,12 @@ from typing import List, Optional
from datetime import datetime
from io import BytesIO
from unittest import mock
from zipfile import is_zipfile
from zipfile import is_zipfile, ZipFile
import humanize
import prison
import pytest
import yaml
from sqlalchemy import and_
from sqlalchemy.sql import func
@@ -35,12 +36,19 @@ from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice
from tests.test_app import app
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import db, security_manager
from superset.models.core import FavStar, FavStarClassName
from superset.models.core import Database, FavStar, FavStarClassName
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils import core as utils
from tests.base_api_tests import ApiOwnersTestCaseMixin
from tests.base_tests import SupersetTestCase
from tests.fixtures.importexport import (
chart_config,
chart_metadata_config,
database_config,
dataset_config,
dataset_metadata_config,
)
from tests.fixtures.query_context import get_query_context
CHART_DATA_URI = "api/v1/chart/data"
@@ -1131,7 +1139,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
def test_export_chart(self):
"""
Chart API: Test export dataset
Chart API: Test export chart
"""
example_chart = db.session.query(Slice).all()[0]
argument = [example_chart.id]
@@ -1147,7 +1155,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
def test_export_chart_not_found(self):
"""
Dataset API: Test export dataset not found
Chart API: Test export chart not found
"""
# Just one does not exist and we get 404
argument = [-1, 1]
@@ -1159,7 +1167,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
def test_export_chart_gamma(self):
"""
Dataset API: Test export dataset has gamma
Chart API: Test export chart has gamma
"""
example_chart = db.session.query(Slice).all()[0]
argument = [example_chart.id]
@@ -1169,3 +1177,79 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
rv = self.client.get(uri)
        assert rv.status_code == 404

    def test_import_chart(self):
"""
Chart API: Test import chart
"""
self.login(username="admin")
uri = "api/v1/chart/import/"
buf = BytesIO()
with ZipFile(buf, "w") as bundle:
with bundle.open("metadata.yaml", "w") as fp:
fp.write(yaml.safe_dump(chart_metadata_config).encode())
with bundle.open("databases/imported_database.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_config).encode())
with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
with bundle.open("charts/imported_chart.yaml", "w") as fp:
fp.write(yaml.safe_dump(chart_config).encode())
buf.seek(0)
form_data = {
"file": (buf, "chart_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {"message": "OK"}
database = (
db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
)
assert database.database_name == "imported_database"
assert len(database.tables) == 1
dataset = database.tables[0]
assert dataset.table_name == "imported_dataset"
assert str(dataset.uuid) == dataset_config["uuid"]
chart = db.session.query(Slice).filter_by(uuid=chart_config["uuid"]).one()
assert chart.table == dataset
db.session.delete(chart)
db.session.delete(dataset)
db.session.delete(database)
        db.session.commit()

    def test_import_chart_invalid(self):
"""
Chart API: Test import invalid chart
"""
self.login(username="admin")
uri = "api/v1/chart/import/"
buf = BytesIO()
with ZipFile(buf, "w") as bundle:
with bundle.open("metadata.yaml", "w") as fp:
fp.write(yaml.safe_dump(dataset_metadata_config).encode())
with bundle.open("databases/imported_database.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_config).encode())
with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
with bundle.open("charts/imported_chart.yaml", "w") as fp:
fp.write(yaml.safe_dump(chart_config).encode())
buf.seek(0)
form_data = {
"file": (buf, "chart_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"message": {"metadata.yaml": {"type": ["Must be equal to Slice."]}}
}
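
The invalid-import test reuses dataset_metadata_config as metadata.yaml, so the chart importer rejects the bundle: per the assertion above, the metadata "type" must equal "Slice" for a chart export, while a dataset export declares "SqlaTable" (see the dataset tests below). As a rough illustration only (the exact fixture contents live in tests/fixtures/importexport and are not part of this diff), the dicts passed through yaml.safe_dump look something like:

# Hypothetical shape of the metadata fixtures; only the "type" value is asserted in these tests.
chart_metadata_config = {"version": "1.0.0", "type": "Slice"}  # accepted by the chart importer
dataset_metadata_config = {"version": "1.0.0", "type": "SqlaTable"}  # rejected here with a 422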


@@ -840,7 +840,7 @@ class TestDatabaseApi(SupersetTestCase):
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open("databases/imported_database.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_config).encode())
with bundle.open("datasets/import_dataset.yaml", "w") as fp:
with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)
@@ -880,7 +880,7 @@ class TestDatabaseApi(SupersetTestCase):
fp.write(yaml.safe_dump(dataset_metadata_config).encode())
with bundle.open("databases/imported_database.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_config).encode())
with bundle.open("datasets/import_dataset.yaml", "w") as fp:
with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)


@@ -1176,7 +1176,7 @@ class TestDatasetApi(SupersetTestCase):
for table_name in self.fixture_tables_names:
assert table_name in [ds["table_name"] for ds in data["result"]]
def test_import_dataset(self):
def test_imported_dataset(self):
"""
Dataset API: Test import dataset
"""
@@ -1189,7 +1189,7 @@ class TestDatasetApi(SupersetTestCase):
fp.write(yaml.safe_dump(dataset_metadata_config).encode())
with bundle.open("databases/imported_database.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_config).encode())
with bundle.open("datasets/import_dataset.yaml", "w") as fp:
with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)
@@ -1216,7 +1216,7 @@ class TestDatasetApi(SupersetTestCase):
db.session.delete(database)
db.session.commit()
def test_import_dataset_invalid(self):
def test_imported_dataset_invalid(self):
"""
Dataset API: Test import invalid dataset
"""
@@ -1229,7 +1229,7 @@ class TestDatasetApi(SupersetTestCase):
fp.write(yaml.safe_dump(database_metadata_config).encode())
with bundle.open("databases/imported_database.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_config).encode())
with bundle.open("datasets/import_dataset.yaml", "w") as fp:
with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)
@@ -1244,7 +1244,7 @@ class TestDatasetApi(SupersetTestCase):
"message": {"metadata.yaml": {"type": ["Must be equal to SqlaTable."]}}
}
def test_import_dataset_invalid_v0_validation(self):
def test_imported_dataset_invalid_v0_validation(self):
"""
Dataset API: Test import invalid dataset
"""
@@ -1255,7 +1255,7 @@ class TestDatasetApi(SupersetTestCase):
with ZipFile(buf, "w") as bundle:
with bundle.open("databases/imported_database.yaml", "w") as fp:
fp.write(yaml.safe_dump(database_config).encode())
with bundle.open("datasets/import_dataset.yaml", "w") as fp:
with bundle.open("datasets/imported_dataset.yaml", "w") as fp:
fp.write(yaml.safe_dump(dataset_config).encode())
buf.seek(0)