fix: Refactor ownership checks and ensure consistency (#20499)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2022-07-07 11:04:27 -07:00 committed by GitHub
parent e7b965a3b2
commit f0ca158989
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
107 changed files with 614 additions and 807 deletions

View File

@ -303,7 +303,7 @@ ignored-modules=numpy,pandas,alembic.op,sqlalchemy,alembic.context,flask_appbuil
# List of class names for which member attributes should not be checked (useful # List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of # for classes with dynamically set attributes). This supports the use of
# qualified names. # qualified names.
ignored-classes=contextlib.closing,optparse.Values,thread._local,_thread._local,sqlalchemy.orm.scoping.scoped_session ignored-classes=contextlib.closing,optparse.Values,thread._local,_thread._local
# List of members which are set dynamically and missed by pylint inference # List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular # system, and so shouldn't trigger E1101 when accessed. Python regular

View File

@ -17,7 +17,7 @@
import logging import logging
from typing import Any, Dict from typing import Any, Dict
from flask import g, request, Response from flask import request, Response
from flask_appbuilder.api import expose, permission_name, protect, rison, safe from flask_appbuilder.api import expose, permission_name, protect, rison, safe
from flask_appbuilder.api.schemas import get_item_schema, get_list_schema from flask_appbuilder.api.schemas import get_item_schema, get_list_schema
from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.models.sqla.interface import SQLAInterface
@ -306,7 +306,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi):
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
try: try:
new_model = CreateAnnotationCommand(g.user, item).run() new_model = CreateAnnotationCommand(item).run()
return self.response(201, id=new_model.id, result=item) return self.response(201, id=new_model.id, result=item)
except AnnotationLayerNotFoundError as ex: except AnnotationLayerNotFoundError as ex:
return self.response_400(message=str(ex)) return self.response_400(message=str(ex))
@ -381,7 +381,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi):
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
try: try:
new_model = UpdateAnnotationCommand(g.user, annotation_id, item).run() new_model = UpdateAnnotationCommand(annotation_id, item).run()
return self.response(200, id=new_model.id, result=item) return self.response(200, id=new_model.id, result=item)
except (AnnotationNotFoundError, AnnotationLayerNotFoundError): except (AnnotationNotFoundError, AnnotationLayerNotFoundError):
return self.response_404() return self.response_404()
@ -438,7 +438,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
DeleteAnnotationCommand(g.user, annotation_id).run() DeleteAnnotationCommand(annotation_id).run()
return self.response(200, message="OK") return self.response(200, message="OK")
except AnnotationNotFoundError: except AnnotationNotFoundError:
return self.response_404() return self.response_404()
@ -495,7 +495,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi):
""" """
item_ids = kwargs["rison"] item_ids = kwargs["rison"]
try: try:
BulkDeleteAnnotationCommand(g.user, item_ids).run() BulkDeleteAnnotationCommand(item_ids).run()
return self.response( return self.response(
200, 200,
message=ngettext( message=ngettext(

View File

@ -17,8 +17,6 @@
import logging import logging
from typing import List, Optional from typing import List, Optional
from flask_appbuilder.security.sqla.models import User
from superset.annotation_layers.annotations.commands.exceptions import ( from superset.annotation_layers.annotations.commands.exceptions import (
AnnotationBulkDeleteFailedError, AnnotationBulkDeleteFailedError,
AnnotationNotFoundError, AnnotationNotFoundError,
@ -32,8 +30,7 @@ logger = logging.getLogger(__name__)
class BulkDeleteAnnotationCommand(BaseCommand): class BulkDeleteAnnotationCommand(BaseCommand):
def __init__(self, user: User, model_ids: List[int]): def __init__(self, model_ids: List[int]):
self._actor = user
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[List[Annotation]] = None self._models: Optional[List[Annotation]] = None

View File

@ -19,7 +19,6 @@ from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
from superset.annotation_layers.annotations.commands.exceptions import ( from superset.annotation_layers.annotations.commands.exceptions import (
@ -38,8 +37,7 @@ logger = logging.getLogger(__name__)
class CreateAnnotationCommand(BaseCommand): class CreateAnnotationCommand(BaseCommand):
def __init__(self, user: User, data: Dict[str, Any]): def __init__(self, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy() self._properties = data.copy()
def run(self) -> Model: def run(self) -> Model:

View File

@ -18,7 +18,6 @@ import logging
from typing import Optional from typing import Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from superset.annotation_layers.annotations.commands.exceptions import ( from superset.annotation_layers.annotations.commands.exceptions import (
AnnotationDeleteFailedError, AnnotationDeleteFailedError,
@ -33,8 +32,7 @@ logger = logging.getLogger(__name__)
class DeleteAnnotationCommand(BaseCommand): class DeleteAnnotationCommand(BaseCommand):
def __init__(self, user: User, model_id: int): def __init__(self, model_id: int):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._model: Optional[Annotation] = None self._model: Optional[Annotation] = None

View File

@ -19,7 +19,6 @@ from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
from superset.annotation_layers.annotations.commands.exceptions import ( from superset.annotation_layers.annotations.commands.exceptions import (
@ -40,8 +39,7 @@ logger = logging.getLogger(__name__)
class UpdateAnnotationCommand(BaseCommand): class UpdateAnnotationCommand(BaseCommand):
def __init__(self, user: User, model_id: int, data: Dict[str, Any]): def __init__(self, model_id: int, data: Dict[str, Any]):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[Annotation] = None self._model: Optional[Annotation] = None

View File

@ -17,7 +17,7 @@
import logging import logging
from typing import Any from typing import Any
from flask import g, request, Response from flask import request, Response
from flask_appbuilder.api import expose, permission_name, protect, rison, safe from flask_appbuilder.api import expose, permission_name, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext from flask_babel import ngettext
@ -151,7 +151,7 @@ class AnnotationLayerRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
DeleteAnnotationLayerCommand(g.user, pk).run() DeleteAnnotationLayerCommand(pk).run()
return self.response(200, message="OK") return self.response(200, message="OK")
except AnnotationLayerNotFoundError: except AnnotationLayerNotFoundError:
return self.response_404() return self.response_404()
@ -216,7 +216,7 @@ class AnnotationLayerRestApi(BaseSupersetModelRestApi):
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
try: try:
new_model = CreateAnnotationLayerCommand(g.user, item).run() new_model = CreateAnnotationLayerCommand(item).run()
return self.response(201, id=new_model.id, result=item) return self.response(201, id=new_model.id, result=item)
except AnnotationLayerNotFoundError as ex: except AnnotationLayerNotFoundError as ex:
return self.response_400(message=str(ex)) return self.response_400(message=str(ex))
@ -288,7 +288,7 @@ class AnnotationLayerRestApi(BaseSupersetModelRestApi):
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
try: try:
new_model = UpdateAnnotationLayerCommand(g.user, pk, item).run() new_model = UpdateAnnotationLayerCommand(pk, item).run()
return self.response(200, id=new_model.id, result=item) return self.response(200, id=new_model.id, result=item)
except AnnotationLayerNotFoundError: except AnnotationLayerNotFoundError:
return self.response_404() return self.response_404()
@ -346,7 +346,7 @@ class AnnotationLayerRestApi(BaseSupersetModelRestApi):
""" """
item_ids = kwargs["rison"] item_ids = kwargs["rison"]
try: try:
BulkDeleteAnnotationLayerCommand(g.user, item_ids).run() BulkDeleteAnnotationLayerCommand(item_ids).run()
return self.response( return self.response(
200, 200,
message=ngettext( message=ngettext(

View File

@ -17,8 +17,6 @@
import logging import logging
from typing import List, Optional from typing import List, Optional
from flask_appbuilder.security.sqla.models import User
from superset.annotation_layers.commands.exceptions import ( from superset.annotation_layers.commands.exceptions import (
AnnotationLayerBulkDeleteFailedError, AnnotationLayerBulkDeleteFailedError,
AnnotationLayerBulkDeleteIntegrityError, AnnotationLayerBulkDeleteIntegrityError,
@ -33,8 +31,7 @@ logger = logging.getLogger(__name__)
class BulkDeleteAnnotationLayerCommand(BaseCommand): class BulkDeleteAnnotationLayerCommand(BaseCommand):
def __init__(self, user: User, model_ids: List[int]): def __init__(self, model_ids: List[int]):
self._actor = user
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[List[AnnotationLayer]] = None self._models: Optional[List[AnnotationLayer]] = None

View File

@ -18,7 +18,6 @@ import logging
from typing import Any, Dict, List from typing import Any, Dict, List
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
from superset.annotation_layers.commands.exceptions import ( from superset.annotation_layers.commands.exceptions import (
@ -34,8 +33,7 @@ logger = logging.getLogger(__name__)
class CreateAnnotationLayerCommand(BaseCommand): class CreateAnnotationLayerCommand(BaseCommand):
def __init__(self, user: User, data: Dict[str, Any]): def __init__(self, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy() self._properties = data.copy()
def run(self) -> Model: def run(self) -> Model:

View File

@ -18,7 +18,6 @@ import logging
from typing import Optional from typing import Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from superset.annotation_layers.commands.exceptions import ( from superset.annotation_layers.commands.exceptions import (
AnnotationLayerDeleteFailedError, AnnotationLayerDeleteFailedError,
@ -34,8 +33,7 @@ logger = logging.getLogger(__name__)
class DeleteAnnotationLayerCommand(BaseCommand): class DeleteAnnotationLayerCommand(BaseCommand):
def __init__(self, user: User, model_id: int): def __init__(self, model_id: int):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._model: Optional[AnnotationLayer] = None self._model: Optional[AnnotationLayer] = None

View File

@ -18,7 +18,6 @@ import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
from superset.annotation_layers.commands.exceptions import ( from superset.annotation_layers.commands.exceptions import (
@ -36,8 +35,7 @@ logger = logging.getLogger(__name__)
class UpdateAnnotationLayerCommand(BaseCommand): class UpdateAnnotationLayerCommand(BaseCommand):
def __init__(self, user: User, model_id: int, data: Dict[str, Any]): def __init__(self, model_id: int, data: Dict[str, Any]):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[AnnotationLayer] = None self._model: Optional[AnnotationLayer] = None

View File

@ -21,7 +21,7 @@ from io import BytesIO
from typing import Any, Optional from typing import Any, Optional
from zipfile import ZipFile from zipfile import ZipFile
from flask import g, redirect, request, Response, send_file, url_for from flask import redirect, request, Response, send_file, url_for
from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.hooks import before_request from flask_appbuilder.hooks import before_request
from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.models.sqla.interface import SQLAInterface
@ -285,7 +285,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
try: try:
new_model = CreateChartCommand(g.user, item).run() new_model = CreateChartCommand(item).run()
return self.response(201, id=new_model.id, result=item) return self.response(201, id=new_model.id, result=item)
except ChartInvalidError as ex: except ChartInvalidError as ex:
return self.response_422(message=ex.normalized_messages()) return self.response_422(message=ex.normalized_messages())
@ -356,7 +356,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
try: try:
changed_model = UpdateChartCommand(g.user, pk, item).run() changed_model = UpdateChartCommand(pk, item).run()
response = self.response(200, id=changed_model.id, result=item) response = self.response(200, id=changed_model.id, result=item)
except ChartNotFoundError: except ChartNotFoundError:
response = self.response_404() response = self.response_404()
@ -416,7 +416,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
DeleteChartCommand(g.user, pk).run() DeleteChartCommand(pk).run()
return self.response(200, message="OK") return self.response(200, message="OK")
except ChartNotFoundError: except ChartNotFoundError:
return self.response_404() return self.response_404()
@ -476,7 +476,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
""" """
item_ids = kwargs["rison"] item_ids = kwargs["rison"]
try: try:
BulkDeleteChartCommand(g.user, item_ids).run() BulkDeleteChartCommand(item_ids).run()
return self.response( return self.response(
200, 200,
message=ngettext( message=ngettext(

View File

@ -17,9 +17,9 @@
import logging import logging
from typing import List, Optional from typing import List, Optional
from flask_appbuilder.security.sqla.models import User
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
from superset import security_manager
from superset.charts.commands.exceptions import ( from superset.charts.commands.exceptions import (
ChartBulkDeleteFailedError, ChartBulkDeleteFailedError,
ChartBulkDeleteFailedReportsExistError, ChartBulkDeleteFailedReportsExistError,
@ -32,14 +32,12 @@ from superset.commands.exceptions import DeleteFailedError
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.reports.dao import ReportScheduleDAO from superset.reports.dao import ReportScheduleDAO
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BulkDeleteChartCommand(BaseCommand): class BulkDeleteChartCommand(BaseCommand):
def __init__(self, user: User, model_ids: List[int]): def __init__(self, model_ids: List[int]):
self._actor = user
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[List[Slice]] = None self._models: Optional[List[Slice]] = None
@ -66,6 +64,6 @@ class BulkDeleteChartCommand(BaseCommand):
# Check ownership # Check ownership
for model in self._models: for model in self._models:
try: try:
check_ownership(model) security_manager.raise_for_ownership(model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise ChartForbiddenError() from ex raise ChartForbiddenError() from ex

View File

@ -18,8 +18,8 @@ import logging
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from flask import g
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
from superset.charts.commands.exceptions import ( from superset.charts.commands.exceptions import (
@ -37,15 +37,14 @@ logger = logging.getLogger(__name__)
class CreateChartCommand(CreateMixin, BaseCommand): class CreateChartCommand(CreateMixin, BaseCommand):
def __init__(self, user: User, data: Dict[str, Any]): def __init__(self, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy() self._properties = data.copy()
def run(self) -> Model: def run(self) -> Model:
self.validate() self.validate()
try: try:
self._properties["last_saved_at"] = datetime.now() self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = self._actor self._properties["last_saved_by"] = g.user
chart = ChartDAO.create(self._properties) chart = ChartDAO.create(self._properties)
except DAOCreateFailedError as ex: except DAOCreateFailedError as ex:
logger.exception(ex.exception) logger.exception(ex.exception)
@ -73,7 +72,7 @@ class CreateChartCommand(CreateMixin, BaseCommand):
self._properties["dashboards"] = dashboards self._properties["dashboards"] = dashboards
try: try:
owners = self.populate_owners(self._actor, owner_ids) owners = self.populate_owners(owner_ids)
self._properties["owners"] = owners self._properties["owners"] = owners
except ValidationError as ex: except ValidationError as ex:
exceptions.append(ex) exceptions.append(ex)

View File

@ -18,9 +18,9 @@ import logging
from typing import Optional from typing import Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
from superset import security_manager
from superset.charts.commands.exceptions import ( from superset.charts.commands.exceptions import (
ChartDeleteFailedError, ChartDeleteFailedError,
ChartDeleteFailedReportsExistError, ChartDeleteFailedReportsExistError,
@ -34,14 +34,12 @@ from superset.exceptions import SupersetSecurityException
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.reports.dao import ReportScheduleDAO from superset.reports.dao import ReportScheduleDAO
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DeleteChartCommand(BaseCommand): class DeleteChartCommand(BaseCommand):
def __init__(self, user: User, model_id: int): def __init__(self, model_id: int):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._model: Optional[Slice] = None self._model: Optional[Slice] = None
@ -69,6 +67,6 @@ class DeleteChartCommand(BaseCommand):
) )
# Check ownership # Check ownership
try: try:
check_ownership(self._model) security_manager.raise_for_ownership(self._model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise ChartForbiddenError() from ex raise ChartForbiddenError() from ex

View File

@ -18,10 +18,11 @@ import logging
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from flask import g
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
from superset import security_manager
from superset.charts.commands.exceptions import ( from superset.charts.commands.exceptions import (
ChartForbiddenError, ChartForbiddenError,
ChartInvalidError, ChartInvalidError,
@ -37,7 +38,6 @@ from superset.dao.exceptions import DAOUpdateFailedError
from superset.dashboards.dao import DashboardDAO from superset.dashboards.dao import DashboardDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,8 +49,7 @@ def is_query_context_update(properties: Dict[str, Any]) -> bool:
class UpdateChartCommand(UpdateMixin, BaseCommand): class UpdateChartCommand(UpdateMixin, BaseCommand):
def __init__(self, user: User, model_id: int, data: Dict[str, Any]): def __init__(self, model_id: int, data: Dict[str, Any]):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[Slice] = None self._model: Optional[Slice] = None
@ -60,7 +59,7 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
try: try:
if self._properties.get("query_context_generation") is None: if self._properties.get("query_context_generation") is None:
self._properties["last_saved_at"] = datetime.now() self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = self._actor self._properties["last_saved_by"] = g.user
chart = ChartDAO.update(self._model, self._properties) chart = ChartDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex: except DAOUpdateFailedError as ex:
logger.exception(ex.exception) logger.exception(ex.exception)
@ -88,8 +87,8 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
# ownership so the update can be performed by report workers # ownership so the update can be performed by report workers
if not is_query_context_update(self._properties): if not is_query_context_update(self._properties):
try: try:
check_ownership(self._model) security_manager.raise_for_ownership(self._model)
owners = self.populate_owners(self._actor, owner_ids) owners = self.populate_owners(owner_ids)
self._properties["owners"] = owners self._properties["owners"] = owners
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise ChartForbiddenError() from ex raise ChartForbiddenError() from ex

View File

@ -45,34 +45,28 @@ class BaseCommand(ABC):
class CreateMixin: # pylint: disable=too-few-public-methods class CreateMixin: # pylint: disable=too-few-public-methods
@staticmethod @staticmethod
def populate_owners( def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]:
user: User, owner_ids: Optional[List[int]] = None
) -> List[User]:
""" """
Populate list of owners, defaulting to the current user if `owner_ids` is Populate list of owners, defaulting to the current user if `owner_ids` is
undefined or empty. If current user is missing in `owner_ids`, current user undefined or empty. If current user is missing in `owner_ids`, current user
is added unless belonging to the Admin role. is added unless belonging to the Admin role.
:param user: current user
:param owner_ids: list of owners by id's :param owner_ids: list of owners by id's
:raises OwnersNotFoundValidationError: if at least one owner can't be resolved :raises OwnersNotFoundValidationError: if at least one owner can't be resolved
:returns: Final list of owners :returns: Final list of owners
""" """
return populate_owners(user, owner_ids, default_to_user=True) return populate_owners(owner_ids, default_to_user=True)
class UpdateMixin: # pylint: disable=too-few-public-methods class UpdateMixin: # pylint: disable=too-few-public-methods
@staticmethod @staticmethod
def populate_owners( def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]:
user: User, owner_ids: Optional[List[int]] = None
) -> List[User]:
""" """
Populate list of owners. If current user is missing in `owner_ids`, current user Populate list of owners. If current user is missing in `owner_ids`, current user
is added unless belonging to the Admin role. is added unless belonging to the Admin role.
:param user: current user
:param owner_ids: list of owners by id's :param owner_ids: list of owners by id's
:raises OwnersNotFoundValidationError: if at least one owner can't be resolved :raises OwnersNotFoundValidationError: if at least one owner can't be resolved
:returns: Final list of owners :returns: Final list of owners
""" """
return populate_owners(user, owner_ids, default_to_user=False) return populate_owners(owner_ids, default_to_user=False)

View File

@ -18,8 +18,10 @@ from __future__ import annotations
from typing import List, Optional, TYPE_CHECKING from typing import List, Optional, TYPE_CHECKING
from flask import g
from flask_appbuilder.security.sqla.models import Role, User from flask_appbuilder.security.sqla.models import Role, User
from superset import security_manager
from superset.commands.exceptions import ( from superset.commands.exceptions import (
DatasourceNotFoundValidationError, DatasourceNotFoundValidationError,
OwnersNotFoundValidationError, OwnersNotFoundValidationError,
@ -27,21 +29,20 @@ from superset.commands.exceptions import (
) )
from superset.dao.exceptions import DatasourceNotFound from superset.dao.exceptions import DatasourceNotFound
from superset.datasource.dao import DatasourceDAO from superset.datasource.dao import DatasourceDAO
from superset.extensions import db, security_manager from superset.extensions import db
from superset.utils.core import DatasourceType from superset.utils.core import DatasourceType, get_user_id
if TYPE_CHECKING: if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource from superset.connectors.base.models import BaseDatasource
def populate_owners( def populate_owners(
user: User,
owner_ids: Optional[List[int]], owner_ids: Optional[List[int]],
default_to_user: bool, default_to_user: bool,
) -> List[User]: ) -> List[User]:
""" """
Helper function for commands, will fetch all users from owners id's Helper function for commands, will fetch all users from owners id's
:param user: current user
:param owner_ids: list of owners by id's :param owner_ids: list of owners by id's
:param default_to_user: make user the owner if `owner_ids` is None or empty :param default_to_user: make user the owner if `owner_ids` is None or empty
:raises OwnersNotFoundValidationError: if at least one owner id can't be resolved :raises OwnersNotFoundValidationError: if at least one owner id can't be resolved
@ -50,12 +51,10 @@ def populate_owners(
owner_ids = owner_ids or [] owner_ids = owner_ids or []
owners = [] owners = []
if not owner_ids and default_to_user: if not owner_ids and default_to_user:
return [user] return [g.user]
if user.id not in owner_ids and "admin" not in [ if not (security_manager.is_admin() or get_user_id() in owner_ids):
role.name.lower() for role in user.roles
]:
# make sure non-admins can't remove themselves as owner by mistake # make sure non-admins can't remove themselves as owner by mistake
owners.append(user) owners.append(g.user)
for owner_id in owner_ids: for owner_id in owner_ids:
owner = security_manager.get_user_by_id(owner_id) owner = security_manager.get_user_by_id(owner_id)
if not owner: if not owner:

View File

@ -1,25 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from superset import conf, security_manager
def is_user_admin() -> bool:
user_roles = [role.name.lower() for role in security_manager.get_user_roles()]
admin_role = conf.get("AUTH_ROLE_ADMIN").lower()
return admin_role in user_roles

View File

@ -60,9 +60,10 @@ class SelectDataRequired(DataRequired): # pylint: disable=too-few-public-method
field_flags = () field_flags = ()
class TableColumnInlineView( class TableColumnInlineView( # pylint: disable=too-many-ancestors
CompactCRUDMixin, SupersetModelView CompactCRUDMixin,
): # pylint: disable=too-many-ancestors SupersetModelView,
):
datamodel = SQLAInterface(models.TableColumn) datamodel = SQLAInterface(models.TableColumn)
# TODO TODO, review need for this on related_views # TODO TODO, review need for this on related_views
class_permission_name = "Dataset" class_permission_name = "Dataset"
@ -194,9 +195,10 @@ class TableColumnInlineView(
edit_form_extra_fields = add_form_extra_fields edit_form_extra_fields = add_form_extra_fields
class SqlMetricInlineView( class SqlMetricInlineView( # pylint: disable=too-many-ancestors
CompactCRUDMixin, SupersetModelView CompactCRUDMixin,
): # pylint: disable=too-many-ancestors SupersetModelView,
):
datamodel = SQLAInterface(models.SqlMetric) datamodel = SQLAInterface(models.SqlMetric)
class_permission_name = "Dataset" class_permission_name = "Dataset"
method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP method_permission_name = MODEL_VIEW_RW_METHOD_PERMISSION_MAP
@ -278,9 +280,9 @@ class RowLevelSecurityListWidget(
super().__init__(**kwargs) super().__init__(**kwargs)
class RowLevelSecurityFiltersModelView( class RowLevelSecurityFiltersModelView( # pylint: disable=too-many-ancestors
SupersetModelView, DeleteMixin SupersetModelView, DeleteMixin
): # pylint: disable=too-many-ancestors ):
datamodel = SQLAInterface(models.RowLevelSecurityFilter) datamodel = SQLAInterface(models.RowLevelSecurityFilter)
list_widget = cast(SupersetListWidget, RowLevelSecurityListWidget) list_widget = cast(SupersetListWidget, RowLevelSecurityListWidget)

View File

@ -17,7 +17,7 @@
import logging import logging
from typing import Any from typing import Any
from flask import g, Response from flask import Response
from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext from flask_babel import ngettext
@ -130,7 +130,7 @@ class CssTemplateRestApi(BaseSupersetModelRestApi):
""" """
item_ids = kwargs["rison"] item_ids = kwargs["rison"]
try: try:
BulkDeleteCssTemplateCommand(g.user, item_ids).run() BulkDeleteCssTemplateCommand(item_ids).run()
return self.response( return self.response(
200, 200,
message=ngettext( message=ngettext(

View File

@ -17,8 +17,6 @@
import logging import logging
from typing import List, Optional from typing import List, Optional
from flask_appbuilder.security.sqla.models import User
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.css_templates.commands.exceptions import ( from superset.css_templates.commands.exceptions import (
CssTemplateBulkDeleteFailedError, CssTemplateBulkDeleteFailedError,
@ -32,8 +30,7 @@ logger = logging.getLogger(__name__)
class BulkDeleteCssTemplateCommand(BaseCommand): class BulkDeleteCssTemplateCommand(BaseCommand):
def __init__(self, user: User, model_ids: List[int]): def __init__(self, model_ids: List[int]):
self._actor = user
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[List[CssTemplate]] = None self._models: Optional[List[CssTemplate]] = None

View File

@ -23,7 +23,7 @@ from io import BytesIO
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
from zipfile import is_zipfile, ZipFile from zipfile import is_zipfile, ZipFile
from flask import g, make_response, redirect, request, Response, send_file, url_for from flask import make_response, redirect, request, Response, send_file, url_for
from flask_appbuilder import permission_name from flask_appbuilder import permission_name
from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.hooks import before_request from flask_appbuilder.hooks import before_request
@ -504,7 +504,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
try: try:
new_model = CreateDashboardCommand(g.user, item).run() new_model = CreateDashboardCommand(item).run()
return self.response(201, id=new_model.id, result=item) return self.response(201, id=new_model.id, result=item)
except DashboardInvalidError as ex: except DashboardInvalidError as ex:
return self.response_422(message=ex.normalized_messages()) return self.response_422(message=ex.normalized_messages())
@ -577,7 +577,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
try: try:
changed_model = UpdateDashboardCommand(g.user, pk, item).run() changed_model = UpdateDashboardCommand(pk, item).run()
last_modified_time = changed_model.changed_on.replace( last_modified_time = changed_model.changed_on.replace(
microsecond=0 microsecond=0
).timestamp() ).timestamp()
@ -644,7 +644,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
DeleteDashboardCommand(g.user, pk).run() DeleteDashboardCommand(pk).run()
return self.response(200, message="OK") return self.response(200, message="OK")
except DashboardNotFoundError: except DashboardNotFoundError:
return self.response_404() return self.response_404()
@ -704,7 +704,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
""" """
item_ids = kwargs["rison"] item_ids = kwargs["rison"]
try: try:
BulkDeleteDashboardCommand(g.user, item_ids).run() BulkDeleteDashboardCommand(item_ids).run()
return self.response( return self.response(
200, 200,
message=ngettext( message=ngettext(
@ -942,6 +942,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
dashboards = DashboardDAO.find_by_ids(requested_ids) dashboards = DashboardDAO.find_by_ids(requested_ids)
if not dashboards: if not dashboards:
return self.response_404() return self.response_404()
favorited_dashboard_ids = DashboardDAO.favorited_ids(dashboards) favorited_dashboard_ids = DashboardDAO.favorited_ids(dashboards)
res = [ res = [
{"id": request_id, "value": request_id in favorited_dashboard_ids} {"id": request_id, "value": request_id in favorited_dashboard_ids}

View File

@ -17,9 +17,9 @@
import logging import logging
from typing import List, Optional from typing import List, Optional
from flask_appbuilder.security.sqla.models import User
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
from superset import security_manager
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.commands.exceptions import DeleteFailedError from superset.commands.exceptions import DeleteFailedError
from superset.dashboards.commands.exceptions import ( from superset.dashboards.commands.exceptions import (
@ -32,14 +32,12 @@ from superset.dashboards.dao import DashboardDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.reports.dao import ReportScheduleDAO from superset.reports.dao import ReportScheduleDAO
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BulkDeleteDashboardCommand(BaseCommand): class BulkDeleteDashboardCommand(BaseCommand):
def __init__(self, user: User, model_ids: List[int]): def __init__(self, model_ids: List[int]):
self._actor = user
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[List[Dashboard]] = None self._models: Optional[List[Dashboard]] = None
@ -67,6 +65,6 @@ class BulkDeleteDashboardCommand(BaseCommand):
# Check ownership # Check ownership
for model in self._models: for model in self._models:
try: try:
check_ownership(model) security_manager.raise_for_ownership(model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise DashboardForbiddenError() from ex raise DashboardForbiddenError() from ex

View File

@ -18,7 +18,6 @@ import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
from superset.commands.base import BaseCommand, CreateMixin from superset.commands.base import BaseCommand, CreateMixin
@ -35,8 +34,7 @@ logger = logging.getLogger(__name__)
class CreateDashboardCommand(CreateMixin, BaseCommand): class CreateDashboardCommand(CreateMixin, BaseCommand):
def __init__(self, user: User, data: Dict[str, Any]): def __init__(self, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy() self._properties = data.copy()
def run(self) -> Model: def run(self) -> Model:
@ -60,7 +58,7 @@ class CreateDashboardCommand(CreateMixin, BaseCommand):
exceptions.append(DashboardSlugExistsValidationError()) exceptions.append(DashboardSlugExistsValidationError())
try: try:
owners = self.populate_owners(self._actor, owner_ids) owners = self.populate_owners(owner_ids)
self._properties["owners"] = owners self._properties["owners"] = owners
except ValidationError as ex: except ValidationError as ex:
exceptions.append(ex) exceptions.append(ex)

View File

@ -18,9 +18,9 @@ import logging
from typing import Optional from typing import Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
from superset import security_manager
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAODeleteFailedError from superset.dao.exceptions import DAODeleteFailedError
from superset.dashboards.commands.exceptions import ( from superset.dashboards.commands.exceptions import (
@ -33,14 +33,12 @@ from superset.dashboards.dao import DashboardDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.reports.dao import ReportScheduleDAO from superset.reports.dao import ReportScheduleDAO
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DeleteDashboardCommand(BaseCommand): class DeleteDashboardCommand(BaseCommand):
def __init__(self, user: User, model_id: int): def __init__(self, model_id: int):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._model: Optional[Dashboard] = None self._model: Optional[Dashboard] = None
@ -67,6 +65,6 @@ class DeleteDashboardCommand(BaseCommand):
) )
# Check ownership # Check ownership
try: try:
check_ownership(self._model) security_manager.raise_for_ownership(self._model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise DashboardForbiddenError() from ex raise DashboardForbiddenError() from ex

View File

@ -19,9 +19,9 @@ import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
from superset import security_manager
from superset.commands.base import BaseCommand, UpdateMixin from superset.commands.base import BaseCommand, UpdateMixin
from superset.commands.utils import populate_roles from superset.commands.utils import populate_roles
from superset.dao.exceptions import DAOUpdateFailedError from superset.dao.exceptions import DAOUpdateFailedError
@ -36,14 +36,12 @@ from superset.dashboards.dao import DashboardDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.extensions import db from superset.extensions import db
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UpdateDashboardCommand(UpdateMixin, BaseCommand): class UpdateDashboardCommand(UpdateMixin, BaseCommand):
def __init__(self, user: User, model_id: int, data: Dict[str, Any]): def __init__(self, model_id: int, data: Dict[str, Any]):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[Dashboard] = None self._model: Optional[Dashboard] = None
@ -77,7 +75,7 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
raise DashboardNotFoundError() raise DashboardNotFoundError()
# Check ownership # Check ownership
try: try:
check_ownership(self._model) security_manager.raise_for_ownership(self._model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise DashboardForbiddenError() from ex raise DashboardForbiddenError() from ex
@ -89,7 +87,7 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
if owners_ids is None: if owners_ids is None:
owners_ids = [owner.id for owner in self._model.owners] owners_ids = [owner.id for owner in self._model.owners]
try: try:
owners = self.populate_owners(self._actor, owners_ids) owners = self.populate_owners(owners_ids)
self._properties["owners"] = owners self._properties["owners"] = owners
except ValidationError as ex: except ValidationError as ex:
exceptions.append(ex) exceptions.append(ex)

View File

@ -17,7 +17,7 @@
import logging import logging
from typing import Any, cast from typing import Any, cast
from flask import g, request, Response from flask import request, Response
from flask_appbuilder.api import ( from flask_appbuilder.api import (
expose, expose,
get_list_schema, get_list_schema,
@ -243,7 +243,7 @@ class FilterSetRestApi(BaseSupersetModelRestApi):
""" """
try: try:
item = self.add_model_schema.load(request.json) item = self.add_model_schema.load(request.json)
new_model = CreateFilterSetCommand(g.user, dashboard_id, item).run() new_model = CreateFilterSetCommand(dashboard_id, item).run()
return self.response( return self.response(
201, **self.show_model_schema.dump(new_model, many=False) 201, **self.show_model_schema.dump(new_model, many=False)
) )
@ -314,7 +314,7 @@ class FilterSetRestApi(BaseSupersetModelRestApi):
""" """
try: try:
item = self.edit_model_schema.load(request.json) item = self.edit_model_schema.load(request.json)
changed_model = UpdateFilterSetCommand(g.user, dashboard_id, pk, item).run() changed_model = UpdateFilterSetCommand(dashboard_id, pk, item).run()
return self.response( return self.response(
200, **self.show_model_schema.dump(changed_model, many=False) 200, **self.show_model_schema.dump(changed_model, many=False)
) )
@ -374,7 +374,7 @@ class FilterSetRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
changed_model = DeleteFilterSetCommand(g.user, dashboard_id, pk).run() changed_model = DeleteFilterSetCommand(dashboard_id, pk).run()
return self.response(200, id=changed_model.id) return self.response(200, id=changed_model.id)
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)

View File

@ -18,10 +18,9 @@ import logging
from typing import cast, Optional from typing import cast, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from superset import security_manager
from superset.common.not_authrized_object import NotAuthorizedException from superset.common.not_authrized_object import NotAuthorizedException
from superset.common.request_contexed_based import is_user_admin
from superset.dashboards.commands.exceptions import DashboardNotFoundError from superset.dashboards.commands.exceptions import DashboardNotFoundError
from superset.dashboards.dao import DashboardDAO from superset.dashboards.dao import DashboardDAO
from superset.dashboards.filter_sets.commands.exceptions import ( from superset.dashboards.filter_sets.commands.exceptions import (
@ -31,6 +30,7 @@ from superset.dashboards.filter_sets.commands.exceptions import (
from superset.dashboards.filter_sets.consts import USER_OWNER_TYPE from superset.dashboards.filter_sets.consts import USER_OWNER_TYPE
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.models.filter_set import FilterSet from superset.models.filter_set import FilterSet
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,9 +41,7 @@ class BaseFilterSetCommand:
_filter_set_id: Optional[int] _filter_set_id: Optional[int]
_filter_set: Optional[FilterSet] _filter_set: Optional[FilterSet]
def __init__(self, user: User, dashboard_id: int): def __init__(self, dashboard_id: int):
self._actor = user
self._is_actor_admin = is_user_admin()
self._dashboard_id = dashboard_id self._dashboard_id = dashboard_id
def run(self) -> Model: def run(self) -> Model:
@ -54,9 +52,6 @@ class BaseFilterSetCommand:
if not self._dashboard: if not self._dashboard:
raise DashboardNotFoundError() raise DashboardNotFoundError()
def is_user_dashboard_owner(self) -> bool:
return self._is_actor_admin or self._dashboard.is_actor_owner()
def validate_exist_filter_use_cases_set(self) -> None: # pylint: disable=C0103 def validate_exist_filter_use_cases_set(self) -> None: # pylint: disable=C0103
self._validate_filter_set_exists_and_set_when_exists() self._validate_filter_set_exists_and_set_when_exists()
self.check_ownership() self.check_ownership()
@ -70,15 +65,15 @@ class BaseFilterSetCommand:
def check_ownership(self) -> None: def check_ownership(self) -> None:
try: try:
if not self._is_actor_admin: if not security_manager.is_admin():
filter_set: FilterSet = cast(FilterSet, self._filter_set) filter_set: FilterSet = cast(FilterSet, self._filter_set)
if filter_set.owner_type == USER_OWNER_TYPE: if filter_set.owner_type == USER_OWNER_TYPE:
if self._actor.id != filter_set.owner_id: if get_user_id() != filter_set.owner_id:
raise FilterSetForbiddenError( raise FilterSetForbiddenError(
str(self._filter_set_id), str(self._filter_set_id),
"The user is not the owner of the filter_set", "The user is not the owner of the filter_set",
) )
elif not self.is_user_dashboard_owner(): elif not security_manager.is_owner(self._dashboard):
raise FilterSetForbiddenError( raise FilterSetForbiddenError(
str(self._filter_set_id), str(self._filter_set_id),
"The user is not an owner of the filter_set's dashboard", "The user is not an owner of the filter_set's dashboard",

View File

@ -17,9 +17,7 @@
import logging import logging
from typing import Any, Dict from typing import Any, Dict
from flask import g
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from superset import security_manager from superset import security_manager
from superset.dashboards.filter_sets.commands.base import BaseFilterSetCommand from superset.dashboards.filter_sets.commands.base import BaseFilterSetCommand
@ -35,14 +33,15 @@ from superset.dashboards.filter_sets.consts import (
OWNER_TYPE_FIELD, OWNER_TYPE_FIELD,
) )
from superset.dashboards.filter_sets.dao import FilterSetDAO from superset.dashboards.filter_sets.dao import FilterSetDAO
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CreateFilterSetCommand(BaseFilterSetCommand): class CreateFilterSetCommand(BaseFilterSetCommand):
# pylint: disable=C0103 # pylint: disable=C0103
def __init__(self, user: User, dashboard_id: int, data: Dict[str, Any]): def __init__(self, dashboard_id: int, data: Dict[str, Any]):
super().__init__(user, dashboard_id) super().__init__(dashboard_id)
self._properties = data.copy() self._properties = data.copy()
def run(self) -> Model: def run(self) -> Model:
@ -61,13 +60,13 @@ class CreateFilterSetCommand(BaseFilterSetCommand):
def _validate_owner_id_exists(self) -> None: def _validate_owner_id_exists(self) -> None:
owner_id = self._properties[OWNER_ID_FIELD] owner_id = self._properties[OWNER_ID_FIELD]
if not (g.user.id == owner_id or security_manager.get_user_by_id(owner_id)): if not (get_user_id() == owner_id or security_manager.get_user_by_id(owner_id)):
raise FilterSetCreateFailedError( raise FilterSetCreateFailedError(
str(self._dashboard_id), "owner_id does not exists" str(self._dashboard_id), "owner_id does not exists"
) )
def _validate_user_is_the_dashboard_owner(self) -> None: def _validate_user_is_the_dashboard_owner(self) -> None:
if not self.is_user_dashboard_owner(): if not security_manager.is_owner(self._dashboard):
raise UserIsNotDashboardOwnerError(str(self._dashboard_id)) raise UserIsNotDashboardOwnerError(str(self._dashboard_id))
def _validate_owner_id_is_dashboard_id(self) -> None: def _validate_owner_id_is_dashboard_id(self) -> None:

View File

@ -17,7 +17,6 @@
import logging import logging
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from superset.dao.exceptions import DAODeleteFailedError from superset.dao.exceptions import DAODeleteFailedError
from superset.dashboards.filter_sets.commands.base import BaseFilterSetCommand from superset.dashboards.filter_sets.commands.base import BaseFilterSetCommand
@ -32,8 +31,8 @@ logger = logging.getLogger(__name__)
class DeleteFilterSetCommand(BaseFilterSetCommand): class DeleteFilterSetCommand(BaseFilterSetCommand):
def __init__(self, user: User, dashboard_id: int, filter_set_id: int): def __init__(self, dashboard_id: int, filter_set_id: int):
super().__init__(user, dashboard_id) super().__init__(dashboard_id)
self._filter_set_id = filter_set_id self._filter_set_id = filter_set_id
def run(self) -> Model: def run(self) -> Model:

View File

@ -18,7 +18,6 @@ import logging
from typing import Any, Dict from typing import Any, Dict
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from superset.dao.exceptions import DAOUpdateFailedError from superset.dao.exceptions import DAOUpdateFailedError
from superset.dashboards.filter_sets.commands.base import BaseFilterSetCommand from superset.dashboards.filter_sets.commands.base import BaseFilterSetCommand
@ -32,10 +31,8 @@ logger = logging.getLogger(__name__)
class UpdateFilterSetCommand(BaseFilterSetCommand): class UpdateFilterSetCommand(BaseFilterSetCommand):
def __init__( def __init__(self, dashboard_id: int, filter_set_id: int, data: Dict[str, Any]):
self, user: User, dashboard_id: int, filter_set_id: int, data: Dict[str, Any] super().__init__(dashboard_id)
):
super().__init__(user, dashboard_id)
self._filter_set_id = filter_set_id self._filter_set_id = filter_set_id
self._properties = data.copy() self._properties = data.copy()

View File

@ -18,13 +18,14 @@ from __future__ import annotations
from typing import Any, TYPE_CHECKING from typing import Any, TYPE_CHECKING
from flask import g
from sqlalchemy import and_, or_ from sqlalchemy import and_, or_
from superset import security_manager
from superset.dashboards.filter_sets.consts import DASHBOARD_OWNER_TYPE, USER_OWNER_TYPE from superset.dashboards.filter_sets.consts import DASHBOARD_OWNER_TYPE, USER_OWNER_TYPE
from superset.models.dashboard import dashboard_user from superset.models.dashboard import dashboard_user
from superset.models.filter_set import FilterSet from superset.models.filter_set import FilterSet
from superset.views.base import BaseFilter, is_user_admin from superset.utils.core import get_user_id
from superset.views.base import BaseFilter
if TYPE_CHECKING: if TYPE_CHECKING:
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
@ -32,9 +33,8 @@ if TYPE_CHECKING:
class FilterSetFilter(BaseFilter): # pylint: disable=too-few-public-methods) class FilterSetFilter(BaseFilter): # pylint: disable=too-few-public-methods)
def apply(self, query: Query, value: Any) -> Query: def apply(self, query: Query, value: Any) -> Query:
if is_user_admin(): if security_manager.is_admin():
return query return query
current_user_id = g.user.id
filter_set_ids_by_dashboard_owners = ( # pylint: disable=C0103 filter_set_ids_by_dashboard_owners = ( # pylint: disable=C0103
query.from_self(FilterSet.id) query.from_self(FilterSet.id)
@ -42,7 +42,7 @@ class FilterSetFilter(BaseFilter): # pylint: disable=too-few-public-methods)
.filter( .filter(
and_( and_(
FilterSet.owner_type == DASHBOARD_OWNER_TYPE, FilterSet.owner_type == DASHBOARD_OWNER_TYPE,
dashboard_user.c.user_id == current_user_id, dashboard_user.c.user_id == get_user_id(),
) )
) )
) )
@ -51,7 +51,7 @@ class FilterSetFilter(BaseFilter): # pylint: disable=too-few-public-methods)
or_( or_(
and_( and_(
FilterSet.owner_type == USER_OWNER_TYPE, FilterSet.owner_type == USER_OWNER_TYPE,
FilterSet.owner_id == current_user_id, FilterSet.owner_id == get_user_id(),
), ),
FilterSet.id.in_(filter_set_ids_by_dashboard_owners), FilterSet.id.in_(filter_set_ids_by_dashboard_owners),
) )

View File

@ -20,17 +20,17 @@ from flask import session
from superset.dashboards.filter_state.commands.utils import check_access from superset.dashboards.filter_state.commands.utils import check_access
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.utils import get_owner, random_key from superset.key_value.utils import random_key
from superset.temporary_cache.commands.create import CreateTemporaryCacheCommand from superset.temporary_cache.commands.create import CreateTemporaryCacheCommand
from superset.temporary_cache.commands.entry import Entry from superset.temporary_cache.commands.entry import Entry
from superset.temporary_cache.commands.parameters import CommandParameters from superset.temporary_cache.commands.parameters import CommandParameters
from superset.temporary_cache.utils import cache_key from superset.temporary_cache.utils import cache_key
from superset.utils.core import get_user_id
class CreateFilterStateCommand(CreateTemporaryCacheCommand): class CreateFilterStateCommand(CreateTemporaryCacheCommand):
def create(self, cmd_params: CommandParameters) -> str: def create(self, cmd_params: CommandParameters) -> str:
resource_id = cmd_params.resource_id resource_id = cmd_params.resource_id
actor = cmd_params.actor
tab_id = cmd_params.tab_id tab_id = cmd_params.tab_id
contextual_key = cache_key(session.get("_id"), tab_id, resource_id) contextual_key = cache_key(session.get("_id"), tab_id, resource_id)
key = cache_manager.filter_state_cache.get(contextual_key) key = cache_manager.filter_state_cache.get(contextual_key)
@ -38,7 +38,7 @@ class CreateFilterStateCommand(CreateTemporaryCacheCommand):
key = random_key() key = random_key()
value = cast(str, cmd_params.value) # schema ensures that value is not optional value = cast(str, cmd_params.value) # schema ensures that value is not optional
check_access(resource_id) check_access(resource_id)
entry: Entry = {"owner": get_owner(actor), "value": value} entry: Entry = {"owner": get_user_id(), "value": value}
cache_manager.filter_state_cache.set(cache_key(resource_id, key), entry) cache_manager.filter_state_cache.set(cache_key(resource_id, key), entry)
cache_manager.filter_state_cache.set(contextual_key, key) cache_manager.filter_state_cache.set(contextual_key, key)
return key return key

View File

@ -18,23 +18,22 @@ from flask import session
from superset.dashboards.filter_state.commands.utils import check_access from superset.dashboards.filter_state.commands.utils import check_access
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.utils import get_owner
from superset.temporary_cache.commands.delete import DeleteTemporaryCacheCommand from superset.temporary_cache.commands.delete import DeleteTemporaryCacheCommand
from superset.temporary_cache.commands.entry import Entry from superset.temporary_cache.commands.entry import Entry
from superset.temporary_cache.commands.exceptions import TemporaryCacheAccessDeniedError from superset.temporary_cache.commands.exceptions import TemporaryCacheAccessDeniedError
from superset.temporary_cache.commands.parameters import CommandParameters from superset.temporary_cache.commands.parameters import CommandParameters
from superset.temporary_cache.utils import cache_key from superset.temporary_cache.utils import cache_key
from superset.utils.core import get_user_id
class DeleteFilterStateCommand(DeleteTemporaryCacheCommand): class DeleteFilterStateCommand(DeleteTemporaryCacheCommand):
def delete(self, cmd_params: CommandParameters) -> bool: def delete(self, cmd_params: CommandParameters) -> bool:
resource_id = cmd_params.resource_id resource_id = cmd_params.resource_id
actor = cmd_params.actor
key = cache_key(resource_id, cmd_params.key) key = cache_key(resource_id, cmd_params.key)
check_access(resource_id) check_access(resource_id)
entry: Entry = cache_manager.filter_state_cache.get(key) entry: Entry = cache_manager.filter_state_cache.get(key)
if entry: if entry:
if entry["owner"] != get_owner(actor): if entry["owner"] != get_user_id():
raise TemporaryCacheAccessDeniedError() raise TemporaryCacheAccessDeniedError()
tab_id = cmd_params.tab_id tab_id = cmd_params.tab_id
contextual_key = cache_key(session.get("_id"), tab_id, resource_id) contextual_key = cache_key(session.get("_id"), tab_id, resource_id)

View File

@ -20,23 +20,23 @@ from flask import session
from superset.dashboards.filter_state.commands.utils import check_access from superset.dashboards.filter_state.commands.utils import check_access
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.utils import get_owner, random_key from superset.key_value.utils import random_key
from superset.temporary_cache.commands.entry import Entry from superset.temporary_cache.commands.entry import Entry
from superset.temporary_cache.commands.exceptions import TemporaryCacheAccessDeniedError from superset.temporary_cache.commands.exceptions import TemporaryCacheAccessDeniedError
from superset.temporary_cache.commands.parameters import CommandParameters from superset.temporary_cache.commands.parameters import CommandParameters
from superset.temporary_cache.commands.update import UpdateTemporaryCacheCommand from superset.temporary_cache.commands.update import UpdateTemporaryCacheCommand
from superset.temporary_cache.utils import cache_key from superset.temporary_cache.utils import cache_key
from superset.utils.core import get_user_id
class UpdateFilterStateCommand(UpdateTemporaryCacheCommand): class UpdateFilterStateCommand(UpdateTemporaryCacheCommand):
def update(self, cmd_params: CommandParameters) -> Optional[str]: def update(self, cmd_params: CommandParameters) -> Optional[str]:
resource_id = cmd_params.resource_id resource_id = cmd_params.resource_id
actor = cmd_params.actor
key = cmd_params.key key = cmd_params.key
value = cast(str, cmd_params.value) # schema ensures that value is not optional value = cast(str, cmd_params.value) # schema ensures that value is not optional
check_access(resource_id) check_access(resource_id)
entry: Entry = cache_manager.filter_state_cache.get(cache_key(resource_id, key)) entry: Entry = cache_manager.filter_state_cache.get(cache_key(resource_id, key))
owner = get_owner(actor) owner = get_user_id()
if entry: if entry:
if entry["owner"] != owner: if entry["owner"] != owner:
raise TemporaryCacheAccessDeniedError() raise TemporaryCacheAccessDeniedError()

View File

@ -30,7 +30,7 @@ from superset.models.embedded_dashboard import EmbeddedDashboard
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.security.guest_token import GuestTokenResourceType, GuestUser from superset.security.guest_token import GuestTokenResourceType, GuestUser
from superset.utils.core import get_user_id from superset.utils.core import get_user_id
from superset.views.base import BaseFilter, is_user_admin from superset.views.base import BaseFilter
from superset.views.base_api import BaseFavoriteFilter from superset.views.base_api import BaseFavoriteFilter
@ -98,7 +98,7 @@ class DashboardAccessFilter(BaseFilter): # pylint: disable=too-few-public-metho
""" """
def apply(self, query: Query, value: Any) -> Query: def apply(self, query: Query, value: Any) -> Query:
if is_user_admin(): if security_manager.is_admin():
return query return query
datasource_perms = security_manager.user_view_menu_names("datasource_access") datasource_perms = security_manager.user_view_menu_names("datasource_access")

View File

@ -16,7 +16,7 @@
# under the License. # under the License.
import logging import logging
from flask import g, request, Response from flask import request, Response
from flask_appbuilder.api import BaseApi, expose, protect, safe from flask_appbuilder.api import BaseApi, expose, protect, safe
from marshmallow import ValidationError from marshmallow import ValidationError
@ -104,7 +104,6 @@ class DashboardPermalinkRestApi(BaseApi):
try: try:
state = self.add_model_schema.load(request.json) state = self.add_model_schema.load(request.json)
key = CreateDashboardPermalinkCommand( key = CreateDashboardPermalinkCommand(
actor=g.user,
dashboard_id=pk, dashboard_id=pk,
state=state, state=state,
).run() ).run()
@ -162,7 +161,7 @@ class DashboardPermalinkRestApi(BaseApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
value = GetDashboardPermalinkCommand(actor=g.user, key=key).run() value = GetDashboardPermalinkCommand(key=key).run()
if not value: if not value:
return self.response_404() return self.response_404()
return self.response(200, **value) return self.response(200, **value)

View File

@ -16,7 +16,6 @@
# under the License. # under the License.
import logging import logging
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset.dashboards.dao import DashboardDAO from superset.dashboards.dao import DashboardDAO
@ -32,11 +31,9 @@ logger = logging.getLogger(__name__)
class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand): class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
def __init__( def __init__(
self, self,
actor: User,
dashboard_id: str, dashboard_id: str,
state: DashboardPermalinkState, state: DashboardPermalinkState,
): ):
self.actor = actor
self.dashboard_id = dashboard_id self.dashboard_id = dashboard_id
self.state = state self.state = state
@ -49,7 +46,6 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
"state": self.state, "state": self.state,
} }
key = CreateKeyValueCommand( key = CreateKeyValueCommand(
actor=self.actor,
resource=self.resource, resource=self.resource,
value=value, value=value,
).run() ).run()

View File

@ -17,7 +17,6 @@
import logging import logging
from typing import Optional from typing import Optional
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset.dashboards.commands.exceptions import DashboardNotFoundError from superset.dashboards.commands.exceptions import DashboardNotFoundError
@ -33,8 +32,7 @@ logger = logging.getLogger(__name__)
class GetDashboardPermalinkCommand(BaseDashboardPermalinkCommand): class GetDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
def __init__(self, actor: User, key: str): def __init__(self, key: str):
self.actor = actor
self.key = key self.key = key
def run(self) -> Optional[DashboardPermalinkValue]: def run(self) -> Optional[DashboardPermalinkValue]:

View File

@ -22,7 +22,7 @@ from io import BytesIO
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from zipfile import ZipFile from zipfile import ZipFile
from flask import g, request, Response, send_file from flask import request, Response, send_file
from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.models.sqla.interface import SQLAInterface
from marshmallow import ValidationError from marshmallow import ValidationError
@ -261,7 +261,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
try: try:
new_model = CreateDatabaseCommand(g.user, item).run() new_model = CreateDatabaseCommand(item).run()
# Return censored version for sqlalchemy URI # Return censored version for sqlalchemy URI
item["sqlalchemy_uri"] = new_model.sqlalchemy_uri item["sqlalchemy_uri"] = new_model.sqlalchemy_uri
item["expose_in_sqllab"] = new_model.expose_in_sqllab item["expose_in_sqllab"] = new_model.expose_in_sqllab
@ -342,7 +342,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
try: try:
changed_model = UpdateDatabaseCommand(g.user, pk, item).run() changed_model = UpdateDatabaseCommand(pk, item).run()
# Return censored version for sqlalchemy URI # Return censored version for sqlalchemy URI
item["sqlalchemy_uri"] = changed_model.sqlalchemy_uri item["sqlalchemy_uri"] = changed_model.sqlalchemy_uri
if changed_model.parameters: if changed_model.parameters:
@ -404,7 +404,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
DeleteDatabaseCommand(g.user, pk).run() DeleteDatabaseCommand(pk).run()
return self.response(200, message="OK") return self.response(200, message="OK")
except DatabaseNotFoundError: except DatabaseNotFoundError:
return self.response_404() return self.response_404()
@ -706,7 +706,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
# This validates custom Schema with custom validations # This validates custom Schema with custom validations
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
TestConnectionDatabaseCommand(g.user, item).run() TestConnectionDatabaseCommand(item).run()
return self.response(200, message="OK") return self.response(200, message="OK")
@expose("/<int:pk>/related_objects/", methods=["GET"]) @expose("/<int:pk>/related_objects/", methods=["GET"])
@ -1174,6 +1174,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
] ]
raise InvalidParametersError(errors) from ex raise InvalidParametersError(errors) from ex
command = ValidateDatabaseParametersCommand(g.user, payload) command = ValidateDatabaseParametersCommand(payload)
command.run() command.run()
return self.response(200, message="OK") return self.response(200, message="OK")

View File

@ -18,7 +18,6 @@ import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
@ -38,8 +37,7 @@ logger = logging.getLogger(__name__)
class CreateDatabaseCommand(BaseCommand): class CreateDatabaseCommand(BaseCommand):
def __init__(self, user: User, data: Dict[str, Any]): def __init__(self, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy() self._properties = data.copy()
def run(self) -> Model: def run(self) -> Model:
@ -47,7 +45,7 @@ class CreateDatabaseCommand(BaseCommand):
try: try:
# Test connection before starting create transaction # Test connection before starting create transaction
TestConnectionDatabaseCommand(self._actor, self._properties).run() TestConnectionDatabaseCommand(self._properties).run()
except Exception as ex: except Exception as ex:
event_logger.log_with_context( event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}", action=f"db_creation_failed.{ex.__class__.__name__}",

View File

@ -18,7 +18,6 @@ import logging
from typing import Optional from typing import Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
@ -37,8 +36,7 @@ logger = logging.getLogger(__name__)
class DeleteDatabaseCommand(BaseCommand): class DeleteDatabaseCommand(BaseCommand):
def __init__(self, user: User, model_id: int): def __init__(self, model_id: int):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._model: Optional[Database] = None self._model: Optional[Database] = None

View File

@ -20,7 +20,6 @@ from contextlib import closing
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from flask import current_app as app from flask import current_app as app
from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as _ from flask_babel import gettext as _
from func_timeout import func_timeout, FunctionTimedOut from func_timeout import func_timeout, FunctionTimedOut
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
@ -39,14 +38,12 @@ from superset.errors import ErrorLevel, SupersetErrorType
from superset.exceptions import SupersetSecurityException, SupersetTimeoutException from superset.exceptions import SupersetSecurityException, SupersetTimeoutException
from superset.extensions import event_logger from superset.extensions import event_logger
from superset.models.core import Database from superset.models.core import Database
from superset.utils.core import override_user
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TestConnectionDatabaseCommand(BaseCommand): class TestConnectionDatabaseCommand(BaseCommand):
def __init__(self, user: User, data: Dict[str, Any]): def __init__(self, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[Database] = None self._model: Optional[Database] = None
@ -77,47 +74,41 @@ class TestConnectionDatabaseCommand(BaseCommand):
database.set_sqlalchemy_uri(uri) database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database) database.db_engine_spec.mutate_db_for_connection_test(database)
with override_user(self._actor): engine = database.get_sqla_engine()
engine = database.get_sqla_engine() event_logger.log_with_context(
event_logger.log_with_context( action="test_connection_attempt",
action="test_connection_attempt", engine=database.db_engine_spec.__name__,
engine=database.db_engine_spec.__name__, )
def ping(engine: Engine) -> bool:
with closing(engine.raw_connection()) as conn:
return engine.dialect.do_ping(conn)
try:
alive = func_timeout(
int(app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds()),
ping,
args=(engine,),
) )
except (sqlite3.ProgrammingError, RuntimeError):
def ping(engine: Engine) -> bool: # SQLite can't run on a separate thread, so ``func_timeout`` fails
with closing(engine.raw_connection()) as conn: # RuntimeError catches the equivalent error from duckdb.
return engine.dialect.do_ping(conn) alive = engine.dialect.do_ping(engine)
except FunctionTimedOut as ex:
try: raise SupersetTimeoutException(
alive = func_timeout( error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
int( message=(
app.config[ "Please check your connection details and database settings, "
"TEST_DATABASE_CONNECTION_TIMEOUT" "and ensure that your database is accepting connections, "
].total_seconds() "then try connecting again."
), ),
ping, level=ErrorLevel.ERROR,
args=(engine,), extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) ) from ex
except Exception: # pylint: disable=broad-except
except (sqlite3.ProgrammingError, RuntimeError): alive = False
# SQLite can't run on a separate thread, so ``func_timeout`` fails if not alive:
# RuntimeError catches the equivalent error from duckdb. raise DBAPIError(None, None, None)
alive = engine.dialect.do_ping(engine)
except FunctionTimedOut as ex:
raise SupersetTimeoutException(
error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
message=(
"Please check your connection details and database settings, "
"and ensure that your database is accepting connections, "
"then try connecting again."
),
level=ErrorLevel.ERROR,
extra={"sqlalchemy_uri": database.sqlalchemy_uri},
) from ex
except Exception: # pylint: disable=broad-except
alive = False
if not alive:
raise DBAPIError(None, None, None)
# Log succesful connection test with engine # Log succesful connection test with engine
event_logger.log_with_context( event_logger.log_with_context(

View File

@ -18,7 +18,6 @@ import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
@ -38,8 +37,7 @@ logger = logging.getLogger(__name__)
class UpdateDatabaseCommand(BaseCommand): class UpdateDatabaseCommand(BaseCommand):
def __init__(self, user: User, model_id: int, data: Dict[str, Any]): def __init__(self, model_id: int, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy() self._properties = data.copy()
self._model_id = model_id self._model_id = model_id
self._model: Optional[Database] = None self._model: Optional[Database] = None

View File

@ -18,7 +18,6 @@ import json
from contextlib import closing from contextlib import closing
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as __ from flask_babel import gettext as __
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
@ -35,14 +34,12 @@ from superset.db_engine_specs.base import BasicParametersMixin
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.extensions import event_logger from superset.extensions import event_logger
from superset.models.core import Database from superset.models.core import Database
from superset.utils.core import override_user
BYPASS_VALIDATION_ENGINES = {"bigquery"} BYPASS_VALIDATION_ENGINES = {"bigquery"}
class ValidateDatabaseParametersCommand(BaseCommand): class ValidateDatabaseParametersCommand(BaseCommand):
def __init__(self, user: User, parameters: Dict[str, Any]): def __init__(self, parameters: Dict[str, Any]):
self._actor = user
self._properties = parameters.copy() self._properties = parameters.copy()
self._model: Optional[Database] = None self._model: Optional[Database] = None
@ -117,22 +114,21 @@ class ValidateDatabaseParametersCommand(BaseCommand):
database.set_sqlalchemy_uri(sqlalchemy_uri) database.set_sqlalchemy_uri(sqlalchemy_uri)
database.db_engine_spec.mutate_db_for_connection_test(database) database.db_engine_spec.mutate_db_for_connection_test(database)
with override_user(self._actor): engine = database.get_sqla_engine()
engine = database.get_sqla_engine() try:
try: with closing(engine.raw_connection()) as conn:
with closing(engine.raw_connection()) as conn: alive = engine.dialect.do_ping(conn)
alive = engine.dialect.do_ping(conn) except Exception as ex:
except Exception as ex: url = make_url_safe(sqlalchemy_uri)
url = make_url_safe(sqlalchemy_uri) context = {
context = { "hostname": url.host,
"hostname": url.host, "password": url.password,
"password": url.password, "port": url.port,
"port": url.port, "username": url.username,
"username": url.username, "database": url.database,
"database": url.database, }
} errors = database.db_engine_spec.extract_errors(ex, context)
errors = database.db_engine_spec.extract_errors(ex, context) raise DatabaseTestConnectionFailedError(errors) from ex
raise DatabaseTestConnectionFailedError(errors) from ex
if not alive: if not alive:
raise DatabaseOfflineError( raise DatabaseOfflineError(

View File

@ -23,7 +23,7 @@ from zipfile import is_zipfile, ZipFile
import simplejson import simplejson
import yaml import yaml
from flask import g, make_response, request, Response, send_file from flask import make_response, request, Response, send_file
from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext from flask_babel import ngettext
@ -264,7 +264,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
try: try:
new_model = CreateDatasetCommand(g.user, item).run() new_model = CreateDatasetCommand(item).run()
return self.response(201, id=new_model.id, result=item) return self.response(201, id=new_model.id, result=item)
except DatasetInvalidError as ex: except DatasetInvalidError as ex:
return self.response_422(message=ex.normalized_messages()) return self.response_422(message=ex.normalized_messages())
@ -344,11 +344,9 @@ class DatasetRestApi(BaseSupersetModelRestApi):
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
try: try:
changed_model = UpdateDatasetCommand( changed_model = UpdateDatasetCommand(pk, item, override_columns).run()
g.user, pk, item, override_columns
).run()
if override_columns: if override_columns:
RefreshDatasetCommand(g.user, pk).run() RefreshDatasetCommand(pk).run()
response = self.response(200, id=changed_model.id, result=item) response = self.response(200, id=changed_model.id, result=item)
except DatasetNotFoundError: except DatasetNotFoundError:
response = self.response_404() response = self.response_404()
@ -407,7 +405,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
DeleteDatasetCommand(g.user, pk).run() DeleteDatasetCommand(pk).run()
return self.response(200, message="OK") return self.response(200, message="OK")
except DatasetNotFoundError: except DatasetNotFoundError:
return self.response_404() return self.response_404()
@ -547,7 +545,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
RefreshDatasetCommand(g.user, pk).run() RefreshDatasetCommand(pk).run()
return self.response(200, message="OK") return self.response(200, message="OK")
except DatasetNotFoundError: except DatasetNotFoundError:
return self.response_404() return self.response_404()
@ -671,7 +669,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
""" """
item_ids = kwargs["rison"] item_ids = kwargs["rison"]
try: try:
BulkDeleteDatasetCommand(g.user, item_ids).run() BulkDeleteDatasetCommand(item_ids).run()
return self.response( return self.response(
200, 200,
message=ngettext( message=ngettext(
@ -812,7 +810,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
""" """
try: try:
force = parse_boolean_string(request.args.get("force")) force = parse_boolean_string(request.args.get("force"))
rv = SamplesDatasetCommand(g.user, pk, force).run() rv = SamplesDatasetCommand(pk, force).run()
response_data = simplejson.dumps( response_data = simplejson.dumps(
{"result": rv}, {"result": rv},
default=json_int_dttm_ser, default=json_int_dttm_ser,

View File

@ -16,7 +16,7 @@
# under the License. # under the License.
import logging import logging
from flask import g, Response from flask import Response
from flask_appbuilder.api import expose, permission_name, protect, safe from flask_appbuilder.api import expose, permission_name, protect, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.models.sqla.interface import SQLAInterface
@ -91,7 +91,7 @@ class DatasetColumnsRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
DeleteDatasetColumnCommand(g.user, pk, column_id).run() DeleteDatasetColumnCommand(pk, column_id).run()
return self.response(200, message="OK") return self.response(200, message="OK")
except DatasetColumnNotFoundError: except DatasetColumnNotFoundError:
return self.response_404() return self.response_404()

View File

@ -18,8 +18,8 @@ import logging
from typing import Optional from typing import Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from superset import security_manager
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import TableColumn from superset.connectors.sqla.models import TableColumn
from superset.dao.exceptions import DAODeleteFailedError from superset.dao.exceptions import DAODeleteFailedError
@ -30,14 +30,12 @@ from superset.datasets.columns.commands.exceptions import (
) )
from superset.datasets.dao import DatasetDAO from superset.datasets.dao import DatasetDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DeleteDatasetColumnCommand(BaseCommand): class DeleteDatasetColumnCommand(BaseCommand):
def __init__(self, user: User, dataset_id: int, model_id: int): def __init__(self, dataset_id: int, model_id: int):
self._actor = user
self._dataset_id = dataset_id self._dataset_id = dataset_id
self._model_id = model_id self._model_id = model_id
self._model: Optional[TableColumn] = None self._model: Optional[TableColumn] = None
@ -60,6 +58,6 @@ class DeleteDatasetColumnCommand(BaseCommand):
raise DatasetColumnNotFoundError() raise DatasetColumnNotFoundError()
# Check ownership # Check ownership
try: try:
check_ownership(self._model) security_manager.raise_for_ownership(self._model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise DatasetColumnForbiddenError() from ex raise DatasetColumnForbiddenError() from ex

View File

@ -17,8 +17,7 @@
import logging import logging
from typing import List, Optional from typing import List, Optional
from flask_appbuilder.security.sqla.models import User from superset import security_manager
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.commands.exceptions import DeleteFailedError from superset.commands.exceptions import DeleteFailedError
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
@ -29,15 +28,13 @@ from superset.datasets.commands.exceptions import (
) )
from superset.datasets.dao import DatasetDAO from superset.datasets.dao import DatasetDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.extensions import db, security_manager from superset.extensions import db
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BulkDeleteDatasetCommand(BaseCommand): class BulkDeleteDatasetCommand(BaseCommand):
def __init__(self, user: User, model_ids: List[int]): def __init__(self, model_ids: List[int]):
self._actor = user
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[List[SqlaTable]] = None self._models: Optional[List[SqlaTable]] = None
@ -84,6 +81,6 @@ class BulkDeleteDatasetCommand(BaseCommand):
# Check ownership # Check ownership
for model in self._models: for model in self._models:
try: try:
check_ownership(model) security_manager.raise_for_ownership(model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise DatasetForbiddenError() from ex raise DatasetForbiddenError() from ex

View File

@ -18,7 +18,6 @@ import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
@ -38,8 +37,7 @@ logger = logging.getLogger(__name__)
class CreateDatasetCommand(CreateMixin, BaseCommand): class CreateDatasetCommand(CreateMixin, BaseCommand):
def __init__(self, user: User, data: Dict[str, Any]): def __init__(self, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy() self._properties = data.copy()
def run(self) -> Model: def run(self) -> Model:
@ -89,7 +87,7 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
exceptions.append(TableNotFoundValidationError(table_name)) exceptions.append(TableNotFoundValidationError(table_name))
try: try:
owners = self.populate_owners(self._actor, owner_ids) owners = self.populate_owners(owner_ids)
self._properties["owners"] = owners self._properties["owners"] = owners
except ValidationError as ex: except ValidationError as ex:
exceptions.append(ex) exceptions.append(ex)

View File

@ -18,9 +18,9 @@ import logging
from typing import Optional from typing import Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset import security_manager
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.dao.exceptions import DAODeleteFailedError from superset.dao.exceptions import DAODeleteFailedError
@ -31,15 +31,13 @@ from superset.datasets.commands.exceptions import (
) )
from superset.datasets.dao import DatasetDAO from superset.datasets.dao import DatasetDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.extensions import db, security_manager from superset.extensions import db
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DeleteDatasetCommand(BaseCommand): class DeleteDatasetCommand(BaseCommand):
def __init__(self, user: User, model_id: int): def __init__(self, model_id: int):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._model: Optional[SqlaTable] = None self._model: Optional[SqlaTable] = None
@ -85,6 +83,6 @@ class DeleteDatasetCommand(BaseCommand):
raise DatasetNotFoundError() raise DatasetNotFoundError()
# Check ownership # Check ownership
try: try:
check_ownership(self._model) security_manager.raise_for_ownership(self._model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise DatasetForbiddenError() from ex raise DatasetForbiddenError() from ex

View File

@ -18,8 +18,8 @@ import logging
from typing import Optional from typing import Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from superset import security_manager
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.exceptions import ( from superset.datasets.commands.exceptions import (
@ -29,14 +29,12 @@ from superset.datasets.commands.exceptions import (
) )
from superset.datasets.dao import DatasetDAO from superset.datasets.dao import DatasetDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RefreshDatasetCommand(BaseCommand): class RefreshDatasetCommand(BaseCommand):
def __init__(self, user: User, model_id: int): def __init__(self, model_id: int):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._model: Optional[SqlaTable] = None self._model: Optional[SqlaTable] = None
@ -58,6 +56,6 @@ class RefreshDatasetCommand(BaseCommand):
raise DatasetNotFoundError() raise DatasetNotFoundError()
# Check ownership # Check ownership
try: try:
check_ownership(self._model) security_manager.raise_for_ownership(self._model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise DatasetForbiddenError() from ex raise DatasetForbiddenError() from ex

View File

@ -17,8 +17,7 @@
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from flask_appbuilder.security.sqla.models import User from superset import security_manager
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.common.chart_data import ChartDataResultType from superset.common.chart_data import ChartDataResultType
from superset.common.query_context_factory import QueryContextFactory from superset.common.query_context_factory import QueryContextFactory
@ -33,14 +32,12 @@ from superset.datasets.commands.exceptions import (
from superset.datasets.dao import DatasetDAO from superset.datasets.dao import DatasetDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.utils.core import QueryStatus from superset.utils.core import QueryStatus
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SamplesDatasetCommand(BaseCommand): class SamplesDatasetCommand(BaseCommand):
def __init__(self, user: User, model_id: int, force: bool): def __init__(self, model_id: int, force: bool):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._force = force self._force = force
self._model: Optional[SqlaTable] = None self._model: Optional[SqlaTable] = None
@ -78,6 +75,6 @@ class SamplesDatasetCommand(BaseCommand):
raise DatasetNotFoundError() raise DatasetNotFoundError()
# Check ownership # Check ownership
try: try:
check_ownership(self._model) security_manager.raise_for_ownership(self._model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise DatasetForbiddenError() from ex raise DatasetForbiddenError() from ex

View File

@ -19,9 +19,9 @@ from collections import Counter
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
from superset import security_manager
from superset.commands.base import BaseCommand, UpdateMixin from superset.commands.base import BaseCommand, UpdateMixin
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.dao.exceptions import DAOUpdateFailedError from superset.dao.exceptions import DAOUpdateFailedError
@ -41,7 +41,6 @@ from superset.datasets.commands.exceptions import (
) )
from superset.datasets.dao import DatasetDAO from superset.datasets.dao import DatasetDAO
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,12 +48,10 @@ logger = logging.getLogger(__name__)
class UpdateDatasetCommand(UpdateMixin, BaseCommand): class UpdateDatasetCommand(UpdateMixin, BaseCommand):
def __init__( def __init__(
self, self,
user: User,
model_id: int, model_id: int,
data: Dict[str, Any], data: Dict[str, Any],
override_columns: bool = False, override_columns: bool = False,
): ):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[SqlaTable] = None self._model: Optional[SqlaTable] = None
@ -83,7 +80,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
raise DatasetNotFoundError() raise DatasetNotFoundError()
# Check ownership # Check ownership
try: try:
check_ownership(self._model) security_manager.raise_for_ownership(self._model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise DatasetForbiddenError() from ex raise DatasetForbiddenError() from ex
@ -99,7 +96,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
exceptions.append(DatabaseChangeValidationError()) exceptions.append(DatabaseChangeValidationError())
# Validate/Populate owner # Validate/Populate owner
try: try:
owners = self.populate_owners(self._actor, owner_ids) owners = self.populate_owners(owner_ids)
self._properties["owners"] = owners self._properties["owners"] = owners
except ValidationError as ex: except ValidationError as ex:
exceptions.append(ex) exceptions.append(ex)

View File

@ -17,7 +17,6 @@
import logging import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from flask import current_app
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
@ -36,14 +35,6 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
model_cls = SqlaTable model_cls = SqlaTable
base_filter = DatasourceFilter base_filter = DatasourceFilter
@staticmethod
def get_owner_by_id(owner_id: int) -> Optional[object]:
return (
db.session.query(current_app.appbuilder.sm.user_model)
.filter_by(id=owner_id)
.one_or_none()
)
@staticmethod @staticmethod
def get_database_by_id(database_id: int) -> Optional[Database]: def get_database_by_id(database_id: int) -> Optional[Database]:
try: try:

View File

@ -16,7 +16,7 @@
# under the License. # under the License.
import logging import logging
from flask import g, Response from flask import Response
from flask_appbuilder.api import expose, permission_name, protect, safe from flask_appbuilder.api import expose, permission_name, protect, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.models.sqla.interface import SQLAInterface
@ -91,7 +91,7 @@ class DatasetMetricRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
DeleteDatasetMetricCommand(g.user, pk, metric_id).run() DeleteDatasetMetricCommand(pk, metric_id).run()
return self.response(200, message="OK") return self.response(200, message="OK")
except DatasetMetricNotFoundError: except DatasetMetricNotFoundError:
return self.response_404() return self.response_404()

View File

@ -18,8 +18,8 @@ import logging
from typing import Optional from typing import Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from superset import security_manager
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlMetric from superset.connectors.sqla.models import SqlMetric
from superset.dao.exceptions import DAODeleteFailedError from superset.dao.exceptions import DAODeleteFailedError
@ -30,14 +30,12 @@ from superset.datasets.metrics.commands.exceptions import (
DatasetMetricNotFoundError, DatasetMetricNotFoundError,
) )
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DeleteDatasetMetricCommand(BaseCommand): class DeleteDatasetMetricCommand(BaseCommand):
def __init__(self, user: User, dataset_id: int, model_id: int): def __init__(self, dataset_id: int, model_id: int):
self._actor = user
self._dataset_id = dataset_id self._dataset_id = dataset_id
self._model_id = model_id self._model_id = model_id
self._model: Optional[SqlMetric] = None self._model: Optional[SqlMetric] = None
@ -60,6 +58,6 @@ class DeleteDatasetMetricCommand(BaseCommand):
raise DatasetMetricNotFoundError() raise DatasetMetricNotFoundError()
# Check ownership # Check ownership
try: try:
check_ownership(self._model) security_manager.raise_for_ownership(self._model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise DatasetMetricForbiddenError() from ex raise DatasetMetricForbiddenError() from ex

View File

@ -54,7 +54,6 @@ class GetExploreCommand(BaseCommand, ABC):
self, self,
params: CommandParameters, params: CommandParameters,
) -> None: ) -> None:
self._actor = params.actor
self._permalink_key = params.permalink_key self._permalink_key = params.permalink_key
self._form_data_key = params.form_data_key self._form_data_key = params.form_data_key
self._dataset_id = params.dataset_id self._dataset_id = params.dataset_id
@ -66,7 +65,7 @@ class GetExploreCommand(BaseCommand, ABC):
initial_form_data = {} initial_form_data = {}
if self._permalink_key is not None: if self._permalink_key is not None:
command = GetExplorePermalinkCommand(self._actor, self._permalink_key) command = GetExplorePermalinkCommand(self._permalink_key)
permalink_value = command.run() permalink_value = command.run()
if not permalink_value: if not permalink_value:
raise ExplorePermalinkGetFailedError() raise ExplorePermalinkGetFailedError()
@ -76,9 +75,7 @@ class GetExploreCommand(BaseCommand, ABC):
if url_params: if url_params:
initial_form_data["url_params"] = dict(url_params) initial_form_data["url_params"] = dict(url_params)
elif self._form_data_key: elif self._form_data_key:
parameters = FormDataCommandParameters( parameters = FormDataCommandParameters(key=self._form_data_key)
actor=self._actor, key=self._form_data_key
)
value = GetFormDataCommand(parameters).run() value = GetFormDataCommand(parameters).run()
initial_form_data = json.loads(value) if value else {} initial_form_data = json.loads(value) if value else {}

View File

@ -16,7 +16,7 @@
# under the License. # under the License.
import logging import logging
from flask import g, request, Response from flask import request, Response
from flask_appbuilder.api import BaseApi, expose, protect, safe from flask_appbuilder.api import BaseApi, expose, protect, safe
from marshmallow import ValidationError from marshmallow import ValidationError
@ -102,7 +102,6 @@ class ExploreFormDataRestApi(BaseApi):
item = self.add_model_schema.load(request.json) item = self.add_model_schema.load(request.json)
tab_id = request.args.get("tab_id") tab_id = request.args.get("tab_id")
args = CommandParameters( args = CommandParameters(
actor=g.user,
datasource_id=item["datasource_id"], datasource_id=item["datasource_id"],
datasource_type=item["datasource_type"], datasource_type=item["datasource_type"],
chart_id=item.get("chart_id"), chart_id=item.get("chart_id"),
@ -173,7 +172,6 @@ class ExploreFormDataRestApi(BaseApi):
item = self.edit_model_schema.load(request.json) item = self.edit_model_schema.load(request.json)
tab_id = request.args.get("tab_id") tab_id = request.args.get("tab_id")
args = CommandParameters( args = CommandParameters(
actor=g.user,
datasource_id=item["datasource_id"], datasource_id=item["datasource_id"],
datasource_type=item["datasource_type"], datasource_type=item["datasource_type"],
chart_id=item.get("chart_id"), chart_id=item.get("chart_id"),
@ -233,7 +231,7 @@ class ExploreFormDataRestApi(BaseApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
args = CommandParameters(actor=g.user, key=key) args = CommandParameters(key=key)
form_data = GetFormDataCommand(args).run() form_data = GetFormDataCommand(args).run()
if not form_data: if not form_data:
return self.response_404() return self.response_404()
@ -285,7 +283,7 @@ class ExploreFormDataRestApi(BaseApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
args = CommandParameters(actor=g.user, key=key) args = CommandParameters(key=key)
result = DeleteFormDataCommand(args).run() result = DeleteFormDataCommand(args).run()
if not result: if not result:
return self.response_404() return self.response_404()

View File

@ -24,10 +24,10 @@ from superset.explore.form_data.commands.parameters import CommandParameters
from superset.explore.form_data.commands.state import TemporaryExploreState from superset.explore.form_data.commands.state import TemporaryExploreState
from superset.explore.form_data.commands.utils import check_access from superset.explore.form_data.commands.utils import check_access
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.utils import get_owner, random_key from superset.key_value.utils import random_key
from superset.temporary_cache.commands.exceptions import TemporaryCacheCreateFailedError from superset.temporary_cache.commands.exceptions import TemporaryCacheCreateFailedError
from superset.temporary_cache.utils import cache_key from superset.temporary_cache.utils import cache_key
from superset.utils.core import DatasourceType from superset.utils.core import DatasourceType, get_user_id
from superset.utils.schema import validate_json from superset.utils.schema import validate_json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -44,9 +44,8 @@ class CreateFormDataCommand(BaseCommand):
datasource_type = self._cmd_params.datasource_type datasource_type = self._cmd_params.datasource_type
chart_id = self._cmd_params.chart_id chart_id = self._cmd_params.chart_id
tab_id = self._cmd_params.tab_id tab_id = self._cmd_params.tab_id
actor = self._cmd_params.actor
form_data = self._cmd_params.form_data form_data = self._cmd_params.form_data
check_access(datasource_id, chart_id, actor, datasource_type) check_access(datasource_id, chart_id, datasource_type)
contextual_key = cache_key( contextual_key = cache_key(
session.get("_id"), tab_id, datasource_id, chart_id, datasource_type session.get("_id"), tab_id, datasource_id, chart_id, datasource_type
) )
@ -55,7 +54,7 @@ class CreateFormDataCommand(BaseCommand):
key = random_key() key = random_key()
if form_data: if form_data:
state: TemporaryExploreState = { state: TemporaryExploreState = {
"owner": get_owner(actor), "owner": get_user_id(),
"datasource_id": datasource_id, "datasource_id": datasource_id,
"datasource_type": DatasourceType(datasource_type), "datasource_type": DatasourceType(datasource_type),
"chart_id": chart_id, "chart_id": chart_id,

View File

@ -26,13 +26,12 @@ from superset.explore.form_data.commands.parameters import CommandParameters
from superset.explore.form_data.commands.state import TemporaryExploreState from superset.explore.form_data.commands.state import TemporaryExploreState
from superset.explore.form_data.commands.utils import check_access from superset.explore.form_data.commands.utils import check_access
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.utils import get_owner
from superset.temporary_cache.commands.exceptions import ( from superset.temporary_cache.commands.exceptions import (
TemporaryCacheAccessDeniedError, TemporaryCacheAccessDeniedError,
TemporaryCacheDeleteFailedError, TemporaryCacheDeleteFailedError,
) )
from superset.temporary_cache.utils import cache_key from superset.temporary_cache.utils import cache_key
from superset.utils.core import DatasourceType from superset.utils.core import DatasourceType, get_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -43,7 +42,6 @@ class DeleteFormDataCommand(BaseCommand, ABC):
def run(self) -> bool: def run(self) -> bool:
try: try:
actor = self._cmd_params.actor
key = self._cmd_params.key key = self._cmd_params.key
state: TemporaryExploreState = cache_manager.explore_form_data_cache.get( state: TemporaryExploreState = cache_manager.explore_form_data_cache.get(
key key
@ -52,8 +50,8 @@ class DeleteFormDataCommand(BaseCommand, ABC):
datasource_id: int = state["datasource_id"] datasource_id: int = state["datasource_id"]
chart_id: Optional[int] = state["chart_id"] chart_id: Optional[int] = state["chart_id"]
datasource_type = DatasourceType(state["datasource_type"]) datasource_type = DatasourceType(state["datasource_type"])
check_access(datasource_id, chart_id, actor, datasource_type) check_access(datasource_id, chart_id, datasource_type)
if state["owner"] != get_owner(actor): if state["owner"] != get_user_id():
raise TemporaryCacheAccessDeniedError() raise TemporaryCacheAccessDeniedError()
tab_id = self._cmd_params.tab_id tab_id = self._cmd_params.tab_id
contextual_key = cache_key( contextual_key = cache_key(

View File

@ -40,7 +40,6 @@ class GetFormDataCommand(BaseCommand, ABC):
def run(self) -> Optional[str]: def run(self) -> Optional[str]:
try: try:
actor = self._cmd_params.actor
key = self._cmd_params.key key = self._cmd_params.key
state: TemporaryExploreState = cache_manager.explore_form_data_cache.get( state: TemporaryExploreState = cache_manager.explore_form_data_cache.get(
key key
@ -49,7 +48,6 @@ class GetFormDataCommand(BaseCommand, ABC):
check_access( check_access(
state["datasource_id"], state["datasource_id"],
state["chart_id"], state["chart_id"],
actor,
DatasourceType(state["datasource_type"]), DatasourceType(state["datasource_type"]),
) )
if self._refresh_timeout: if self._refresh_timeout:

View File

@ -17,14 +17,11 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from flask_appbuilder.security.sqla.models import User
from superset.utils.core import DatasourceType from superset.utils.core import DatasourceType
@dataclass @dataclass
class CommandParameters: class CommandParameters:
actor: User
datasource_type: DatasourceType = DatasourceType.TABLE datasource_type: DatasourceType = DatasourceType.TABLE
datasource_id: int = 0 datasource_id: int = 0
chart_id: int = 0 chart_id: int = 0

View File

@ -26,13 +26,13 @@ from superset.explore.form_data.commands.parameters import CommandParameters
from superset.explore.form_data.commands.state import TemporaryExploreState from superset.explore.form_data.commands.state import TemporaryExploreState
from superset.explore.form_data.commands.utils import check_access from superset.explore.form_data.commands.utils import check_access
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.utils import get_owner, random_key from superset.key_value.utils import random_key
from superset.temporary_cache.commands.exceptions import ( from superset.temporary_cache.commands.exceptions import (
TemporaryCacheAccessDeniedError, TemporaryCacheAccessDeniedError,
TemporaryCacheUpdateFailedError, TemporaryCacheUpdateFailedError,
) )
from superset.temporary_cache.utils import cache_key from superset.temporary_cache.utils import cache_key
from superset.utils.core import DatasourceType from superset.utils.core import DatasourceType, get_user_id
from superset.utils.schema import validate_json from superset.utils.schema import validate_json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,14 +51,13 @@ class UpdateFormDataCommand(BaseCommand, ABC):
datasource_id = self._cmd_params.datasource_id datasource_id = self._cmd_params.datasource_id
chart_id = self._cmd_params.chart_id chart_id = self._cmd_params.chart_id
datasource_type = self._cmd_params.datasource_type datasource_type = self._cmd_params.datasource_type
actor = self._cmd_params.actor
key = self._cmd_params.key key = self._cmd_params.key
form_data = self._cmd_params.form_data form_data = self._cmd_params.form_data
check_access(datasource_id, chart_id, actor, datasource_type) check_access(datasource_id, chart_id, datasource_type)
state: TemporaryExploreState = cache_manager.explore_form_data_cache.get( state: TemporaryExploreState = cache_manager.explore_form_data_cache.get(
key key
) )
owner = get_owner(actor) owner = get_user_id()
if state and form_data: if state and form_data:
if state["owner"] != owner: if state["owner"] != owner:
raise TemporaryCacheAccessDeniedError() raise TemporaryCacheAccessDeniedError()

View File

@ -16,8 +16,6 @@
# under the License. # under the License.
from typing import Optional from typing import Optional
from flask_appbuilder.security.sqla.models import User
from superset.charts.commands.exceptions import ( from superset.charts.commands.exceptions import (
ChartAccessDeniedError, ChartAccessDeniedError,
ChartNotFoundError, ChartNotFoundError,
@ -37,11 +35,10 @@ from superset.utils.core import DatasourceType
def check_access( def check_access(
datasource_id: int, datasource_id: int,
chart_id: Optional[int], chart_id: Optional[int],
actor: User,
datasource_type: DatasourceType, datasource_type: DatasourceType,
) -> None: ) -> None:
try: try:
explore_check_access(datasource_id, chart_id, actor, datasource_type) explore_check_access(datasource_id, chart_id, datasource_type)
except (ChartNotFoundError, DatasetNotFoundError) as ex: except (ChartNotFoundError, DatasetNotFoundError) as ex:
raise TemporaryCacheResourceNotFoundError from ex raise TemporaryCacheResourceNotFoundError from ex
except (ChartAccessDeniedError, DatasetAccessDeniedError) as ex: except (ChartAccessDeniedError, DatasetAccessDeniedError) as ex:

View File

@ -16,7 +16,7 @@
# under the License. # under the License.
import logging import logging
from flask import g, request, Response from flask import request, Response
from flask_appbuilder.api import BaseApi, expose, protect, safe from flask_appbuilder.api import BaseApi, expose, protect, safe
from marshmallow import ValidationError from marshmallow import ValidationError
@ -100,7 +100,7 @@ class ExplorePermalinkRestApi(BaseApi):
""" """
try: try:
state = self.add_model_schema.load(request.json) state = self.add_model_schema.load(request.json)
key = CreateExplorePermalinkCommand(actor=g.user, state=state).run() key = CreateExplorePermalinkCommand(state=state).run()
http_origin = request.headers.environ.get("HTTP_ORIGIN") http_origin = request.headers.environ.get("HTTP_ORIGIN")
url = f"{http_origin}/superset/explore/p/{key}/" url = f"{http_origin}/superset/explore/p/{key}/"
return self.response(201, key=key, url=url) return self.response(201, key=key, url=url)
@ -156,7 +156,7 @@ class ExplorePermalinkRestApi(BaseApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
value = GetExplorePermalinkCommand(actor=g.user, key=key).run() value = GetExplorePermalinkCommand(key=key).run()
if not value: if not value:
return self.response_404() return self.response_404()
return self.response(200, **value) return self.response(200, **value)

View File

@ -17,7 +17,6 @@
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand
@ -31,8 +30,7 @@ logger = logging.getLogger(__name__)
class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand): class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
def __init__(self, actor: User, state: Dict[str, Any]): def __init__(self, state: Dict[str, Any]):
self.actor = actor
self.chart_id: Optional[int] = state["formData"].get("slice_id") self.chart_id: Optional[int] = state["formData"].get("slice_id")
self.datasource: str = state["formData"]["datasource"] self.datasource: str = state["formData"]["datasource"]
self.state = state self.state = state
@ -43,9 +41,7 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
d_id, d_type = self.datasource.split("__") d_id, d_type = self.datasource.split("__")
datasource_id = int(d_id) datasource_id = int(d_id)
datasource_type = DatasourceType(d_type) datasource_type = DatasourceType(d_type)
check_chart_access( check_chart_access(datasource_id, self.chart_id, datasource_type)
datasource_id, self.chart_id, self.actor, datasource_type
)
value = { value = {
"chartId": self.chart_id, "chartId": self.chart_id,
"datasourceId": datasource_id, "datasourceId": datasource_id,
@ -54,7 +50,6 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
"state": self.state, "state": self.state,
} }
command = CreateKeyValueCommand( command = CreateKeyValueCommand(
actor=self.actor,
resource=self.resource, resource=self.resource,
value=value, value=value,
) )

View File

@ -17,7 +17,6 @@
import logging import logging
from typing import Optional from typing import Optional
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.datasets.commands.exceptions import DatasetNotFoundError
@ -34,8 +33,7 @@ logger = logging.getLogger(__name__)
class GetExplorePermalinkCommand(BaseExplorePermalinkCommand): class GetExplorePermalinkCommand(BaseExplorePermalinkCommand):
def __init__(self, actor: User, key: str): def __init__(self, key: str):
self.actor = actor
self.key = key self.key = key
def run(self) -> Optional[ExplorePermalinkValue]: def run(self) -> Optional[ExplorePermalinkValue]:
@ -55,7 +53,7 @@ class GetExplorePermalinkCommand(BaseExplorePermalinkCommand):
datasource_type = DatasourceType( datasource_type = DatasourceType(
value.get("datasourceType", DatasourceType.TABLE) value.get("datasourceType", DatasourceType.TABLE)
) )
check_chart_access(datasource_id, chart_id, self.actor, datasource_type) check_chart_access(datasource_id, chart_id, datasource_type)
return value return value
return None return None
except ( except (

View File

@ -16,8 +16,6 @@
# under the License. # under the License.
from typing import Optional from typing import Optional
from flask_appbuilder.security.sqla.models import User
from superset import security_manager from superset import security_manager
from superset.charts.commands.exceptions import ( from superset.charts.commands.exceptions import (
ChartAccessDeniedError, ChartAccessDeniedError,
@ -36,8 +34,6 @@ from superset.datasets.commands.exceptions import (
from superset.datasets.dao import DatasetDAO from superset.datasets.dao import DatasetDAO
from superset.queries.dao import QueryDAO from superset.queries.dao import QueryDAO
from superset.utils.core import DatasourceType from superset.utils.core import DatasourceType
from superset.views.base import is_user_admin
from superset.views.utils import is_owner
def check_dataset_access(dataset_id: int) -> Optional[bool]: def check_dataset_access(dataset_id: int) -> Optional[bool]:
@ -80,7 +76,6 @@ def check_datasource_access(
def check_access( def check_access(
datasource_id: int, datasource_id: int,
chart_id: Optional[int], chart_id: Optional[int],
actor: User,
datasource_type: DatasourceType, datasource_type: DatasourceType,
) -> Optional[bool]: ) -> Optional[bool]:
check_datasource_access(datasource_id, datasource_type) check_datasource_access(datasource_id, datasource_type)
@ -88,11 +83,9 @@ def check_access(
return True return True
chart = ChartDAO.find_by_id(chart_id) chart = ChartDAO.find_by_id(chart_id)
if chart: if chart:
can_access_chart = ( can_access_chart = security_manager.is_owner(
is_user_admin() chart
or is_owner(chart, actor) ) or security_manager.can_access("can_read", "Chart")
or security_manager.can_access("can_read", "Chart")
)
if can_access_chart: if can_access_chart:
return True return True
raise ChartAccessDeniedError() raise ChartAccessDeniedError()

View File

@ -20,7 +20,6 @@ from datetime import datetime
from typing import Any, Optional, Union from typing import Any, Optional, Union
from uuid import UUID from uuid import UUID
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset import db from superset import db
@ -28,23 +27,21 @@ from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueCreateFailedError from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.models import KeyValueEntry from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyValueResource from superset.key_value.types import Key, KeyValueResource
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CreateKeyValueCommand(BaseCommand): class CreateKeyValueCommand(BaseCommand):
actor: Optional[User]
resource: KeyValueResource resource: KeyValueResource
value: Any value: Any
key: Optional[Union[int, UUID]] key: Optional[Union[int, UUID]]
expires_on: Optional[datetime] expires_on: Optional[datetime]
# pylint: disable=too-many-arguments
def __init__( def __init__(
self, self,
resource: KeyValueResource, resource: KeyValueResource,
value: Any, value: Any,
actor: Optional[User] = None,
key: Optional[Union[int, UUID]] = None, key: Optional[Union[int, UUID]] = None,
expires_on: Optional[datetime] = None, expires_on: Optional[datetime] = None,
): ):
@ -53,13 +50,11 @@ class CreateKeyValueCommand(BaseCommand):
:param resource: the resource (dashboard, chart etc) :param resource: the resource (dashboard, chart etc)
:param value: the value to persist in the key-value store :param value: the value to persist in the key-value store
:param actor: the user performing the command
:param key: id of entry (autogenerated if undefined) :param key: id of entry (autogenerated if undefined)
:param expires_on: entry expiration time :param expires_on: entry expiration time
:return: the key associated with the persisted value :return: the key associated with the persisted value
""" """
self.resource = resource self.resource = resource
self.actor = actor
self.value = value self.value = value
self.key = key self.key = key
self.expires_on = expires_on self.expires_on = expires_on
@ -80,9 +75,7 @@ class CreateKeyValueCommand(BaseCommand):
resource=self.resource.value, resource=self.resource.value,
value=pickle.dumps(self.value), value=pickle.dumps(self.value),
created_on=datetime.now(), created_on=datetime.now(),
created_by_fk=None created_by_fk=get_user_id(),
if self.actor is None or self.actor.is_anonymous
else self.actor.id,
expires_on=self.expires_on, expires_on=self.expires_on,
) )
if self.key is not None: if self.key is not None:

View File

@ -21,7 +21,6 @@ from datetime import datetime
from typing import Any, Optional, Union from typing import Any, Optional, Union
from uuid import UUID from uuid import UUID
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset import db from superset import db
@ -30,24 +29,22 @@ from superset.key_value.exceptions import KeyValueUpdateFailedError
from superset.key_value.models import KeyValueEntry from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyValueResource from superset.key_value.types import Key, KeyValueResource
from superset.key_value.utils import get_filter from superset.key_value.utils import get_filter
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UpdateKeyValueCommand(BaseCommand): class UpdateKeyValueCommand(BaseCommand):
actor: Optional[User]
resource: KeyValueResource resource: KeyValueResource
value: Any value: Any
key: Union[int, UUID] key: Union[int, UUID]
expires_on: Optional[datetime] expires_on: Optional[datetime]
# pylint: disable=too-many-argumentsåå
def __init__( def __init__(
self, self,
resource: KeyValueResource, resource: KeyValueResource,
key: Union[int, UUID], key: Union[int, UUID],
value: Any, value: Any,
actor: Optional[User] = None,
expires_on: Optional[datetime] = None, expires_on: Optional[datetime] = None,
): ):
""" """
@ -56,11 +53,9 @@ class UpdateKeyValueCommand(BaseCommand):
:param resource: the resource (dashboard, chart etc) :param resource: the resource (dashboard, chart etc)
:param key: the key to update :param key: the key to update
:param value: the value to persist in the key-value store :param value: the value to persist in the key-value store
:param actor: the user performing the command
:param expires_on: entry expiration time :param expires_on: entry expiration time
:return: the key associated with the updated value :return: the key associated with the updated value
""" """
self.actor = actor
self.resource = resource self.resource = resource
self.key = key self.key = key
self.value = value self.value = value
@ -89,9 +84,7 @@ class UpdateKeyValueCommand(BaseCommand):
entry.value = pickle.dumps(self.value) entry.value = pickle.dumps(self.value)
entry.expires_on = self.expires_on entry.expires_on = self.expires_on
entry.changed_on = datetime.now() entry.changed_on = datetime.now()
entry.changed_by_fk = ( entry.changed_by_fk = get_user_id()
None if self.actor is None or self.actor.is_anonymous else self.actor.id
)
db.session.merge(entry) db.session.merge(entry)
db.session.commit() db.session.commit()
return Key(id=entry.id, uuid=entry.uuid) return Key(id=entry.id, uuid=entry.uuid)

View File

@ -21,7 +21,6 @@ from datetime import datetime
from typing import Any, Optional, Union from typing import Any, Optional, Union
from uuid import UUID from uuid import UUID
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from superset import db from superset import db
@ -31,24 +30,22 @@ from superset.key_value.exceptions import KeyValueUpdateFailedError
from superset.key_value.models import KeyValueEntry from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyValueResource from superset.key_value.types import Key, KeyValueResource
from superset.key_value.utils import get_filter from superset.key_value.utils import get_filter
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UpsertKeyValueCommand(BaseCommand): class UpsertKeyValueCommand(BaseCommand):
actor: Optional[User]
resource: KeyValueResource resource: KeyValueResource
value: Any value: Any
key: Union[int, UUID] key: Union[int, UUID]
expires_on: Optional[datetime] expires_on: Optional[datetime]
# pylint: disable=too-many-arguments
def __init__( def __init__(
self, self,
resource: KeyValueResource, resource: KeyValueResource,
key: Union[int, UUID], key: Union[int, UUID],
value: Any, value: Any,
actor: Optional[User] = None,
expires_on: Optional[datetime] = None, expires_on: Optional[datetime] = None,
): ):
""" """
@ -58,11 +55,9 @@ class UpsertKeyValueCommand(BaseCommand):
:param key: the key to update :param key: the key to update
:param value: the value to persist in the key-value store :param value: the value to persist in the key-value store
:param key_type: the type of the key to update :param key_type: the type of the key to update
:param actor: the user performing the command
:param expires_on: entry expiration time :param expires_on: entry expiration time
:return: the key associated with the updated value :return: the key associated with the updated value
""" """
self.actor = actor
self.resource = resource self.resource = resource
self.key = key self.key = key
self.value = value self.value = value
@ -91,16 +86,13 @@ class UpsertKeyValueCommand(BaseCommand):
entry.value = pickle.dumps(self.value) entry.value = pickle.dumps(self.value)
entry.expires_on = self.expires_on entry.expires_on = self.expires_on
entry.changed_on = datetime.now() entry.changed_on = datetime.now()
entry.changed_by_fk = ( entry.changed_by_fk = get_user_id()
None if self.actor is None or self.actor.is_anonymous else self.actor.id
)
db.session.merge(entry) db.session.merge(entry)
db.session.commit() db.session.commit()
return Key(entry.id, entry.uuid) return Key(entry.id, entry.uuid)
return CreateKeyValueCommand( return CreateKeyValueCommand(
resource=self.resource, resource=self.resource,
value=self.value, value=self.value,
actor=self.actor,
key=self.key, key=self.key,
expires_on=self.expires_on, expires_on=self.expires_on,
).run() ).run()

View File

@ -18,11 +18,10 @@ from __future__ import annotations
from hashlib import md5 from hashlib import md5
from secrets import token_urlsafe from secrets import token_urlsafe
from typing import Optional, Union from typing import Union
from uuid import UUID from uuid import UUID
import hashids import hashids
from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as _ from flask_babel import gettext as _
from superset.key_value.exceptions import KeyValueParseKeyError from superset.key_value.exceptions import KeyValueParseKeyError
@ -64,7 +63,3 @@ def get_uuid_namespace(seed: str) -> UUID:
md5_obj = md5() md5_obj = md5()
md5_obj.update(seed.encode("utf-8")) md5_obj.update(seed.encode("utf-8"))
return UUID(md5_obj.hexdigest()) return UUID(md5_obj.hexdigest())
def get_owner(user: User) -> Optional[int]:
return user.id if not user.is_anonymous else None

View File

@ -23,7 +23,6 @@ from functools import partial
from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union
import sqlalchemy as sqla import sqlalchemy as sqla
from flask import g
from flask_appbuilder import Model from flask_appbuilder import Model
from flask_appbuilder.models.decorators import renders from flask_appbuilder.models.decorators import renders
from flask_appbuilder.security.sqla.models import User from flask_appbuilder.security.sqla.models import User
@ -47,7 +46,6 @@ from sqlalchemy.sql import join, select
from sqlalchemy.sql.elements import BinaryExpression from sqlalchemy.sql.elements import BinaryExpression
from superset import app, db, is_feature_enabled, security_manager from superset import app, db, is_feature_enabled, security_manager
from superset.common.request_contexed_based import is_user_admin
from superset.connectors.base.models import BaseDatasource from superset.connectors.base.models import BaseDatasource
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.datasource.dao import DatasourceDAO from superset.datasource.dao import DatasourceDAO
@ -59,6 +57,7 @@ from superset.models.tags import DashboardUpdater
from superset.models.user_attributes import UserAttribute from superset.models.user_attributes import UserAttribute
from superset.tasks.thumbnails import cache_dashboard_thumbnail from superset.tasks.thumbnails import cache_dashboard_thumbnail
from superset.utils import core as utils from superset.utils import core as utils
from superset.utils.core import get_user_id
from superset.utils.decorators import debounce from superset.utils.decorators import debounce
from superset.utils.hashing import md5_sha_from_str from superset.utils.hashing import md5_sha_from_str
from superset.utils.urls import get_url_path from superset.utils.urls import get_url_path
@ -203,15 +202,14 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
@property @property
def filter_sets_lst(self) -> Dict[int, FilterSet]: def filter_sets_lst(self) -> Dict[int, FilterSet]:
if is_user_admin(): if security_manager.is_admin():
return self._filter_sets return self._filter_sets
current_user = g.user.id
filter_sets_by_owner_type: Dict[str, List[Any]] = {"Dashboard": [], "User": []} filter_sets_by_owner_type: Dict[str, List[Any]] = {"Dashboard": [], "User": []}
for fs in self._filter_sets: for fs in self._filter_sets:
filter_sets_by_owner_type[fs.owner_type].append(fs) filter_sets_by_owner_type[fs.owner_type].append(fs)
user_filter_sets = list( user_filter_sets = list(
filter( filter(
lambda filter_set: filter_set.owner_id == current_user, lambda filter_set: filter_set.owner_id == get_user_id(),
filter_sets_by_owner_type["User"], filter_sets_by_owner_type["User"],
) )
) )
@ -445,11 +443,6 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
qry = session.query(Dashboard).filter(id_or_slug_filter(id_or_slug)) qry = session.query(Dashboard).filter(id_or_slug_filter(id_or_slug))
return qry.one_or_none() return qry.one_or_none()
def is_actor_owner(self) -> bool:
if g.user is None or g.user.is_anonymous or not g.user.is_authenticated:
return False
return g.user.id in set(map(lambda user: user.id, self.owners))
def id_or_slug_filter(id_or_slug: str) -> BinaryExpression: def id_or_slug_filter(id_or_slug: str) -> BinaryExpression:
if id_or_slug.isdigit(): if id_or_slug.isdigit():

View File

@ -192,7 +192,7 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
""" """
item_ids = kwargs["rison"] item_ids = kwargs["rison"]
try: try:
BulkDeleteSavedQueryCommand(g.user, item_ids).run() BulkDeleteSavedQueryCommand(item_ids).run()
return self.response( return self.response(
200, 200,
message=ngettext( message=ngettext(

View File

@ -17,8 +17,6 @@
import logging import logging
from typing import List, Optional from typing import List, Optional
from flask_appbuilder.security.sqla.models import User
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAODeleteFailedError from superset.dao.exceptions import DAODeleteFailedError
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
@ -32,8 +30,7 @@ logger = logging.getLogger(__name__)
class BulkDeleteSavedQueryCommand(BaseCommand): class BulkDeleteSavedQueryCommand(BaseCommand):
def __init__(self, user: User, model_ids: List[int]): def __init__(self, model_ids: List[int]):
self._actor = user
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[List[Dashboard]] = None self._models: Optional[List[Dashboard]] = None

View File

@ -17,7 +17,7 @@
import logging import logging
from typing import Any, Optional from typing import Any, Optional
from flask import g, request, Response from flask import request, Response
from flask_appbuilder.api import expose, permission_name, protect, rison, safe from flask_appbuilder.api import expose, permission_name, protect, rison, safe
from flask_appbuilder.hooks import before_request from flask_appbuilder.hooks import before_request
from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.models.sqla.interface import SQLAInterface
@ -266,7 +266,7 @@ class ReportScheduleRestApi(BaseSupersetModelRestApi):
$ref: '#/components/responses/500' $ref: '#/components/responses/500'
""" """
try: try:
DeleteReportScheduleCommand(g.user, pk).run() DeleteReportScheduleCommand(pk).run()
return self.response(200, message="OK") return self.response(200, message="OK")
except ReportScheduleNotFoundError: except ReportScheduleNotFoundError:
return self.response_404() return self.response_404()
@ -340,7 +340,7 @@ class ReportScheduleRestApi(BaseSupersetModelRestApi):
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
try: try:
new_model = CreateReportScheduleCommand(g.user, item).run() new_model = CreateReportScheduleCommand(item).run()
return self.response(201, id=new_model.id, result=item) return self.response(201, id=new_model.id, result=item)
except ReportScheduleNotFoundError as ex: except ReportScheduleNotFoundError as ex:
return self.response_400(message=str(ex)) return self.response_400(message=str(ex))
@ -421,7 +421,7 @@ class ReportScheduleRestApi(BaseSupersetModelRestApi):
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
try: try:
new_model = UpdateReportScheduleCommand(g.user, pk, item).run() new_model = UpdateReportScheduleCommand(pk, item).run()
return self.response(200, id=new_model.id, result=item) return self.response(200, id=new_model.id, result=item)
except ReportScheduleNotFoundError: except ReportScheduleNotFoundError:
return self.response_404() return self.response_404()
@ -483,7 +483,7 @@ class ReportScheduleRestApi(BaseSupersetModelRestApi):
""" """
item_ids = kwargs["rison"] item_ids = kwargs["rison"]
try: try:
BulkDeleteReportScheduleCommand(g.user, item_ids).run() BulkDeleteReportScheduleCommand(item_ids).run()
return self.response( return self.response(
200, 200,
message=ngettext( message=ngettext(

View File

@ -17,8 +17,7 @@
import logging import logging
from typing import List, Optional from typing import List, Optional
from flask_appbuilder.security.sqla.models import User from superset import security_manager
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAODeleteFailedError from superset.dao.exceptions import DAODeleteFailedError
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
@ -29,14 +28,12 @@ from superset.reports.commands.exceptions import (
ReportScheduleNotFoundError, ReportScheduleNotFoundError,
) )
from superset.reports.dao import ReportScheduleDAO from superset.reports.dao import ReportScheduleDAO
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BulkDeleteReportScheduleCommand(BaseCommand): class BulkDeleteReportScheduleCommand(BaseCommand):
def __init__(self, user: User, model_ids: List[int]): def __init__(self, model_ids: List[int]):
self._actor = user
self._model_ids = model_ids self._model_ids = model_ids
self._models: Optional[List[ReportSchedule]] = None self._models: Optional[List[ReportSchedule]] = None
@ -58,6 +55,6 @@ class BulkDeleteReportScheduleCommand(BaseCommand):
# Check ownership # Check ownership
for model in self._models: for model in self._models:
try: try:
check_ownership(model) security_manager.raise_for_ownership(model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise ReportScheduleForbiddenError() from ex raise ReportScheduleForbiddenError() from ex

View File

@ -19,7 +19,6 @@ import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
from superset.commands.base import CreateMixin from superset.commands.base import CreateMixin
@ -42,8 +41,7 @@ logger = logging.getLogger(__name__)
class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand): class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand):
def __init__(self, user: User, data: Dict[str, Any]): def __init__(self, data: Dict[str, Any]):
self._actor = user
self._properties = data.copy() self._properties = data.copy()
def run(self) -> Model: def run(self) -> Model:
@ -63,7 +61,6 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand):
creation_method = self._properties.get("creation_method") creation_method = self._properties.get("creation_method")
chart_id = self._properties.get("chart") chart_id = self._properties.get("chart")
dashboard_id = self._properties.get("dashboard") dashboard_id = self._properties.get("dashboard")
user_id = self._actor.id
# Validate type is required # Validate type is required
if not report_type: if not report_type:
@ -99,7 +96,7 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand):
if ( if (
creation_method != ReportCreationMethod.ALERTS_REPORTS creation_method != ReportCreationMethod.ALERTS_REPORTS
and not ReportScheduleDAO.validate_unique_creation_method( and not ReportScheduleDAO.validate_unique_creation_method(
user_id, dashboard_id, chart_id dashboard_id, chart_id
) )
): ):
raise ReportScheduleCreationMethodUniquenessValidationError() raise ReportScheduleCreationMethodUniquenessValidationError()
@ -110,7 +107,7 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand):
) )
try: try:
owners = self.populate_owners(self._actor, owner_ids) owners = self.populate_owners(owner_ids)
self._properties["owners"] = owners self._properties["owners"] = owners
except ValidationError as ex: except ValidationError as ex:
exceptions.append(ex) exceptions.append(ex)

View File

@ -18,8 +18,8 @@ import logging
from typing import Optional from typing import Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from superset import security_manager
from superset.commands.base import BaseCommand from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAODeleteFailedError from superset.dao.exceptions import DAODeleteFailedError
from superset.exceptions import SupersetSecurityException from superset.exceptions import SupersetSecurityException
@ -30,14 +30,12 @@ from superset.reports.commands.exceptions import (
ReportScheduleNotFoundError, ReportScheduleNotFoundError,
) )
from superset.reports.dao import ReportScheduleDAO from superset.reports.dao import ReportScheduleDAO
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DeleteReportScheduleCommand(BaseCommand): class DeleteReportScheduleCommand(BaseCommand):
def __init__(self, user: User, model_id: int): def __init__(self, model_id: int):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._model: Optional[ReportSchedule] = None self._model: Optional[ReportSchedule] = None
@ -58,6 +56,6 @@ class DeleteReportScheduleCommand(BaseCommand):
# Check ownership # Check ownership
try: try:
check_ownership(self._model) security_manager.raise_for_ownership(self._model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise ReportScheduleForbiddenError() from ex raise ReportScheduleForbiddenError() from ex

View File

@ -19,9 +19,9 @@ import logging
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError from marshmallow import ValidationError
from superset import security_manager
from superset.commands.base import UpdateMixin from superset.commands.base import UpdateMixin
from superset.dao.exceptions import DAOUpdateFailedError from superset.dao.exceptions import DAOUpdateFailedError
from superset.databases.dao import DatabaseDAO from superset.databases.dao import DatabaseDAO
@ -37,14 +37,12 @@ from superset.reports.commands.exceptions import (
ReportScheduleUpdateFailedError, ReportScheduleUpdateFailedError,
) )
from superset.reports.dao import ReportScheduleDAO from superset.reports.dao import ReportScheduleDAO
from superset.views.base import check_ownership
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand): class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand):
def __init__(self, user: User, model_id: int, data: Dict[str, Any]): def __init__(self, model_id: int, data: Dict[str, Any]):
self._actor = user
self._model_id = model_id self._model_id = model_id
self._properties = data.copy() self._properties = data.copy()
self._model: Optional[ReportSchedule] = None self._model: Optional[ReportSchedule] = None
@ -113,7 +111,7 @@ class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand):
# Check ownership # Check ownership
try: try:
check_ownership(self._model) security_manager.raise_for_ownership(self._model)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise ReportScheduleForbiddenError() from ex raise ReportScheduleForbiddenError() from ex
@ -121,7 +119,7 @@ class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand):
if owner_ids is None: if owner_ids is None:
owner_ids = [owner.id for owner in self._model.owners] owner_ids = [owner.id for owner in self._model.owners]
try: try:
owners = self.populate_owners(self._actor, owner_ids) owners = self.populate_owners(owner_ids)
self._properties["owners"] = owners self._properties["owners"] = owners
except ValidationError as ex: except ValidationError as ex:
exceptions.append(ex) exceptions.append(ex)

View File

@ -33,6 +33,7 @@ from superset.models.reports import (
ReportScheduleType, ReportScheduleType,
ReportState, ReportState,
) )
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -116,14 +117,14 @@ class ReportScheduleDAO(BaseDAO):
@staticmethod @staticmethod
def validate_unique_creation_method( def validate_unique_creation_method(
user_id: int, dashboard_id: Optional[int] = None, chart_id: Optional[int] = None dashboard_id: Optional[int] = None, chart_id: Optional[int] = None
) -> bool: ) -> bool:
""" """
Validate if the user already has a chart or dashboard Validate if the user already has a chart or dashboard
with a report attached form the self subscribe reports with a report attached form the self subscribe reports
""" """
query = db.session.query(ReportSchedule).filter_by(created_by_fk=user_id) query = db.session.query(ReportSchedule).filter_by(created_by_fk=get_user_id())
if dashboard_id is not None: if dashboard_id is not None:
query = query.filter(ReportSchedule.dashboard_id == dashboard_id) query = query.filter(ReportSchedule.dashboard_id == dashboard_id)

View File

@ -1093,7 +1093,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.extensions import feature_flag_manager from superset.extensions import feature_flag_manager
from superset.sql_parse import Table from superset.sql_parse import Table
from superset.views.utils import is_owner
if database and table or query: if database and table or query:
if query: if query:
@ -1126,7 +1125,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
for datasource_ in datasources: for datasource_ in datasources:
if self.can_access( if self.can_access(
"datasource_access", datasource_.perm "datasource_access", datasource_.perm
) or is_owner(datasource_, getattr(g, "user", None)): ) or self.is_owner(datasource_):
break break
else: else:
denied.add(table_) denied.add(table_)
@ -1152,7 +1151,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
if not ( if not (
self.can_access_schema(datasource) self.can_access_schema(datasource)
or self.can_access("datasource_access", datasource.perm or "") or self.can_access("datasource_access", datasource.perm or "")
or is_owner(datasource, getattr(g, "user", None)) or self.is_owner(datasource)
or ( or (
should_check_dashboard_access should_check_dashboard_access
and self.can_access_based_on_dashboard(datasource) and self.can_access_based_on_dashboard(datasource)
@ -1327,8 +1326,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from superset import is_feature_enabled from superset import is_feature_enabled
from superset.dashboards.commands.exceptions import DashboardAccessDeniedError from superset.dashboards.commands.exceptions import DashboardAccessDeniedError
from superset.views.base import is_user_admin
from superset.views.utils import is_owner
def has_rbac_access() -> bool: def has_rbac_access() -> bool:
return (not is_feature_enabled("DASHBOARD_RBAC")) or any( return (not is_feature_enabled("DASHBOARD_RBAC")) or any(
@ -1341,8 +1338,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
can_access = self.has_guest_access(dashboard) can_access = self.has_guest_access(dashboard)
else: else:
can_access = ( can_access = (
is_user_admin() self.is_admin()
or is_owner(dashboard, g.user) or self.is_owner(dashboard)
or (dashboard.published and has_rbac_access()) or (dashboard.published and has_rbac_access())
or (not dashboard.published and not dashboard.roles) or (not dashboard.published and not dashboard.roles)
) )
@ -1520,3 +1517,69 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
if str(resource["id"]) == str(dashboard.embedded[0].uuid): if str(resource["id"]) == str(dashboard.embedded[0].uuid):
return True return True
return False return False
def raise_for_ownership(self, resource: Model) -> None:
"""
Raise an exception if the user does not own the resource.
Note admins are deemed owners of all resources.
:param resource: The dashboard, dataste, chart, etc. resource
:raises SupersetSecurityException: If the current user is not an owner
"""
# pylint: disable=import-outside-toplevel
from superset import db
if self.is_admin():
return
# Set of wners that works across ORM models.
owners: List[User] = []
orig_resource = db.session.query(resource.__class__).get(resource.id)
if orig_resource:
if hasattr(resource, "owners"):
owners += orig_resource.owners
if hasattr(resource, "owner"):
owners.append(orig_resource.owner)
if hasattr(resource, "created_by"):
owners.append(orig_resource.created_by)
if g.user.is_anonymous or g.user not in owners:
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.MISSING_OWNERSHIP_ERROR,
message=f"You don't have the rights to alter [{resource}]",
level=ErrorLevel.ERROR,
)
)
def is_owner(self, resource: Model) -> bool:
"""
Returns True if the current user is an owner of the resource, False otherwise.
:param resource: The dashboard, dataste, chart, etc. resource
:returns: Whethe the current user is an owner of the resource
"""
try:
self.raise_for_ownership(resource)
except SupersetSecurityException:
return False
return True
def is_admin(self) -> bool:
"""
Returns True if the current user is an admin user, False otherwise.
:returns: Whehther the current user is an admin user
"""
return current_app.config["AUTH_ROLE_ADMIN"] in [
role.name for role in self.get_user_roles()
]

View File

@ -20,7 +20,7 @@ from typing import Any
from apispec import APISpec from apispec import APISpec
from apispec.exceptions import DuplicateComponentNameError from apispec.exceptions import DuplicateComponentNameError
from flask import g, request, Response from flask import request, Response
from flask_appbuilder.api import BaseApi from flask_appbuilder.api import BaseApi
from marshmallow import ValidationError from marshmallow import ValidationError
@ -70,9 +70,7 @@ class TemporaryCacheRestApi(BaseApi, ABC):
try: try:
item = self.add_model_schema.load(request.json) item = self.add_model_schema.load(request.json)
tab_id = request.args.get("tab_id") tab_id = request.args.get("tab_id")
args = CommandParameters( args = CommandParameters(resource_id=pk, value=item["value"], tab_id=tab_id)
actor=g.user, resource_id=pk, value=item["value"], tab_id=tab_id
)
key = self.get_create_command()(args).run() key = self.get_create_command()(args).run()
return self.response(201, key=key) return self.response(201, key=key)
except ValidationError as ex: except ValidationError as ex:
@ -88,7 +86,6 @@ class TemporaryCacheRestApi(BaseApi, ABC):
item = self.edit_model_schema.load(request.json) item = self.edit_model_schema.load(request.json)
tab_id = request.args.get("tab_id") tab_id = request.args.get("tab_id")
args = CommandParameters( args = CommandParameters(
actor=g.user,
resource_id=pk, resource_id=pk,
key=key, key=key,
value=item["value"], value=item["value"],
@ -105,7 +102,7 @@ class TemporaryCacheRestApi(BaseApi, ABC):
def get(self, pk: int, key: str) -> Response: def get(self, pk: int, key: str) -> Response:
try: try:
args = CommandParameters(actor=g.user, resource_id=pk, key=key) args = CommandParameters(resource_id=pk, key=key)
value = self.get_get_command()(args).run() value = self.get_get_command()(args).run()
if not value: if not value:
return self.response_404() return self.response_404()
@ -117,7 +114,7 @@ class TemporaryCacheRestApi(BaseApi, ABC):
def delete(self, pk: int, key: str) -> Response: def delete(self, pk: int, key: str) -> Response:
try: try:
args = CommandParameters(actor=g.user, resource_id=pk, key=key) args = CommandParameters(resource_id=pk, key=key)
result = self.get_delete_command()(args).run() result = self.get_delete_command()(args).run()
if not result: if not result:
return self.response_404() return self.response_404()

View File

@ -17,12 +17,9 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from flask_appbuilder.security.sqla.models import User
@dataclass @dataclass
class CommandParameters: class CommandParameters:
actor: User
resource_id: int resource_id: int
tab_id: Optional[int] = None tab_id: Optional[int] = None
key: Optional[str] = None key: Optional[str] = None

View File

@ -25,9 +25,10 @@ from superset.views.base import DeleteMixin, SupersetModelView
from superset.views.core import DAR from superset.views.core import DAR
class AccessRequestsModelView( class AccessRequestsModelView( # pylint: disable=too-many-ancestors
SupersetModelView, DeleteMixin SupersetModelView,
): # pylint: disable=too-many-ancestors DeleteMixin,
):
datamodel = SQLAInterface(DAR) datamodel = SQLAInterface(DAR)
include_route_methods = RouteMethod.CRUD_SET include_route_methods = RouteMethod.CRUD_SET
list_columns = [ list_columns = [

View File

@ -47,9 +47,10 @@ class StartEndDttmValidator: # pylint: disable=too-few-public-methods
) )
class AnnotationModelView( class AnnotationModelView( # pylint: disable=too-many-ancestors
SupersetModelView, CompactCRUDMixin SupersetModelView,
): # pylint: disable=too-many-ancestors CompactCRUDMixin,
):
datamodel = SQLAInterface(Annotation) datamodel = SQLAInterface(Annotation)
include_route_methods = RouteMethod.CRUD_SET | {"annotation"} include_route_methods = RouteMethod.CRUD_SET | {"annotation"}

View File

@ -38,7 +38,6 @@ from flask_appbuilder import BaseView, Model, ModelView
from flask_appbuilder.actions import action from flask_appbuilder.actions import action
from flask_appbuilder.forms import DynamicForm from flask_appbuilder.forms import DynamicForm
from flask_appbuilder.models.sqla.filters import BaseFilter from flask_appbuilder.models.sqla.filters import BaseFilter
from flask_appbuilder.security.sqla.models import User
from flask_appbuilder.widgets import ListWidget from flask_appbuilder.widgets import ListWidget
from flask_babel import get_locale, gettext as __, lazy_gettext as _ from flask_babel import get_locale, gettext as __, lazy_gettext as _
from flask_jwt_extended.exceptions import NoAuthorizationError from flask_jwt_extended.exceptions import NoAuthorizationError
@ -270,11 +269,6 @@ def create_table_permissions(table: models.SqlaTable) -> None:
security_manager.add_permission_view_menu("schema_access", table.schema_perm) security_manager.add_permission_view_menu("schema_access", table.schema_perm)
def is_user_admin() -> bool:
user_roles = [role.name.lower() for role in list(security_manager.get_user_roles())]
return "admin" in user_roles
class BaseSupersetView(BaseView): class BaseSupersetView(BaseView):
@staticmethod @staticmethod
def json_response(obj: Any, status: int = 200) -> FlaskResponse: def json_response(obj: Any, status: int = 200) -> FlaskResponse:
@ -644,53 +638,6 @@ class CsvResponse(Response):
default_mimetype = "text/csv" default_mimetype = "text/csv"
def check_ownership(obj: Any, raise_if_false: bool = True) -> bool:
"""Meant to be used in `pre_update` hooks on models to enforce ownership
Admin have all access, and other users need to be referenced on either
the created_by field that comes with the ``AuditMixin``, or in a field
named ``owners`` which is expected to be a one-to-many with the User
model. It is meant to be used in the ModelView's pre_update hook in
which raising will abort the update.
"""
if not obj:
return False
security_exception = SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.MISSING_OWNERSHIP_ERROR,
message="You don't have the rights to alter [{}]".format(obj),
level=ErrorLevel.ERROR,
)
)
if g.user.is_anonymous:
if raise_if_false:
raise security_exception
return False
if is_user_admin():
return True
scoped_session = db.create_scoped_session()
orig_obj = scoped_session.query(obj.__class__).filter_by(id=obj.id).first()
# Making a list of owners that works across ORM models
owners: List[User] = []
if hasattr(orig_obj, "owners"):
owners += orig_obj.owners
if hasattr(orig_obj, "owner"):
owners += [orig_obj.owner]
if hasattr(orig_obj, "created_by"):
owners += [orig_obj.created_by]
owner_names = [o.username for o in owners if o]
if g.user and hasattr(g.user, "username") and g.user.username in owner_names:
return True
if raise_if_false:
raise security_exception
return False
def bind_field( def bind_field(
_: Any, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any] _: Any, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any]
) -> Field: ) -> Field:

View File

@ -21,16 +21,12 @@ from flask_appbuilder import expose, has_access
from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
from superset import security_manager
from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP, RouteMethod from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.superset_typing import FlaskResponse from superset.superset_typing import FlaskResponse
from superset.utils import core as utils from superset.utils import core as utils
from superset.views.base import ( from superset.views.base import common_bootstrap_payload, DeleteMixin, SupersetModelView
check_ownership,
common_bootstrap_payload,
DeleteMixin,
SupersetModelView,
)
from superset.views.chart.mixin import SliceMixin from superset.views.chart.mixin import SliceMixin
from superset.views.utils import bootstrap_user_data from superset.views.utils import bootstrap_user_data
@ -53,10 +49,10 @@ class SliceModelView(
def pre_update(self, item: "SliceModelView") -> None: def pre_update(self, item: "SliceModelView") -> None:
utils.validate_json(item.params) utils.validate_json(item.params)
check_ownership(item) security_manager.raise_for_ownership(item)
def pre_delete(self, item: "SliceModelView") -> None: def pre_delete(self, item: "SliceModelView") -> None:
check_ownership(item) security_manager.raise_for_ownership(item)
@expose("/add", methods=["GET", "POST"]) @expose("/add", methods=["GET", "POST"])
@has_access @has_access

View File

@ -140,7 +140,6 @@ from superset.utils.decorators import check_dashboard_access
from superset.views.base import ( from superset.views.base import (
api, api,
BaseSupersetView, BaseSupersetView,
check_ownership,
common_bootstrap_payload, common_bootstrap_payload,
create_table_permissions, create_table_permissions,
CsvResponse, CsvResponse,
@ -164,7 +163,6 @@ from superset.views.utils import (
get_datasource_info, get_datasource_info,
get_form_data, get_form_data,
get_viz, get_viz,
is_owner,
sanitize_datasource_data, sanitize_datasource_data,
) )
from superset.viz import BaseViz from superset.viz import BaseViz
@ -368,8 +366,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
return json_error_response(err) return json_error_response(err)
# check if you can approve # check if you can approve
if security_manager.can_access_all_datasources() or check_ownership( if security_manager.can_access_all_datasources() or security_manager.is_owner(
datasource, raise_if_false=False datasource
): ):
# can by done by admin only # can by done by admin only
if role_to_grant: if role_to_grant:
@ -758,7 +756,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
form_data_key = request.args.get("form_data_key") form_data_key = request.args.get("form_data_key")
if key is not None: if key is not None:
command = GetExplorePermalinkCommand(g.user, key) command = GetExplorePermalinkCommand(key)
try: try:
permalink_value = command.run() permalink_value = command.run()
if permalink_value: if permalink_value:
@ -775,7 +773,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
flash(__("Error: %(msg)s", msg=ex.message), "danger") flash(__("Error: %(msg)s", msg=ex.message), "danger")
return redirect("/chart/list/") return redirect("/chart/list/")
elif form_data_key: elif form_data_key:
parameters = CommandParameters(actor=g.user, key=form_data_key) parameters = CommandParameters(key=form_data_key)
value = GetFormDataCommand(parameters).run() value = GetFormDataCommand(parameters).run()
initial_form_data = json.loads(value) if value else {} initial_form_data = json.loads(value) if value else {}
@ -857,7 +855,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
# slc perms # slc perms
slice_add_perm = security_manager.can_access("can_write", "Chart") slice_add_perm = security_manager.can_access("can_write", "Chart")
slice_overwrite_perm = is_owner(slc, g.user) if slc else False slice_overwrite_perm = security_manager.is_owner(slc) if slc else False
slice_download_perm = security_manager.can_access("can_csv", "Superset") slice_download_perm = security_manager.can_access("can_csv", "Superset")
form_data["datasource"] = str(datasource_id) + "__" + cast(str, datasource_type) form_data["datasource"] = str(datasource_id) + "__" + cast(str, datasource_type)
@ -1050,7 +1048,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
.one(), .one(),
) )
# check edit dashboard permissions # check edit dashboard permissions
dash_overwrite_perm = check_ownership(dash, raise_if_false=False) dash_overwrite_perm = security_manager.is_owner(dash)
if not dash_overwrite_perm: if not dash_overwrite_perm:
return json_error_response( return json_error_response(
_("You don't have the rights to ") _("You don't have the rights to ")
@ -1297,7 +1295,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
"""Save a dashboard's metadata""" """Save a dashboard's metadata"""
session = db.session() session = db.session()
dash = session.query(Dashboard).get(dashboard_id) dash = session.query(Dashboard).get(dashboard_id)
check_ownership(dash, raise_if_false=True) security_manager.raise_for_ownership(dash)
data = json.loads(request.form["data"]) data = json.loads(request.form["data"])
# client-side send back last_modified_time which was set when # client-side send back last_modified_time which was set when
# the dashboard was open. it was use to avoid mid-air collision. # the dashboard was open. it was use to avoid mid-air collision.
@ -1340,7 +1338,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
data = json.loads(request.form["data"]) data = json.loads(request.form["data"])
session = db.session() session = db.session()
dash = session.query(Dashboard).get(dashboard_id) dash = session.query(Dashboard).get(dashboard_id)
check_ownership(dash, raise_if_false=True) security_manager.raise_for_ownership(dash)
new_slices = session.query(Slice).filter(Slice.id.in_(data["slice_ids"])) new_slices = session.query(Slice).filter(Slice.id.in_(data["slice_ids"]))
dash.slices += new_slices dash.slices += new_slices
session.merge(dash) session.merge(dash)
@ -1664,7 +1662,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
def user_slices(self, user_id: Optional[int] = None) -> FlaskResponse: def user_slices(self, user_id: Optional[int] = None) -> FlaskResponse:
"""List of slices a user owns, created, modified or faved""" """List of slices a user owns, created, modified or faved"""
if not user_id: if not user_id:
user_id = cast(int, g.user.id) user_id = cast(int, get_user_id())
error_obj = self.get_user_activity_access_error(user_id) error_obj = self.get_user_activity_access_error(user_id)
if error_obj: if error_obj:
return error_obj return error_obj
@ -1717,7 +1715,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
def created_slices(self, user_id: Optional[int] = None) -> FlaskResponse: def created_slices(self, user_id: Optional[int] = None) -> FlaskResponse:
"""List of slices created by this user""" """List of slices created by this user"""
if not user_id: if not user_id:
user_id = cast(int, g.user.id) user_id = cast(int, get_user_id())
error_obj = self.get_user_activity_access_error(user_id) error_obj = self.get_user_activity_access_error(user_id)
if error_obj: if error_obj:
return error_obj return error_obj
@ -1748,7 +1746,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
def fave_slices(self, user_id: Optional[int] = None) -> FlaskResponse: def fave_slices(self, user_id: Optional[int] = None) -> FlaskResponse:
"""Favorite slices for a user""" """Favorite slices for a user"""
if user_id is None: if user_id is None:
user_id = g.user.id user_id = cast(int, get_user_id())
error_obj = self.get_user_activity_access_error(user_id) error_obj = self.get_user_activity_access_error(user_id)
if error_obj: if error_obj:
return error_obj return error_obj
@ -1957,8 +1955,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
f"/superset/request_access/?dashboard_id={dashboard.id}" f"/superset/request_access/?dashboard_id={dashboard.id}"
) )
dash_edit_perm = check_ownership( dash_edit_perm = security_manager.is_owner(
dashboard, raise_if_false=False dashboard
) and security_manager.can_access("can_save_dash", "Superset") ) and security_manager.can_access("can_save_dash", "Superset")
edit_mode = ( edit_mode = (
request.args.get(utils.ReservedUrlParameters.EDIT_MODE.value) == "true" request.args.get(utils.ReservedUrlParameters.EDIT_MODE.value) == "true"
@ -1994,7 +1992,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
key: str, key: str,
) -> FlaskResponse: ) -> FlaskResponse:
try: try:
value = GetDashboardPermalinkCommand(g.user, key).run() value = GetDashboardPermalinkCommand(key).run()
except DashboardPermalinkGetFailedError as ex: except DashboardPermalinkGetFailedError as ex:
flash(__("Error: %(msg)s", msg=ex.message), "danger") flash(__("Error: %(msg)s", msg=ex.message), "danger")
return redirect("/dashboard/list/") return redirect("/dashboard/list/")

View File

@ -25,9 +25,10 @@ from superset.superset_typing import FlaskResponse
from superset.views.base import DeleteMixin, SupersetModelView from superset.views.base import DeleteMixin, SupersetModelView
class CssTemplateModelView( class CssTemplateModelView( # pylint: disable=too-many-ancestors
SupersetModelView, DeleteMixin SupersetModelView,
): # pylint: disable=too-many-ancestors DeleteMixin,
):
datamodel = SQLAInterface(models.CssTemplate) datamodel = SQLAInterface(models.CssTemplate)
include_route_methods = RouteMethod.CRUD_SET include_route_methods = RouteMethod.CRUD_SET

View File

@ -16,8 +16,8 @@
# under the License. # under the License.
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
from ...dashboards.filters import DashboardAccessFilter from superset import security_manager
from ..base import check_ownership from superset.dashboards.filters import DashboardAccessFilter
class DashboardMixin: # pylint: disable=too-few-public-methods class DashboardMixin: # pylint: disable=too-few-public-methods
@ -90,4 +90,4 @@ class DashboardMixin: # pylint: disable=too-few-public-methods
} }
def pre_delete(self, item: "DashboardMixin") -> None: # pylint: disable=no-self-use def pre_delete(self, item: "DashboardMixin") -> None: # pylint: disable=no-self-use
check_ownership(item) security_manager.raise_for_ownership(item)

View File

@ -33,7 +33,6 @@ from superset.superset_typing import FlaskResponse
from superset.utils import core as utils from superset.utils import core as utils
from superset.views.base import ( from superset.views.base import (
BaseSupersetView, BaseSupersetView,
check_ownership,
common_bootstrap_payload, common_bootstrap_payload,
DeleteMixin, DeleteMixin,
generate_download_headers, generate_download_headers,
@ -97,12 +96,11 @@ class DashboardModelView(
item.owners.append(g.user) item.owners.append(g.user)
utils.validate_json(item.json_metadata) utils.validate_json(item.json_metadata)
utils.validate_json(item.position_json) utils.validate_json(item.position_json)
owners = list(item.owners)
for slc in item.slices: for slc in item.slices:
slc.owners = list(set(owners) | set(slc.owners)) slc.owners = list(set(item.owners) | set(slc.owners))
def pre_update(self, item: "DashboardModelView") -> None: def pre_update(self, item: "DashboardModelView") -> None:
check_ownership(item) security_manager.raise_for_ownership(item)
self.pre_add(item) self.pre_add(item)

View File

@ -18,7 +18,7 @@ import json
from collections import Counter from collections import Counter
from typing import Any from typing import Any
from flask import g, request from flask import request
from flask_appbuilder import expose from flask_appbuilder import expose
from flask_appbuilder.api import rison from flask_appbuilder.api import rison
from flask_appbuilder.security.decorators import has_access_api from flask_appbuilder.security.decorators import has_access_api
@ -27,7 +27,7 @@ from marshmallow import ValidationError
from sqlalchemy.exc import NoSuchTableError from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
from superset import db, event_logger from superset import db, event_logger, security_manager
from superset.commands.utils import populate_owners from superset.commands.utils import populate_owners
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.connectors.sqla.utils import get_physical_table_metadata from superset.connectors.sqla.utils import get_physical_table_metadata
@ -37,14 +37,12 @@ from superset.datasets.commands.exceptions import (
) )
from superset.datasource.dao import DatasourceDAO from superset.datasource.dao import DatasourceDAO
from superset.exceptions import SupersetException, SupersetSecurityException from superset.exceptions import SupersetException, SupersetSecurityException
from superset.extensions import security_manager
from superset.models.core import Database from superset.models.core import Database
from superset.superset_typing import FlaskResponse from superset.superset_typing import FlaskResponse
from superset.utils.core import DatasourceType from superset.utils.core import DatasourceType
from superset.views.base import ( from superset.views.base import (
api, api,
BaseSupersetView, BaseSupersetView,
check_ownership,
handle_api_exception, handle_api_exception,
json_error_response, json_error_response,
) )
@ -84,13 +82,12 @@ class Datasource(BaseSupersetView):
if "owners" in datasource_dict and orm_datasource.owner_class is not None: if "owners" in datasource_dict and orm_datasource.owner_class is not None:
# Check ownership # Check ownership
try: try:
check_ownership(orm_datasource) security_manager.raise_for_ownership(orm_datasource)
except SupersetSecurityException as ex: except SupersetSecurityException as ex:
raise DatasetForbiddenError() from ex raise DatasetForbiddenError() from ex
user = security_manager.get_user_by_id(g.user.id)
datasource_dict["owners"] = populate_owners( datasource_dict["owners"] = populate_owners(
user, datasource_dict["owners"], default_to_user=False datasource_dict["owners"], default_to_user=False
) )
duplicates = [ duplicates = [

View File

@ -26,7 +26,10 @@ from superset.views.base import SupersetModelView
from . import LogMixin from . import LogMixin
class LogModelView(LogMixin, SupersetModelView): # pylint: disable=too-many-ancestors class LogModelView( # pylint: disable=too-many-ancestors
LogMixin,
SupersetModelView,
):
datamodel = SQLAInterface(models.Log) datamodel = SQLAInterface(models.Log)
include_route_methods = {RouteMethod.LIST, RouteMethod.SHOW} include_route_methods = {RouteMethod.LIST, RouteMethod.SHOW}
class_permission_name = "Log" class_permission_name = "Log"

View File

@ -36,9 +36,10 @@ from .base import BaseSupersetView, DeleteMixin, json_success, SupersetModelView
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SavedQueryView( class SavedQueryView( # pylint: disable=too-many-ancestors
SupersetModelView, DeleteMixin SupersetModelView,
): # pylint: disable=too-many-ancestors DeleteMixin,
):
datamodel = SQLAInterface(SavedQuery) datamodel = SQLAInterface(SavedQuery)
include_route_methods = RouteMethod.CRUD_SET include_route_methods = RouteMethod.CRUD_SET

View File

@ -32,7 +32,6 @@ from sqlalchemy.orm.exc import NoResultFound
import superset.models.core as models import superset.models.core as models
from superset import app, dataframe, db, result_set, viz from superset import app, dataframe, db, result_set, viz
from superset.common.db_query_status import QueryStatus from superset.common.db_query_status import QueryStatus
from superset.connectors.sqla.models import SqlaTable
from superset.datasource.dao import DatasourceDAO from superset.datasource.dao import DatasourceDAO
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import ( from superset.exceptions import (
@ -427,11 +426,6 @@ def is_slice_in_container(
return False return False
def is_owner(obj: Union[Dashboard, Slice, SqlaTable], user: User) -> bool:
"""Check if user is owner of the slice"""
return obj and user in obj.owners
def check_resource_permissions( def check_resource_permissions(
check_perms: Callable[..., Any], check_perms: Callable[..., Any],
) -> Callable[..., Any]: ) -> Callable[..., Any]:

View File

@ -350,58 +350,60 @@ class TestImportChartsCommand(SupersetTestCase):
class TestChartsCreateCommand(SupersetTestCase): class TestChartsCreateCommand(SupersetTestCase):
@patch("superset.views.base.g") @patch("superset.utils.core.g")
@patch("superset.charts.commands.create.g")
@patch("superset.security.manager.g") @patch("superset.security.manager.g")
@pytest.mark.usefixtures("load_energy_table_with_slice") @pytest.mark.usefixtures("load_energy_table_with_slice")
def test_create_v1_response(self, mock_sm_g, mock_g): def test_create_v1_response(self, mock_sm_g, mock_c_g, mock_u_g):
"""Test that the create chart command creates a chart""" """Test that the create chart command creates a chart"""
actor = security_manager.find_user(username="admin") user = security_manager.find_user(username="admin")
mock_g.user = mock_sm_g.user = actor mock_u_g.user = mock_c_g.user = mock_sm_g.user = user
chart_data = { chart_data = {
"slice_name": "new chart", "slice_name": "new chart",
"description": "new description", "description": "new description",
"owners": [actor.id], "owners": [user.id],
"viz_type": "new_viz_type", "viz_type": "new_viz_type",
"params": json.dumps({"viz_type": "new_viz_type"}), "params": json.dumps({"viz_type": "new_viz_type"}),
"cache_timeout": 1000, "cache_timeout": 1000,
"datasource_id": 1, "datasource_id": 1,
"datasource_type": "table", "datasource_type": "table",
} }
command = CreateChartCommand(actor, chart_data) command = CreateChartCommand(chart_data)
chart = command.run() chart = command.run()
chart = db.session.query(Slice).get(chart.id) chart = db.session.query(Slice).get(chart.id)
assert chart.viz_type == "new_viz_type" assert chart.viz_type == "new_viz_type"
json_params = json.loads(chart.params) json_params = json.loads(chart.params)
assert json_params == {"viz_type": "new_viz_type"} assert json_params == {"viz_type": "new_viz_type"}
assert chart.slice_name == "new chart" assert chart.slice_name == "new chart"
assert chart.owners == [actor] assert chart.owners == [user]
db.session.delete(chart) db.session.delete(chart)
db.session.commit() db.session.commit()
class TestChartsUpdateCommand(SupersetTestCase): class TestChartsUpdateCommand(SupersetTestCase):
@patch("superset.views.base.g") @patch("superset.charts.commands.update.g")
@patch("superset.utils.core.g")
@patch("superset.security.manager.g") @patch("superset.security.manager.g")
@pytest.mark.usefixtures("load_energy_table_with_slice") @pytest.mark.usefixtures("load_energy_table_with_slice")
def test_update_v1_response(self, mock_sm_g, mock_g): def test_update_v1_response(self, mock_sm_g, mock_c_g, mock_u_g):
"""Test that a chart command updates properties""" """Test that a chart command updates properties"""
pk = db.session.query(Slice).all()[0].id pk = db.session.query(Slice).all()[0].id
actor = security_manager.find_user(username="admin") user = security_manager.find_user(username="admin")
mock_g.user = mock_sm_g.user = actor mock_u_g.user = mock_c_g.user = mock_sm_g.user = user
model_id = pk model_id = pk
json_obj = { json_obj = {
"description": "test for update", "description": "test for update",
"cache_timeout": None, "cache_timeout": None,
"owners": [actor.id], "owners": [user.id],
} }
command = UpdateChartCommand(actor, model_id, json_obj) command = UpdateChartCommand(model_id, json_obj)
last_saved_before = db.session.query(Slice).get(pk).last_saved_at last_saved_before = db.session.query(Slice).get(pk).last_saved_at
command.run() command.run()
chart = db.session.query(Slice).get(pk) chart = db.session.query(Slice).get(pk)
assert chart.last_saved_at != last_saved_before assert chart.last_saved_at != last_saved_before
assert chart.last_saved_by == actor assert chart.last_saved_by == user
@patch("superset.views.base.g") @patch("superset.utils.core.g")
@patch("superset.security.manager.g") @patch("superset.security.manager.g")
@pytest.mark.usefixtures("load_energy_table_with_slice") @pytest.mark.usefixtures("load_energy_table_with_slice")
def test_query_context_update_command(self, mock_sm_g, mock_g): def test_query_context_update_command(self, mock_sm_g, mock_g):
@ -415,14 +417,14 @@ class TestChartsUpdateCommand(SupersetTestCase):
chart.owners = [admin] chart.owners = [admin]
db.session.commit() db.session.commit()
actor = security_manager.find_user(username="alpha") user = security_manager.find_user(username="alpha")
mock_g.user = mock_sm_g.user = actor mock_g.user = mock_sm_g.user = user
query_context = json.dumps({"foo": "bar"}) query_context = json.dumps({"foo": "bar"})
json_obj = { json_obj = {
"query_context_generation": True, "query_context_generation": True,
"query_context": query_context, "query_context": query_context,
} }
command = UpdateChartCommand(actor, pk, json_obj) command = UpdateChartCommand(pk, json_obj)
command.run() command.run()
chart = db.session.query(Slice).get(pk) chart = db.session.query(Slice).get(pk)
assert chart.query_context == query_context assert chart.query_context == query_context

View File

@ -70,10 +70,11 @@ class TestCreateDatabaseCommand(SupersetTestCase):
@mock.patch( @mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context" "superset.databases.commands.test_connection.event_logger.log_with_context"
) )
def test_create_duplicate_error(self, mock_logger): @mock.patch("superset.utils.core.g")
def test_create_duplicate_error(self, mock_g, mock_logger):
example_db = get_example_database() example_db = get_example_database()
mock_g.user = security_manager.find_user("admin")
command = CreateDatabaseCommand( command = CreateDatabaseCommand(
security_manager.find_user("admin"),
{"database_name": example_db.database_name}, {"database_name": example_db.database_name},
) )
with pytest.raises(DatabaseInvalidError) as excinfo: with pytest.raises(DatabaseInvalidError) as excinfo:
@ -90,8 +91,10 @@ class TestCreateDatabaseCommand(SupersetTestCase):
@mock.patch( @mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context" "superset.databases.commands.test_connection.event_logger.log_with_context"
) )
def test_multiple_error_logging(self, mock_logger): @mock.patch("superset.utils.core.g")
command = CreateDatabaseCommand(security_manager.find_user("admin"), {}) def test_multiple_error_logging(self, mock_g, mock_logger):
mock_g.user = security_manager.find_user("admin")
command = CreateDatabaseCommand({})
with pytest.raises(DatabaseInvalidError) as excinfo: with pytest.raises(DatabaseInvalidError) as excinfo:
command.run() command.run()
assert str(excinfo.value) == ("Database parameters are invalid.") assert str(excinfo.value) == ("Database parameters are invalid.")
@ -643,15 +646,17 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
@mock.patch( @mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context" "superset.databases.commands.test_connection.event_logger.log_with_context"
) )
def test_connection_db_exception(self, mock_event_logger, mock_get_sqla_engine): @mock.patch("superset.utils.core.g")
def test_connection_db_exception(
self, mock_g, mock_event_logger, mock_get_sqla_engine
):
"""Test to make sure event_logger is called when an exception is raised""" """Test to make sure event_logger is called when an exception is raised"""
database = get_example_database() database = get_example_database()
mock_g.user = security_manager.find_user("admin")
mock_get_sqla_engine.side_effect = Exception("An error has occurred!") mock_get_sqla_engine.side_effect = Exception("An error has occurred!")
db_uri = database.sqlalchemy_uri_decrypted db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri} json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand( command_without_db_name = TestConnectionDatabaseCommand(json_payload)
security_manager.find_user("admin"), json_payload
)
with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo: with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo:
command_without_db_name.run() command_without_db_name.run()
@ -664,19 +669,19 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
@mock.patch( @mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context" "superset.databases.commands.test_connection.event_logger.log_with_context"
) )
@mock.patch("superset.utils.core.g")
def test_connection_do_ping_exception( def test_connection_do_ping_exception(
self, mock_event_logger, mock_get_sqla_engine self, mock_g, mock_event_logger, mock_get_sqla_engine
): ):
"""Test to make sure do_ping exceptions gets captured""" """Test to make sure do_ping exceptions gets captured"""
database = get_example_database() database = get_example_database()
mock_g.user = security_manager.find_user("admin")
mock_get_sqla_engine.return_value.dialect.do_ping.side_effect = Exception( mock_get_sqla_engine.return_value.dialect.do_ping.side_effect = Exception(
"An error has occurred!" "An error has occurred!"
) )
db_uri = database.sqlalchemy_uri_decrypted db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri} json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand( command_without_db_name = TestConnectionDatabaseCommand(json_payload)
security_manager.find_user("admin"), json_payload
)
with pytest.raises(DatabaseTestConnectionFailedError) as excinfo: with pytest.raises(DatabaseTestConnectionFailedError) as excinfo:
command_without_db_name.run() command_without_db_name.run()
@ -689,15 +694,17 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
@mock.patch( @mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context" "superset.databases.commands.test_connection.event_logger.log_with_context"
) )
def test_connection_do_ping_timeout(self, mock_event_logger, mock_func_timeout): @mock.patch("superset.utils.core.g")
def test_connection_do_ping_timeout(
self, mock_g, mock_event_logger, mock_func_timeout
):
"""Test to make sure do_ping exceptions gets captured""" """Test to make sure do_ping exceptions gets captured"""
database = get_example_database() database = get_example_database()
mock_g.user = security_manager.find_user("admin")
mock_func_timeout.side_effect = FunctionTimedOut("Time out") mock_func_timeout.side_effect = FunctionTimedOut("Time out")
db_uri = database.sqlalchemy_uri_decrypted db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri} json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand( command_without_db_name = TestConnectionDatabaseCommand(json_payload)
security_manager.find_user("admin"), json_payload
)
with pytest.raises(SupersetTimeoutException) as excinfo: with pytest.raises(SupersetTimeoutException) as excinfo:
command_without_db_name.run() command_without_db_name.run()
@ -711,20 +718,20 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
@mock.patch( @mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context" "superset.databases.commands.test_connection.event_logger.log_with_context"
) )
@mock.patch("superset.utils.core.g")
def test_connection_superset_security_connection( def test_connection_superset_security_connection(
self, mock_event_logger, mock_get_sqla_engine self, mock_g, mock_event_logger, mock_get_sqla_engine
): ):
"""Test to make sure event_logger is called when security """Test to make sure event_logger is called when security
connection exc is raised""" connection exc is raised"""
database = get_example_database() database = get_example_database()
mock_g.user = security_manager.find_user("admin")
mock_get_sqla_engine.side_effect = SupersetSecurityException( mock_get_sqla_engine.side_effect = SupersetSecurityException(
SupersetError(error_type=500, message="test", level="info") SupersetError(error_type=500, message="test", level="info")
) )
db_uri = database.sqlalchemy_uri_decrypted db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri} json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand( command_without_db_name = TestConnectionDatabaseCommand(json_payload)
security_manager.find_user("admin"), json_payload
)
with pytest.raises(DatabaseSecurityUnsafeError) as excinfo: with pytest.raises(DatabaseSecurityUnsafeError) as excinfo:
command_without_db_name.run() command_without_db_name.run()
@ -736,17 +743,19 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
@mock.patch( @mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context" "superset.databases.commands.test_connection.event_logger.log_with_context"
) )
def test_connection_db_api_exc(self, mock_event_logger, mock_get_sqla_engine): @mock.patch("superset.utils.core.g")
def test_connection_db_api_exc(
self, mock_g, mock_event_logger, mock_get_sqla_engine
):
"""Test to make sure event_logger is called when DBAPIError is raised""" """Test to make sure event_logger is called when DBAPIError is raised"""
database = get_example_database() database = get_example_database()
mock_g.user = security_manager.find_user("admin")
mock_get_sqla_engine.side_effect = DBAPIError( mock_get_sqla_engine.side_effect = DBAPIError(
statement="error", params={}, orig={} statement="error", params={}, orig={}
) )
db_uri = database.sqlalchemy_uri_decrypted db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri} json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand( command_without_db_name = TestConnectionDatabaseCommand(json_payload)
security_manager.find_user("admin"), json_payload
)
with pytest.raises(DatabaseTestConnectionFailedError) as excinfo: with pytest.raises(DatabaseTestConnectionFailedError) as excinfo:
command_without_db_name.run() command_without_db_name.run()
@ -778,7 +787,7 @@ def test_validate(DatabaseDAO, is_port_open, is_hostname_valid, app_context):
"query": {}, "query": {},
}, },
} }
command = ValidateDatabaseParametersCommand(None, payload) command = ValidateDatabaseParametersCommand(payload)
command.run() command.run()
@ -802,7 +811,7 @@ def test_validate_partial(is_port_open, is_hostname_valid, app_context):
"query": {}, "query": {},
}, },
} }
command = ValidateDatabaseParametersCommand(None, payload) command = ValidateDatabaseParametersCommand(payload)
with pytest.raises(SupersetErrorsException) as excinfo: with pytest.raises(SupersetErrorsException) as excinfo:
command.run() command.run()
assert excinfo.value.errors == [ assert excinfo.value.errors == [
@ -841,7 +850,7 @@ def test_validate_partial_invalid_hostname(is_hostname_valid, app_context):
"query": {}, "query": {},
}, },
} }
command = ValidateDatabaseParametersCommand(None, payload) command = ValidateDatabaseParametersCommand(payload)
with pytest.raises(SupersetErrorsException) as excinfo: with pytest.raises(SupersetErrorsException) as excinfo:
command.run() command.run()
assert excinfo.value.errors == [ assert excinfo.value.errors == [

Some files were not shown because too many files have changed in this diff Show More