fix: improve explore REST api validations (#27395)

This commit is contained in:
Daniel Vaz Gaspar 2024-03-05 17:44:51 +00:00 committed by GitHub
parent 721977a474
commit a3d2e0bf44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 3 deletions

View File

@ -37,6 +37,7 @@ from superset.daos.exceptions import DatasourceNotFound
from superset.exceptions import SupersetException from superset.exceptions import SupersetException
from superset.explore.exceptions import WrongEndpointError from superset.explore.exceptions import WrongEndpointError
from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
from superset.extensions import security_manager
from superset.utils import core as utils from superset.utils import core as utils
from superset.views.utils import ( from superset.views.utils import (
get_datasource_info, get_datasource_info,
@ -61,7 +62,6 @@ class GetExploreCommand(BaseCommand, ABC):
# pylint: disable=too-many-locals,too-many-branches,too-many-statements # pylint: disable=too-many-locals,too-many-branches,too-many-statements
def run(self) -> Optional[dict[str, Any]]: def run(self) -> Optional[dict[str, Any]]:
initial_form_data = {} initial_form_data = {}
if self._permalink_key is not None: if self._permalink_key is not None:
command = GetExplorePermalinkCommand(self._permalink_key) command = GetExplorePermalinkCommand(self._permalink_key)
permalink_value = command.run() permalink_value = command.run()
@ -110,12 +110,19 @@ class GetExploreCommand(BaseCommand, ABC):
self._datasource_type = SqlaTable.type self._datasource_type = SqlaTable.type
datasource: Optional[BaseDatasource] = None datasource: Optional[BaseDatasource] = None
if self._datasource_id is not None: if self._datasource_id is not None:
with contextlib.suppress(DatasourceNotFound): with contextlib.suppress(DatasourceNotFound):
datasource = DatasourceDAO.get_datasource( datasource = DatasourceDAO.get_datasource(
cast(str, self._datasource_type), self._datasource_id cast(str, self._datasource_type), self._datasource_id
) )
datasource_name = datasource.name if datasource else _("[Missing Dataset]")
datasource_name = _("[Missing Dataset]")
if datasource:
datasource_name = datasource.name
security_manager.can_access_datasource(datasource)
viz_type = form_data.get("viz_type") viz_type = form_data.get("viz_type")
if not viz_type and datasource and datasource.default_endpoint: if not viz_type and datasource and datasource.default_endpoint:
raise WrongEndpointError(redirect=datasource.default_endpoint) raise WrongEndpointError(redirect=datasource.default_endpoint)

View File

@ -197,7 +197,7 @@ def test_get_from_permalink_unknown_key(test_client, login_as_admin):
@patch("superset.security.SupersetSecurityManager.can_access_datasource") @patch("superset.security.SupersetSecurityManager.can_access_datasource")
def test_get_dataset_access_denied( def test_get_dataset_access_denied_with_form_data_key(
mock_can_access_datasource, test_client, login_as_admin, dataset mock_can_access_datasource, test_client, login_as_admin, dataset
): ):
message = "Dataset access denied" message = "Dataset access denied"
@ -214,6 +214,24 @@ def test_get_dataset_access_denied(
assert data["message"] == message assert data["message"] == message
@patch("superset.security.SupersetSecurityManager.can_access_datasource")
def test_get_dataset_access_denied(
mock_can_access_datasource, test_client, login_as_admin, dataset
):
message = "Dataset access denied"
mock_can_access_datasource.side_effect = DatasetAccessDeniedError(
message=message, datasource_id=dataset.id, datasource_type=dataset.type
)
resp = test_client.get(
f"api/v1/explore/?datasource_id={dataset.id}&datasource_type={dataset.type}"
)
data = json.loads(resp.data.decode("utf-8"))
assert resp.status_code == 403
assert data["datasource_id"] == dataset.id
assert data["datasource_type"] == dataset.type
assert data["message"] == message
@patch("superset.daos.datasource.DatasourceDAO.get_datasource") @patch("superset.daos.datasource.DatasourceDAO.get_datasource")
def test_wrong_endpoint(mock_get_datasource, test_client, login_as_admin, dataset): def test_wrong_endpoint(mock_get_datasource, test_client, login_as_admin, dataset):
dataset.default_endpoint = "another_endpoint" dataset.default_endpoint = "another_endpoint"