initial commit

This commit is contained in:
Arash 2021-03-26 10:06:17 -04:00
parent 3c4591ef15
commit a4a7bf9bc9
7 changed files with 423 additions and 2 deletions

View File

@ -14,16 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
from datetime import datetime
from io import BytesIO
from typing import Any
from zipfile import ZipFile
from flask import g, Response, send_file
from flask import g, Response, send_file, request
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext
from marshmallow import ValidationError
from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.databases.filters import DatabaseFilter
@ -31,11 +33,19 @@ from superset.models.sql_lab import SavedQuery
from superset.queries.saved_queries.commands.bulk_delete import (
BulkDeleteSavedQueryCommand,
)
from superset.queries.saved_queries.commands.create import (
CreateSavedQueryCommand
)
from superset.queries.saved_queries.commands.exceptions import (
SavedQueryBulkDeleteFailedError,
SavedQueryNotFoundError,
SavedQueryImportError,
SavedQueryImportError,
SavedQueryCreateError,
SavedQueryInvalidError,
)
from superset.queries.saved_queries.commands.export import ExportSavedQueriesCommand
from superset.queries.saved_queries.commands.importers.dispatcher import ImportSavedQueriesCommand
from superset.queries.saved_queries.filters import (
SavedQueryAllTextFilter,
SavedQueryFavoriteFilter,
@ -46,6 +56,9 @@ from superset.queries.saved_queries.schemas import (
get_export_ids_schema,
openapi_spec_methods_override,
)
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.v1.utils import get_contents_from_bundle
from superset.extensions import event_logger
from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics
logger = logging.getLogger(__name__)
@ -252,3 +265,135 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
as_attachment=True,
attachment_filename=filename,
)
@expose("/import/", methods=["POST"])
@protect()
@safe
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.import_",
log_to_statsd=False,
)
def import_(self) -> Response:
"""Import chart(s) with associated datasets and databases
---
post:
requestBody:
required: true
content:
multipart/form-data:
schema:
type: object
properties:
formData:
description: upload file (ZIP)
type: string
format: binary
passwords:
description: JSON map of passwords for each file
type: string
overwrite:
description: overwrite existing databases?
type: bool
responses:
200:
description: Chart import result
content:
application/json:
schema:
type: object
properties:
message:
type: string
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
upload = request.files.get("formData")
if not upload:
return self.response_400()
with ZipFile(upload) as bundle:
contents = get_contents_from_bundle(bundle)
passwords = (
json.loads(request.form["passwords"])
if "passwords" in request.form
else None
)
overwrite = request.form.get("overwrite") == "true"
command = ImportSavedQueriesCommand(
contents, passwords=passwords, overwrite=overwrite
)
try:
command.run()
return self.response(200, message="OK")
except CommandInvalidError as exc:
logger.warning("Import Saved Query failed")
return self.response_422(message=exc.normalized_messages())
except Exception as exc: # pylint: disable=broad-except
logger.exception("Import Saved Query failed")
return self.response_500(message=str(exc))
@expose("/", methods=["POST"])
@protect()
@safe
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.post",
log_to_statsd=False,
)
def post(self) -> Response:
"""Creates a new Saved Query
---
post:
description: >-
Create a new Saved Query.
requestBody:
description: Saved Query schema
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/{{self.__class__.__name__}}.post'
responses:
201:
description: Chart added
content:
application/json:
schema:
type: object
properties:
id:
type: number
result:
$ref: '#/components/schemas/{{self.__class__.__name__}}.post'
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
if not request.is_json:
return self.response_400(message="Request is not JSON")
try:
item = self.add_model_schema.load(request.json)
# This validates custom Schema with custom validations
except ValidationError as error:
return self.response_400(message=error.messages)
try:
new_model = CreateSavedQueryCommand(g.user, item).run()
return self.response(201, id=new_model.id, result=item)
except SavedQueryInvalidError as ex:
return self.response_422(message=ex.normalized_messages())
except SavedQueryCreateFailedError as ex:
logger.error(
"Error creating model %s: %s", self.__class__.__name__, str(ex)
)
return self.response_422(message=str(ex))

View File

@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError
from supereset.queries.saved_queries.commands.exceptions import SavedQueryCreateError, SavedQueryInvalidError
from superset.queries.saved_queries.dao import SavedQueryDAO
from superset.commands.base import BaseCommand
from superset.commands.utils import get_datasource_by_id
from superset.dao.exceptions import DAOCreateFailedError
logger = logging.getLogger(__name__)
class CreateSavedQueryCommand(BaseCommand):
def __init__(self, user: User, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy()
def run(self) -> Model:
self.validate()
try:
saved_query = SavedQueryDAO.create(self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex._exception)
raise SavedQueryCreateError
return saved_query
def validate(self) -> None:
exceptions = list()
datasource_type = self._properties["datasource_type"]
datasource_id = self._properties["datasource_id"]
owner_ids: Optional[List[int]] = self._properties.get("owners")
try:
datasource = get_datasource_by_id(datasource_id, datasource_type)
self._properties["datasource_name"] = datasource.name
except ValidationError as ex:
exceptions.append(ex)

View File

@ -16,7 +16,13 @@
# under the License.
from flask_babel import lazy_gettext as _
from superset.commands.exceptions import CommandException, DeleteFailedError
from superset.commands.exceptions import (
CommandException,
CommandInvalidError,
DeleteFailedError,
CreateFailedError,
ImportFailedError
)
class SavedQueryBulkDeleteFailedError(DeleteFailedError):
@ -25,3 +31,13 @@ class SavedQueryBulkDeleteFailedError(DeleteFailedError):
class SavedQueryNotFoundError(CommandException):
message = _("Saved query not found.")
class SavedQueryImportError(ImportFailedError):
message = _("Import chart failed for an unknown reason")
class SavedQueryCreateError(CreateFailedError):
message = _("Saved Query could not be created")
class SavedQueryInvalidError(CommandInvalidError):
message = _("Saved Query Parameters are invalid")

View File

@ -0,0 +1,65 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict
from marshmallow.exceptions import ValidationError
from superset.commands.base import BaseCommand
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.exceptions import IncorrectVersionError
from superset.queries.saved_queries.importers import v1
logger = logging.getLogger(__name__)
command_versions = [
v1.ImportSavedQueriesCommand,
]
class ImportSavedQueriesCommand(BaseCommand):
"""
Import Saved Queries
This command dispatches the import to different versions of the command
until it finds one that matches.
"""
# pylint: disable=unused-argument
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs
def run(self) -> None:
# iterate over all commands until we find a version that can
# handle the contents
for version in command_versions:
command = version(self.contents, *self.args, **self.kwargs)
try:
command.run()
return
except IncorrectVersionError:
logger.debug("File not handled by command, skipping")
except(CommandInvalidError, ValidationError) as exc:
# found right version, but file is invalid
logger.exception("Error running import command")
raise exc
raise CommandInvalidError("Could not find a valid command to import file")
def validate(self) -> None:
pass

View File

@ -0,0 +1,90 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Set
from marshmallow import Schema
from sqlalchemy.orm import Session
from supereset.queries.saved_queries.commands.exceptions import SavedQueryImportError
from superset.commands.importers.v1 import ImportModelsCommand
from superset.connectors.sqla.models import SqlaTable
from superset.databases.commands.importers.v1.utils import import_database
from superset.datasets.commands.importers.v1.utils import import_dataset
from superset.dataset.schemas import ImportV1DatasetSchema
from superset.queries.dao import SavedQueryDAO
from superset.queries.saved_queries.commands.importers.v1.utils import import_saved_query
from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema
class ImportSavedQueriesCommand(ImportModelsCommand):
"""Import Saved Queries"""
dao = SavedQueryDAO
model_name= "saved_queries"
prefix ="saved_queries/"
schemas: Dict[str, Schema] = {
"datasets/": ImportV1DatasetSchema(),
"saved_query": ImportV1SavedQuerySchema(),
}
import_error = SavedQueryImportError
@staticmethod
def _import(
session: Session, configs: Dict[str, Any], overwrite: bool = False
) -> None:
# discover datasets associated with saved queries
dataset_uuids: Set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("saved_queries/"):
dataset_uuids.add(config["dataset_uuid"])
# discover databases associated with datasets
database_uuids: Set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
database_uuids.add(config["database_uuid"])
# import related databases
database_ids: Dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
database_ids[str(database.uuid)] = database.id
# import datasets with the correct parent ref
datasets: Dict[str, SqlaTable] = {}
for file_name, config in configs.items():
if (
file_name.startswith("datasets/")
and config["database_uuid"] in database_ids
):
config["database_id"] = database_ids[config["database_uuid"]]
dataset = import_dataset(session, config, overwrite=False)
datasets[str(dataset.uuid)] = dataset
# import saved queries with the correct parent ref
for file_name, config in configs.items():
if file_name.startswith("saved_queries/") and config["dataset_uuid"] in datasets:
# update datasource id, type, and name
dataset = datasets[config["dataset_uuid"]]
config.update(
{
"datasource_id": dataset.id,
"datasource_name": dataset.table_name,
}
)
config["params"].update({"datasource": dataset.uid})
import_saved_query(session, config, overwrite=overwrite)

View File

@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict
from sqlalchemy.orm import Session
from superset.models.sql_lab import SavedQuery
def import_saved_query(
session: Session, config: Dict[str, Any], overwrite: bool = False
) -> SavedQuery:
existing = session.query(SavedQuery).filter_by(uuid= config["uuid"]).first()
if existing:
if not overwrite:
return existing
config["id"] = existing.id
saved_query = SavedQuery.import_from_dict(session, config, recursive=False)
if saved_query.id is None:
session.flush()
return saved_query

View File

@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_babel import lazy_gettext as _
from marshmallow import fields, Schema, ValidationError
from marshmallow.validate import Length
openapi_spec_methods_override = {
"get": {"get": {"description": "Get a saved query",}},
@ -32,3 +35,12 @@ openapi_spec_methods_override = {
get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}}
get_export_ids_schema = {"type": "array", "items": {"type": "integer"}}
class ImportV1SavedQuerySchema(Schema):
schema = fields.String(allow_none=True, validate= Length(0, 128))
label = fields.String(allow_none=True, validate= Length(0,256))
description = fields.String(allow_none = True)
sql = fields.String(required= True)
uuid = fields.UUID(required=True)
version = fields.String(required=True)
database_uuid = fields.UUID(required=True)