revisions

This commit is contained in:
Arash 2021-04-07 12:32:13 -04:00
parent a8416b00b6
commit b2ec820adb
9 changed files with 71 additions and 74 deletions

View File

@ -21,14 +21,16 @@ from io import BytesIO
from typing import Any
from zipfile import ZipFile
from flask import g, Response, send_file, request
from flask import g, request, Response, send_file
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext
from marshmallow import ValidationError
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.v1.utils import get_contents_from_bundle
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.databases.filters import DatabaseFilter
from superset.extensions import event_logger
from superset.models.sql_lab import SavedQuery
from superset.queries.saved_queries.commands.bulk_delete import (
BulkDeleteSavedQueryCommand,
@ -36,12 +38,11 @@ from superset.queries.saved_queries.commands.bulk_delete import (
from superset.queries.saved_queries.commands.exceptions import (
SavedQueryBulkDeleteFailedError,
SavedQueryNotFoundError,
SavedQueryImportError,
SavedQueryImportError,
SavedQueryInvalidError,
)
from superset.queries.saved_queries.commands.export import ExportSavedQueriesCommand
from superset.queries.saved_queries.commands.importers.dispatcher import ImportSavedQueriesCommand
from superset.queries.saved_queries.commands.importers.dispatcher import (
ImportSavedQueriesCommand,
)
from superset.queries.saved_queries.filters import (
SavedQueryAllTextFilter,
SavedQueryFavoriteFilter,
@ -52,9 +53,6 @@ from superset.queries.saved_queries.schemas import (
get_export_ids_schema,
openapi_spec_methods_override,
)
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.v1.utils import get_contents_from_bundle
from superset.extensions import event_logger
from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics
logger = logging.getLogger(__name__)
@ -262,6 +260,7 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
as_attachment=True,
attachment_filename=filename,
)
@expose("/import/", methods=["POST"])
@protect()
@safe
@ -271,7 +270,7 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
log_to_statsd=False,
)
def import_(self) -> Response:
"""Import Saved Queries with associated datasets and databases
"""Import Saved Queries with associated databases
---
post:
requestBody:
@ -289,7 +288,7 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
description: JSON map of passwords for each file
type: string
overwrite:
description: overwrite existing databases?
description: overwrite existing saved queries?
type: bool
responses:
200:

View File

@ -17,10 +17,10 @@
from flask_babel import lazy_gettext as _
from superset.commands.exceptions import (
CommandException,
CommandException,
CommandInvalidError,
DeleteFailedError,
ImportFailedError
ImportFailedError,
)
@ -31,9 +31,10 @@ class SavedQueryBulkDeleteFailedError(DeleteFailedError):
class SavedQueryNotFoundError(CommandException):
message = _("Saved query not found.")
class SavedQueryImportError(ImportFailedError):
message = _("Import saved query failed for an unknown reason.")
class SavedQueryInvalidError(CommandInvalidError):
message = _("Saved query parameters are invalid.")

View File

@ -30,6 +30,8 @@ logger = logging.getLogger(__name__)
command_versions = [
v1.ImportSavedQueriesCommand,
]
class ImportSavedQueriesCommand(BaseCommand):
"""
Import Saved Queries
@ -54,7 +56,7 @@ class ImportSavedQueriesCommand(BaseCommand):
return
except IncorrectVersionError:
logger.debug("File not handled by command, skipping")
except(CommandInvalidError, ValidationError) as exc:
except (CommandInvalidError, ValidationError) as exc:
# found right version, but file is invalid
logger.exception("Error running import command")
raise exc

View File

@ -20,22 +20,25 @@ from typing import Any, Dict, Set
from marshmallow import Schema
from sqlalchemy.orm import Session
from superset.queries.saved_queries.commands.exceptions import SavedQueryImportError
from superset.commands.importers.v1 import ImportModelsCommand
from superset.connectors.sqla.models import SqlaTable
from superset.databases.commands.importers.v1.utils import import_database
from superset.datasets.commands.importers.v1.utils import import_dataset
from superset.datasets.schemas import ImportV1DatasetSchema
from superset.queries.saved_queries.commands.exceptions import SavedQueryImportError
from superset.queries.saved_queries.commands.importers.v1.utils import (
import_saved_query,
)
from superset.queries.saved_queries.dao import SavedQueryDAO
from superset.queries.saved_queries.commands.importers.v1.utils import import_saved_query
from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema
class ImportSavedQueriesCommand(ImportModelsCommand):
"""Import Saved Queries"""
dao = SavedQueryDAO
model_name= "saved_queries"
prefix ="saved_queries/"
model_name = "saved_queries"
prefix = "queries/"
schemas: Dict[str, Schema] = {
"datasets/": ImportV1DatasetSchema(),
"queries/": ImportV1SavedQuerySchema(),
@ -59,17 +62,11 @@ class ImportSavedQueriesCommand(ImportModelsCommand):
database = import_database(session, config, overwrite=False)
database_ids[str(database.uuid)] = database.id
# import saved queries with the correct parent ref
for file_name, config in configs.items():
if file_name.startswith("queries/") and config["database_uuid"] in database:
# update datasource id, type, and name
database = database[config["dataset_uuid"]]
config.update(
{
"datasource_id": database.id,
"datasource_name": database.table_name,
}
)
config["params"].update({"datasource": database.uid})
if (
file_name.startswith("queries/")
and config["database_uuid"] in database_ids
):
config["db_id"] = database_ids[config["database_uuid"]]
import_saved_query(session, config, overwrite=overwrite)

View File

@ -21,10 +21,11 @@ from sqlalchemy.orm import Session
from superset.models.sql_lab import SavedQuery
def import_saved_query(
session: Session, config: Dict[str, Any], overwrite: bool = False
) -> SavedQuery:
existing = session.query(SavedQuery).filter_by(uuid= config["uuid"]).first()
existing = session.query(SavedQuery).filter_by(uuid=config["uuid"]).first()
if existing:
if not overwrite:
return existing

View File

@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_babel import lazy_gettext as _
from marshmallow import fields, Schema, ValidationError
from marshmallow import fields, Schema
from marshmallow.validate import Length
openapi_spec_methods_override = {
@ -36,11 +35,12 @@ openapi_spec_methods_override = {
get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}}
get_export_ids_schema = {"type": "array", "items": {"type": "integer"}}
class ImportV1SavedQuerySchema(Schema):
schema = fields.String(allow_none=True, validate= Length(0, 128))
label = fields.String(allow_none=True, validate= Length(0,256))
description = fields.String(allow_none = True)
sql = fields.String(required= True)
schema = fields.String(allow_none=True, validate=Length(0, 128))
label = fields.String(allow_none=True, validate=Length(0, 256))
description = fields.String(allow_none=True)
sql = fields.String(required=True)
uuid = fields.UUID(required=True)
version = fields.String(required=True)
database_uuid = fields.UUID(required=True)

View File

@ -346,7 +346,7 @@ dashboard_metadata_config: Dict[str, Any] = {
saved_queries_metadata_config: Dict[str, Any] = {
"version": "1.0.0",
"type": "SavedQuery",
"timestamp": "2021-03-30T20:37:54.791187+00:00"
"timestamp": "2021-03-30T20:37:54.791187+00:00",
}
database_config: Dict[str, Any] = {
"allow_csv_upload": True,
@ -510,5 +510,5 @@ saved_queries_config = {
"sql": "-- Note: Unless you save your query, these tabs will NOT persist if you clear\nyour cookies or change browsers.\n\n\nSELECT * from birth_names",
"uuid": "05b679b5-8eaf-452c-b874-a7a774cfa4e9",
"version": "1.0.0",
"database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89"
"database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
}

View File

@ -761,12 +761,14 @@ class TestSavedQueryApi(SupersetTestCase):
"saved_query_export/databases/imported_database.yaml", "w"
) as fp:
fp.write(yaml.safe_dump(database_config).encode())
with bundle.open("saved_query_export/queries/imported_database/public/imported_saved_query.yaml", "w") as fp:
with bundle.open(
"saved_query_export/queries/imported_database/public/imported_saved_query.yaml",
"w",
) as fp:
fp.write(yaml.safe_dump(saved_queries_config).encode())
buf.seek(0)
return buf
@pytest.mark.usefixtures("create_saved_queries")
def test_import_saved_queries(self):
"""
Saved Query API: Test import
@ -791,8 +793,8 @@ class TestSavedQueryApi(SupersetTestCase):
assert len(database.tables) == 1
saved_query = (
db.session
.query(SavedQuery)
.filter_by(uuid=saved_queries_config["uuid"]).one()
db.session.query(SavedQuery)
.filter_by(uuid=saved_queries_config["uuid"])
.one()
)
assert saved_query.database == database

View File

@ -21,15 +21,15 @@ import pytest
import yaml
from superset import db, security_manager
from superset.queries.saved_queries.commands.importers.v1 import (
ImportSavedQueriesCommand
)
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.exceptions import IncorrectVersionError
from superset.models.sql_lab import SavedQuery
from superset.models.core import Database
from superset.models.sql_lab import SavedQuery
from superset.queries.saved_queries.commands.exceptions import SavedQueryNotFoundError
from superset.queries.saved_queries.commands.export import ExportSavedQueriesCommand
from superset.queries.saved_queries.commands.importers.v1 import (
ImportSavedQueriesCommand,
)
from superset.utils.core import get_example_database
from tests.base_tests import SupersetTestCase
from tests.fixtures.importexport import (
@ -39,6 +39,7 @@ from tests.fixtures.importexport import (
saved_queries_metadata_config,
)
class TestExportSavedQueriesCommand(SupersetTestCase):
def setUp(self):
self.example_database = get_example_database()
@ -120,33 +121,26 @@ class TestExportSavedQueriesCommand(SupersetTestCase):
"version",
"database_uuid",
]
class TestImportSavedQueriesCommand(SupersetTestCase):
def test_import_v1_saved_queries(self):
"""Test that we can import a saved query"""
contents = {
"metadata.yaml": yaml.safe_dump(saved_queries_metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(database_config),
"queries/imported_query.yaml": yaml.safe_dump(saved_queries_config)
"queries/imported_query.yaml": yaml.safe_dump(saved_queries_config),
}
command = ImportSavedQueriesCommand(contents)
command.run()
saved_query = db.session.query(SavedQuery).filter_by(
uuid=saved_queries_config["uuid"]
).one()
assert saved_query.schema == "public"
assert saved_query.sql == (
"""
-- Note: Unless you save your query,
these tabs will NOT persist if you clear
your cookies or change browsers.
SELECT * from birth_names
"""
saved_query = (
db.session.query(SavedQuery)
.filter_by(uuid=saved_queries_config["uuid"])
.one()
)
assert saved_query.uuid == "05b679b5-8eaf-452c-b874-a7a774cfa4e9"
assert saved_query.database_uuid == "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89"
assert saved_query.schema == "public"
database = (
db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
@ -155,12 +149,13 @@ class TestImportSavedQueriesCommand(SupersetTestCase):
db.session.delete(saved_query)
db.session.delete(database)
db.session.commit()
def test_import_v1_saved_queries_multiple(self):
"""Test that a saved query can be imported multiple times"""
contents = {
"metadata.yaml": yaml.safe_dump(saved_queries_metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(database_config),
"queries/imported_query.yaml": yaml.safe_dump(saved_queries_config)
"queries/imported_query.yaml": yaml.safe_dump(saved_queries_config),
}
command = ImportSavedQueriesCommand(contents, overwrite=True)
command.run()
@ -168,27 +163,28 @@ class TestImportSavedQueriesCommand(SupersetTestCase):
database = (
db.session.query(SavedQuery).filter_by(uuid=database_config["uuid"]).one()
)
saved_query = db.session.query(SavedQuery).filter_by(datasource_id=database.id).all()
saved_query = (
db.session.query(SavedQuery).filter_by(datasource_id=database.id).all()
)
assert len(saved_query) == 1
db.session.delete(saved_query[0])
db.session.delete(database)
db.session.commit()
def test_import_v1_saved_queries_validation(self):
"""Test different validations applied when importing a chart"""
"""Test different validations applied when importing a saved query"""
# metadata.yaml must be present
contents = {
"metadata.yaml": yaml.safe_dump(saved_queries_metadata_config),
"databases/imported_database.yaml": yaml.safe_dump(database_config),
"queries/imported_query.yaml": yaml.safe_dump(saved_queries_config)
"queries/imported_query.yaml": yaml.safe_dump(saved_queries_config),
}
command = ImportSavedQueriesCommand(contents)
with pytest.raises(IncorrectVersionError) as excinfo:
command.run()
assert str(excinfo.value) == "Missing metadata.yaml"
#version should be 1.0.0
# version should be 1.0.0
contents["metadata.yaml"] = yaml.safe_dump(
{
"version": "2.0.0",
@ -201,7 +197,7 @@ class TestImportSavedQueriesCommand(SupersetTestCase):
command.run()
assert str(excinfo.value) == "Must be equal to 1.0.0"
#type should be a SavedQuery
# type should be a SavedQuery
contents["metadata.yaml"] = yaml.safe_dump(database_metadata_config)
command = ImportSavedQueriesCommand(contents)
with pytest.raises(CommandInvalidError) as excinfo:
@ -211,7 +207,7 @@ class TestImportSavedQueriesCommand(SupersetTestCase):
"metadata.yaml": {"type": ["Must be equal to SavedQuery."]}
}
# must also validate databases
# must also validate databases
broken_config = database_config.copy()
del broken_config["database_name"]
contents["metadata.yaml"] = yaml.safe_dump(saved_queries_metadata_config)
@ -219,10 +215,9 @@ class TestImportSavedQueriesCommand(SupersetTestCase):
command = ImportSavedQueriesCommand(contents)
with pytest.raises(CommandInvalidError) as excinfo:
command.run()
assert str(excinfo.value) = "Error importing saved query."
assert str(excinfo.value) == "Error importing saved query."
assert excinfo.value.normalized_messages() == {
"databases/imported_database.yaml": {
"database_name": ["Missing data for required field."],
}
}