From 0fdb4b7c23cee51b665a4bef187051abe9d05008 Mon Sep 17 00:00:00 2001 From: Vitor Avila <96086495+Vitor-Avila@users.noreply.github.com> Date: Tue, 28 May 2024 12:41:31 -0300 Subject: [PATCH] chore(tags): Handle tagging as part of asset update call (#28570) --- .../src/components/Tags/utils.test.tsx | 6 +- .../src/components/Tags/utils.tsx | 14 +- .../components/PropertiesModal/index.tsx | 84 +--- .../components/PropertiesModal/index.tsx | 77 +--- superset-frontend/src/types/TagType.ts | 2 +- superset/charts/api.py | 4 +- superset/charts/schemas.py | 10 +- superset/commands/chart/update.py | 19 +- superset/commands/dashboard/update.py | 23 +- superset/commands/exceptions.py | 10 + superset/commands/utils.py | 85 +++- superset/dashboards/api.py | 3 + superset/dashboards/schemas.py | 4 + superset/security/manager.py | 2 + tests/integration_tests/charts/api_tests.py | 259 ++++++++++++ .../integration_tests/dashboards/api_tests.py | 273 +++++++++++++ tests/unit_tests/commands/test_utils.py | 367 +++++++++++++++++- 17 files changed, 1075 insertions(+), 167 deletions(-) diff --git a/superset-frontend/src/components/Tags/utils.test.tsx b/superset-frontend/src/components/Tags/utils.test.tsx index b6d28d60c1..e4ba71e3e2 100644 --- a/superset-frontend/src/components/Tags/utils.test.tsx +++ b/superset-frontend/src/components/Tags/utils.test.tsx @@ -21,15 +21,15 @@ import { tagToSelectOption } from 'src/components/Tags/utils'; describe('tagToSelectOption', () => { test('converts a Tag object with table_name to a SelectTagsValue', () => { const tag = { - id: '1', + id: 1, name: 'TagName', table_name: 'Table1', }; const expectedSelectTagsValue = { - value: 'TagName', + value: 1, label: 'TagName', - key: '1', + key: 1, }; expect(tagToSelectOption(tag)).toEqual(expectedSelectTagsValue); diff --git a/superset-frontend/src/components/Tags/utils.tsx b/superset-frontend/src/components/Tags/utils.tsx index f0dd4c46f6..ea3f9b1982 100644 --- a/superset-frontend/src/components/Tags/utils.tsx +++ b/superset-frontend/src/components/Tags/utils.tsx @@ -37,17 +37,17 @@ const cachedSupersetGet = cacheWrapper( ); type SelectTagsValue = { - value: string | number | undefined; - label: string; - key: string | number | undefined; + value: number | undefined; + label: string | undefined; + key: number | undefined; }; export const tagToSelectOption = ( - item: Tag & { table_name: string }, + tag: Tag & { table_name: string }, ): SelectTagsValue => ({ - value: item.name, - label: item.name, - key: item.id, + value: tag.id, + label: tag.name, + key: tag.id, }); export const loadTags = async ( diff --git a/superset-frontend/src/dashboard/components/PropertiesModal/index.tsx b/superset-frontend/src/dashboard/components/PropertiesModal/index.tsx index 8613a8db6f..12dadaa2eb 100644 --- a/superset-frontend/src/dashboard/components/PropertiesModal/index.tsx +++ b/superset-frontend/src/dashboard/components/PropertiesModal/index.tsx @@ -44,12 +44,7 @@ import ColorSchemeControlWrapper from 'src/dashboard/components/ColorSchemeContr import FilterScopeModal from 'src/dashboard/components/filterscope/FilterScopeModal'; import withToasts from 'src/components/MessageToasts/withToasts'; import TagType from 'src/types/TagType'; -import { - addTag, - deleteTaggedObjects, - fetchTags, - OBJECT_TYPES, -} from 'src/features/tags/tags'; +import { fetchTags, OBJECT_TYPES } from 'src/features/tags/tags'; import { loadTags } from 'src/components/Tags/utils'; const StyledFormItem = styled(FormItem)` @@ -115,10 +110,9 @@ const PropertiesModal = ({ const categoricalSchemeRegistry = getCategoricalSchemeRegistry(); const tagsAsSelectValues = useMemo(() => { - const selectTags = tags.map(tag => ({ - value: tag.name, + const selectTags = tags.map((tag: { id: number; name: string }) => ({ + value: tag.id, label: tag.name, - key: tag.name, })); return selectTags; }, [tags.length]); @@ -309,41 +303,6 @@ const PropertiesModal = ({ setColorScheme(colorScheme); }; - const updateTags = (oldTags: TagType[], newTags: TagType[]) => { - // update the tags for this object - // add tags that are in new tags, but not in old tags - // eslint-disable-next-line array-callback-return - newTags.map((tag: TagType) => { - if (!oldTags.some(t => t.name === tag.name)) { - addTag( - { - objectType: OBJECT_TYPES.DASHBOARD, - objectId: dashboardId, - includeTypes: false, - }, - tag.name, - () => {}, - () => {}, - ); - } - }); - // delete tags that are in old tags, but not in new tags - // eslint-disable-next-line array-callback-return - oldTags.map((tag: TagType) => { - if (!newTags.some(t => t.name === tag.name)) { - deleteTaggedObjects( - { - objectType: OBJECT_TYPES.DASHBOARD, - objectId: dashboardId, - }, - tag, - () => {}, - () => {}, - ); - } - }); - }; - const onFinish = () => { const { title, slug, certifiedBy, certificationDetails } = form.getFieldsValue(); @@ -401,31 +360,16 @@ const PropertiesModal = ({ updateMetadata: false, }); - if (isFeatureEnabled(FeatureFlag.TaggingSystem)) { - // update tags - try { - fetchTags( - { - objectType: OBJECT_TYPES.DASHBOARD, - objectId: dashboardId, - includeTypes: false, - }, - (currentTags: TagType[]) => updateTags(currentTags, tags), - error => { - handleErrorResponse(error); - }, - ); - } catch (error) { - handleErrorResponse(error); - } - } - const moreOnSubmitProps: { roles?: Roles } = {}; - const morePutProps: { roles?: number[] } = {}; + const morePutProps: { roles?: number[]; tags?: (number | undefined)[] } = + {}; if (isFeatureEnabled(FeatureFlag.DashboardRbac)) { moreOnSubmitProps.roles = roles; morePutProps.roles = (roles || []).map(r => r.id); } + if (isFeatureEnabled(FeatureFlag.TaggingSystem)) { + morePutProps.tags = tags.map(tag => tag.id); + } const onSubmitProps = { id: dashboardId, title, @@ -621,12 +565,12 @@ const PropertiesModal = ({ } }, [dashboardId]); - const handleChangeTags = (values: { label: string; value: number }[]) => { - // triggered whenever a new tag is selected or a tag was deselected - // on new tag selected, add the tag - - const uniqueTags = [...new Set(values.map(v => v.label))]; - setTags([...uniqueTags.map(t => ({ name: t }))]); + const handleChangeTags = (tags: { label: string; value: number }[]) => { + const parsedTags: TagType[] = ensureIsArray(tags).map(r => ({ + id: r.value, + name: r.label, + })); + setTags(parsedTags); }; return ( diff --git a/superset-frontend/src/explore/components/PropertiesModal/index.tsx b/superset-frontend/src/explore/components/PropertiesModal/index.tsx index 45af79a54b..39b4224b93 100644 --- a/superset-frontend/src/explore/components/PropertiesModal/index.tsx +++ b/superset-frontend/src/explore/components/PropertiesModal/index.tsx @@ -30,16 +30,12 @@ import { isFeatureEnabled, FeatureFlag, getClientErrorObject, + ensureIsArray, } from '@superset-ui/core'; import Chart, { Slice } from 'src/types/Chart'; import withToasts from 'src/components/MessageToasts/withToasts'; import { loadTags } from 'src/components/Tags/utils'; -import { - addTag, - deleteTaggedObjects, - fetchTags, - OBJECT_TYPES, -} from 'src/features/tags/tags'; +import { fetchTags, OBJECT_TYPES } from 'src/features/tags/tags'; import TagType from 'src/types/TagType'; export type PropertiesModalProps = { @@ -80,10 +76,9 @@ function PropertiesModal({ const [tags, setTags] = useState([]); const tagsAsSelectValues = useMemo(() => { - const selectTags = tags.map(tag => ({ - value: tag.name, + const selectTags = tags.map((tag: { id: number; name: string }) => ({ + value: tag.id, label: tag.name, - key: tag.name, })); return selectTags; }, [tags.length]); @@ -144,41 +139,6 @@ function PropertiesModal({ [], ); - const updateTags = (oldTags: TagType[], newTags: TagType[]) => { - // update the tags for this object - // add tags that are in new tags, but not in old tags - // eslint-disable-next-line array-callback-return - newTags.map((tag: TagType) => { - if (!oldTags.some(t => t.name === tag.name)) { - addTag( - { - objectType: OBJECT_TYPES.CHART, - objectId: slice.slice_id, - includeTypes: false, - }, - tag.name, - () => {}, - () => {}, - ); - } - }); - // delete tags that are in old tags, but not in new tags - // eslint-disable-next-line array-callback-return - oldTags.map((tag: TagType) => { - if (!newTags.some(t => t.name === tag.name)) { - deleteTaggedObjects( - { - objectType: OBJECT_TYPES.CHART, - objectId: slice.slice_id, - }, - tag, - () => {}, - () => {}, - ); - } - }); - }; - const onSubmit = async (values: { certified_by?: string; certification_details?: string; @@ -209,22 +169,7 @@ function PropertiesModal({ ).map(o => o.value); } if (isFeatureEnabled(FeatureFlag.TaggingSystem)) { - // update tags - try { - fetchTags( - { - objectType: OBJECT_TYPES.CHART, - objectId: slice.slice_id, - includeTypes: false, - }, - (currentTags: TagType[]) => updateTags(currentTags, tags), - error => { - showError(error); - }, - ); - } catch (error) { - showError(error); - } + payload.tags = tags.map(tag => tag.id); } try { @@ -282,12 +227,12 @@ function PropertiesModal({ } }, [slice.slice_id]); - const handleChangeTags = (values: { label: string; value: number }[]) => { - // triggered whenever a new tag is selected or a tag was deselected - // on new tag selected, add the tag - - const uniqueTags = [...new Set(values.map(v => v.label))]; - setTags([...uniqueTags.map(t => ({ name: t }))]); + const handleChangeTags = (tags: { label: string; value: number }[]) => { + const parsedTags: TagType[] = ensureIsArray(tags).map(r => ({ + id: r.value, + name: r.label, + })); + setTags(parsedTags); }; const handleClearTags = () => { diff --git a/superset-frontend/src/types/TagType.ts b/superset-frontend/src/types/TagType.ts index 8f445e50da..0ea5f44d2d 100644 --- a/superset-frontend/src/types/TagType.ts +++ b/superset-frontend/src/types/TagType.ts @@ -20,7 +20,7 @@ import { MouseEventHandler } from 'react'; export interface TagType { - id?: string | number; + id?: number; type?: string | number; editable?: boolean; onDelete?: (index: number) => void; diff --git a/superset/charts/api.py b/superset/charts/api.py index 05eb0ab0c2..4034c8ef27 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -69,7 +69,7 @@ from superset.commands.chart.export import ExportChartsCommand from superset.commands.chart.importers.dispatcher import ImportChartsCommand from superset.commands.chart.update import UpdateChartCommand from superset.commands.chart.warm_up_cache import ChartWarmUpCacheCommand -from superset.commands.exceptions import CommandException +from superset.commands.exceptions import CommandException, TagForbiddenError from superset.commands.importers.exceptions import ( IncorrectFormatError, NoValidFilesFoundError, @@ -404,6 +404,8 @@ class ChartRestApi(BaseSupersetModelRestApi): response = self.response_404() except ChartForbiddenError: response = self.response_403() + except TagForbiddenError as ex: + response = self.response(403, message=str(ex)) except ChartInvalidError as ex: response = self.response_422(message=ex.normalized_messages()) except ChartUpdateFailedError as ex: diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 611f7af597..89e47a9dcb 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -27,7 +27,6 @@ from marshmallow.validate import Length, Range from superset import app from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.db_engine_specs.base import builtin_time_grains -from superset.tags.models import TagType from superset.utils import pandas_postprocessing, schema as utils from superset.utils.core import ( AnnotationType, @@ -122,6 +121,7 @@ description_markeddown_description = "Sanitized HTML version of the chart descri owners_name_description = "Name of an owner of the chart." certified_by_description = "Person or group that has certified this chart" certification_details_description = "Details of the certification" +tags_description = "Tags to be associated with the chart" openapi_spec_methods_override = { "get": {"get": {"summary": "Get a chart detail information"}}, @@ -143,12 +143,6 @@ openapi_spec_methods_override = { } -class TagSchema(Schema): - id = fields.Int() - name = fields.String() - type = fields.Enum(TagType, by_value=True) - - class ChartEntityResponseSchema(Schema): """ Schema for a chart object @@ -284,7 +278,7 @@ class ChartPutSchema(Schema): ) is_managed_externally = fields.Boolean(allow_none=True, dump_default=False) external_url = fields.String(allow_none=True) - tags = fields.Nested(TagSchema, many=True) + tags = fields.List(fields.Integer(metadata={"description": tags_description})) class ChartGetDatasourceObjectDataResponseSchema(Schema): diff --git a/superset/commands/chart/update.py b/superset/commands/chart/update.py index 178344634e..74b1c30aa8 100644 --- a/superset/commands/chart/update.py +++ b/superset/commands/chart/update.py @@ -32,12 +32,13 @@ from superset.commands.chart.exceptions import ( DashboardsNotFoundValidationError, DatasourceTypeUpdateRequiredValidationError, ) -from superset.commands.utils import get_datasource_by_id +from superset.commands.utils import get_datasource_by_id, update_tags, validate_tags from superset.daos.chart import ChartDAO from superset.daos.dashboard import DashboardDAO -from superset.daos.exceptions import DAOUpdateFailedError +from superset.daos.exceptions import DAODeleteFailedError, DAOUpdateFailedError from superset.exceptions import SupersetSecurityException from superset.models.slice import Slice +from superset.tags.models import ObjectType logger = logging.getLogger(__name__) @@ -59,11 +60,16 @@ class UpdateChartCommand(UpdateMixin, BaseCommand): assert self._model try: + # Update tags + tags = self._properties.pop("tags", None) + if tags is not None: + update_tags(ObjectType.chart, self._model.id, self._model.tags, tags) + if self._properties.get("query_context_generation") is None: self._properties["last_saved_at"] = datetime.now() self._properties["last_saved_by"] = g.user chart = ChartDAO.update(self._model, self._properties) - except DAOUpdateFailedError as ex: + except (DAOUpdateFailedError, DAODeleteFailedError) as ex: logger.exception(ex.exception) raise ChartUpdateFailedError() from ex return chart @@ -72,6 +78,7 @@ class UpdateChartCommand(UpdateMixin, BaseCommand): exceptions: list[ValidationError] = [] dashboard_ids = self._properties.get("dashboards") owner_ids: Optional[list[int]] = self._properties.get("owners") + tag_ids: Optional[list[int]] = self._properties.get("tags") # Validate if datasource_id is provided datasource_type is required datasource_id = self._properties.get("datasource_id") @@ -100,6 +107,12 @@ class UpdateChartCommand(UpdateMixin, BaseCommand): except ValidationError as ex: exceptions.append(ex) + # validate tags + try: + validate_tags(ObjectType.chart, self._model.tags, tag_ids) + except ValidationError as ex: + exceptions.append(ex) + # Validate/Populate datasource if datasource_id is not None: try: diff --git a/superset/commands/dashboard/update.py b/superset/commands/dashboard/update.py index b2b11e5f6e..d35fb6b28e 100644 --- a/superset/commands/dashboard/update.py +++ b/superset/commands/dashboard/update.py @@ -30,12 +30,13 @@ from superset.commands.dashboard.exceptions import ( DashboardSlugExistsValidationError, DashboardUpdateFailedError, ) -from superset.commands.utils import populate_roles +from superset.commands.utils import populate_roles, update_tags, validate_tags from superset.daos.dashboard import DashboardDAO -from superset.daos.exceptions import DAOUpdateFailedError +from superset.daos.exceptions import DAODeleteFailedError, DAOUpdateFailedError from superset.exceptions import SupersetSecurityException from superset.extensions import db from superset.models.dashboard import Dashboard +from superset.tags.models import ObjectType logger = logging.getLogger(__name__) @@ -51,6 +52,13 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand): assert self._model try: + # Update tags + tags = self._properties.pop("tags", None) + if tags is not None: + update_tags( + ObjectType.dashboard, self._model.id, self._model.tags, tags + ) + dashboard = DashboardDAO.update(self._model, self._properties, commit=False) if self._properties.get("json_metadata"): dashboard = DashboardDAO.set_dash_metadata( @@ -59,7 +67,7 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand): commit=False, ) db.session.commit() - except DAOUpdateFailedError as ex: + except (DAOUpdateFailedError, DAODeleteFailedError) as ex: logger.exception(ex.exception) raise DashboardUpdateFailedError() from ex return dashboard @@ -69,6 +77,7 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand): owner_ids: Optional[list[int]] = self._properties.get("owners") roles_ids: Optional[list[int]] = self._properties.get("roles") slug: Optional[str] = self._properties.get("slug") + tag_ids: Optional[list[int]] = self._properties.get("tags") # Validate/populate model exists self._model = DashboardDAO.find_by_id(self._model_id) @@ -93,8 +102,12 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand): self._properties["owners"] = owners except ValidationError as ex: exceptions.append(ex) - if exceptions: - raise DashboardInvalidError(exceptions=exceptions) + + # validate tags + try: + validate_tags(ObjectType.dashboard, self._model.tags, tag_ids) + except ValidationError as ex: + exceptions.append(ex) # Validate/Populate role if roles_ids is None: diff --git a/superset/commands/exceptions.py b/superset/commands/exceptions.py index 4fb36c5b57..7fc89ac1d9 100644 --- a/superset/commands/exceptions.py +++ b/superset/commands/exceptions.py @@ -140,3 +140,13 @@ class QueryNotFoundValidationError(ValidationError): def __init__(self) -> None: super().__init__([_("Query does not exist")], field_name="datasource_id") + + +class TagNotFoundValidationError(ValidationError): + def __init__(self, message: str) -> None: + super().__init__(message, field_name="tags") + + +class TagForbiddenError(ForbiddenError): + def __init__(self, message: str) -> None: + super().__init__(message) diff --git a/superset/commands/utils.py b/superset/commands/utils.py index f01c96ba28..29a31aa2aa 100644 --- a/superset/commands/utils.py +++ b/superset/commands/utils.py @@ -16,7 +16,8 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING +from collections import Counter +from typing import Optional, TYPE_CHECKING from flask import g from flask_appbuilder.security.sqla.models import Role, User @@ -26,9 +27,13 @@ from superset.commands.exceptions import ( DatasourceNotFoundValidationError, OwnersNotFoundValidationError, RolesNotFoundValidationError, + TagForbiddenError, + TagNotFoundValidationError, ) from superset.daos.datasource import DatasourceDAO from superset.daos.exceptions import DatasourceNotFound +from superset.daos.tag import TagDAO +from superset.tags.models import ObjectType, Tag, TagType from superset.utils.core import DatasourceType, get_user_id if TYPE_CHECKING: @@ -102,3 +107,81 @@ def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDataso ) except DatasourceNotFound as ex: raise DatasourceNotFoundValidationError() from ex + + +def validate_tags( + object_type: ObjectType, + current_tags: list[Tag], + new_tag_ids: Optional[list[int]], +) -> None: + """ + Helper function for update commands, to validate the tags list. Users + with `can_write` on `Tag` are allowed to both create new tags and manage + tag association with objects. Users with `can_tag` on `object_type` are + only allowed to manage existing existing tags' associations with the object. + + :param current_tags: list of current tags + :param new_tag_ids: list of tags specified in the update payload + """ + + # `tags` not part of the update payload + if new_tag_ids is None: + return + + # No changes in the list + current_custom_tags = [tag.id for tag in current_tags if tag.type == TagType.custom] + if Counter(current_custom_tags) == Counter(new_tag_ids): + return + + # No perm to tags assets + if not ( + security_manager.can_access("can_write", "Tag") + or security_manager.can_access("can_tag", object_type.name.capitalize()) + ): + validation_error = ( + f"You do not have permission to manage tags on {object_type.name}s" + ) + raise TagForbiddenError(validation_error) + + # Validate if new tags already exist + additional_tags = [tag for tag in new_tag_ids if tag not in current_custom_tags] + for tag_id in additional_tags: + if not TagDAO.find_by_id(tag_id): + validation_error = f"Tag ID {tag_id} not found" + raise TagNotFoundValidationError(validation_error) + + return + + +def update_tags( + object_type: ObjectType, + object_id: int, + current_tags: list[Tag], + new_tag_ids: list[int], +) -> None: + """ + Helper function for update commands, to update the tag relationship. + + :param object_id: The object (dashboard, chart, etc) ID + :param object_type: The object type + :param current_tags: list of current tags + :param new_tag_ids: list of tags specified in the update payload + """ + + current_custom_tags = [tag for tag in current_tags if tag.type == TagType.custom] + current_custom_tag_ids = [ + tag.id for tag in current_tags if tag.type == TagType.custom + ] + + tags_to_delete = [tag for tag in current_custom_tags if tag.id not in new_tag_ids] + for tag in tags_to_delete: + TagDAO.delete_tagged_object(object_type, object_id, tag.name) + + tag_ids_to_add = [ + tag_id for tag_id in new_tag_ids if tag_id not in current_custom_tag_ids + ] + if tag_ids_to_add: + tags_to_add = TagDAO.find_by_ids(tag_ids_to_add) + TagDAO.create_custom_tagged_objects( + object_type, object_id, [tag.name for tag in tags_to_add] + ) diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py index 375c384660..ff6f5f5f48 100644 --- a/superset/dashboards/api.py +++ b/superset/dashboards/api.py @@ -49,6 +49,7 @@ from superset.commands.dashboard.exceptions import ( from superset.commands.dashboard.export import ExportDashboardsCommand from superset.commands.dashboard.importers.dispatcher import ImportDashboardsCommand from superset.commands.dashboard.update import UpdateDashboardCommand +from superset.commands.exceptions import TagForbiddenError from superset.commands.importers.exceptions import NoValidFilesFoundError from superset.commands.importers.v1.utils import get_contents_from_bundle from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod @@ -577,6 +578,8 @@ class DashboardRestApi(BaseSupersetModelRestApi): response = self.response_404() except DashboardForbiddenError: response = self.response_403() + except TagForbiddenError as ex: + response = self.response(403, message=str(ex)) except DashboardInvalidError as ex: return self.response_422(message=ex.normalized_messages()) except DashboardUpdateFailedError as ex: diff --git a/superset/dashboards/schemas.py b/superset/dashboards/schemas.py index c90da33734..60eb50918e 100644 --- a/superset/dashboards/schemas.py +++ b/superset/dashboards/schemas.py @@ -68,6 +68,7 @@ charts_description = ( ) certified_by_description = "Person or group that has certified this dashboard" certification_details_description = "Details of the certification" +tags_description = "Tags to be associated with the dashboard" openapi_spec_methods_override = { "get": {"get": {"summary": "Get a dashboard detail information"}}, @@ -369,6 +370,9 @@ class DashboardPutSchema(BaseDashboardSchema): ) is_managed_externally = fields.Boolean(allow_none=True, dump_default=False) external_url = fields.String(allow_none=True) + tags = fields.List( + fields.Integer(metadata={"description": tags_description}, allow_none=True) + ) class ChartFavStarResponseResult(Schema): diff --git a/superset/security/manager.py b/superset/security/manager.py index b5ca455ce6..009fd662fd 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -968,6 +968,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods self.add_permission_view_menu("can_view_query", "Dashboard") self.add_permission_view_menu("can_view_chart_as_table", "Dashboard") self.add_permission_view_menu("can_drill", "Dashboard") + self.add_permission_view_menu("can_tag", "Chart") + self.add_permission_view_menu("can_tag", "Dashboard") def create_missing_perms(self) -> None: """ diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 65ede9221c..3b28cfbcaa 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -37,6 +37,7 @@ from superset.models.core import Database, FavStar, FavStarClassName from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.reports.models import ReportSchedule, ReportScheduleType +from superset.tags.models import ObjectType, Tag, TaggedObject, TagType from superset.utils.core import get_example_default_schema from superset.utils.database import get_example_database # noqa: F401 from superset.viz import viz_types # noqa: F401 @@ -199,6 +200,53 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase): db.session.delete(self.chart) db.session.commit() + @pytest.fixture() + def create_custom_tags(self): + with self.create_app().app_context(): + tags: list[Tag] = [] + for tag_name in {"one_tag", "new_tag"}: + tag = Tag( + name=tag_name, + type="custom", + ) + db.session.add(tag) + db.session.commit() + tags.append(tag) + + yield tags + + for tags in tags: + db.session.delete(tags) + db.session.commit() + + @pytest.fixture() + def create_chart_with_tag(self, create_custom_tags): + with self.create_app().app_context(): + alpha_user = self.get_user(ALPHA_USERNAME) + + chart = self.insert_chart( + "chart with tag", + [alpha_user.id], + 1, + ) + + tag = db.session.query(Tag).filter(Tag.name == "one_tag").first() + tag_association = TaggedObject( + object_id=chart.id, + object_type=ObjectType.chart, + tag=tag, + ) + + db.session.add(tag_association) + db.session.commit() + + yield chart + + # rollback changes + db.session.delete(tag_association) + db.session.delete(chart) + db.session.commit() + def test_info_security_chart(self): """ Chart API: Test info security @@ -2000,3 +2048,214 @@ class TestChartApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCase): }, ], } + + @pytest.mark.usefixtures("create_chart_with_tag") + def test_update_chart_add_tags_can_write_on_tag(self): + """ + Validates a user with can write on tag permission can + add tags while updating a chart + """ + self.login(ADMIN_USERNAME) + + chart = ( + db.session.query(Slice).filter(Slice.slice_name == "chart with tag").first() + ) + new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one() + + # get existing tag and add a new one + new_tags = [tag.id for tag in chart.tags if tag.type == TagType.custom] + new_tags.append(new_tag.id) + update_payload = {"tags": new_tags} + + uri = f"api/v1/chart/{chart.id}" + rv = self.put_assert_metric(uri, update_payload, "put") + self.assertEqual(rv.status_code, 200) + model = db.session.query(Slice).get(chart.id) + + # Clean up system tags + tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] + self.assertEqual(tag_list, new_tags) + + @pytest.mark.usefixtures("create_chart_with_tag") + def test_update_chart_remove_tags_can_write_on_tag(self): + """ + Validates a user with can write on tag permission can + remove tags while updating a chart + """ + self.login(ADMIN_USERNAME) + + chart = ( + db.session.query(Slice).filter(Slice.slice_name == "chart with tag").first() + ) + + # get existing tag and add a new one + new_tags = [tag.id for tag in chart.tags if tag.type == TagType.custom] + new_tags.pop() + + update_payload = {"tags": new_tags} + + uri = f"api/v1/chart/{chart.id}" + rv = self.put_assert_metric(uri, update_payload, "put") + self.assertEqual(rv.status_code, 200) + model = db.session.query(Slice).get(chart.id) + + # Clean up system tags + tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] + self.assertEqual(tag_list, new_tags) + + @pytest.mark.usefixtures("create_chart_with_tag") + def test_update_chart_add_tags_can_tag_on_chart(self): + """ + Validates an owner with can tag on chart permission can + add tags while updating a chart + """ + self.login(ALPHA_USERNAME) + + alpha_role = security_manager.find_role("Alpha") + write_tags_perm = security_manager.add_permission_view_menu("can_write", "Tag") + security_manager.del_permission_role(alpha_role, write_tags_perm) + assert "can tag on Chart" in str(alpha_role.permissions) + + chart = ( + db.session.query(Slice).filter(Slice.slice_name == "chart with tag").first() + ) + new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one() + + # get existing tag and add a new one + new_tags = [tag.id for tag in chart.tags if tag.type == TagType.custom] + new_tags.append(new_tag.id) + update_payload = {"tags": new_tags} + + uri = f"api/v1/chart/{chart.id}" + rv = self.put_assert_metric(uri, update_payload, "put") + self.assertEqual(rv.status_code, 200) + model = db.session.query(Slice).get(chart.id) + + # Clean up system tags + tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] + self.assertEqual(tag_list, new_tags) + + security_manager.add_permission_role(alpha_role, write_tags_perm) + + @pytest.mark.usefixtures("create_chart_with_tag") + def test_update_chart_remove_tags_can_tag_on_chart(self): + """ + Validates an owner with can tag on chart permission can + remove tags from a chart + """ + self.login(ALPHA_USERNAME) + + alpha_role = security_manager.find_role("Alpha") + write_tags_perm = security_manager.add_permission_view_menu("can_write", "Tag") + security_manager.del_permission_role(alpha_role, write_tags_perm) + assert "can tag on Chart" in str(alpha_role.permissions) + + chart = ( + db.session.query(Slice).filter(Slice.slice_name == "chart with tag").first() + ) + + update_payload = {"tags": []} + + uri = f"api/v1/chart/{chart.id}" + rv = self.put_assert_metric(uri, update_payload, "put") + self.assertEqual(rv.status_code, 200) + model = db.session.query(Slice).get(chart.id) + + # Clean up system tags + tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] + self.assertEqual(tag_list, []) + + security_manager.add_permission_role(alpha_role, write_tags_perm) + + @pytest.mark.usefixtures("create_chart_with_tag") + def test_update_chart_add_tags_missing_permission(self): + """ + Validates an owner can't add tags to a chart if they don't + have permission to it + """ + self.login(ALPHA_USERNAME) + + alpha_role = security_manager.find_role("Alpha") + write_tags_perm = security_manager.add_permission_view_menu("can_write", "Tag") + tag_charts_perm = security_manager.add_permission_view_menu("can_tag", "Chart") + security_manager.del_permission_role(alpha_role, write_tags_perm) + security_manager.del_permission_role(alpha_role, tag_charts_perm) + + chart = ( + db.session.query(Slice).filter(Slice.slice_name == "chart with tag").first() + ) + new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one() + + # get existing tag and add a new one + new_tags = [tag.id for tag in chart.tags if tag.type == TagType.custom] + new_tags.append(new_tag.id) + update_payload = {"tags": new_tags} + + uri = f"api/v1/chart/{chart.id}" + rv = self.put_assert_metric(uri, update_payload, "put") + self.assertEqual(rv.status_code, 403) + self.assertEqual( + rv.json["message"], + "You do not have permission to manage tags on charts", + ) + + security_manager.add_permission_role(alpha_role, write_tags_perm) + security_manager.add_permission_role(alpha_role, tag_charts_perm) + + @pytest.mark.usefixtures("create_chart_with_tag") + def test_update_chart_remove_tags_missing_permission(self): + """ + Validates an owner can't remove tags from a chart if they don't + have permission to it + """ + self.login(ALPHA_USERNAME) + + alpha_role = security_manager.find_role("Alpha") + write_tags_perm = security_manager.add_permission_view_menu("can_write", "Tag") + tag_charts_perm = security_manager.add_permission_view_menu("can_tag", "Chart") + security_manager.del_permission_role(alpha_role, write_tags_perm) + security_manager.del_permission_role(alpha_role, tag_charts_perm) + + chart = ( + db.session.query(Slice).filter(Slice.slice_name == "chart with tag").first() + ) + + update_payload = {"tags": []} + + uri = f"api/v1/chart/{chart.id}" + rv = self.put_assert_metric(uri, update_payload, "put") + self.assertEqual(rv.status_code, 403) + self.assertEqual( + rv.json["message"], + "You do not have permission to manage tags on charts", + ) + + security_manager.add_permission_role(alpha_role, write_tags_perm) + security_manager.add_permission_role(alpha_role, tag_charts_perm) + + @pytest.mark.usefixtures("create_chart_with_tag") + def test_update_chart_no_tag_changes(self): + """ + Validates an owner without permission to change tags is able to + update a chart when tags haven't changed + """ + self.login(ALPHA_USERNAME) + + alpha_role = security_manager.find_role("Alpha") + write_tags_perm = security_manager.add_permission_view_menu("can_write", "Tag") + tag_charts_perm = security_manager.add_permission_view_menu("can_tag", "Chart") + security_manager.del_permission_role(alpha_role, write_tags_perm) + security_manager.del_permission_role(alpha_role, tag_charts_perm) + + chart = ( + db.session.query(Slice).filter(Slice.slice_name == "chart with tag").first() + ) + existing_tags = [tag.id for tag in chart.tags if tag.type == TagType.custom] + update_payload = {"tags": existing_tags} + + uri = f"api/v1/chart/{chart.id}" + rv = self.put_assert_metric(uri, update_payload, "put") + self.assertEqual(rv.status_code, 200) + + security_manager.add_permission_role(alpha_role, write_tags_perm) + security_manager.add_permission_role(alpha_role, tag_charts_perm) diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index 328fb5774e..614524cb2d 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -36,6 +36,7 @@ from superset.models.dashboard import Dashboard from superset.models.core import FavStar, FavStarClassName from superset.reports.models import ReportSchedule, ReportScheduleType from superset.models.slice import Slice +from superset.tags.models import Tag, TaggedObject, TagType, ObjectType from superset.utils.core import backend, override_user from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin @@ -168,6 +169,52 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas db.session.delete(dashboard) db.session.commit() + @pytest.fixture() + def create_custom_tags(self): + with self.create_app().app_context(): + tags: list[Tag] = [] + for tag_name in {"one_tag", "new_tag"}: + tag = Tag( + name=tag_name, + type="custom", + ) + db.session.add(tag) + db.session.commit() + tags.append(tag) + + yield tags + + for tags in tags: + db.session.delete(tags) + db.session.commit() + + @pytest.fixture() + def create_dashboard_with_tag(self, create_custom_tags): + with self.create_app().app_context(): + gamma = self.get_user("gamma") + + dashboard = self.insert_dashboard( + "dash with tag", + None, + [gamma.id], + ) + tag = db.session.query(Tag).filter(Tag.name == "one_tag").first() + tag_association = TaggedObject( + object_id=dashboard.id, + object_type=ObjectType.dashboard, + tag=tag, + ) + + db.session.add(tag_association) + db.session.commit() + + yield dashboard + + # rollback changes + db.session.delete(tag_association) + db.session.delete(dashboard) + db.session.commit() + @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") def test_get_dashboard_datasets(self): self.login(ADMIN_USERNAME) @@ -2263,3 +2310,229 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas db.session.delete(dash) db.session.commit() + + @pytest.mark.usefixtures("create_dashboard_with_tag") + def test_update_dashboard_add_tags_can_write_on_tag(self): + """ + Validates a user with can write on tag permission can + add tags while updating a dashboard + """ + self.login(ADMIN_USERNAME) + + dashboard = ( + db.session.query(Dashboard) + .filter(Dashboard.dashboard_title == "dash with tag") + .first() + ) + new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one() + + # get existing tag and add a new one + new_tags = [tag.id for tag in dashboard.tags if tag.type == TagType.custom] + new_tags.append(new_tag.id) + update_payload = {"tags": new_tags} + + uri = f"api/v1/dashboard/{dashboard.id}" + rv = self.put_assert_metric(uri, update_payload, "put") + self.assertEqual(rv.status_code, 200) + model = db.session.query(Dashboard).get(dashboard.id) + + # Clean up system tags + tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] + self.assertEqual(tag_list, new_tags) + + @pytest.mark.usefixtures("create_dashboard_with_tag") + def test_update_dashboard_remove_tags_can_write_on_tag(self): + """ + Validates a user with can write on tag permission can + remove tags while updating a dashboard + """ + self.login(ADMIN_USERNAME) + + dashboard = ( + db.session.query(Dashboard) + .filter(Dashboard.dashboard_title == "dash with tag") + .first() + ) + + # get existing tag and add a new one + new_tags = [tag.id for tag in dashboard.tags if tag.type == TagType.custom] + new_tags.pop() + + update_payload = {"tags": new_tags} + + uri = f"api/v1/dashboard/{dashboard.id}" + rv = self.put_assert_metric(uri, update_payload, "put") + self.assertEqual(rv.status_code, 200) + model = db.session.query(Dashboard).get(dashboard.id) + + # Clean up system tags + tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] + self.assertEqual(tag_list, new_tags) + + @pytest.mark.usefixtures("create_dashboard_with_tag") + def test_update_dashboard_add_tags_can_tag_on_dashboard(self): + """ + Validates an owner with can tag on dashboard permission can + add tags while updating a dashboard + """ + self.login(GAMMA_USERNAME) + write_tags_perm = security_manager.add_permission_view_menu("can_write", "Tag") + gamma_role = security_manager.find_role("Gamma") + security_manager.del_permission_role(gamma_role, write_tags_perm) + assert "can tag on Dashboard" in str(gamma_role.permissions) + + dashboard = ( + db.session.query(Dashboard) + .filter(Dashboard.dashboard_title == "dash with tag") + .first() + ) + new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one() + + # get existing tag and add a new one + new_tags = [tag.id for tag in dashboard.tags if tag.type == TagType.custom] + new_tags.append(new_tag.id) + update_payload = {"tags": new_tags} + + uri = f"api/v1/dashboard/{dashboard.id}" + rv = self.put_assert_metric(uri, update_payload, "put") + self.assertEqual(rv.status_code, 200) + model = db.session.query(Dashboard).get(dashboard.id) + + # Clean up system tags + tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] + self.assertEqual(tag_list, new_tags) + + security_manager.add_permission_role(gamma_role, write_tags_perm) + + @pytest.mark.usefixtures("create_dashboard_with_tag") + def test_update_dashboard_remove_tags_can_tag_on_dashboard(self): + """ + Validates an owner with can tag on dashboard permission can + remove tags from a dashboard + """ + self.login(GAMMA_USERNAME) + write_tags_perm = security_manager.add_permission_view_menu("can_write", "Tag") + gamma_role = security_manager.find_role("Gamma") + security_manager.del_permission_role(gamma_role, write_tags_perm) + assert "can tag on Dashboard" in str(gamma_role.permissions) + + dashboard = ( + db.session.query(Dashboard) + .filter(Dashboard.dashboard_title == "dash with tag") + .first() + ) + + update_payload = {"tags": []} + + uri = f"api/v1/dashboard/{dashboard.id}" + rv = self.put_assert_metric(uri, update_payload, "put") + self.assertEqual(rv.status_code, 200) + model = db.session.query(Dashboard).get(dashboard.id) + + # Clean up system tags + tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom] + self.assertEqual(tag_list, []) + + security_manager.add_permission_role(gamma_role, write_tags_perm) + + @pytest.mark.usefixtures("create_dashboard_with_tag") + def test_update_dashboard_add_tags_missing_permission(self): + """ + Validates an owner can't add tags to a dashboard if they don't + have permission to it + """ + self.login(GAMMA_USERNAME) + write_tags_perm = security_manager.add_permission_view_menu("can_write", "Tag") + tag_dashboards_perm = security_manager.add_permission_view_menu( + "can_tag", "Dashboard" + ) + gamma_role = security_manager.find_role("Gamma") + security_manager.del_permission_role(gamma_role, write_tags_perm) + security_manager.del_permission_role(gamma_role, tag_dashboards_perm) + + dashboard = ( + db.session.query(Dashboard) + .filter(Dashboard.dashboard_title == "dash with tag") + .first() + ) + new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one() + + # get existing tag and add a new one + new_tags = [tag.id for tag in dashboard.tags if tag.type == TagType.custom] + new_tags.append(new_tag.id) + update_payload = {"tags": new_tags} + + uri = f"api/v1/dashboard/{dashboard.id}" + rv = self.put_assert_metric(uri, update_payload, "put") + self.assertEqual(rv.status_code, 403) + self.assertEqual( + rv.json["message"], + "You do not have permission to manage tags on dashboards", + ) + + security_manager.add_permission_role(gamma_role, write_tags_perm) + security_manager.add_permission_role(gamma_role, tag_dashboards_perm) + + @pytest.mark.usefixtures("create_dashboard_with_tag") + def test_update_dashboard_remove_tags_missing_permission(self): + """ + Validates an owner can't remove tags from a dashboard if they don't + have permission to it + """ + self.login(GAMMA_USERNAME) + write_tags_perm = security_manager.add_permission_view_menu("can_write", "Tag") + tag_dashboards_perm = security_manager.add_permission_view_menu( + "can_tag", "Dashboard" + ) + gamma_role = security_manager.find_role("Gamma") + security_manager.del_permission_role(gamma_role, write_tags_perm) + security_manager.del_permission_role(gamma_role, tag_dashboards_perm) + + dashboard = ( + db.session.query(Dashboard) + .filter(Dashboard.dashboard_title == "dash with tag") + .first() + ) + + update_payload = {"tags": []} + + uri = f"api/v1/dashboard/{dashboard.id}" + rv = self.put_assert_metric(uri, update_payload, "put") + self.assertEqual(rv.status_code, 403) + self.assertEqual( + rv.json["message"], + "You do not have permission to manage tags on dashboards", + ) + + security_manager.add_permission_role(gamma_role, write_tags_perm) + security_manager.add_permission_role(gamma_role, tag_dashboards_perm) + + @pytest.mark.usefixtures("create_dashboard_with_tag") + def test_update_dashboard_no_tag_changes(self): + """ + Validates an owner without permission to change tags is able to + update a dashboard when tags haven't changed + """ + self.login(GAMMA_USERNAME) + write_tags_perm = security_manager.add_permission_view_menu("can_write", "Tag") + tag_dashboards_perm = security_manager.add_permission_view_menu( + "can_tag", "Dashboard" + ) + gamma_role = security_manager.find_role("Gamma") + security_manager.del_permission_role(gamma_role, write_tags_perm) + security_manager.del_permission_role(gamma_role, tag_dashboards_perm) + + dashboard = ( + db.session.query(Dashboard) + .filter(Dashboard.dashboard_title == "dash with tag") + .first() + ) + existing_tags = [tag.id for tag in dashboard.tags if tag.type == TagType.custom] + update_payload = {"tags": existing_tags} + + uri = f"api/v1/dashboard/{dashboard.id}" + rv = self.put_assert_metric(uri, update_payload, "put") + self.assertEqual(rv.status_code, 200) + + security_manager.add_permission_role(gamma_role, write_tags_perm) + security_manager.add_permission_role(gamma_role, tag_dashboards_perm) diff --git a/tests/unit_tests/commands/test_utils.py b/tests/unit_tests/commands/test_utils.py index d3ba5bbe41..810142d3d9 100644 --- a/tests/unit_tests/commands/test_utils.py +++ b/tests/unit_tests/commands/test_utils.py @@ -15,19 +15,74 @@ # specific language governing permissions and limitations # under the License. -from unittest.mock import MagicMock, patch -from superset.commands.utils import compute_owner_list, populate_owner_list, User +from unittest.mock import call, MagicMock, patch + +import pytest + +from superset.commands.exceptions import TagForbiddenError, TagNotFoundValidationError +from superset.commands.utils import ( + compute_owner_list, + populate_owner_list, + Tag, + TagType, + update_tags, + User, + validate_tags, +) +from superset.tags.models import ObjectType + +OBJECT_TYPES = {ObjectType.chart, ObjectType.chart} +MOCK_TAGS = [ + Tag( + id=1, + name="first", + type=TagType.custom, + ), + Tag( + id=2, + name="second", + type=TagType.custom, + ), + Tag( + id=3, + name="third", + type=TagType.custom, + ), + Tag( + id=4, + name="type:dashboard", + type=TagType.type, + ), + Tag( + id=4, + name="owner:1", + type=TagType.owner, + ), + Tag( + id=4, + name="avorited_by:2", + type=TagType.favorited_by, + ), +] @patch("superset.commands.utils.g") def test_populate_owner_list_default_to_user(mock_user): + """ + Test the ``populate_owner_list`` method when no owners are provided + and default_to_user is True (non-admin). + """ owner_list = populate_owner_list([], True) assert owner_list == [mock_user.user] @patch("superset.commands.utils.g") def test_populate_owner_list_default_to_user_handle_none(mock_user): + """ + Test the ``populate_owner_list`` method when owners is None + and default_to_user is True (non-admin). + """ owner_list = populate_owner_list(None, True) assert owner_list == [mock_user.user] @@ -36,6 +91,10 @@ def test_populate_owner_list_default_to_user_handle_none(mock_user): @patch("superset.commands.utils.security_manager") @patch("superset.commands.utils.get_user_id") def test_populate_owner_list_admin_user(mock_user_id, mock_sm, mock_g): + """ + Test the ``populate_owner_list`` method when an admin is setting + another user as an owner and default_to_user is False. + """ test_user = User(id=1, first_name="First", last_name="Last") mock_g.user = User(id=4, first_name="Admin", last_name="User") mock_user_id.return_value = 4 @@ -50,6 +109,10 @@ def test_populate_owner_list_admin_user(mock_user_id, mock_sm, mock_g): @patch("superset.commands.utils.security_manager") @patch("superset.commands.utils.get_user_id") def test_populate_owner_list_admin_user_empty_list(mock_user_id, mock_sm, mock_g): + """ + Test the ``populate_owner_list`` method when an admin is setting an empty list + of owners. + """ mock_g.user = User(id=4, first_name="Admin", last_name="User") mock_user_id.return_value = 4 mock_sm.is_admin = MagicMock(return_value=True) @@ -61,6 +124,10 @@ def test_populate_owner_list_admin_user_empty_list(mock_user_id, mock_sm, mock_g @patch("superset.commands.utils.security_manager") @patch("superset.commands.utils.get_user_id") def test_populate_owner_list_non_admin(mock_user_id, mock_sm, mock_g): + """ + Test the ``populate_owner_list`` method when a non admin is adding + another user as an owner and default_to_user is False (both get added). + """ test_user = User(id=1, first_name="First", last_name="Last") mock_g.user = User(id=4, first_name="Non", last_name="admin") mock_user_id.return_value = 4 @@ -73,6 +140,9 @@ def test_populate_owner_list_non_admin(mock_user_id, mock_sm, mock_g): @patch("superset.commands.utils.populate_owner_list") def test_compute_owner_list_new_owners(mock_populate_owner_list): + """ + Test the ``compute_owner_list`` method when replacing the owner list. + """ current_owners = [User(id=1), User(id=2), User(id=3)] new_owners = [4, 5, 6] @@ -82,6 +152,9 @@ def test_compute_owner_list_new_owners(mock_populate_owner_list): @patch("superset.commands.utils.populate_owner_list") def test_compute_owner_list_no_new_owners(mock_populate_owner_list): + """ + Test the ``compute_owner_list`` method when replacing new_owners is None. + """ current_owners = [User(id=1), User(id=2), User(id=3)] new_owners = None @@ -91,6 +164,9 @@ def test_compute_owner_list_no_new_owners(mock_populate_owner_list): @patch("superset.commands.utils.populate_owner_list") def test_compute_owner_list_new_owner_empty_list(mock_populate_owner_list): + """ + Test the ``compute_owner_list`` method when new_owners is an empty list. + """ current_owners = [User(id=1), User(id=2), User(id=3)] new_owners = [] @@ -100,6 +176,9 @@ def test_compute_owner_list_new_owner_empty_list(mock_populate_owner_list): @patch("superset.commands.utils.populate_owner_list") def test_compute_owner_list_no_owners(mock_populate_owner_list): + """ + Test the ``compute_owner_list`` method when current ownership is an empty list. + """ current_owners = [] new_owners = [4, 5, 6] @@ -109,8 +188,292 @@ def test_compute_owner_list_no_owners(mock_populate_owner_list): @patch("superset.commands.utils.populate_owner_list") def test_compute_owner_list_no_owners_handle_none(mock_populate_owner_list): + """ + Test the ``compute_owner_list`` method when current ownership is None. + """ current_owners = None new_owners = [4, 5, 6] compute_owner_list(current_owners, new_owners) mock_populate_owner_list.assert_called_once_with(new_owners, default_to_user=False) + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.security_manager") +def test_validate_tags_new_tags_is_none(mock_sm, object_type): + """ + Test the ``validate_tags`` method when new_tags is None. + """ + validate_tags(object_type, MOCK_TAGS, None) + mock_sm.can_access.assert_not_called() + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.security_manager") +def test_validate_tags_empty_list_can_write_on_tag(mock_sm, object_type): + """ + Test the ``validate_tags`` method when new_tags is an empty list and + user has permission to write on tag. + """ + mock_sm.can_access = MagicMock(return_value=True) + validate_tags(object_type, MOCK_TAGS, []) + mock_sm.can_access.assert_called_once_with("can_write", "Tag") + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.security_manager") +def test_validate_tags_empty_list_can_tag_on_object(mock_sm, object_type): + """ + Test the ``validate_tags`` method when new_tags is an empty list and + user has permission to tag objects. + """ + mock_sm.can_access = MagicMock(side_effect=[False, True]) + validate_tags(object_type, MOCK_TAGS, []) + mock_sm.can_access.assert_has_calls( + [call("can_write", "Tag"), call("can_tag", object_type.name.capitalize())] + ) + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.security_manager") +def test_validate_tags_empty_list_missing_permission(mock_sm, object_type): + """ + Test the ``validate_tags`` method when new_tags is an empty list and + the user doesn't have the required permission. + """ + mock_sm.can_access = MagicMock(side_effect=[False, False]) + with pytest.raises(TagForbiddenError): + validate_tags(object_type, MOCK_TAGS, []) + mock_sm.can_access.assert_has_calls( + [call("can_write", "Tag"), call("can_tag", object_type.name.capitalize())] + ) + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.security_manager") +def test_validate_tags_no_changes_can_write_on_tag(mock_sm, object_type): + """ + Test the ``validate_tags`` method when new_tags is equal to existing tags + and user has permission to write on tag. + """ + new_tags = [tag.id for tag in MOCK_TAGS if tag.type == TagType.custom] + validate_tags(object_type, MOCK_TAGS, new_tags) + mock_sm.can_access.assert_not_called() + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.security_manager") +def test_validate_tags_no_changes_can_tag_on_object(mock_sm, object_type): + """ + Test the ``validate_tags`` method when new_tags is equal to existing tags + and user has permission to tag objects. + """ + new_tags = [tag.id for tag in MOCK_TAGS if tag.type == TagType.custom] + validate_tags(object_type, MOCK_TAGS, new_tags) + mock_sm.can_access.assert_not_called() + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.security_manager") +def test_validate_tags_no_changes_missing_permission(mock_sm, object_type): + """ + Test the ``validate_tags`` method when new_tags is equal to existing tags + the user doens't have the required perms. + """ + new_tags = [tag.id for tag in MOCK_TAGS if tag.type == TagType.custom] + validate_tags(object_type, MOCK_TAGS, new_tags) + mock_sm.can_access.assert_not_called() + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.security_manager") +@patch("superset.commands.utils.TagDAO.find_by_id") +def test_validate_tags_add_new_tags_can_write_on_tag( + mock_tag_find_by_id, mock_sm, object_type +): + """ + Test the ``validate_tags`` method when new_tags are added and user has + permission to write on tag. + """ + new_tag_ids = [tag.id for tag in MOCK_TAGS if tag.type == TagType.custom] + new_tag = { + "id": 10, + "name": "New test tag", + "type": TagType.custom, + } + new_tag_ids.append(new_tag["id"]) + + mock_tag_find_by_id.return_value = new_tag + mock_sm.can_access = MagicMock(return_value=True) + + validate_tags(object_type, MOCK_TAGS, new_tag_ids) + + mock_sm.can_access.assert_called_once_with("can_write", "Tag") + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.security_manager") +@patch("superset.commands.utils.TagDAO.find_by_id") +def test_validate_tags_add_new_tags_can_tag_on_object( + mock_tag_find_by_id, mock_sm, object_type +): + """ + Test the ``validate_tags`` method when new_tags are added and user has + permission to tag objects. + """ + current_tags = [tag for tag in MOCK_TAGS if tag.type == TagType.custom] + new_tag = current_tags.pop() + new_tag_ids = [tag.id for tag in current_tags] + new_tag_ids.append(new_tag.id) + + mock_sm.can_access = MagicMock(side_effect=[False, True]) + mock_tag_find_by_id.return_value = new_tag + + validate_tags(object_type, current_tags, new_tag_ids) + + mock_sm.can_access.assert_has_calls( + [call("can_write", "Tag"), call("can_tag", object_type.name.capitalize())] + ) + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.security_manager") +@patch("superset.commands.utils.TagDAO.find_by_name") +def test_validate_tags_can_write_on_tag_unable_to_find_tag( + mock_tag_find_by_id, mock_sm, object_type +): + """ + Test the ``validate_tags`` method when an un-existing tag is being + added and user has permission to write on tag. + """ + fake_id = 100 + mock_sm.can_access = MagicMock(return_value=True) + mock_tag_find_by_id.return_value = None + with pytest.raises(TagNotFoundValidationError): + validate_tags(object_type, MOCK_TAGS, [fake_id]) + mock_sm.can_access.assert_called_once_with("can_write", "Tag") + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.security_manager") +@patch("superset.commands.utils.TagDAO.find_by_name") +def test_validate_tags_can_tag_on_object_unable_to_find_tag( + mock_tag_find_by_id, mock_sm, object_type +): + """ + Test the ``validate_tags`` method when an un-existing tag is being + added and user has permission to tag on object. + """ + fake_id = 100 + mock_sm.can_access = MagicMock(side_effect=[False, True]) + mock_tag_find_by_id.return_value = None + with pytest.raises(TagNotFoundValidationError): + validate_tags(object_type, MOCK_TAGS, [fake_id]) + mock_sm.can_access.assert_has_calls( + [call("can_write", "Tag"), call("can_tag", object_type.name.capitalize())] + ) + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.TagDAO") +def test_update_tags_adding_tags(mock_tag_dao, object_type): + """ + Test the ``update_tags`` method when adding tags. + """ + current_tags = [tag for tag in MOCK_TAGS if tag.type == TagType.custom] + new_tag = current_tags.pop() + new_tags = [tag for tag in MOCK_TAGS if tag.type == TagType.custom] + new_tag_ids = [tag.id for tag in new_tags] + + mock_tag_dao.find_by_ids.return_value = [new_tag] + + update_tags(object_type, 1, current_tags, new_tag_ids) + + mock_tag_dao.find_by_ids.assert_called_once_with([new_tag.id]) + mock_tag_dao.delete_tagged_object.assert_not_called() + mock_tag_dao.create_custom_tagged_objects.assert_called_once_with( + object_type, 1, [new_tag.name] + ) + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.TagDAO") +def test_update_tags_removing_tags(mock_tag_dao, object_type): + """ + Test the ``update_tags`` method when removing existing tags. + """ + new_tags = [tag for tag in MOCK_TAGS if tag.type == TagType.custom] + tag_to_be_removed = new_tags.pop() + new_tag_ids = [tag.id for tag in new_tags] + + update_tags(object_type, 1, MOCK_TAGS, new_tag_ids) + + mock_tag_dao.delete_tagged_object.assert_called_once_with( + object_type, 1, tag_to_be_removed.name + ) + mock_tag_dao.create_custom_tagged_objects.assert_not_called() + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.TagDAO") +def test_update_tags_adding_and_removing_tags(mock_tag_dao, object_type): + """ + Test the ``update_tags`` method when adding and removing existing tags. + """ + new_tags = [tag for tag in MOCK_TAGS if tag.type == TagType.custom] + tag_to_be_removed = new_tags.pop() + new_tag = Tag(id=10, name="my new tag", type=TagType.custom) + new_tags.append(new_tag) + new_tag_ids = [tag.id for tag in new_tags] + + mock_tag_dao.find_by_ids.return_value = [new_tag] + + update_tags(object_type, 1, MOCK_TAGS, new_tag_ids) + + mock_tag_dao.delete_tagged_object.assert_called_once_with( + object_type, 1, tag_to_be_removed.name + ) + mock_tag_dao.find_by_ids.assert_called_once_with([new_tag.id]) + mock_tag_dao.create_custom_tagged_objects.assert_called_once_with( + object_type, 1, ["my new tag"] + ) + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.TagDAO") +def test_update_tags_removing_all_tags(mock_tag_dao, object_type): + """ + Test the ``update_tags`` method when removing all tags. + """ + update_tags(object_type, 1, MOCK_TAGS, []) + + mock_tag_dao.delete_tagged_object.assert_has_calls( + [ + call(object_type, 1, tag.name) + for tag in MOCK_TAGS + if tag.type == TagType.custom + ] + ) + mock_tag_dao.create_custom_tagged_objects.assert_not_called() + + +@pytest.mark.parametrize("object_type", OBJECT_TYPES) +@patch("superset.commands.utils.TagDAO") +def test_update_tags_no_tags(mock_tag_dao, object_type): + """ + Test the ``update_tags`` method when the asset only has system tags. + """ + system_tags = [tag for tag in MOCK_TAGS if tag.type != TagType.custom] + new_tags = [tag for tag in MOCK_TAGS if tag.type == TagType.custom] + new_tag_ids = [tag.id for tag in new_tags] + new_tag_names = [tag.name for tag in new_tags] + + mock_tag_dao.find_by_ids.return_value = new_tags + + update_tags(object_type, 1, system_tags, new_tag_ids) + + mock_tag_dao.delete_tagged_object.assert_not_called() + mock_tag_dao.find_by_ids.assert_called_once_with(new_tag_ids) + mock_tag_dao.create_custom_tagged_objects.assert_called_once_with( + object_type, 1, new_tag_names + )