fix: Allows PUT and DELETE only for owners of dashboard filter state (#17644)

* fix: Allows PUT and DELETE only for owners of dashboard filter state

* Converts the values to TypedDict

* Fixes variable name
This commit is contained in:
Michael S. Molina 2021-12-05 22:13:09 -03:00 committed by GitHub
parent 8e02d11909
commit 2ae83fac86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 239 additions and 134 deletions

View File

@ -16,17 +16,23 @@
# under the License. # under the License.
from typing import Optional from typing import Optional
from flask_appbuilder.security.sqla.models import User
from superset.dashboards.dao import DashboardDAO from superset.dashboards.dao import DashboardDAO
from superset.dashboards.filter_state.commands.entry import Entry
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.commands.create import CreateKeyValueCommand from superset.key_value.commands.create import CreateKeyValueCommand
from superset.key_value.utils import cache_key from superset.key_value.utils import cache_key
class CreateFilterStateCommand(CreateKeyValueCommand): class CreateFilterStateCommand(CreateKeyValueCommand):
def create(self, resource_id: int, key: str, value: str) -> Optional[bool]: def create(
self, actor: User, resource_id: int, key: str, value: str
) -> Optional[bool]:
dashboard = DashboardDAO.get_by_id_or_slug(str(resource_id)) dashboard = DashboardDAO.get_by_id_or_slug(str(resource_id))
if dashboard: if dashboard:
entry: Entry = {"owner": actor.get_user_id(), "value": value}
return cache_manager.filter_state_cache.set( return cache_manager.filter_state_cache.set(
cache_key(resource_id, key), value cache_key(resource_id, key), entry
) )
return False return False

View File

@ -16,15 +16,27 @@
# under the License. # under the License.
from typing import Optional from typing import Optional
from flask_appbuilder.security.sqla.models import User
from superset.dashboards.dao import DashboardDAO from superset.dashboards.dao import DashboardDAO
from superset.dashboards.filter_state.commands.entry import Entry
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.commands.delete import DeleteKeyValueCommand from superset.key_value.commands.delete import DeleteKeyValueCommand
from superset.key_value.commands.exceptions import KeyValueAccessDeniedError
from superset.key_value.utils import cache_key from superset.key_value.utils import cache_key
class DeleteFilterStateCommand(DeleteKeyValueCommand): class DeleteFilterStateCommand(DeleteKeyValueCommand):
def delete(self, resource_id: int, key: str) -> Optional[bool]: def delete(self, actor: User, resource_id: int, key: str) -> Optional[bool]:
dashboard = DashboardDAO.get_by_id_or_slug(str(resource_id)) dashboard = DashboardDAO.get_by_id_or_slug(str(resource_id))
if dashboard: if dashboard:
return cache_manager.filter_state_cache.delete(cache_key(resource_id, key)) entry: Entry = cache_manager.filter_state_cache.get(
cache_key(resource_id, key)
)
if entry:
if entry["owner"] != actor.get_user_id():
raise KeyValueAccessDeniedError()
return cache_manager.filter_state_cache.delete(
cache_key(resource_id, key)
)
return False return False

View File

@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import TypedDict
class Entry(TypedDict):
owner: int
value: str

View File

@ -17,6 +17,7 @@
from typing import Optional from typing import Optional
from superset.dashboards.dao import DashboardDAO from superset.dashboards.dao import DashboardDAO
from superset.dashboards.filter_state.commands.entry import Entry
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.commands.get import GetKeyValueCommand from superset.key_value.commands.get import GetKeyValueCommand
from superset.key_value.utils import cache_key from superset.key_value.utils import cache_key
@ -26,8 +27,10 @@ class GetFilterStateCommand(GetKeyValueCommand):
def get(self, resource_id: int, key: str, refreshTimeout: bool) -> Optional[str]: def get(self, resource_id: int, key: str, refreshTimeout: bool) -> Optional[str]:
dashboard = DashboardDAO.get_by_id_or_slug(str(resource_id)) dashboard = DashboardDAO.get_by_id_or_slug(str(resource_id))
if dashboard: if dashboard:
value = cache_manager.filter_state_cache.get(cache_key(resource_id, key)) entry: Entry = cache_manager.filter_state_cache.get(
cache_key(resource_id, key)
)
if refreshTimeout: if refreshTimeout:
cache_manager.filter_state_cache.set(key, value) cache_manager.filter_state_cache.set(key, entry)
return value return entry["value"]
return None return None

View File

@ -16,17 +16,31 @@
# under the License. # under the License.
from typing import Optional from typing import Optional
from flask_appbuilder.security.sqla.models import User
from superset.dashboards.dao import DashboardDAO from superset.dashboards.dao import DashboardDAO
from superset.dashboards.filter_state.commands.entry import Entry
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.commands.exceptions import KeyValueAccessDeniedError
from superset.key_value.commands.update import UpdateKeyValueCommand from superset.key_value.commands.update import UpdateKeyValueCommand
from superset.key_value.utils import cache_key from superset.key_value.utils import cache_key
class UpdateFilterStateCommand(UpdateKeyValueCommand): class UpdateFilterStateCommand(UpdateKeyValueCommand):
def update(self, resource_id: int, key: str, value: str) -> Optional[bool]: def update(
self, actor: User, resource_id: int, key: str, value: str
) -> Optional[bool]:
dashboard = DashboardDAO.get_by_id_or_slug(str(resource_id)) dashboard = DashboardDAO.get_by_id_or_slug(str(resource_id))
if dashboard: if dashboard:
entry: Entry = cache_manager.filter_state_cache.get(
cache_key(resource_id, key)
)
if entry:
user_id = actor.get_user_id()
if entry["owner"] != user_id:
raise KeyValueAccessDeniedError()
new_entry: Entry = {"owner": actor.get_user_id(), "value": value}
return cache_manager.filter_state_cache.set( return cache_manager.filter_state_cache.set(
cache_key(resource_id, key), value cache_key(resource_id, key), new_entry
) )
return False return False

View File

@ -29,6 +29,7 @@ from superset.dashboards.commands.exceptions import (
DashboardNotFoundError, DashboardNotFoundError,
) )
from superset.exceptions import InvalidPayloadFormatError from superset.exceptions import InvalidPayloadFormatError
from superset.key_value.commands.exceptions import KeyValueAccessDeniedError
from superset.key_value.schemas import KeyValuePostSchema, KeyValuePutSchema from superset.key_value.schemas import KeyValuePostSchema, KeyValuePutSchema
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,7 +65,7 @@ class KeyValueRestApi(BaseApi, ABC):
return self.response(201, key=key) return self.response(201, key=key)
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
except DashboardAccessDeniedError: except (DashboardAccessDeniedError, KeyValueAccessDeniedError):
return self.response_403() return self.response_403()
except DashboardNotFoundError: except DashboardNotFoundError:
return self.response_404() return self.response_404()
@ -80,7 +81,7 @@ class KeyValueRestApi(BaseApi, ABC):
return self.response(200, message="Value updated successfully.") return self.response(200, message="Value updated successfully.")
except ValidationError as error: except ValidationError as error:
return self.response_400(message=error.messages) return self.response_400(message=error.messages)
except DashboardAccessDeniedError: except (DashboardAccessDeniedError, KeyValueAccessDeniedError):
return self.response_403() return self.response_403()
except DashboardNotFoundError: except DashboardNotFoundError:
return self.response_404() return self.response_404()
@ -91,7 +92,7 @@ class KeyValueRestApi(BaseApi, ABC):
if not value: if not value:
return self.response_404() return self.response_404()
return self.response(200, value=value) return self.response(200, value=value)
except DashboardAccessDeniedError: except (DashboardAccessDeniedError, KeyValueAccessDeniedError):
return self.response_403() return self.response_403()
except DashboardNotFoundError: except DashboardNotFoundError:
return self.response_404() return self.response_404()
@ -102,7 +103,7 @@ class KeyValueRestApi(BaseApi, ABC):
if not result: if not result:
return self.response_404() return self.response_404()
return self.response(200, message="Deleted successfully") return self.response(200, message="Deleted successfully")
except DashboardAccessDeniedError: except (DashboardAccessDeniedError, KeyValueAccessDeniedError):
return self.response_403() return self.response_403()
except DashboardNotFoundError: except DashboardNotFoundError:
return self.response_404() return self.response_404()

View File

@ -31,16 +31,16 @@ logger = logging.getLogger(__name__)
class CreateKeyValueCommand(BaseCommand, ABC): class CreateKeyValueCommand(BaseCommand, ABC):
def __init__( def __init__(
self, user: User, resource_id: int, value: str, self, actor: User, resource_id: int, value: str,
): ):
self._actor = user self._actor = actor
self._resource_id = resource_id self._resource_id = resource_id
self._value = value self._value = value
def run(self) -> Model: def run(self) -> Model:
try: try:
key = token_urlsafe(48) key = token_urlsafe(48)
self.create(self._resource_id, key, self._value) self.create(self._actor, self._resource_id, key, self._value)
return key return key
except SQLAlchemyError as ex: except SQLAlchemyError as ex:
logger.exception("Error running create command") logger.exception("Error running create command")
@ -50,5 +50,7 @@ class CreateKeyValueCommand(BaseCommand, ABC):
pass pass
@abstractmethod @abstractmethod
def create(self, resource_id: int, key: str, value: str) -> Optional[bool]: def create(
self, actor: User, resource_id: int, key: str, value: str
) -> Optional[bool]:
... ...

View File

@ -29,14 +29,14 @@ logger = logging.getLogger(__name__)
class DeleteKeyValueCommand(BaseCommand, ABC): class DeleteKeyValueCommand(BaseCommand, ABC):
def __init__(self, user: User, resource_id: int, key: str): def __init__(self, actor: User, resource_id: int, key: str):
self._actor = user self._actor = actor
self._resource_id = resource_id self._resource_id = resource_id
self._key = key self._key = key
def run(self) -> Model: def run(self) -> Model:
try: try:
return self.delete(self._resource_id, self._key) return self.delete(self._actor, self._resource_id, self._key)
except SQLAlchemyError as ex: except SQLAlchemyError as ex:
logger.exception("Error running delete command") logger.exception("Error running delete command")
raise KeyValueDeleteFailedError() from ex raise KeyValueDeleteFailedError() from ex
@ -45,5 +45,5 @@ class DeleteKeyValueCommand(BaseCommand, ABC):
pass pass
@abstractmethod @abstractmethod
def delete(self, resource_id: int, key: str) -> Optional[bool]: def delete(self, actor: User, resource_id: int, key: str) -> Optional[bool]:
... ...

View File

@ -20,6 +20,7 @@ from superset.commands.exceptions import (
CommandException, CommandException,
CreateFailedError, CreateFailedError,
DeleteFailedError, DeleteFailedError,
ForbiddenError,
UpdateFailedError, UpdateFailedError,
) )
@ -38,3 +39,7 @@ class KeyValueDeleteFailedError(DeleteFailedError):
class KeyValueUpdateFailedError(UpdateFailedError): class KeyValueUpdateFailedError(UpdateFailedError):
message = _("An error occurred while updating the value.") message = _("An error occurred while updating the value.")
class KeyValueAccessDeniedError(ForbiddenError):
message = _("You don't have permission to modify the value.")

View File

@ -30,8 +30,8 @@ logger = logging.getLogger(__name__)
class GetKeyValueCommand(BaseCommand, ABC): class GetKeyValueCommand(BaseCommand, ABC):
def __init__(self, user: User, resource_id: int, key: str): def __init__(self, actor: User, resource_id: int, key: str):
self._actor = user self._actor = actor
self._resource_id = resource_id self._resource_id = resource_id
self._key = key self._key = key

View File

@ -30,16 +30,16 @@ logger = logging.getLogger(__name__)
class UpdateKeyValueCommand(BaseCommand, ABC): class UpdateKeyValueCommand(BaseCommand, ABC):
def __init__( def __init__(
self, user: User, resource_id: int, key: str, value: str, self, actor: User, resource_id: int, key: str, value: str,
): ):
self._actor = user self._actor = actor
self._resource_id = resource_id self._resource_id = resource_id
self._key = key self._key = key
self._value = value self._value = value
def run(self) -> Model: def run(self) -> Model:
try: try:
return self.update(self._resource_id, self._key, self._value) return self.update(self._actor, self._resource_id, self._key, self._value)
except SQLAlchemyError as ex: except SQLAlchemyError as ex:
logger.exception("Error running update command") logger.exception("Error running update command")
raise KeyValueUpdateFailedError() from ex raise KeyValueUpdateFailedError() from ex
@ -48,5 +48,7 @@ class UpdateKeyValueCommand(BaseCommand, ABC):
pass pass
@abstractmethod @abstractmethod
def update(self, resource_id: int, key: str, value: str) -> Optional[bool]: def update(
self, actor: User, resource_id: int, key: str, value: str
) -> Optional[bool]:
... ...

View File

@ -15,61 +15,70 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import json import json
from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from flask.testing import FlaskClient from flask_appbuilder.security.sqla.models import User
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from superset import app from superset import app
from superset.dashboards.commands.exceptions import DashboardAccessDeniedError from superset.dashboards.commands.exceptions import DashboardAccessDeniedError
from superset.dashboards.filter_state.commands.entry import Entry
from superset.extensions import cache_manager from superset.extensions import cache_manager
from superset.key_value.utils import cache_key from superset.key_value.utils import cache_key
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from tests.integration_tests.base_tests import login
from tests.integration_tests.fixtures.world_bank_dashboard import ( from tests.integration_tests.fixtures.world_bank_dashboard import (
load_world_bank_dashboard_with_slices, load_world_bank_dashboard_with_slices,
) )
from tests.integration_tests.test_app import app from tests.integration_tests.test_app import app
dashboardId = 985374
key = "test-key" key = "test-key"
value = "test" value = "test"
class FilterStateTests:
@pytest.fixture @pytest.fixture
def client(self): def client():
with app.test_client() as client: with app.test_client() as client:
with app.app_context(): with app.app_context():
yield client yield client
@pytest.fixture @pytest.fixture
def dashboard_id(self, load_world_bank_dashboard_with_slices) -> int: def dashboard_id(load_world_bank_dashboard_with_slices) -> int:
with app.app_context() as ctx: with app.app_context() as ctx:
session: Session = ctx.app.appbuilder.get_session session: Session = ctx.app.appbuilder.get_session
dashboard = session.query(Dashboard).filter_by(slug="world_health").one() dashboard = session.query(Dashboard).filter_by(slug="world_health").one()
return dashboard.id return dashboard.id
@pytest.fixture @pytest.fixture
def cache(self, dashboard_id): def admin_id() -> int:
with app.app_context() as ctx:
session: Session = ctx.app.appbuilder.get_session
admin = session.query(User).filter_by(username="admin").one_or_none()
return admin.id
@pytest.fixture(autouse=True)
def cache(dashboard_id, admin_id):
app.config["FILTER_STATE_CACHE_CONFIG"] = {"CACHE_TYPE": "SimpleCache"} app.config["FILTER_STATE_CACHE_CONFIG"] = {"CACHE_TYPE": "SimpleCache"}
cache_manager.init_app(app) cache_manager.init_app(app)
cache_manager.filter_state_cache.set(cache_key(dashboard_id, key), value) entry: Entry = {"owner": admin_id, "value": value}
cache_manager.filter_state_cache.set(cache_key(dashboard_id, key), entry)
def setUp(self):
self.login(username="admin")
def test_post(self, client, dashboard_id: int): def test_post(client, dashboard_id: int):
login(client, "admin")
payload = { payload = {
"value": value, "value": value,
} }
resp = client.post( resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload)
f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload
)
assert resp.status_code == 201 assert resp.status_code == 201
def test_post_bad_request(self, client, dashboard_id: int):
def test_post_bad_request(client, dashboard_id: int):
login(client, "admin")
payload = { payload = {
"value": 1234, "value": 1234,
} }
@ -78,20 +87,20 @@ class FilterStateTests:
) )
assert resp.status_code == 400 assert resp.status_code == 400
@patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access") @patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access")
def test_post_access_denied( def test_post_access_denied(mock_raise_for_dashboard_access, client, dashboard_id: int):
self, client, mock_raise_for_dashboard_access, dashboard_id: int login(client, "admin")
):
mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError() mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError()
payload = { payload = {
"value": value, "value": value,
} }
resp = client.post( resp = client.post(f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload)
f"api/v1/dashboard/{dashboard_id}/filter_state", json=payload
)
assert resp.status_code == 403 assert resp.status_code == 403
def test_put(self, client, dashboard_id: int):
def test_put(client, dashboard_id: int):
login(client, "admin")
payload = { payload = {
"value": "new value", "value": "new value",
} }
@ -100,7 +109,9 @@ class FilterStateTests:
) )
assert resp.status_code == 200 assert resp.status_code == 200
def test_put_bad_request(self, client, dashboard_id: int):
def test_put_bad_request(client, dashboard_id: int):
login(client, "admin")
payload = { payload = {
"value": 1234, "value": 1234,
} }
@ -109,10 +120,10 @@ class FilterStateTests:
) )
assert resp.status_code == 400 assert resp.status_code == 400
@patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access") @patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access")
def test_put_access_denied( def test_put_access_denied(mock_raise_for_dashboard_access, client, dashboard_id: int):
self, client, mock_raise_for_dashboard_access, dashboard_id: int login(client, "admin")
):
mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError() mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError()
payload = { payload = {
"value": "new value", "value": "new value",
@ -122,36 +133,63 @@ class FilterStateTests:
) )
assert resp.status_code == 403 assert resp.status_code == 403
def test_get_key_not_found(self, client):
def test_put_not_owner(client, dashboard_id: int):
login(client, "gamma")
payload = {
"value": "new value",
}
resp = client.put(
f"api/v1/dashboard/{dashboard_id}/filter_state/{key}/", json=payload
)
assert resp.status_code == 403
def test_get_key_not_found(client):
login(client, "admin")
resp = client.get("unknown-key") resp = client.get("unknown-key")
assert resp.status_code == 404 assert resp.status_code == 404
def test_get_dashboard_not_found(self, client):
def test_get_dashboard_not_found(client):
login(client, "admin")
resp = client.get(f"api/v1/dashboard/{-1}/filter_state/{key}/") resp = client.get(f"api/v1/dashboard/{-1}/filter_state/{key}/")
assert resp.status_code == 404 assert resp.status_code == 404
def test_get(self, client, dashboard_id: int):
def test_get(client, dashboard_id: int):
login(client, "admin")
resp = client.get(f"api/v1/dashboard/{dashboard_id}/filter_state/{key}/") resp = client.get(f"api/v1/dashboard/{dashboard_id}/filter_state/{key}/")
assert resp.status_code == 200 assert resp.status_code == 200
data = json.loads(resp.data.decode("utf-8")) data = json.loads(resp.data.decode("utf-8"))
assert value == data.get("value") assert value == data.get("value")
@patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access") @patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access")
def test_get_access_denied( def test_get_access_denied(mock_raise_for_dashboard_access, client, dashboard_id):
self, client, mock_raise_for_dashboard_access, dashboard_id login(client, "admin")
):
mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError() mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError()
resp = client.get(f"api/v1/dashboard/{dashboard_id}/filter_state/{key}/") resp = client.get(f"api/v1/dashboard/{dashboard_id}/filter_state/{key}/")
assert resp.status_code == 403 assert resp.status_code == 403
def test_delete(self, client, dashboard_id: int):
def test_delete(client, dashboard_id: int):
login(client, "admin")
resp = client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{key}/") resp = client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{key}/")
assert resp.status_code == 200 assert resp.status_code == 200
@patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access") @patch("superset.security.SupersetSecurityManager.raise_for_dashboard_access")
def test_delete_access_denied( def test_delete_access_denied(
self, client, mock_raise_for_dashboard_access, dashboard_id: int mock_raise_for_dashboard_access, client, dashboard_id: int
): ):
login(client, "admin")
mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError() mock_raise_for_dashboard_access.side_effect = DashboardAccessDeniedError()
resp = client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{key}/") resp = client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{key}/")
assert resp.status_code == 403 assert resp.status_code == 403
def test_delete_not_owner(client, dashboard_id: int):
login(client, "gamma")
resp = client.delete(f"api/v1/dashboard/{dashboard_id}/filter_state/{key}/")
assert resp.status_code == 403