[charts] New, REST API (#8917)

* [charts] New REST API

* [charts] Small improvements

* [charts] Fix, lint

* [charts] Tests and datasource validation

* [charts] Fix, lint

* [charts] DRY post schemas

* [charts] lint and improve type declarations

* [charts] DRY owned REST APIs

* [charts] Small fixes

* [charts] More tests

* [charts] Tests and DRY

* [charts] Tests for update

* [charts] More tests

* [charts] Fix, isort

* [charts] DRY and improve quality

* [charts] DRY and more tests

* [charts] Refactor base for api and schemas

* [charts] Fix bug on partial updates for dashboards

* [charts] Fix missing apache license

* black app.py after merge

* [charts] Fix, missing imports and black

* [api] Log on sqlalchemy error

* [api] isort
This commit is contained in:
Daniel Vaz Gaspar 2020-01-21 18:04:52 +00:00 committed by Maxime Beauchemin
parent 2fc5fd4f29
commit 74158694c5
16 changed files with 1425 additions and 586 deletions

View File

@ -142,15 +142,14 @@ class SupersetAppInitializer:
from superset.views.api import Api
from superset.views.core import (
AccessRequestsModelView,
SliceModelView,
SliceAsync,
SliceAddView,
KV,
R,
Superset,
CssTemplateModelView,
CssTemplateAsyncModelView,
)
from superset.views.chart.api import ChartRestApi
from superset.views.chart.views import SliceModelView, SliceAsync, SliceAddView
from superset.views.dashboard.api import DashboardRestApi
from superset.views.dashboard.views import (
DashboardModelView,
@ -185,6 +184,7 @@ class SupersetAppInitializer:
#
# Setup API views
#
appbuilder.add_api(ChartRestApi)
appbuilder.add_api(DashboardRestApi)
appbuilder.add_api(DatabaseRestApi)

View File

@ -25,10 +25,12 @@ from superset.common.query_context import QueryContext
from superset.legacy import update_time_range
from superset.models.slice import Slice
from superset.utils import core as utils
from .base import api, BaseSupersetView, handle_api_exception
from .dashboard import api as dashboard_api # pylint: disable=unused-import
from .database import api as database_api # pylint: disable=unused-import
from superset.views.base import api, BaseSupersetView, handle_api_exception
from superset.views.chart import api as chart_api # pylint: disable=unused-import
from superset.views.dashboard import ( # pylint: disable=unused-import
api as dashboard_api,
)
from superset.views.database import api as database_api # pylint: disable=unused-import
class Api(BaseSupersetView):

View File

@ -18,21 +18,18 @@ import functools
import logging
import traceback
from datetime import datetime
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional
import simplejson as json
import yaml
from flask import abort, flash, g, get_flashed_messages, redirect, Response, session
from flask_appbuilder import BaseView, Model, ModelRestApi, ModelView
from flask_appbuilder import BaseView, ModelView
from flask_appbuilder.actions import action
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.forms import DynamicForm
from flask_appbuilder.models.filters import Filters
from flask_appbuilder.models.sqla.filters import BaseFilter
from flask_appbuilder.widgets import ListWidget
from flask_babel import get_locale, gettext as __, lazy_gettext as _
from flask_wtf.form import FlaskForm
from marshmallow import Schema
from sqlalchemy import or_
from werkzeug.exceptions import HTTPException
from wtforms.fields.core import Field, UnboundField
@ -155,26 +152,6 @@ def handle_api_exception(f):
return functools.update_wrapper(wraps, f)
def check_ownership_and_item_exists(f):
"""
A Decorator that checks if an object exists and is owned by the current user
"""
def wraps(self, pk): # pylint: disable=invalid-name
item = self.datamodel.get(
pk, self._base_filters # pylint: disable=protected-access
)
if not item:
return self.response_404()
try:
check_ownership(item)
except SupersetSecurityException as e:
return self.response(403, message=str(e))
return f(self, item)
return functools.update_wrapper(wraps, f)
def get_datasource_exist_error_msg(full_name):
return __("Datasource %(name)s already exists", name=full_name)
@ -378,148 +355,6 @@ class DatasourceFilter(BaseFilter): # pylint: disable=too-few-public-methods
)
class BaseSupersetSchema(Schema):
"""
Extends Marshmallow schema so that we can pass a Model to load
(following marshamallow-sqlalchemy pattern). This is useful
to perform partial model merges on HTTP PUT
"""
def __init__(self, **kwargs):
self.instance = None
super().__init__(**kwargs)
def load(
self, data, many=None, partial=None, instance: Model = None, **kwargs
): # pylint: disable=arguments-differ
self.instance = instance
return super().load(data, many=many, partial=partial, **kwargs)
get_related_schema = {
"type": "object",
"properties": {
"page_size": {"type": "integer"},
"page": {"type": "integer"},
"filter": {"type": "string"},
},
}
class BaseSupersetModelRestApi(ModelRestApi):
"""
Extends FAB's ModelResApi to implement specific superset generic functionality
"""
order_rel_fields: Dict[str, Tuple[str, str]] = {}
"""
Impose ordering on related fields query::
order_rel_fields = {
"<RELATED_FIELD>": ("<RELATED_FIELD_FIELD>", "<asc|desc>"),
...
}
""" # pylint: disable=pointless-string-statement
filter_rel_fields_field: Dict[str, str] = {}
"""
Declare the related field field for filtering::
filter_rel_fields_field = {
"<RELATED_FIELD>": "<RELATED_FIELD_FIELD>", "<asc|desc>")
}
""" # pylint: disable=pointless-string-statement
def _get_related_filter(self, datamodel, column_name: str, value: str) -> Filters:
filter_field = self.filter_rel_fields_field.get(column_name)
filters = datamodel.get_filters([filter_field])
if value:
filters.rest_add_filters(
[{"opr": "sw", "col": filter_field, "value": value}]
)
return filters
@expose("/related/<column_name>", methods=["GET"])
@protect()
@safe
@rison(get_related_schema)
def related(self, column_name: str, **kwargs):
"""Get related fields data
---
get:
parameters:
- in: path
schema:
type: string
name: column_name
- in: query
name: q
content:
application/json:
schema:
type: object
properties:
page_size:
type: integer
page:
type: integer
filter:
type: string
responses:
200:
description: Related column data
content:
application/json:
schema:
type: object
properties:
count:
type: integer
result:
type: object
properties:
value:
type: integer
text:
type: string
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
args = kwargs.get("rison", {})
# handle pagination
page, page_size = self._handle_page_args(args)
try:
datamodel = self.datamodel.get_related_interface(column_name)
except KeyError:
return self.response_404()
page, page_size = self._sanitize_page_args(page, page_size)
# handle ordering
order_field = self.order_rel_fields.get(column_name)
if order_field:
order_column, order_direction = order_field
else:
order_column, order_direction = "", ""
# handle filters
filters = self._get_related_filter(datamodel, column_name, args.get("filter"))
# Make the query
count, values = datamodel.query(
filters, order_column, order_direction, page=page, page_size=page_size
)
# produce response
result = [
{"value": datamodel.get_pk_value(value), "text": str(value)}
for value in values
]
return self.response(200, count=count, result=result)
class CsvResponse(Response): # pylint: disable=too-many-ancestors
"""
Override Response to take into account csv encoding from config.py

325
superset/views/base_api.py Normal file
View File

@ -0,0 +1,325 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import functools
import logging
from typing import Dict, Tuple
from flask import request
from flask_appbuilder import ModelRestApi
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.filters import Filters
from sqlalchemy.exc import SQLAlchemyError
from superset.exceptions import SupersetSecurityException
from superset.views.base import check_ownership
get_related_schema = {
"type": "object",
"properties": {
"page_size": {"type": "integer"},
"page": {"type": "integer"},
"filter": {"type": "string"},
},
}
def check_ownership_and_item_exists(f):
"""
A Decorator that checks if an object exists and is owned by the current user
"""
def wraps(self, pk): # pylint: disable=invalid-name
item = self.datamodel.get(
pk, self._base_filters # pylint: disable=protected-access
)
if not item:
return self.response_404()
try:
check_ownership(item)
except SupersetSecurityException as e:
return self.response(403, message=str(e))
return f(self, item)
return functools.update_wrapper(wraps, f)
class BaseSupersetModelRestApi(ModelRestApi):
"""
Extends FAB's ModelResApi to implement specific superset generic functionality
"""
logger = logging.getLogger(__name__)
order_rel_fields: Dict[str, Tuple[str, str]] = {}
"""
Impose ordering on related fields query::
order_rel_fields = {
"<RELATED_FIELD>": ("<RELATED_FIELD_FIELD>", "<asc|desc>"),
...
}
""" # pylint: disable=pointless-string-statement
filter_rel_fields_field: Dict[str, str] = {}
"""
Declare the related field field for filtering::
filter_rel_fields_field = {
"<RELATED_FIELD>": "<RELATED_FIELD_FIELD>", "<asc|desc>")
}
""" # pylint: disable=pointless-string-statement
def _get_related_filter(self, datamodel, column_name: str, value: str) -> Filters:
filter_field = self.filter_rel_fields_field.get(column_name)
filters = datamodel.get_filters([filter_field])
if value:
filters.rest_add_filters(
[{"opr": "sw", "col": filter_field, "value": value}]
)
return filters
@expose("/related/<column_name>", methods=["GET"])
@protect()
@safe
@rison(get_related_schema)
def related(self, column_name: str, **kwargs):
"""Get related fields data
---
get:
parameters:
- in: path
schema:
type: string
name: column_name
- in: query
name: q
content:
application/json:
schema:
type: object
properties:
page_size:
type: integer
page:
type: integer
filter:
type: string
responses:
200:
description: Related column data
content:
application/json:
schema:
type: object
properties:
count:
type: integer
result:
type: object
properties:
value:
type: integer
text:
type: string
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
args = kwargs.get("rison", {})
# handle pagination
page, page_size = self._handle_page_args(args)
try:
datamodel = self.datamodel.get_related_interface(column_name)
except KeyError:
return self.response_404()
page, page_size = self._sanitize_page_args(page, page_size)
# handle ordering
order_field = self.order_rel_fields.get(column_name)
if order_field:
order_column, order_direction = order_field
else:
order_column, order_direction = "", ""
# handle filters
filters = self._get_related_filter(datamodel, column_name, args.get("filter"))
# Make the query
count, values = datamodel.query(
filters, order_column, order_direction, page=page, page_size=page_size
)
# produce response
result = [
{"value": datamodel.get_pk_value(value), "text": str(value)}
for value in values
]
return self.response(200, count=count, result=result)
class BaseOwnedModelRestApi(BaseSupersetModelRestApi):
@expose("/<pk>", methods=["PUT"])
@protect()
@check_ownership_and_item_exists
@safe
def put(self, item): # pylint: disable=arguments-differ
"""Changes a owned Model
---
put:
parameters:
- in: path
schema:
type: integer
name: pk
requestBody:
description: Model schema
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/{{self.__class__.__name__}}.put'
responses:
200:
description: Item changed
content:
application/json:
schema:
type: object
properties:
result:
$ref: '#/components/schemas/{{self.__class__.__name__}}.put'
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
403:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
if not request.is_json:
self.response_400(message="Request is not JSON")
item = self.edit_model_schema.load(request.json, instance=item)
if item.errors:
return self.response_422(message=item.errors)
try:
self.datamodel.edit(item.data, raise_exception=True)
return self.response(
200, result=self.edit_model_schema.dump(item.data, many=False).data
)
except SQLAlchemyError as e:
self.logger.error(f"Error updating model {self.__class__.__name__}: {e}")
return self.response_422(message=str(e))
@expose("/", methods=["POST"])
@protect()
@safe
def post(self):
"""Creates a new owned Model
---
post:
requestBody:
description: Model schema
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/{{self.__class__.__name__}}.post'
responses:
201:
description: Model added
content:
application/json:
schema:
type: object
properties:
id:
type: string
result:
$ref: '#/components/schemas/{{self.__class__.__name__}}.post'
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
if not request.is_json:
return self.response_400(message="Request is not JSON")
item = self.add_model_schema.load(request.json)
# This validates custom Schema with custom validations
if item.errors:
return self.response_422(message=item.errors)
try:
self.datamodel.add(item.data, raise_exception=True)
return self.response(
201,
result=self.add_model_schema.dump(item.data, many=False).data,
id=item.data.id,
)
except SQLAlchemyError as e:
self.logger.error(f"Error creating model {self.__class__.__name__}: {e}")
return self.response_422(message=str(e))
@expose("/<pk>", methods=["DELETE"])
@protect()
@check_ownership_and_item_exists
@safe
def delete(self, item): # pylint: disable=arguments-differ
"""Deletes owned Model
---
delete:
parameters:
- in: path
schema:
type: integer
name: pk
responses:
200:
description: Model delete
content:
application/json:
schema:
type: object
properties:
message:
type: string
401:
$ref: '#/components/responses/401'
403:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
try:
self.datamodel.delete(item, raise_exception=True)
return self.response(200, message="OK")
except SQLAlchemyError as e:
self.logger.error(f"Error deleting model {self.__class__.__name__}: {e}")
return self.response_422(message=str(e))

View File

@ -0,0 +1,110 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, List, Optional
from flask import current_app, g
from flask_appbuilder import Model
from marshmallow import post_load, pre_load, Schema, ValidationError
from sqlalchemy.orm.exc import NoResultFound
def validate_owner(value):
try:
(
current_app.appbuilder.get_session.query(
current_app.appbuilder.sm.user_model.id
)
.filter_by(id=value)
.one()
)
except NoResultFound:
raise ValidationError(f"User {value} does not exist")
class BaseSupersetSchema(Schema):
"""
Extends Marshmallow schema so that we can pass a Model to load
(following marshamallow-sqlalchemy pattern). This is useful
to perform partial model merges on HTTP PUT
"""
__class_model__: Model = None
def __init__(self, **kwargs):
self.instance: Optional[Model] = None
super().__init__(**kwargs)
def load(
self, data, many=None, partial=None, instance: Model = None, **kwargs
): # pylint: disable=arguments-differ
self.instance = instance
return super().load(data, many=many, partial=partial, **kwargs)
@post_load
def make_object(self, data: Dict, discard: List[str] = None) -> Model:
"""
Creates a Model object from POST or PUT requests. PUT will use self.instance
previously fetched from the endpoint handler
:param data: Schema data payload
:param discard: List of fields to not set on the model
"""
discard = discard or []
if not self.instance:
self.instance = self.__class_model__() # pylint: disable=not-callable
for field in data:
if field not in discard:
setattr(self.instance, field, data.get(field))
return self.instance
class BaseOwnedSchema(BaseSupersetSchema):
"""
Implements owners validation,pre load and post_load
(to populate the owners field) on Marshmallow schemas
"""
owners_field_name = "owners"
@post_load
def make_object(self, data: Dict, discard: List[str] = None) -> Model:
discard = discard or []
discard.append(self.owners_field_name)
instance = super().make_object(data, discard)
if "owners" not in data and g.user not in instance.owners:
instance.owners.append(g.user)
if self.owners_field_name in data:
self.set_owners(instance, data[self.owners_field_name])
return instance
@pre_load
def pre_load(self, data: Dict):
# if PUT request don't set owners to empty list
if not self.instance:
data[self.owners_field_name] = data.get(self.owners_field_name, [])
@staticmethod
def set_owners(instance: Model, owners: List[int]):
owner_objs = list()
if g.user.id not in owners:
owners.append(g.user.id)
for owner_id in owners:
user = current_app.appbuilder.get_session.query(
current_app.appbuilder.sm.user_model
).get(owner_id)
owner_objs.append(user)
instance.owners = owner_objs

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

183
superset/views/chart/api.py Normal file
View File

@ -0,0 +1,183 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, List
from flask import current_app
from flask_appbuilder.models.sqla.interface import SQLAInterface
from marshmallow import fields, post_load, validates_schema, ValidationError
from marshmallow.validate import Length
from sqlalchemy.orm.exc import NoResultFound
from superset import appbuilder
from superset.connectors.connector_registry import ConnectorRegistry
from superset.exceptions import SupersetException
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils import core as utils
from superset.views.base_api import BaseOwnedModelRestApi
from superset.views.base_schemas import BaseOwnedSchema, validate_owner
from superset.views.chart.mixin import SliceMixin
def validate_json(value):
try:
utils.validate_json(value)
except SupersetException:
raise ValidationError("JSON not valid")
def validate_dashboard(value):
try:
(current_app.appbuilder.get_session.query(Dashboard).filter_by(id=value).one())
except NoResultFound:
raise ValidationError(f"Dashboard {value} does not exist")
def validate_update_datasource(data: Dict):
if not ("datasource_type" in data and "datasource_id" in data):
return
datasource_type = data["datasource_type"]
datasource_id = data["datasource_id"]
try:
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, current_app.appbuilder.get_session
)
except (NoResultFound, KeyError):
raise ValidationError(
f"Datasource [{datasource_type}].{datasource_id} does not exist"
)
data["datasource_name"] = datasource.name
def populate_dashboards(instance: Slice, dashboards: List[int]):
"""
Mutates a Slice with the dashboards SQLA Models
"""
dashboards_tmp = []
for dashboard_id in dashboards:
dashboards_tmp.append(
current_app.appbuilder.get_session.query(Dashboard)
.filter_by(id=dashboard_id)
.one()
)
instance.dashboards = dashboards_tmp
class ChartPostSchema(BaseOwnedSchema):
__class_model__ = Slice
slice_name = fields.String(required=True, validate=Length(1, 250))
description = fields.String(allow_none=True)
viz_type = fields.String(allow_none=True, validate=Length(0, 250))
owners = fields.List(fields.Integer(validate=validate_owner))
params = fields.String(allow_none=True, validate=validate_json)
cache_timeout = fields.Integer()
datasource_id = fields.Integer(required=True)
datasource_type = fields.String(required=True)
datasource_name = fields.String(allow_none=True)
dashboards = fields.List(fields.Integer(validate=validate_dashboard))
@validates_schema
def validate_schema(self, data: Dict): # pylint: disable=no-self-use
validate_update_datasource(data)
@post_load
def make_object(self, data: Dict, discard: List[str] = None) -> Slice:
instance = super().make_object(data, discard=["dashboards"])
populate_dashboards(instance, data.get("dashboards", []))
return instance
class ChartPutSchema(BaseOwnedSchema):
instance: Slice
slice_name = fields.String(allow_none=True, validate=Length(0, 250))
description = fields.String(allow_none=True)
viz_type = fields.String(allow_none=True, validate=Length(0, 250))
owners = fields.List(fields.Integer(validate=validate_owner))
params = fields.String(allow_none=True)
cache_timeout = fields.Integer()
datasource_id = fields.Integer(allow_none=True)
datasource_type = fields.String(allow_none=True)
dashboards = fields.List(fields.Integer(validate=validate_dashboard))
@validates_schema
def validate_schema(self, data: Dict): # pylint: disable=no-self-use
validate_update_datasource(data)
@post_load
def make_object(self, data: Dict, discard: List[str] = None) -> Slice:
self.instance = super().make_object(data, ["dashboards"])
if "dashboards" in data:
populate_dashboards(self.instance, data["dashboards"])
return self.instance
class ChartRestApi(SliceMixin, BaseOwnedModelRestApi):
datamodel = SQLAInterface(Slice)
resource_name = "chart"
allow_browser_login = True
class_permission_name = "SliceModelView"
method_permission_name = {
"get_list": "list",
"get": "show",
"post": "add",
"put": "edit",
"delete": "delete",
"info": "list",
"related": "list",
}
show_columns = [
"slice_name",
"description",
"owners.id",
"owners.username",
"dashboards.id",
"dashboards.dashboard_title",
"viz_type",
"params",
"cache_timeout",
]
list_columns = [
"slice_name",
"description",
"changed_by.username",
"changed_by_name",
"changed_on",
"viz_type",
"params",
"cache_timeout",
]
# Will just affect _info endpoint
edit_columns = ["slice_name"]
add_columns = edit_columns
# exclude_route_methods = ("info",)
add_model_schema = ChartPostSchema()
edit_model_schema = ChartPutSchema()
order_rel_fields = {
"slices": ("slice_name", "asc"),
"owners": ("first_name", "asc"),
}
filter_rel_fields_field = {"owners": "first_name", "dashboards": "dashboard_title"}
appbuilder.add_api(ChartRestApi)

View File

@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from sqlalchemy import or_
from superset import security_manager
from superset.views.base import BaseFilter
class SliceFilter(BaseFilter): # pylint: disable=too-few-public-methods
def apply(self, query, value):
if security_manager.all_datasource_access():
return query
perms = security_manager.user_view_menu_names("datasource_access")
schema_perms = security_manager.user_view_menu_names("schema_access")
return query.filter(
or_(self.model.perm.in_(perms), self.model.schema_perm.in_(schema_perms))
)

View File

@ -0,0 +1,87 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask import Markup
from flask_babel import lazy_gettext as _
from superset.views.chart.filters import SliceFilter
from superset.views.dashboard.filters import DashboardFilter
class SliceMixin: # pylint: disable=too-few-public-methods
list_title = _("Charts")
show_title = _("Show Chart")
add_title = _("Add Chart")
edit_title = _("Edit Chart")
can_add = False
search_columns = (
"slice_name",
"description",
"viz_type",
"datasource_name",
"owners",
)
list_columns = ["slice_link", "viz_type", "datasource_link", "creator", "modified"]
order_columns = ["viz_type", "datasource_link", "modified"]
edit_columns = [
"slice_name",
"description",
"viz_type",
"owners",
"dashboards",
"params",
"cache_timeout",
]
base_order = ("changed_on", "desc")
description_columns = {
"description": Markup(
"The content here can be displayed as widget headers in the "
"dashboard view. Supports "
'<a href="https://daringfireball.net/projects/markdown/"">'
"markdown</a>"
),
"params": _(
"These parameters are generated dynamically when clicking "
"the save or overwrite button in the explore view. This JSON "
"object is exposed here for reference and for power users who may "
"want to alter specific parameters."
),
"cache_timeout": _(
"Duration (in seconds) of the caching timeout for this chart. "
"Note this defaults to the datasource/table timeout if undefined."
),
}
base_filters = [["id", SliceFilter, lambda: []]]
label_columns = {
"cache_timeout": _("Cache Timeout"),
"creator": _("Creator"),
"dashboards": _("Dashboards"),
"datasource_link": _("Datasource"),
"description": _("Description"),
"modified": _("Last Modified"),
"owners": _("Owners"),
"params": _("Parameters"),
"slice_link": _("Chart"),
"slice_name": _("Name"),
"table": _("Table"),
"viz_type": _("Visualization Type"),
}
add_form_query_rel_fields = {"dashboards": [["name", DashboardFilter, None]]}
edit_form_query_rel_fields = add_form_query_rel_fields

View File

@ -0,0 +1,96 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from flask_appbuilder import expose, has_access
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import lazy_gettext as _
from superset import db
from superset.connectors.connector_registry import ConnectorRegistry
from superset.models.slice import Slice
from superset.utils import core as utils
from superset.views.base import check_ownership, DeleteMixin, SupersetModelView
from superset.views.chart.mixin import SliceMixin
class SliceModelView(
SliceMixin, SupersetModelView, DeleteMixin
): # pylint: disable=too-many-ancestors
route_base = "/chart"
datamodel = SQLAInterface(Slice)
def pre_add(self, item):
utils.validate_json(item.params)
def pre_update(self, item):
utils.validate_json(item.params)
check_ownership(item)
def pre_delete(self, item):
check_ownership(item)
@expose("/add", methods=["GET", "POST"])
@has_access
def add(self):
datasources = ConnectorRegistry.get_all_datasources(db.session)
datasources = [
{"value": str(d.id) + "__" + d.type, "label": repr(d)} for d in datasources
]
return self.render_template(
"superset/add_slice.html",
bootstrap_data=json.dumps(
{"datasources": sorted(datasources, key=lambda d: d["label"])}
),
)
class SliceAsync(SliceModelView): # pylint: disable=too-many-ancestors
route_base = "/sliceasync"
list_columns = [
"id",
"slice_link",
"viz_type",
"slice_name",
"creator",
"modified",
"icons",
"changed_on_humanized",
]
label_columns = {"icons": " ", "slice_link": _("Chart")}
class SliceAddView(SliceModelView): # pylint: disable=too-many-ancestors
route_base = "/sliceaddview"
list_columns = [
"id",
"slice_name",
"slice_url",
"edit_url",
"viz_type",
"params",
"description",
"description_markeddown",
"datasource_id",
"datasource_type",
"datasource_name_text",
"datasource_link",
"owners",
"modified",
"changed_on",
"changed_on_humanized",
]

View File

@ -19,7 +19,6 @@ import logging
import re
from contextlib import closing
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Union
from urllib import parse
@ -88,6 +87,7 @@ from superset.sql_validators import get_validator_by_name
from superset.utils import core as utils, dashboard_import_export
from superset.utils.dates import now_as_float
from superset.utils.decorators import etag_cache, stats_timing
from superset.views.chart import views as chart_views
from .base import (
api,
@ -106,7 +106,8 @@ from .base import (
json_success,
SupersetModelView,
)
from .dashboard.filters import DashboardFilter
from .dashboard import views as dash_views
from .database import views as in_views
from .utils import (
apply_display_max_row_limit,
bootstrap_user_data,
@ -247,17 +248,6 @@ def _deserialize_results_payload(
return json.loads(payload) # type: ignore
class SliceFilter(BaseFilter):
def apply(self, query, func): # noqa
if security_manager.all_datasource_access():
return query
perms = security_manager.user_view_menu_names("datasource_access")
schema_perms = security_manager.user_view_menu_names("schema_access")
return query.filter(
or_(self.model.perm.in_(perms), self.model.schema_perm.in_(schema_perms))
)
class AccessRequestsModelView(SupersetModelView, DeleteMixin):
datamodel = SQLAInterface(DAR)
list_columns = [
@ -279,135 +269,6 @@ class AccessRequestsModelView(SupersetModelView, DeleteMixin):
}
class SliceModelView(SupersetModelView, DeleteMixin):
route_base = "/chart"
datamodel = SQLAInterface(Slice)
list_title = _("Charts")
show_title = _("Show Chart")
add_title = _("Add Chart")
edit_title = _("Edit Chart")
can_add = False
search_columns = (
"slice_name",
"description",
"viz_type",
"datasource_name",
"owners",
)
list_columns = ["slice_link", "viz_type", "datasource_link", "creator", "modified"]
order_columns = ["viz_type", "datasource_link", "modified"]
edit_columns = [
"slice_name",
"description",
"viz_type",
"owners",
"dashboards",
"params",
"cache_timeout",
]
base_order = ("changed_on", "desc")
description_columns = {
"description": Markup(
"The content here can be displayed as widget headers in the "
"dashboard view. Supports "
'<a href="https://daringfireball.net/projects/markdown/"">'
"markdown</a>"
),
"params": _(
"These parameters are generated dynamically when clicking "
"the save or overwrite button in the explore view. This JSON "
"object is exposed here for reference and for power users who may "
"want to alter specific parameters."
),
"cache_timeout": _(
"Duration (in seconds) of the caching timeout for this chart. "
"Note this defaults to the datasource/table timeout if undefined."
),
}
base_filters = [["id", SliceFilter, lambda: []]]
label_columns = {
"cache_timeout": _("Cache Timeout"),
"creator": _("Creator"),
"dashboards": _("Dashboards"),
"datasource_link": _("Datasource"),
"description": _("Description"),
"modified": _("Last Modified"),
"owners": _("Owners"),
"params": _("Parameters"),
"slice_link": _("Chart"),
"slice_name": _("Name"),
"table": _("Table"),
"viz_type": _("Visualization Type"),
}
add_form_query_rel_fields = {"dashboards": [["name", DashboardFilter, None]]}
edit_form_query_rel_fields = add_form_query_rel_fields
def pre_add(self, obj):
utils.validate_json(obj.params)
def pre_update(self, obj):
utils.validate_json(obj.params)
check_ownership(obj)
def pre_delete(self, obj):
check_ownership(obj)
@expose("/add", methods=["GET", "POST"])
@has_access
def add(self):
datasources = ConnectorRegistry.get_all_datasources(db.session)
datasources = [
{"value": str(d.id) + "__" + d.type, "label": repr(d)} for d in datasources
]
return self.render_template(
"superset/add_slice.html",
bootstrap_data=json.dumps(
{"datasources": sorted(datasources, key=lambda d: d["label"])}
),
)
class SliceAsync(SliceModelView):
route_base = "/sliceasync"
list_columns = [
"id",
"slice_link",
"viz_type",
"slice_name",
"creator",
"modified",
"icons",
"changed_on_humanized",
]
label_columns = {"icons": " ", "slice_link": _("Chart")}
class SliceAddView(SliceModelView):
route_base = "/sliceaddview"
list_columns = [
"id",
"slice_name",
"slice_url",
"edit_url",
"viz_type",
"params",
"description",
"description_markeddown",
"datasource_id",
"datasource_type",
"datasource_name_text",
"datasource_link",
"owners",
"modified",
"changed_on",
"changed_on_humanized",
]
@talisman(force_https=False)
@app.route("/health")
def health():

View File

@ -16,23 +16,20 @@
# under the License.
import json
import re
from typing import Dict, List
from flask import current_app, g, make_response, request
from flask import current_app, make_response
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from marshmallow import fields, post_load, pre_load, Schema, ValidationError
from marshmallow.validate import Length
from sqlalchemy.exc import SQLAlchemyError
from superset.exceptions import SupersetException
from superset.models.dashboard import Dashboard
from superset.utils import core as utils
from superset.views.base import (
BaseSupersetModelRestApi,
BaseSupersetSchema,
check_ownership_and_item_exists,
generate_download_headers,
)
from superset.views.base import generate_download_headers
from superset.views.base_api import BaseOwnedModelRestApi
from superset.views.base_schemas import BaseOwnedSchema, validate_owner
from .mixin import DashboardMixin
@ -79,33 +76,10 @@ def validate_slug_uniqueness(value):
raise ValidationError("Must be unique")
def validate_owners(value):
owner = (
current_app.appbuilder.get_session.query(
current_app.appbuilder.sm.user_model.id
)
.filter_by(id=value)
.one_or_none()
)
if not owner:
raise ValidationError(f"User {value} does not exist")
class BaseDashboardSchema(BaseSupersetSchema):
@staticmethod
def set_owners(instance, owners):
owner_objs = list()
if g.user.id not in owners:
owners.append(g.user.id)
for owner_id in owners:
user = current_app.appbuilder.get_session.query(
current_app.appbuilder.sm.user_model
).get(owner_id)
owner_objs.append(user)
instance.owners = owner_objs
class BaseDashboardSchema(BaseOwnedSchema):
@pre_load
def pre_load(self, data): # pylint: disable=no-self-use
super().pre_load(data)
data["slug"] = data.get("slug")
data["owners"] = data.get("owners", [])
if data["slug"]:
@ -115,46 +89,31 @@ class BaseDashboardSchema(BaseSupersetSchema):
class DashboardPostSchema(BaseDashboardSchema):
__class_model__ = Dashboard
dashboard_title = fields.String(allow_none=True, validate=Length(0, 500))
slug = fields.String(
allow_none=True, validate=[Length(1, 255), validate_slug_uniqueness]
)
owners = fields.List(fields.Integer(validate=validate_owners))
owners = fields.List(fields.Integer(validate=validate_owner))
position_json = fields.String(validate=validate_json)
css = fields.String()
json_metadata = fields.String(validate=validate_json_metadata)
published = fields.Boolean()
@post_load
def make_object(self, data): # pylint: disable=no-self-use
instance = Dashboard()
self.set_owners(instance, data["owners"])
for field in data:
if field == "owners":
self.set_owners(instance, data["owners"])
else:
setattr(instance, field, data.get(field))
return instance
class DashboardPutSchema(BaseDashboardSchema):
dashboard_title = fields.String(allow_none=True, validate=Length(0, 500))
slug = fields.String(allow_none=True, validate=Length(0, 255))
owners = fields.List(fields.Integer(validate=validate_owners))
owners = fields.List(fields.Integer(validate=validate_owner))
position_json = fields.String(validate=validate_json)
css = fields.String()
json_metadata = fields.String(validate=validate_json_metadata)
published = fields.Boolean()
@post_load
def make_object(self, data): # pylint: disable=no-self-use
if "owners" not in data and g.user not in self.instance.owners:
self.instance.owners.append(g.user)
for field in data:
if field == "owners":
self.set_owners(self.instance, data["owners"])
else:
setattr(self.instance, field, data.get(field))
def make_object(self, data: Dict, discard: List[str] = None) -> Dashboard:
self.instance = super().make_object(data, [])
for slc in self.instance.slices:
slc.owners = list(set(self.instance.owners) | set(slc.owners))
return self.instance
@ -163,7 +122,7 @@ class DashboardPutSchema(BaseDashboardSchema):
get_export_ids_schema = {"type": "array", "items": {"type": "integer"}}
class DashboardRestApi(DashboardMixin, BaseSupersetModelRestApi):
class DashboardRestApi(DashboardMixin, BaseOwnedModelRestApi):
datamodel = SQLAInterface(Dashboard)
resource_name = "dashboard"
@ -213,153 +172,6 @@ class DashboardRestApi(DashboardMixin, BaseSupersetModelRestApi):
}
filter_rel_fields_field = {"owners": "first_name", "slices": "slice_name"}
@expose("/<pk>", methods=["PUT"])
@protect()
@check_ownership_and_item_exists
@safe
def put(self, item): # pylint: disable=arguments-differ
"""Changes a dashboard
---
put:
parameters:
- in: path
schema:
type: integer
name: pk
requestBody:
description: Model schema
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/{{self.__class__.__name__}}.put'
responses:
200:
description: Item changed
content:
application/json:
schema:
type: object
properties:
result:
$ref: '#/components/schemas/{{self.__class__.__name__}}.put'
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
403:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
if not request.is_json:
self.response_400(message="Request is not JSON")
item = self.edit_model_schema.load(request.json, instance=item)
if item.errors:
return self.response_422(message=item.errors)
try:
self.datamodel.edit(item.data, raise_exception=True)
return self.response(
200, result=self.edit_model_schema.dump(item.data, many=False).data
)
except SQLAlchemyError as e:
return self.response_422(message=str(e))
@expose("/", methods=["POST"])
@protect()
@safe
def post(self):
"""Creates a new dashboard
---
post:
requestBody:
description: Model schema
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/{{self.__class__.__name__}}.post'
responses:
201:
description: Dashboard added
content:
application/json:
schema:
type: object
properties:
id:
type: string
result:
$ref: '#/components/schemas/{{self.__class__.__name__}}.post'
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
if not request.is_json:
return self.response_400(message="Request is not JSON")
item = self.add_model_schema.load(request.json)
# This validates custom Schema with custom validations
if item.errors:
return self.response_422(message=item.errors)
try:
self.datamodel.add(item.data, raise_exception=True)
return self.response(
201,
result=self.add_model_schema.dump(item.data, many=False).data,
id=item.data.id,
)
except SQLAlchemyError as e:
return self.response_422(message=str(e))
@expose("/<pk>", methods=["DELETE"])
@protect()
@check_ownership_and_item_exists
@safe
def delete(self, item): # pylint: disable=arguments-differ
"""Delete Dashboard
---
delete:
parameters:
- in: path
schema:
type: integer
name: pk
responses:
200:
description: Dashboard delete
content:
application/json:
schema:
type: object
properties:
message:
type: string
401:
$ref: '#/components/responses/401'
403:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
try:
self.datamodel.delete(item, raise_exception=True)
return self.response(200, message="OK")
except SQLAlchemyError as e:
return self.response_422(message=str(e))
@expose("/export/", methods=["GET"])
@protect()
@safe

78
tests/base_api_tests.py Normal file
View File

@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import json
import prison
from superset import db, security_manager
class ApiOwnersTestCaseMixin:
"""
Implements shared tests for owners related field
"""
resource_name: str = ""
def test_get_related_owners(self):
"""
API: Test get related owners
"""
self.login(username="admin")
uri = f"api/v1/{self.resource_name}/related/owners"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
users = db.session.query(security_manager.user_model).all()
expected_users = [str(user) for user in users]
self.assertEqual(response["count"], len(users))
# This needs to be implemented like this, because ordering varies between
# postgres and mysql
response_users = [result["text"] for result in response["result"]]
for expected_user in expected_users:
self.assertIn(expected_user, response_users)
def test_get_filter_related_owners(self):
"""
API: Test get filter related owners
"""
self.login(username="admin")
argument = {"filter": "a"}
uri = f"api/v1/{self.resource_name}/related/owners?q={prison.dumps(argument)}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"count": 2,
"result": [
{"text": "admin user", "value": 1},
{"text": "alpha user", "value": 5},
],
}
self.assertEqual(response, expected_response)
def test_get_related_fail(self):
"""
API: Test get related fail
"""
self.login(username="admin")
uri = f"api/v1/{self.resource_name}/related/owner"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)

View File

@ -61,6 +61,15 @@ class SupersetTestCase(TestCase):
username, first_name, last_name, email, role_admin, password
)
@staticmethod
def get_user(username: str) -> ab_models.User:
user = (
db.session.query(security_manager.user_model)
.filter_by(username=username)
.one_or_none()
)
return user
@classmethod
def create_druid_test_objects(cls):
# create druid cluster and druid datasources

458
tests/chart_api_tests.py Normal file
View File

@ -0,0 +1,458 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for Superset"""
import json
from typing import List, Optional
import prison
from superset import db, security_manager
from superset.connectors.connector_registry import ConnectorRegistry
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from .base_api_tests import ApiOwnersTestCaseMixin
from .base_tests import SupersetTestCase
class ChartApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
resource_name = "chart"
def __init__(self, *args, **kwargs):
super(ChartApiTests, self).__init__(*args, **kwargs)
def insert_chart(
self,
slice_name: str,
owners: List[int],
datasource_id: int,
datasource_type: str = "table",
description: str = None,
viz_type: str = None,
params: str = None,
cache_timeout: Optional[int] = None,
) -> Slice:
obj_owners = list()
for owner in owners:
user = db.session.query(security_manager.user_model).get(owner)
obj_owners.append(user)
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
)
slice = Slice(
slice_name=slice_name,
datasource_id=datasource.id,
datasource_name=datasource.name,
datasource_type=datasource.type,
owners=obj_owners,
description=description,
viz_type=viz_type,
params=params,
cache_timeout=cache_timeout,
)
db.session.add(slice)
db.session.commit()
return slice
def test_delete_chart(self):
"""
Chart API: Test delete
"""
admin_id = self.get_user("admin").id
chart_id = self.insert_chart("name", [admin_id], 1).id
self.login(username="admin")
uri = f"api/v1/chart/{chart_id}"
rv = self.client.delete(uri)
self.assertEqual(rv.status_code, 200)
model = db.session.query(Slice).get(chart_id)
self.assertEqual(model, None)
def test_delete_not_found_chart(self):
"""
Chart API: Test not found delete
"""
self.login(username="admin")
chart_id = 1000
uri = f"api/v1/chart/{chart_id}"
rv = self.client.delete(uri)
self.assertEqual(rv.status_code, 404)
def test_delete_chart_admin_not_owned(self):
"""
Chart API: Test admin delete not owned
"""
gamma_id = self.get_user("gamma").id
chart_id = self.insert_chart("title", [gamma_id], 1).id
self.login(username="admin")
uri = f"api/v1/chart/{chart_id}"
rv = self.client.delete(uri)
self.assertEqual(rv.status_code, 200)
model = db.session.query(Slice).get(chart_id)
self.assertEqual(model, None)
def test_delete_chart_not_owned(self):
"""
Chart API: Test delete try not owned
"""
user_alpha1 = self.create_user(
"alpha1", "password", "Alpha", email="alpha1@superset.org"
)
user_alpha2 = self.create_user(
"alpha2", "password", "Alpha", email="alpha2@superset.org"
)
chart = self.insert_chart("title", [user_alpha1.id], 1)
self.login(username="alpha2", password="password")
uri = f"api/v1/chart/{chart.id}"
rv = self.client.delete(uri)
self.assertEqual(rv.status_code, 403)
db.session.delete(chart)
db.session.delete(user_alpha1)
db.session.delete(user_alpha2)
db.session.commit()
def test_create_chart(self):
"""
Chart API: Test create chart
"""
admin_id = self.get_user("admin").id
chart_data = {
"slice_name": "name1",
"description": "description1",
"owners": [admin_id],
"viz_type": "viz_type1",
"params": "1234",
"cache_timeout": 1000,
"datasource_id": 1,
"datasource_type": "table",
"dashboards": [1, 2],
}
self.login(username="admin")
uri = f"api/v1/chart/"
rv = self.client.post(uri, json=chart_data)
self.assertEqual(rv.status_code, 201)
data = json.loads(rv.data.decode("utf-8"))
model = db.session.query(Slice).get(data.get("id"))
db.session.delete(model)
db.session.commit()
def test_create_simple_chart(self):
"""
Chart API: Test create simple chart
"""
chart_data = {
"slice_name": "title1",
"datasource_id": 1,
"datasource_type": "table",
}
self.login(username="admin")
uri = f"api/v1/chart/"
rv = self.client.post(uri, json=chart_data)
self.assertEqual(rv.status_code, 201)
data = json.loads(rv.data.decode("utf-8"))
model = db.session.query(Slice).get(data.get("id"))
db.session.delete(model)
db.session.commit()
def test_create_chart_validate_owners(self):
"""
Chart API: Test create validate owners
"""
chart_data = {
"slice_name": "title1",
"datasource_id": 1,
"datasource_type": "table",
"owners": [1000],
}
self.login(username="admin")
uri = f"api/v1/chart/"
rv = self.client.post(uri, json=chart_data)
self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": {"owners": {"0": ["User 1000 does not exist"]}}}
self.assertEqual(response, expected_response)
def test_create_chart_validate_params(self):
"""
Chart API: Test create validate params json
"""
chart_data = {
"slice_name": "title1",
"datasource_id": 1,
"datasource_type": "table",
"params": '{"A:"a"}',
}
self.login(username="admin")
uri = f"api/v1/chart/"
rv = self.client.post(uri, json=chart_data)
self.assertEqual(rv.status_code, 422)
def test_create_chart_validate_datasource(self):
"""
Chart API: Test create validate datasource
"""
self.login(username="admin")
chart_data = {
"slice_name": "title1",
"datasource_id": 1,
"datasource_type": "unknown",
}
uri = f"api/v1/chart/"
rv = self.client.post(uri, json=chart_data)
self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response,
{"message": {"_schema": ["Datasource [unknown].1 does not exist"]}},
)
chart_data = {
"slice_name": "title1",
"datasource_id": 0,
"datasource_type": "table",
}
uri = f"api/v1/chart/"
rv = self.client.post(uri, json=chart_data)
self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response, {"message": {"_schema": ["Datasource [table].0 does not exist"]}}
)
def test_update_chart(self):
"""
Chart API: Test update
"""
admin = self.get_user("admin")
gamma = self.get_user("gamma")
chart_id = self.insert_chart("title", [admin.id], 1).id
chart_data = {
"slice_name": "title1_changed",
"description": "description1",
"owners": [gamma.id],
"viz_type": "viz_type1",
"params": "{'a': 1}",
"cache_timeout": 1000,
"datasource_id": 1,
"datasource_type": "table",
"dashboards": [1],
}
self.login(username="admin")
uri = f"api/v1/chart/{chart_id}"
rv = self.client.put(uri, json=chart_data)
self.assertEqual(rv.status_code, 200)
model = db.session.query(Slice).get(chart_id)
related_dashboard = db.session.query(Dashboard).get(1)
self.assertEqual(model.slice_name, "title1_changed")
self.assertEqual(model.description, "description1")
self.assertIn(admin, model.owners)
self.assertIn(gamma, model.owners)
self.assertEqual(model.viz_type, "viz_type1")
self.assertEqual(model.params, "{'a': 1}")
self.assertEqual(model.cache_timeout, 1000)
self.assertEqual(model.datasource_id, 1)
self.assertEqual(model.datasource_type, "table")
self.assertEqual(model.datasource_name, "birth_names")
self.assertIn(related_dashboard, model.dashboards)
db.session.delete(model)
db.session.commit()
def test_update_chart_new_owner(self):
"""
Chart API: Test update set new owner to current user
"""
gamma = self.get_user("gamma")
admin = self.get_user("admin")
chart_id = self.insert_chart("title", [gamma.id], 1).id
chart_data = {"slice_name": "title1_changed"}
self.login(username="admin")
uri = f"api/v1/chart/{chart_id}"
rv = self.client.put(uri, json=chart_data)
self.assertEqual(rv.status_code, 200)
model = db.session.query(Slice).get(chart_id)
self.assertIn(admin, model.owners)
db.session.delete(model)
db.session.commit()
def test_update_chart_not_owned(self):
"""
Chart API: Test update not owned
"""
user_alpha1 = self.create_user(
"alpha1", "password", "Alpha", email="alpha1@superset.org"
)
user_alpha2 = self.create_user(
"alpha2", "password", "Alpha", email="alpha2@superset.org"
)
chart = self.insert_chart("title", [user_alpha1.id], 1)
self.login(username="alpha2", password="password")
chart_data = {"slice_name": "title1_changed"}
uri = f"api/v1/chart/{chart.id}"
rv = self.client.put(uri, json=chart_data)
self.assertEqual(rv.status_code, 403)
db.session.delete(chart)
db.session.delete(user_alpha1)
db.session.delete(user_alpha2)
db.session.commit()
def test_update_chart_validate_datasource(self):
"""
Chart API: Test update validate datasource
"""
admin = self.get_user("admin")
chart = self.insert_chart("title", [admin.id], 1)
self.login(username="admin")
chart_data = {"datasource_id": 1, "datasource_type": "unknown"}
uri = f"api/v1/chart/{chart.id}"
rv = self.client.put(uri, json=chart_data)
self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response,
{"message": {"_schema": ["Datasource [unknown].1 does not exist"]}},
)
chart_data = {"datasource_id": 0, "datasource_type": "table"}
uri = f"api/v1/chart/{chart.id}"
rv = self.client.put(uri, json=chart_data)
self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
response, {"message": {"_schema": ["Datasource [table].0 does not exist"]}}
)
db.session.delete(chart)
db.session.commit()
def test_update_chart_validate_owners(self):
"""
Chart API: Test update validate owners
"""
chart_data = {
"slice_name": "title1",
"datasource_id": 1,
"datasource_type": "table",
"owners": [1000],
}
self.login(username="admin")
uri = f"api/v1/chart/"
rv = self.client.post(uri, json=chart_data)
self.assertEqual(rv.status_code, 422)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": {"owners": {"0": ["User 1000 does not exist"]}}}
self.assertEqual(response, expected_response)
def test_get_chart(self):
"""
Chart API: Test get chart
"""
admin = self.get_user("admin")
chart = self.insert_chart("title", [admin.id], 1)
self.login(username="admin")
uri = f"api/v1/chart/{chart.id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
expected_result = {
"cache_timeout": None,
"dashboards": [],
"description": None,
"owners": [{"id": 1, "username": "admin"}],
"params": None,
"slice_name": "title",
"viz_type": None,
}
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["result"], expected_result)
db.session.delete(chart)
db.session.commit()
def test_get_chart_not_found(self):
"""
Chart API: Test get chart not found
"""
chart_id = 1000
self.login(username="admin")
uri = f"api/v1/chart/{chart_id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_get_chart_no_data_access(self):
"""
Chart API: Test get chart without data access
"""
self.login(username="gamma")
chart_no_access = (
db.session.query(Slice)
.filter_by(slice_name="Girl Name Cloud")
.one_or_none()
)
uri = f"api/v1/chart/{chart_no_access.id}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_get_charts(self):
"""
Chart API: Test get charts
"""
self.login(username="admin")
uri = f"api/v1/chart/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 33)
def test_get_charts_filter(self):
"""
Chart API: Test get charts filter
"""
self.login(username="admin")
arguments = {"filters": [{"col": "slice_name", "opr": "sw", "value": "G"}]}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 5)
def test_get_charts_page(self):
"""
Chart API: Test get charts filter
"""
# Assuming we have 33 sample charts
self.login(username="admin")
arguments = {"page_size": 10, "page": 0}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(len(data["result"]), 10)
arguments = {"page_size": 10, "page": 3}
uri = f"api/v1/chart/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(len(data["result"]), 3)
def test_get_charts_no_data_access(self):
"""
Chart API: Test get charts no data access
"""
self.login(username="gamma")
uri = f"api/v1/chart/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(data["count"], 0)

View File

@ -19,17 +19,19 @@ import json
from typing import List
import prison
from flask_appbuilder.security.sqla import models as ab_models
from superset import db, security_manager
from superset.models import core as models
from superset.models.slice import Slice
from superset.views.base import generate_download_headers
from .base_api_tests import ApiOwnersTestCaseMixin
from .base_tests import SupersetTestCase
class DashboardApiTests(SupersetTestCase):
class DashboardApiTests(SupersetTestCase, ApiOwnersTestCaseMixin):
resource_name = "dashboard"
def __init__(self, *args, **kwargs):
super(DashboardApiTests, self).__init__(*args, **kwargs)
@ -63,14 +65,6 @@ class DashboardApiTests(SupersetTestCase):
db.session.commit()
return dashboard
def get_user(self, username: str) -> ab_models.User:
user = (
db.session.query(security_manager.user_model)
.filter_by(username=username)
.one_or_none()
)
return user
def test_delete_dashboard(self):
"""
Dashboard API: Test delete
@ -367,64 +361,6 @@ class DashboardApiTests(SupersetTestCase):
db.session.delete(user_alpha2)
db.session.commit()
def test_get_related_owners(self):
"""
Dashboard API: Test dashboard get related owners
"""
self.login(username="admin")
uri = f"api/v1/dashboard/related/owners"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"count": 6,
"result": [
{"text": "admin user", "value": 1},
{"text": "alpha user", "value": 5},
{"text": "explore_beta user", "value": 6},
{"text": "gamma user", "value": 2},
{"text": "gamma2 user", "value": 3},
{"text": "gamma_sqllab user", "value": 4},
],
}
self.assertEqual(response["count"], expected_response["count"])
# This is needed to be implemented like this because ordering varies between
# postgres and mysql
for result in expected_response["result"]:
self.assertIn(result, response["result"])
def test_get_filter_related_owners(self):
"""
Dashboard API: Test dashboard get filter related owners
"""
self.login(username="admin")
argument = {"filter": "a"}
uri = "api/v1/dashboard/related/owners?{}={}".format(
"q", prison.dumps(argument)
)
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"count": 2,
"result": [
{"text": "admin user", "value": 1},
{"text": "alpha user", "value": 5},
],
}
self.assertEqual(response, expected_response)
def test_get_related_fail(self):
"""
Dashboard API: Test dashboard get related fail
"""
self.login(username="admin")
uri = "api/v1/dashboard/related/owner"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
def test_export(self):
"""
Dashboard API: Test dashboard export