fix(permalink): migrate to marshmallow codec (#24166)

This commit is contained in:
Ville Brofeldt 2023-05-22 13:35:58 +03:00 committed by GitHub
parent 82d4249e17
commit 71d0543f28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 251 additions and 22 deletions

View File

@ -30,7 +30,7 @@ from superset.dashboards.permalink.commands.create import (
)
from superset.dashboards.permalink.commands.get import GetDashboardPermalinkCommand
from superset.dashboards.permalink.exceptions import DashboardPermalinkInvalidStateError
from superset.dashboards.permalink.schemas import DashboardPermalinkPostSchema
from superset.dashboards.permalink.schemas import DashboardPermalinkStateSchema
from superset.extensions import event_logger
from superset.key_value.exceptions import KeyValueAccessDeniedError
from superset.views.base_api import BaseSupersetApi, requires_json
@ -39,13 +39,13 @@ logger = logging.getLogger(__name__)
class DashboardPermalinkRestApi(BaseSupersetApi):
add_model_schema = DashboardPermalinkPostSchema()
add_model_schema = DashboardPermalinkStateSchema()
method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
allow_browser_login = True
class_permission_name = "DashboardPermalinkRestApi"
resource_name = "dashboard"
openapi_spec_tag = "Dashboard Permanent Link"
openapi_spec_component_schemas = (DashboardPermalinkPostSchema,)
openapi_spec_component_schemas = (DashboardPermalinkStateSchema,)
@expose("/<pk>/permalink", methods=("POST",))
@protect()

View File

@ -17,13 +17,18 @@
from abc import ABC
from superset.commands.base import BaseCommand
from superset.dashboards.permalink.schemas import DashboardPermalinkSchema
from superset.key_value.shared_entries import get_permalink_salt
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey
from superset.key_value.types import (
KeyValueResource,
MarshmallowKeyValueCodec,
SharedKey,
)
class BaseDashboardPermalinkCommand(BaseCommand, ABC):
resource = KeyValueResource.DASHBOARD_PERMALINK
codec = JsonKeyValueCodec()
codec = MarshmallowKeyValueCodec(DashboardPermalinkSchema())
@property
def salt(self) -> str:

View File

@ -23,6 +23,7 @@ from superset.dashboards.permalink.commands.base import BaseDashboardPermalinkCo
from superset.dashboards.permalink.exceptions import DashboardPermalinkCreateFailedError
from superset.dashboards.permalink.types import DashboardPermalinkState
from superset.key_value.commands.upsert import UpsertKeyValueCommand
from superset.key_value.exceptions import KeyValueCodecEncodeException
from superset.key_value.utils import encode_permalink_key, get_deterministic_uuid
from superset.utils.core import get_user_id
@ -62,6 +63,8 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
).run()
assert key.id # for type checks
return encode_permalink_key(key=key.id, salt=self.salt)
except KeyValueCodecEncodeException as ex:
raise DashboardPermalinkCreateFailedError(str(ex)) from ex
except SQLAlchemyError as ex:
logger.exception("Error running create command")
raise DashboardPermalinkCreateFailedError() from ex

View File

@ -25,7 +25,11 @@ from superset.dashboards.permalink.commands.base import BaseDashboardPermalinkCo
from superset.dashboards.permalink.exceptions import DashboardPermalinkGetFailedError
from superset.dashboards.permalink.types import DashboardPermalinkValue
from superset.key_value.commands.get import GetKeyValueCommand
from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError
from superset.key_value.exceptions import (
KeyValueCodecDecodeException,
KeyValueGetFailedError,
KeyValueParseKeyError,
)
from superset.key_value.utils import decode_permalink_id
logger = logging.getLogger(__name__)
@ -51,6 +55,7 @@ class GetDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
return None
except (
DashboardNotFoundError,
KeyValueCodecDecodeException,
KeyValueGetFailedError,
KeyValueParseKeyError,
) as ex:

View File

@ -17,7 +17,7 @@
from marshmallow import fields, Schema
class DashboardPermalinkPostSchema(Schema):
class DashboardPermalinkStateSchema(Schema):
dataMask = fields.Dict(
required=False,
allow_none=True,
@ -52,3 +52,12 @@ class DashboardPermalinkPostSchema(Schema):
allow_none=True,
metadata={"description": "Optional anchor link added to url hash"},
)
class DashboardPermalinkSchema(Schema):
dashboardId = fields.String(
required=True,
allow_none=False,
metadata={"description": "The id or slug of the dasbhoard"},
)
state = fields.Nested(DashboardPermalinkStateSchema())

View File

@ -32,7 +32,7 @@ from superset.datasets.commands.exceptions import (
from superset.explore.permalink.commands.create import CreateExplorePermalinkCommand
from superset.explore.permalink.commands.get import GetExplorePermalinkCommand
from superset.explore.permalink.exceptions import ExplorePermalinkInvalidStateError
from superset.explore.permalink.schemas import ExplorePermalinkPostSchema
from superset.explore.permalink.schemas import ExplorePermalinkStateSchema
from superset.extensions import event_logger
from superset.key_value.exceptions import KeyValueAccessDeniedError
from superset.views.base_api import BaseSupersetApi, requires_json, statsd_metrics
@ -41,13 +41,13 @@ logger = logging.getLogger(__name__)
class ExplorePermalinkRestApi(BaseSupersetApi):
add_model_schema = ExplorePermalinkPostSchema()
add_model_schema = ExplorePermalinkStateSchema()
method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP
allow_browser_login = True
class_permission_name = "ExplorePermalinkRestApi"
resource_name = "explore"
openapi_spec_tag = "Explore Permanent Link"
openapi_spec_component_schemas = (ExplorePermalinkPostSchema,)
openapi_spec_component_schemas = (ExplorePermalinkStateSchema,)
@expose("/permalink", methods=("POST",))
@protect()

View File

@ -17,13 +17,18 @@
from abc import ABC
from superset.commands.base import BaseCommand
from superset.explore.permalink.schemas import ExplorePermalinkSchema
from superset.key_value.shared_entries import get_permalink_salt
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey
from superset.key_value.types import (
KeyValueResource,
MarshmallowKeyValueCodec,
SharedKey,
)
class BaseExplorePermalinkCommand(BaseCommand, ABC):
resource: KeyValueResource = KeyValueResource.EXPLORE_PERMALINK
codec = JsonKeyValueCodec()
codec = MarshmallowKeyValueCodec(ExplorePermalinkSchema())
@property
def salt(self) -> str:

View File

@ -23,6 +23,7 @@ from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand
from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
from superset.explore.utils import check_access as check_chart_access
from superset.key_value.commands.create import CreateKeyValueCommand
from superset.key_value.exceptions import KeyValueCodecEncodeException
from superset.key_value.utils import encode_permalink_key
from superset.utils.core import DatasourceType
@ -58,6 +59,8 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
if key.id is None:
raise ExplorePermalinkCreateFailedError("Unexpected missing key id")
return encode_permalink_key(key=key.id, salt=self.salt)
except KeyValueCodecEncodeException as ex:
raise ExplorePermalinkCreateFailedError(str(ex)) from ex
except SQLAlchemyError as ex:
logger.exception("Error running create command")
raise ExplorePermalinkCreateFailedError() from ex

View File

@ -25,7 +25,11 @@ from superset.explore.permalink.exceptions import ExplorePermalinkGetFailedError
from superset.explore.permalink.types import ExplorePermalinkValue
from superset.explore.utils import check_access as check_chart_access
from superset.key_value.commands.get import GetKeyValueCommand
from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError
from superset.key_value.exceptions import (
KeyValueCodecDecodeException,
KeyValueGetFailedError,
KeyValueParseKeyError,
)
from superset.key_value.utils import decode_permalink_id
from superset.utils.core import DatasourceType
@ -59,6 +63,7 @@ class GetExplorePermalinkCommand(BaseExplorePermalinkCommand):
return None
except (
DatasetNotFoundError,
KeyValueCodecDecodeException,
KeyValueGetFailedError,
KeyValueParseKeyError,
) as ex:

View File

@ -17,7 +17,7 @@
from marshmallow import fields, Schema
class ExplorePermalinkPostSchema(Schema):
class ExplorePermalinkStateSchema(Schema):
formData = fields.Dict(
required=True,
allow_none=False,
@ -41,3 +41,27 @@ class ExplorePermalinkPostSchema(Schema):
allow_none=True,
metadata={"description": "URL Parameters"},
)
class ExplorePermalinkSchema(Schema):
chartId = fields.Integer(
required=False,
allow_none=True,
metadata={"description": "The id of the chart"},
)
datasourceType = fields.String(
required=True,
allow_none=False,
metadata={"description": "The type of the datasource"},
)
datasourceId = fields.Integer(
required=False,
allow_none=True,
metadata={"description": "The id of the datasource"},
)
datasource = fields.String(
required=False,
allow_none=True,
metadata={"description": "The fully qualified datasource reference"},
)
state = fields.Nested(ExplorePermalinkStateSchema())

View File

@ -52,3 +52,15 @@ class KeyValueUpsertFailedError(UpdateFailedError):
class KeyValueAccessDeniedError(ForbiddenError):
message = _("You don't have permission to modify the value.")
class KeyValueCodecException(SupersetException):
pass
class KeyValueCodecEncodeException(KeyValueCodecException):
message = _("Unable to encode value")
class KeyValueCodecDecodeException(KeyValueCodecException):
message = _("Unable to decode value")

View File

@ -24,6 +24,13 @@ from enum import Enum
from typing import Any, Optional, TypedDict
from uuid import UUID
from marshmallow import Schema, ValidationError
from superset.key_value.exceptions import (
KeyValueCodecDecodeException,
KeyValueCodecEncodeException,
)
@dataclass
class Key:
@ -61,10 +68,16 @@ class KeyValueCodec(ABC):
class JsonKeyValueCodec(KeyValueCodec):
def encode(self, value: dict[Any, Any]) -> bytes:
return bytes(json.dumps(value), encoding="utf-8")
try:
return bytes(json.dumps(value), encoding="utf-8")
except TypeError as ex:
raise KeyValueCodecEncodeException(str(ex)) from ex
def decode(self, value: bytes) -> dict[Any, Any]:
return json.loads(value)
try:
return json.loads(value)
except TypeError as ex:
raise KeyValueCodecDecodeException(str(ex)) from ex
class PickleKeyValueCodec(KeyValueCodec):
@ -73,3 +86,22 @@ class PickleKeyValueCodec(KeyValueCodec):
def decode(self, value: bytes) -> dict[Any, Any]:
return pickle.loads(value)
class MarshmallowKeyValueCodec(JsonKeyValueCodec):
def __init__(self, schema: Schema):
self.schema = schema
def encode(self, value: dict[Any, Any]) -> bytes:
try:
obj = self.schema.dump(value)
return super().encode(obj)
except ValidationError as ex:
raise KeyValueCodecEncodeException(message=str(ex)) from ex
def decode(self, value: bytes) -> dict[Any, Any]:
try:
obj = super().decode(value)
return self.schema.load(obj)
except ValidationError as ex:
raise KeyValueCodecEncodeException(message=str(ex)) from ex

View File

@ -22,8 +22,9 @@ import pytest
from sqlalchemy.orm import Session
from superset import db
from superset.explore.permalink.schemas import ExplorePermalinkSchema
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
from superset.key_value.types import KeyValueResource, MarshmallowKeyValueCodec
from superset.key_value.utils import decode_permalink_id, encode_permalink_key
from superset.models.slice import Slice
from superset.utils.core import DatasourceType
@ -94,14 +95,17 @@ def test_get_missing_chart(
chart_id = 1234
entry = KeyValueEntry(
resource=KeyValueResource.EXPLORE_PERMALINK,
value=JsonKeyValueCodec().encode(
value=MarshmallowKeyValueCodec(ExplorePermalinkSchema()).encode(
{
"chartId": chart_id,
"datasourceId": chart.datasource.id,
"datasourceType": DatasourceType.TABLE,
"formData": {
"slice_id": chart_id,
"datasource": f"{chart.datasource.id}__{chart.datasource.type}",
"datasourceType": DatasourceType.TABLE.value,
"state": {
"urlParams": [["foo", "bar"]],
"formData": {
"slice_id": chart_id,
"datasource": f"{chart.datasource.id}__{chart.datasource.type}",
},
},
}
),

View File

@ -0,0 +1,122 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from contextlib import nullcontext
from typing import Any
import pytest
from marshmallow import Schema
from superset.dashboards.permalink.schemas import DashboardPermalinkSchema
from superset.key_value.exceptions import KeyValueCodecEncodeException
from superset.key_value.types import (
JsonKeyValueCodec,
MarshmallowKeyValueCodec,
PickleKeyValueCodec,
)
@pytest.mark.parametrize(
"input_,expected_result",
[
(
{"foo": "bar"},
{"foo": "bar"},
),
(
{"foo": (1, 2, 3)},
{"foo": [1, 2, 3]},
),
(
{1, 2, 3},
KeyValueCodecEncodeException(),
),
(
object(),
KeyValueCodecEncodeException(),
),
],
)
def test_json_codec(input_: Any, expected_result: Any):
cm = (
pytest.raises(type(expected_result))
if isinstance(expected_result, Exception)
else nullcontext()
)
with cm:
codec = JsonKeyValueCodec()
encoded_value = codec.encode(input_)
assert expected_result == codec.decode(encoded_value)
@pytest.mark.parametrize(
"schema,input_,expected_result",
[
(
DashboardPermalinkSchema(),
{
"dashboardId": "1",
"state": {
"urlParams": [["foo", "bar"], ["foo", "baz"]],
},
},
{
"dashboardId": "1",
"state": {
"urlParams": [("foo", "bar"), ("foo", "baz")],
},
},
),
(
DashboardPermalinkSchema(),
{"foo": "bar"},
KeyValueCodecEncodeException(),
),
],
)
def test_marshmallow_codec(schema: Schema, input_: Any, expected_result: Any):
cm = (
pytest.raises(type(expected_result))
if isinstance(expected_result, Exception)
else nullcontext()
)
with cm:
codec = MarshmallowKeyValueCodec(schema)
encoded_value = codec.encode(input_)
assert expected_result == codec.decode(encoded_value)
@pytest.mark.parametrize(
"input_,expected_result",
[
(
{1, 2, 3},
{1, 2, 3},
),
(
{"foo": 1, "bar": {1: (1, 2, 3)}, "baz": {1, 2, 3}},
{
"foo": 1,
"bar": {1: (1, 2, 3)},
"baz": {1, 2, 3},
},
),
],
)
def test_pickle_codec(input_: Any, expected_result: Any):
codec = PickleKeyValueCodec()
encoded_value = codec.encode(input_)
assert expected_result == codec.decode(encoded_value)