feat: introduce hashids permalink keys (#19324)

* feat: introduce hashids permalink keys

* implement dashboard permalinks

* remove shorturl notice from UPDATING.md

* lint

* fix test

* introduce KeyValueResource

* make filterState optional

* fix test

* fix resource names
This commit is contained in:
Ville Brofeldt 2022-03-24 21:53:09 +02:00 committed by GitHub
parent dc769a9a34
commit f4b71abb22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 344 additions and 367 deletions

View File

@ -62,7 +62,6 @@ flag for the legacy datasource editor (DISABLE_LEGACY_DATASOURCE_EDITOR) in conf
### Deprecations
- [19078](https://github.com/apache/superset/pull/19078): Creation of old shorturl links has been deprecated in favor of a new permalink feature that solves the long url problem (old shorturls will still work, though!). By default, new permalinks use UUID4 as the key. However, to use serial ids similar to the old shorturls, add the following to your `superset_config.py`: `PERMALINK_KEY_TYPE = "id"`.
- [18960](https://github.com/apache/superset/pull/18960): Persisting URL params in chart metadata is no longer supported. To set a default value for URL params in Jinja code, use the optional second argument: `url_param("my-param", "my-default-value")`.
### Other

View File

@ -118,6 +118,8 @@ graphlib-backport==1.0.3
# via apache-superset
gunicorn==20.1.0
# via apache-superset
hashids==1.3.1
# via apache-superset
holidays==0.10.3
# via apache-superset
humanize==3.11.0

View File

@ -88,6 +88,7 @@ setup(
"geopy",
"graphlib-backport",
"gunicorn>=20.1.0",
"hashids>=1.3.1, <2",
"holidays==0.10.3", # PINNED! https://github.com/dr-prodigy/python-holidays/issues/406
"humanize",
"isodate",

View File

@ -43,7 +43,6 @@ from typing_extensions import Literal
from superset.constants import CHANGE_ME_SECRET_KEY
from superset.jinja_context import BaseTemplateProcessor
from superset.key_value.types import KeyType
from superset.stats_logger import DummyStatsLogger
from superset.superset_typing import CacheConfig
from superset.utils.core import is_test, parse_boolean_string
@ -600,8 +599,6 @@ EXPLORE_FORM_DATA_CACHE_CONFIG: CacheConfig = {
# store cache keys by datasource UID (via CacheKey) for custom processing/invalidation
STORE_CACHE_KEYS_IN_METADATA_DB = False
PERMALINK_KEY_TYPE: KeyType = "uuid"
# CORS Options
ENABLE_CORS = False
CORS_OPTIONS: Dict[Any, Any] = {}

View File

@ -18,10 +18,11 @@ from flask import session
from superset.dashboards.dao import DashboardDAO
from superset.extensions import cache_manager
from superset.key_value.utils import random_key
from superset.temporary_cache.commands.create import CreateTemporaryCacheCommand
from superset.temporary_cache.commands.entry import Entry
from superset.temporary_cache.commands.parameters import CommandParameters
from superset.temporary_cache.utils import cache_key, random_key
from superset.temporary_cache.utils import cache_key
class CreateFilterStateCommand(CreateTemporaryCacheCommand):

View File

@ -20,11 +20,12 @@ from flask import session
from superset.dashboards.dao import DashboardDAO
from superset.extensions import cache_manager
from superset.key_value.utils import random_key
from superset.temporary_cache.commands.entry import Entry
from superset.temporary_cache.commands.exceptions import TemporaryCacheAccessDeniedError
from superset.temporary_cache.commands.parameters import CommandParameters
from superset.temporary_cache.commands.update import UpdateTemporaryCacheCommand
from superset.temporary_cache.utils import cache_key, random_key
from superset.temporary_cache.utils import cache_key
class UpdateFilterStateCommand(UpdateTemporaryCacheCommand):

View File

@ -16,7 +16,7 @@
# under the License.
import logging
from flask import current_app, g, request, Response
from flask import g, request, Response
from flask_appbuilder.api import BaseApi, expose, protect, safe
from marshmallow import ValidationError
@ -101,11 +101,10 @@ class DashboardPermalinkRestApi(BaseApi):
500:
$ref: '#/components/responses/500'
"""
key_type = current_app.config["PERMALINK_KEY_TYPE"]
try:
state = self.add_model_schema.load(request.json)
key = CreateDashboardPermalinkCommand(
actor=g.user, dashboard_id=pk, state=state, key_type=key_type,
actor=g.user, dashboard_id=pk, state=state,
).run()
http_origin = request.headers.environ.get("HTTP_ORIGIN")
url = f"{http_origin}/superset/dashboard/p/{key}/"
@ -158,10 +157,7 @@ class DashboardPermalinkRestApi(BaseApi):
$ref: '#/components/responses/500'
"""
try:
key_type = current_app.config["PERMALINK_KEY_TYPE"]
value = GetDashboardPermalinkCommand(
actor=g.user, key=key, key_type=key_type
).run()
value = GetDashboardPermalinkCommand(actor=g.user, key=key).run()
if not value:
return self.response_404()
return self.response(200, **value)

View File

@ -17,7 +17,13 @@
from abc import ABC
from superset.commands.base import BaseCommand
from superset.key_value.shared_entries import get_permalink_salt
from superset.key_value.types import KeyValueResource, SharedKey
class BaseDashboardPermalinkCommand(BaseCommand, ABC):
resource = "dashboard_permalink"
resource = KeyValueResource.DASHBOARD_PERMALINK
@property
def salt(self) -> str:
return get_permalink_salt(SharedKey.DASHBOARD_PERMALINK_SALT)

View File

@ -24,23 +24,18 @@ from superset.dashboards.permalink.commands.base import BaseDashboardPermalinkCo
from superset.dashboards.permalink.exceptions import DashboardPermalinkCreateFailedError
from superset.dashboards.permalink.types import DashboardPermalinkState
from superset.key_value.commands.create import CreateKeyValueCommand
from superset.key_value.types import KeyType
from superset.key_value.utils import encode_permalink_key
logger = logging.getLogger(__name__)
class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
def __init__(
self,
actor: User,
dashboard_id: str,
state: DashboardPermalinkState,
key_type: KeyType,
self, actor: User, dashboard_id: str, state: DashboardPermalinkState,
):
self.actor = actor
self.dashboard_id = dashboard_id
self.state = state
self.key_type = key_type
def run(self) -> str:
self.validate()
@ -50,12 +45,10 @@ class CreateDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
"dashboardId": self.dashboard_id,
"state": self.state,
}
return CreateKeyValueCommand(
actor=self.actor,
resource=self.resource,
value=value,
key_type=self.key_type,
key = CreateKeyValueCommand(
actor=self.actor, resource=self.resource, value=value,
).run()
return encode_permalink_key(key=key.id, salt=self.salt)
except SQLAlchemyError as ex:
logger.exception("Error running create command")
raise DashboardPermalinkCreateFailedError() from ex

View File

@ -27,25 +27,21 @@ from superset.dashboards.permalink.exceptions import DashboardPermalinkGetFailed
from superset.dashboards.permalink.types import DashboardPermalinkValue
from superset.key_value.commands.get import GetKeyValueCommand
from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError
from superset.key_value.types import KeyType
from superset.key_value.utils import decode_permalink_id
logger = logging.getLogger(__name__)
class GetDashboardPermalinkCommand(BaseDashboardPermalinkCommand):
def __init__(
self, actor: User, key: str, key_type: KeyType,
):
def __init__(self, actor: User, key: str):
self.actor = actor
self.key = key
self.key_type = key_type
def run(self) -> Optional[DashboardPermalinkValue]:
self.validate()
try:
command = GetKeyValueCommand(
resource=self.resource, key=self.key, key_type=self.key_type
)
key = decode_permalink_id(self.key, salt=self.salt)
command = GetKeyValueCommand(resource=self.resource, key=key)
value: Optional[DashboardPermalinkValue] = command.run()
if value:
DashboardDAO.get_by_id_or_slug(value["dashboardId"])

View File

@ -19,7 +19,7 @@ from marshmallow import fields, Schema
class DashboardPermalinkPostSchema(Schema):
filterState = fields.Dict(
required=True, allow_none=False, description="Native filter state",
required=False, allow_none=True, description="Native filter state",
)
urlParams = fields.List(
fields.Tuple(

View File

@ -18,7 +18,7 @@ from typing import Any, Dict, List, Optional, Tuple, TypedDict
class DashboardPermalinkState(TypedDict):
filterState: Dict[str, Any]
filterState: Optional[Dict[str, Any]]
hash: Optional[str]
urlParams: Optional[List[Tuple[str, str]]]

View File

@ -24,8 +24,9 @@ from superset.explore.form_data.commands.parameters import CommandParameters
from superset.explore.form_data.commands.state import TemporaryExploreState
from superset.explore.utils import check_access
from superset.extensions import cache_manager
from superset.key_value.utils import random_key
from superset.temporary_cache.commands.exceptions import TemporaryCacheCreateFailedError
from superset.temporary_cache.utils import cache_key, random_key
from superset.temporary_cache.utils import cache_key
from superset.utils.schema import validate_json
logger = logging.getLogger(__name__)

View File

@ -26,11 +26,12 @@ from superset.explore.form_data.commands.parameters import CommandParameters
from superset.explore.form_data.commands.state import TemporaryExploreState
from superset.explore.utils import check_access
from superset.extensions import cache_manager
from superset.key_value.utils import random_key
from superset.temporary_cache.commands.exceptions import (
TemporaryCacheAccessDeniedError,
TemporaryCacheUpdateFailedError,
)
from superset.temporary_cache.utils import cache_key, random_key
from superset.temporary_cache.utils import cache_key
from superset.utils.schema import validate_json
logger = logging.getLogger(__name__)

View File

@ -16,7 +16,7 @@
# under the License.
import logging
from flask import current_app, g, request, Response
from flask import g, request, Response
from flask_appbuilder.api import BaseApi, expose, protect, safe
from marshmallow import ValidationError
@ -98,12 +98,9 @@ class ExplorePermalinkRestApi(BaseApi):
500:
$ref: '#/components/responses/500'
"""
key_type = current_app.config["PERMALINK_KEY_TYPE"]
try:
state = self.add_model_schema.load(request.json)
key = CreateExplorePermalinkCommand(
actor=g.user, state=state, key_type=key_type,
).run()
key = CreateExplorePermalinkCommand(actor=g.user, state=state).run()
http_origin = request.headers.environ.get("HTTP_ORIGIN")
url = f"{http_origin}/superset/explore/p/{key}/"
return self.response(201, key=key, url=url)
@ -159,10 +156,7 @@ class ExplorePermalinkRestApi(BaseApi):
$ref: '#/components/responses/500'
"""
try:
key_type = current_app.config["PERMALINK_KEY_TYPE"]
value = GetExplorePermalinkCommand(
actor=g.user, key=key, key_type=key_type
).run()
value = GetExplorePermalinkCommand(actor=g.user, key=key).run()
if not value:
return self.response_404()
return self.response(200, **value)

View File

@ -17,7 +17,13 @@
from abc import ABC
from superset.commands.base import BaseCommand
from superset.key_value.shared_entries import get_permalink_salt
from superset.key_value.types import KeyValueResource, SharedKey
class BaseExplorePermalinkCommand(BaseCommand, ABC):
resource = "explore_permalink"
resource: KeyValueResource = KeyValueResource.EXPLORE_PERMALINK
@property
def salt(self) -> str:
return get_permalink_salt(SharedKey.EXPLORE_PERMALINK_SALT)

View File

@ -24,18 +24,17 @@ from superset.explore.permalink.commands.base import BaseExplorePermalinkCommand
from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
from superset.explore.utils import check_access
from superset.key_value.commands.create import CreateKeyValueCommand
from superset.key_value.types import KeyType
from superset.key_value.utils import encode_permalink_key
logger = logging.getLogger(__name__)
class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
def __init__(self, actor: User, state: Dict[str, Any], key_type: KeyType):
def __init__(self, actor: User, state: Dict[str, Any]):
self.actor = actor
self.chart_id: Optional[int] = state["formData"].get("slice_id")
self.datasource: str = state["formData"]["datasource"]
self.state = state
self.key_type = key_type
def run(self) -> str:
self.validate()
@ -49,12 +48,10 @@ class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
"state": self.state,
}
command = CreateKeyValueCommand(
actor=self.actor,
resource=self.resource,
value=value,
key_type=self.key_type,
actor=self.actor, resource=self.resource, value=value,
)
return command.run()
key = command.run()
return encode_permalink_key(key=key.id, salt=self.salt)
except SQLAlchemyError as ex:
logger.exception("Error running create command")
raise ExplorePermalinkCreateFailedError() from ex

View File

@ -27,24 +27,22 @@ from superset.explore.permalink.types import ExplorePermalinkValue
from superset.explore.utils import check_access
from superset.key_value.commands.get import GetKeyValueCommand
from superset.key_value.exceptions import KeyValueGetFailedError, KeyValueParseKeyError
from superset.key_value.types import KeyType
from superset.key_value.utils import decode_permalink_id
logger = logging.getLogger(__name__)
class GetExplorePermalinkCommand(BaseExplorePermalinkCommand):
def __init__(
self, actor: User, key: str, key_type: KeyType,
):
def __init__(self, actor: User, key: str):
self.actor = actor
self.key = key
self.key_type = key_type
def run(self) -> Optional[ExplorePermalinkValue]:
self.validate()
try:
key = decode_permalink_id(self.key, salt=self.salt)
value: Optional[ExplorePermalinkValue] = GetKeyValueCommand(
resource=self.resource, key=self.key, key_type=self.key_type
resource=self.resource, key=key,
).run()
if value:
chart_id: Optional[int] = value.get("chartId")

View File

@ -16,7 +16,6 @@
# under the License.
from datetime import datetime, timedelta
from hashlib import md5
from typing import Any, Dict, List, Optional
from uuid import UUID, uuid3
@ -24,10 +23,10 @@ from flask import Flask
from flask_caching import BaseCache
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.types import KeyType
from superset.key_value.types import KeyValueResource
from superset.key_value.utils import get_uuid_namespace
RESOURCE = "superset_metastore_cache"
KEY_TYPE: KeyType = "uuid"
RESOURCE = KeyValueResource.METASTORE_CACHE
class SupersetMetastoreCache(BaseCache):
@ -39,15 +38,12 @@ class SupersetMetastoreCache(BaseCache):
def factory(
cls, app: Flask, config: Dict[str, Any], args: List[Any], kwargs: Dict[str, Any]
) -> BaseCache:
# base namespace for generating deterministic UUIDs
md5_obj = md5()
seed = config.get("CACHE_KEY_PREFIX", "")
md5_obj.update(seed.encode("utf-8"))
kwargs["namespace"] = UUID(md5_obj.hexdigest())
kwargs["namespace"] = get_uuid_namespace(seed)
return cls(*args, **kwargs)
def get_key(self, key: str) -> str:
return str(uuid3(self.namespace, key))
def get_key(self, key: str) -> UUID:
return uuid3(self.namespace, key)
@staticmethod
def _prune() -> None:
@ -70,7 +66,6 @@ class SupersetMetastoreCache(BaseCache):
UpsertKeyValueCommand(
resource=RESOURCE,
key_type=KEY_TYPE,
key=self.get_key(key),
value=value,
expires_on=self._get_expiry(timeout),
@ -85,7 +80,6 @@ class SupersetMetastoreCache(BaseCache):
CreateKeyValueCommand(
resource=RESOURCE,
value=value,
key_type=KEY_TYPE,
key=self.get_key(key),
expires_on=self._get_expiry(timeout),
).run()
@ -98,9 +92,7 @@ class SupersetMetastoreCache(BaseCache):
# pylint: disable=import-outside-toplevel
from superset.key_value.commands.get import GetKeyValueCommand
return GetKeyValueCommand(
resource=RESOURCE, key_type=KEY_TYPE, key=self.get_key(key),
).run()
return GetKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run()
def has(self, key: str) -> bool:
entry = self.get(key)
@ -112,6 +104,4 @@ class SupersetMetastoreCache(BaseCache):
# pylint: disable=import-outside-toplevel
from superset.key_value.commands.delete import DeleteKeyValueCommand
return DeleteKeyValueCommand(
resource=RESOURCE, key_type=KEY_TYPE, key=self.get_key(key),
).run()
return DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run()

View File

@ -17,7 +17,7 @@
import logging
import pickle
from datetime import datetime
from typing import Any, Optional
from typing import Any, Optional, Union
from uuid import UUID
from flask_appbuilder.security.sqla.models import User
@ -27,27 +27,24 @@ from superset import db
from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueCreateFailedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyType
from superset.key_value.utils import extract_key
from superset.key_value.types import Key, KeyValueResource
logger = logging.getLogger(__name__)
class CreateKeyValueCommand(BaseCommand):
actor: Optional[User]
resource: str
resource: KeyValueResource
value: Any
key_type: KeyType
key: Optional[str]
key: Optional[Union[int, UUID]]
expires_on: Optional[datetime]
def __init__(
self,
resource: str,
resource: KeyValueResource,
value: Any,
key_type: KeyType = "uuid",
actor: Optional[User] = None,
key: Optional[str] = None,
key: Optional[Union[int, UUID]] = None,
expires_on: Optional[datetime] = None,
):
"""
@ -55,7 +52,6 @@ class CreateKeyValueCommand(BaseCommand):
:param resource: the resource (dashboard, chart etc)
:param value: the value to persist in the key-value store
:param key_type: the type of the key to return
:param actor: the user performing the command
:param key: id of entry (autogenerated if undefined)
:param expires_on: entry expiration time
@ -64,11 +60,10 @@ class CreateKeyValueCommand(BaseCommand):
self.resource = resource
self.actor = actor
self.value = value
self.key_type = key_type
self.key = key
self.expires_on = expires_on
def run(self) -> str:
def run(self) -> Key:
try:
return self.create()
except SQLAlchemyError as ex:
@ -79,9 +74,9 @@ class CreateKeyValueCommand(BaseCommand):
def validate(self) -> None:
pass
def create(self) -> str:
def create(self) -> Key:
entry = KeyValueEntry(
resource=self.resource,
resource=self.resource.value,
value=pickle.dumps(self.value),
created_on=datetime.now(),
created_by_fk=None
@ -91,12 +86,12 @@ class CreateKeyValueCommand(BaseCommand):
)
if self.key is not None:
try:
if self.key_type == "uuid":
entry.uuid = UUID(self.key)
if isinstance(self.key, UUID):
entry.uuid = self.key
else:
entry.id = int(self.key)
entry.id = self.key
except ValueError as ex:
raise KeyValueCreateFailedError() from ex
db.session.add(entry)
db.session.commit()
return extract_key(entry, self.key_type)
return Key(id=entry.id, uuid=entry.uuid)

View File

@ -15,40 +15,35 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Optional
from typing import Union
from uuid import UUID
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueDeleteFailedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyType
from superset.key_value.types import KeyValueResource
from superset.key_value.utils import get_filter
logger = logging.getLogger(__name__)
class DeleteKeyValueCommand(BaseCommand):
key: str
key_type: KeyType
resource: str
key: Union[int, UUID]
resource: KeyValueResource
def __init__(
self, resource: str, key: str, key_type: KeyType = "uuid",
):
def __init__(self, resource: KeyValueResource, key: Union[int, UUID]):
"""
Delete a key-value pair
:param resource: the resource (dashboard, chart etc)
:param key: the key to delete
:param key_type: the type of key
:return: was the entry deleted or not
"""
self.resource = resource
self.key = key
self.key_type = key_type
def run(self) -> bool:
try:
@ -62,7 +57,7 @@ class DeleteKeyValueCommand(BaseCommand):
pass
def delete(self) -> bool:
filter_ = get_filter(self.resource, self.key, self.key_type)
filter_ = get_filter(self.resource, self.key)
entry = (
db.session.query(KeyValueEntry)
.filter_by(**filter_)

View File

@ -17,20 +17,22 @@
import logging
from datetime import datetime
from sqlalchemy import and_
from sqlalchemy.exc import SQLAlchemyError
from superset import db
from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueDeleteFailedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyValueResource
logger = logging.getLogger(__name__)
class DeleteExpiredKeyValueCommand(BaseCommand):
resource: str
resource: KeyValueResource
def __init__(self, resource: str):
def __init__(self, resource: KeyValueResource):
"""
Delete all expired key-value pairs
@ -50,11 +52,15 @@ class DeleteExpiredKeyValueCommand(BaseCommand):
def validate(self) -> None:
pass
@staticmethod
def delete_expired() -> None:
def delete_expired(self) -> None:
(
db.session.query(KeyValueEntry)
.filter(KeyValueEntry.expires_on <= datetime.now())
.filter(
and_(
KeyValueEntry.resource == self.resource.value,
KeyValueEntry.expires_on <= datetime.now(),
)
)
.delete()
)
db.session.commit()

View File

@ -18,7 +18,8 @@
import logging
import pickle
from datetime import datetime
from typing import Any, Optional
from typing import Any, Optional, Union
from uuid import UUID
from sqlalchemy.exc import SQLAlchemyError
@ -26,29 +27,26 @@ from superset import db
from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueGetFailedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyType
from superset.key_value.types import KeyValueResource
from superset.key_value.utils import get_filter
logger = logging.getLogger(__name__)
class GetKeyValueCommand(BaseCommand):
key: str
key_type: KeyType
resource: str
resource: KeyValueResource
key: Union[int, UUID]
def __init__(self, resource: str, key: str, key_type: KeyType = "uuid"):
def __init__(self, resource: KeyValueResource, key: Union[int, UUID]):
"""
Retrieve a key value entry
:param resource: the resource (dashboard, chart etc)
:param key: the key to retrieve
:param key_type: the type of the key to retrieve
:return: the value associated with the key if present
"""
self.resource = resource
self.key = key
self.key_type = key_type
def run(self) -> Any:
try:
@ -61,7 +59,7 @@ class GetKeyValueCommand(BaseCommand):
pass
def get(self) -> Optional[Any]:
filter_ = get_filter(self.resource, self.key, self.key_type)
filter_ = get_filter(self.resource, self.key)
entry = (
db.session.query(KeyValueEntry)
.filter_by(**filter_)

View File

@ -18,7 +18,8 @@
import logging
import pickle
from datetime import datetime
from typing import Any, Optional
from typing import Any, Optional, Union
from uuid import UUID
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError
@ -27,27 +28,25 @@ from superset import db
from superset.commands.base import BaseCommand
from superset.key_value.exceptions import KeyValueUpdateFailedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyType
from superset.key_value.utils import extract_key, get_filter
from superset.key_value.types import Key, KeyValueResource
from superset.key_value.utils import get_filter
logger = logging.getLogger(__name__)
class UpdateKeyValueCommand(BaseCommand):
actor: Optional[User]
resource: str
resource: KeyValueResource
value: Any
key: str
key_type: KeyType
key: Union[int, UUID]
expires_on: Optional[datetime]
def __init__(
self,
resource: str,
key: str,
resource: KeyValueResource,
key: Union[int, UUID],
value: Any,
actor: Optional[User] = None,
key_type: KeyType = "uuid",
expires_on: Optional[datetime] = None,
):
"""
@ -57,7 +56,6 @@ class UpdateKeyValueCommand(BaseCommand):
:param key: the key to update
:param value: the value to persist in the key-value store
:param actor: the user performing the command
:param key_type: the type of the key to update
:param expires_on: entry expiration time
:return: the key associated with the updated value
"""
@ -65,10 +63,9 @@ class UpdateKeyValueCommand(BaseCommand):
self.resource = resource
self.key = key
self.value = value
self.key_type = key_type
self.expires_on = expires_on
def run(self) -> Optional[str]:
def run(self) -> Optional[Key]:
try:
return self.update()
except SQLAlchemyError as ex:
@ -79,8 +76,8 @@ class UpdateKeyValueCommand(BaseCommand):
def validate(self) -> None:
pass
def update(self) -> Optional[str]:
filter_ = get_filter(self.resource, self.key, self.key_type)
def update(self) -> Optional[Key]:
filter_ = get_filter(self.resource, self.key)
entry: KeyValueEntry = (
db.session.query(KeyValueEntry)
.filter_by(**filter_)
@ -96,6 +93,6 @@ class UpdateKeyValueCommand(BaseCommand):
)
db.session.merge(entry)
db.session.commit()
return extract_key(entry, self.key_type)
return Key(id=entry.id, uuid=entry.uuid)
return None

View File

@ -18,7 +18,8 @@
import logging
import pickle
from datetime import datetime
from typing import Any, Optional
from typing import Any, Optional, Union
from uuid import UUID
from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import SQLAlchemyError
@ -28,27 +29,25 @@ from superset.commands.base import BaseCommand
from superset.key_value.commands.create import CreateKeyValueCommand
from superset.key_value.exceptions import KeyValueUpdateFailedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyType
from superset.key_value.utils import extract_key, get_filter
from superset.key_value.types import Key, KeyValueResource
from superset.key_value.utils import get_filter
logger = logging.getLogger(__name__)
class UpsertKeyValueCommand(BaseCommand):
actor: Optional[User]
resource: str
resource: KeyValueResource
value: Any
key: str
key_type: KeyType
key: Union[int, UUID]
expires_on: Optional[datetime]
def __init__(
self,
resource: str,
key: str,
resource: KeyValueResource,
key: Union[int, UUID],
value: Any,
actor: Optional[User] = None,
key_type: KeyType = "uuid",
expires_on: Optional[datetime] = None,
):
"""
@ -66,10 +65,9 @@ class UpsertKeyValueCommand(BaseCommand):
self.resource = resource
self.key = key
self.value = value
self.key_type = key_type
self.expires_on = expires_on
def run(self) -> Optional[str]:
def run(self) -> Optional[Key]:
try:
return self.upsert()
except SQLAlchemyError as ex:
@ -80,8 +78,8 @@ class UpsertKeyValueCommand(BaseCommand):
def validate(self) -> None:
pass
def upsert(self) -> Optional[str]:
filter_ = get_filter(self.resource, self.key, self.key_type)
def upsert(self) -> Optional[Key]:
filter_ = get_filter(self.resource, self.key)
entry: KeyValueEntry = (
db.session.query(KeyValueEntry)
.filter_by(**filter_)
@ -97,12 +95,11 @@ class UpsertKeyValueCommand(BaseCommand):
)
db.session.merge(entry)
db.session.commit()
return extract_key(entry, self.key_type)
return Key(entry.id, entry.uuid)
else:
return CreateKeyValueCommand(
resource=self.resource,
value=self.value,
key_type=self.key_type,
actor=self.actor,
key=self.key,
expires_on=self.expires_on,

View File

@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Optional
from uuid import uuid3
from superset.key_value.types import KeyValueResource, SharedKey
from superset.key_value.utils import get_uuid_namespace, random_key
from superset.utils.memoized import memoized
RESOURCE = KeyValueResource.APP
NAMESPACE = get_uuid_namespace("")
def get_shared_value(key: SharedKey) -> Optional[Any]:
# pylint: disable=import-outside-toplevel
from superset.key_value.commands.get import GetKeyValueCommand
uuid_key = uuid3(NAMESPACE, key)
return GetKeyValueCommand(RESOURCE, key=uuid_key).run()
def set_shared_value(key: SharedKey, value: Any) -> None:
# pylint: disable=import-outside-toplevel
from superset.key_value.commands.create import CreateKeyValueCommand
uuid_key = uuid3(NAMESPACE, key)
CreateKeyValueCommand(resource=RESOURCE, value=value, key=uuid_key).run()
@memoized
def get_permalink_salt(key: SharedKey) -> str:
salt = get_shared_value(key)
if salt is None:
salt = random_key()
set_shared_value(key, value=salt)
return salt

View File

@ -15,7 +15,8 @@
# specific language governing permissions and limitations
# under the License.
from dataclasses import dataclass
from typing import Literal, Optional, TypedDict
from enum import Enum
from typing import Optional, TypedDict
from uuid import UUID
@ -25,10 +26,19 @@ class Key:
uuid: Optional[UUID]
KeyType = Literal["id", "uuid"]
class KeyValueFilter(TypedDict, total=False):
resource: str
id: Optional[int]
uuid: Optional[UUID]
class KeyValueResource(str, Enum):
APP = "app"
DASHBOARD_PERMALINK = "dashboard_permalink"
EXPLORE_PERMALINK = "explore_permalink"
METASTORE_CACHE = "superset_metastore_cache"
class SharedKey(str, Enum):
DASHBOARD_PERMALINK_SALT = "dashboard_permalink_salt"
EXPLORE_PERMALINK_SALT = "explore_permalink_salt"

View File

@ -14,44 +14,52 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Literal
from __future__ import annotations
from hashlib import md5
from secrets import token_urlsafe
from typing import Union
from uuid import UUID
from flask import current_app
import hashids
from flask_babel import gettext as _
from superset.key_value.exceptions import KeyValueParseKeyError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import Key, KeyType, KeyValueFilter
from superset.key_value.types import KeyValueFilter, KeyValueResource
HASHIDS_MIN_LENGTH = 11
def parse_permalink_key(key: str) -> Key:
key_type: Literal["id", "uuid"] = current_app.config["PERMALINK_KEY_TYPE"]
if key_type == "id":
return Key(id=int(key), uuid=None)
return Key(id=None, uuid=UUID(key))
def random_key() -> str:
return token_urlsafe(48)
def format_permalink_key(key: Key) -> str:
"""
return the string representation of the key
:param key: a key object with either a numerical or uuid key
:return: a formatted string
"""
return str(key.id if key.id is not None else key.uuid)
def extract_key(entry: KeyValueEntry, key_type: KeyType) -> str:
return str(entry.id if key_type == "id" else entry.uuid)
def get_filter(resource: str, key: str, key_type: KeyType) -> KeyValueFilter:
def get_filter(resource: KeyValueResource, key: Union[int, UUID]) -> KeyValueFilter:
try:
filter_: KeyValueFilter = {"resource": resource}
if key_type == "uuid":
filter_["uuid"] = UUID(key)
filter_: KeyValueFilter = {"resource": resource.value}
if isinstance(key, UUID):
filter_["uuid"] = key
else:
filter_["id"] = int(key)
filter_["id"] = key
return filter_
except ValueError as ex:
raise KeyValueParseKeyError() from ex
def encode_permalink_key(key: int, salt: str) -> str:
obj = hashids.Hashids(salt, min_length=HASHIDS_MIN_LENGTH)
return obj.encode(key)
def decode_permalink_id(key: str, salt: str) -> int:
obj = hashids.Hashids(salt, min_length=HASHIDS_MIN_LENGTH)
ids = obj.decode(key)
if len(ids) == 1:
return ids[0]
raise KeyValueParseKeyError(_("Invalid permalink key"))
def get_uuid_namespace(seed: str) -> UUID:
md5_obj = md5()
md5_obj.update(seed.encode("utf-8"))
return UUID(md5_obj.hexdigest())

View File

@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from secrets import token_urlsafe
from typing import Any
SEPARATOR = ";"
@ -22,7 +21,3 @@ SEPARATOR = ";"
def cache_key(*args: Any) -> str:
return SEPARATOR.join(str(arg) for arg in args)
def random_key() -> str:
return token_urlsafe(48)

View File

@ -748,8 +748,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
form_data_key = request.args.get("form_data_key")
if key is not None:
key_type = config["PERMALINK_KEY_TYPE"]
command = GetExplorePermalinkCommand(g.user, key, key_type)
command = GetExplorePermalinkCommand(g.user, key)
try:
permalink_value = command.run()
if permalink_value:
@ -2008,9 +2007,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
def dashboard_permalink( # pylint: disable=no-self-use
self, key: str,
) -> FlaskResponse:
key_type = config["PERMALINK_KEY_TYPE"]
try:
value = GetDashboardPermalinkCommand(g.user, key, key_type).run()
value = GetDashboardPermalinkCommand(g.user, key).run()
except DashboardPermalinkGetFailedError as ex:
flash(__("Error: %(msg)s", msg=ex.message), "danger")
return redirect("/dashboard/list/")

View File

@ -15,7 +15,9 @@
# specific language governing permissions and limitations
# under the License.
import json
from typing import Iterator
from unittest.mock import patch
from uuid import uuid3
import pytest
from flask_appbuilder.security.sqla.models import User
@ -24,8 +26,9 @@ from sqlalchemy.orm import Session
from superset import db
from superset.dashboards.commands.exceptions import DashboardAccessDeniedError
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyValueResource
from superset.key_value.utils import decode_permalink_id
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from tests.integration_tests.base_tests import login
from tests.integration_tests.fixtures.client import client
from tests.integration_tests.fixtures.world_bank_dashboard import (
@ -35,7 +38,7 @@ from tests.integration_tests.fixtures.world_bank_dashboard import (
from tests.integration_tests.test_app import app
STATE = {
"filterState": {"FILTER_1": "foo",},
"filterState": {"FILTER_1": "foo"},
"hash": "my-anchor",
}
@ -48,7 +51,22 @@ def dashboard_id(load_world_bank_dashboard_with_slices) -> int:
return dashboard.id
def test_post(client, dashboard_id: int):
@pytest.fixture
def permalink_salt() -> Iterator[str]:
from superset.key_value.shared_entries import get_permalink_salt, get_uuid_namespace
from superset.key_value.types import SharedKey
key = SharedKey.DASHBOARD_PERMALINK_SALT
salt = get_permalink_salt(key)
yield salt
namespace = get_uuid_namespace(salt)
db.session.query(KeyValueEntry).filter_by(
resource=KeyValueResource.APP, uuid=uuid3(namespace, key),
)
db.session.commit()
def test_post(client, dashboard_id: int, permalink_salt: str) -> None:
login(client, "admin")
resp = client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE)
assert resp.status_code == 201
@ -56,7 +74,8 @@ def test_post(client, dashboard_id: int):
key = data["key"]
url = data["url"]
assert key in url
db.session.query(KeyValueEntry).filter_by(uuid=key).delete()
id_ = decode_permalink_id(key, permalink_salt)
db.session.query(KeyValueEntry).filter_by(id=id_).delete()
db.session.commit()
@ -76,7 +95,7 @@ def test_post_invalid_schema(client, dashboard_id: int):
assert resp.status_code == 400
def test_get(client, dashboard_id: int):
def test_get(client, dashboard_id: int, permalink_salt: str):
login(client, "admin")
resp = client.post(f"api/v1/dashboard/{dashboard_id}/permalink", json=STATE)
data = json.loads(resp.data.decode("utf-8"))
@ -86,5 +105,6 @@ def test_get(client, dashboard_id: int):
result = json.loads(resp.data.decode("utf-8"))
assert result["dashboardId"] == str(dashboard_id)
assert result["state"] == STATE
db.session.query(KeyValueEntry).filter_by(uuid=key).delete()
id_ = decode_permalink_id(key, permalink_salt)
db.session.query(KeyValueEntry).filter_by(id=id_).delete()
db.session.commit()

View File

@ -16,14 +16,16 @@
# under the License.
import json
import pickle
from typing import Any, Dict
from uuid import UUID
from typing import Any, Dict, Iterator
from uuid import uuid3
import pytest
from sqlalchemy.orm import Session
from superset import db
from superset.key_value.models import KeyValueEntry
from superset.key_value.types import KeyValueResource
from superset.key_value.utils import decode_permalink_id, encode_permalink_key
from superset.models.slice import Slice
from tests.integration_tests.base_tests import login
from tests.integration_tests.fixtures.client import client
@ -51,7 +53,22 @@ def form_data(chart) -> Dict[str, Any]:
}
def test_post(client, form_data):
@pytest.fixture
def permalink_salt() -> Iterator[str]:
from superset.key_value.shared_entries import get_permalink_salt, get_uuid_namespace
from superset.key_value.types import SharedKey
key = SharedKey.EXPLORE_PERMALINK_SALT
salt = get_permalink_salt(key)
yield salt
namespace = get_uuid_namespace(salt)
db.session.query(KeyValueEntry).filter_by(
resource=KeyValueResource.APP, uuid=uuid3(namespace, key),
)
db.session.commit()
def test_post(client, form_data: Dict[str, Any], permalink_salt: str):
login(client, "admin")
resp = client.post(f"api/v1/explore/permalink", json={"formData": form_data})
assert resp.status_code == 201
@ -59,7 +76,8 @@ def test_post(client, form_data):
key = data["key"]
url = data["url"]
assert key in url
db.session.query(KeyValueEntry).filter_by(uuid=key).delete()
id_ = decode_permalink_id(key, permalink_salt)
db.session.query(KeyValueEntry).filter_by(id=id_).delete()
db.session.commit()
@ -69,21 +87,18 @@ def test_post_access_denied(client, form_data):
assert resp.status_code == 404
def test_get_missing_chart(client, chart):
def test_get_missing_chart(client, chart, permalink_salt: str) -> None:
from superset.key_value.models import KeyValueEntry
key = 1234
uuid_key = "e2ea9d19-7988-4862-aa69-c3a1a7628cb9"
chart_id = 1234
entry = KeyValueEntry(
id=int(key),
uuid=UUID("e2ea9d19-7988-4862-aa69-c3a1a7628cb9"),
resource="explore_permalink",
resource=KeyValueResource.EXPLORE_PERMALINK,
value=pickle.dumps(
{
"chartId": key,
"chartId": chart_id,
"datasetId": chart.datasource.id,
"formData": {
"slice_id": key,
"slice_id": chart_id,
"datasource": f"{chart.datasource.id}__{chart.datasource.type}",
},
}
@ -91,20 +106,21 @@ def test_get_missing_chart(client, chart):
)
db.session.add(entry)
db.session.commit()
key = encode_permalink_key(entry.id, permalink_salt)
login(client, "admin")
resp = client.get(f"api/v1/explore/permalink/{uuid_key}")
resp = client.get(f"api/v1/explore/permalink/{key}")
assert resp.status_code == 404
db.session.delete(entry)
db.session.commit()
def test_post_invalid_schema(client):
def test_post_invalid_schema(client) -> None:
login(client, "admin")
resp = client.post(f"api/v1/explore/permalink", json={"abc": 123})
assert resp.status_code == 400
def test_get(client, form_data):
def test_get(client, form_data: Dict[str, Any], permalink_salt: str) -> None:
login(client, "admin")
resp = client.post(f"api/v1/explore/permalink", json={"formData": form_data})
data = json.loads(resp.data.decode("utf-8"))
@ -113,5 +129,6 @@ def test_get(client, form_data):
assert resp.status_code == 200
result = json.loads(resp.data.decode("utf-8"))
assert result["state"]["formData"] == form_data
db.session.query(KeyValueEntry).filter_by(uuid=key).delete()
id_ = decode_permalink_id(key, permalink_salt)
db.session.query(KeyValueEntry).filter_by(id=id_).delete()
db.session.commit()

View File

@ -36,12 +36,8 @@ def test_create_id_entry(app_context: AppContext, admin: User) -> None:
from superset.key_value.commands.create import CreateKeyValueCommand
from superset.key_value.models import KeyValueEntry
key = CreateKeyValueCommand(
actor=admin, resource=RESOURCE, value=VALUE, key_type="id",
).run()
entry = (
db.session.query(KeyValueEntry).filter_by(id=int(key)).autoflush(False).one()
)
key = CreateKeyValueCommand(actor=admin, resource=RESOURCE, value=VALUE).run()
entry = db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one()
assert pickle.loads(entry.value) == VALUE
assert entry.created_by_fk == admin.id
db.session.delete(entry)
@ -52,11 +48,9 @@ def test_create_uuid_entry(app_context: AppContext, admin: User) -> None:
from superset.key_value.commands.create import CreateKeyValueCommand
from superset.key_value.models import KeyValueEntry
key = CreateKeyValueCommand(
actor=admin, resource=RESOURCE, value=VALUE, key_type="uuid",
).run()
key = CreateKeyValueCommand(actor=admin, resource=RESOURCE, value=VALUE).run()
entry = (
db.session.query(KeyValueEntry).filter_by(uuid=UUID(key)).autoflush(False).one()
db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).autoflush(False).one()
)
assert pickle.loads(entry.value) == VALUE
assert entry.created_by_fk == admin.id

View File

@ -30,8 +30,8 @@ from tests.integration_tests.key_value.commands.fixtures import admin, RESOURCE,
if TYPE_CHECKING:
from superset.key_value.models import KeyValueEntry
ID_KEY = "234"
UUID_KEY = "5aae143c-44f1-478e-9153-ae6154df333a"
ID_KEY = 234
UUID_KEY = UUID("5aae143c-44f1-478e-9153-ae6154df333a")
@pytest.fixture
@ -39,10 +39,7 @@ def key_value_entry() -> KeyValueEntry:
from superset.key_value.models import KeyValueEntry
entry = KeyValueEntry(
id=int(ID_KEY),
uuid=UUID(UUID_KEY),
resource=RESOURCE,
value=pickle.dumps(VALUE),
id=ID_KEY, uuid=UUID_KEY, resource=RESOURCE, value=pickle.dumps(VALUE),
)
db.session.add(entry)
db.session.commit()
@ -55,10 +52,7 @@ def test_delete_id_entry(
from superset.key_value.commands.delete import DeleteKeyValueCommand
from superset.key_value.models import KeyValueEntry
assert (
DeleteKeyValueCommand(resource=RESOURCE, key=ID_KEY, key_type="id",).run()
is True
)
assert DeleteKeyValueCommand(resource=RESOURCE, key=ID_KEY).run() is True
def test_delete_uuid_entry(
@ -67,10 +61,7 @@ def test_delete_uuid_entry(
from superset.key_value.commands.delete import DeleteKeyValueCommand
from superset.key_value.models import KeyValueEntry
assert (
DeleteKeyValueCommand(resource=RESOURCE, key=UUID_KEY, key_type="uuid").run()
is True
)
assert DeleteKeyValueCommand(resource=RESOURCE, key=UUID_KEY).run() is True
def test_delete_entry_missing(
@ -79,7 +70,4 @@ def test_delete_entry_missing(
from superset.key_value.commands.delete import DeleteKeyValueCommand
from superset.key_value.models import KeyValueEntry
assert (
DeleteKeyValueCommand(resource=RESOURCE, key="456", key_type="id").run()
is False
)
assert DeleteKeyValueCommand(resource=RESOURCE, key=456).run() is False

View File

@ -26,14 +26,15 @@ from flask_appbuilder.security.sqla.models import User
from sqlalchemy.orm import Session
from superset.extensions import db
from superset.key_value.types import KeyValueResource
from tests.integration_tests.test_app import app
if TYPE_CHECKING:
from superset.key_value.models import KeyValueEntry
ID_KEY = "123"
UUID_KEY = "3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc"
RESOURCE = "my_resource"
ID_KEY = 123
UUID_KEY = UUID("3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc")
RESOURCE = KeyValueResource.APP
VALUE = {"foo": "bar"}
@ -42,10 +43,7 @@ def key_value_entry() -> Generator[KeyValueEntry, None, None]:
from superset.key_value.models import KeyValueEntry
entry = KeyValueEntry(
id=int(ID_KEY),
uuid=UUID(UUID_KEY),
resource=RESOURCE,
value=pickle.dumps(VALUE),
id=ID_KEY, uuid=UUID_KEY, resource=RESOURCE, value=pickle.dumps(VALUE),
)
db.session.add(entry)
db.session.commit()

View File

@ -39,7 +39,7 @@ if TYPE_CHECKING:
def test_get_id_entry(app_context: AppContext, key_value_entry: KeyValueEntry) -> None:
from superset.key_value.commands.get import GetKeyValueCommand
value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY, key_type="id").run()
value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY).run()
assert value == VALUE
@ -48,7 +48,7 @@ def test_get_uuid_entry(
) -> None:
from superset.key_value.commands.get import GetKeyValueCommand
value = GetKeyValueCommand(resource=RESOURCE, key=UUID_KEY, key_type="uuid").run()
value = GetKeyValueCommand(resource=RESOURCE, key=UUID_KEY).run()
assert value == VALUE
@ -57,7 +57,7 @@ def test_get_id_entry_missing(
) -> None:
from superset.key_value.commands.get import GetKeyValueCommand
value = GetKeyValueCommand(resource=RESOURCE, key="456", key_type="id").run()
value = GetKeyValueCommand(resource=RESOURCE, key=456).run()
assert value is None
@ -74,7 +74,7 @@ def test_get_expired_entry(app_context: AppContext) -> None:
)
db.session.add(entry)
db.session.commit()
value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY, key_type="id").run()
value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY).run()
assert value is None
db.session.delete(entry)
db.session.commit()
@ -94,7 +94,7 @@ def test_get_future_expiring_entry(app_context: AppContext) -> None:
)
db.session.add(entry)
db.session.commit()
value = GetKeyValueCommand(resource=RESOURCE, key=str(id_), key_type="id").run()
value = GetKeyValueCommand(resource=RESOURCE, key=id_).run()
assert value == VALUE
db.session.delete(entry)
db.session.commit()

View File

@ -46,12 +46,10 @@ def test_update_id_entry(
from superset.key_value.models import KeyValueEntry
key = UpdateKeyValueCommand(
actor=admin, resource=RESOURCE, key=ID_KEY, value=NEW_VALUE, key_type="id",
actor=admin, resource=RESOURCE, key=ID_KEY, value=NEW_VALUE,
).run()
assert key == ID_KEY
entry = (
db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).autoflush(False).one()
)
assert key.id == ID_KEY
entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).autoflush(False).one()
assert pickle.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id
@ -63,25 +61,20 @@ def test_update_uuid_entry(
from superset.key_value.models import KeyValueEntry
key = UpdateKeyValueCommand(
actor=admin, resource=RESOURCE, key=UUID_KEY, value=NEW_VALUE, key_type="uuid",
actor=admin, resource=RESOURCE, key=UUID_KEY, value=NEW_VALUE,
).run()
assert key == UUID_KEY
assert key.uuid == UUID_KEY
entry = (
db.session.query(KeyValueEntry)
.filter_by(uuid=UUID(UUID_KEY))
.autoflush(False)
.one()
db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one()
)
assert pickle.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id
def test_update_missing_entry(
app_context: AppContext, admin: User, key_value_entry: KeyValueEntry,
) -> None:
def test_update_missing_entry(app_context: AppContext, admin: User) -> None:
from superset.key_value.commands.update import UpdateKeyValueCommand
key = UpdateKeyValueCommand(
actor=admin, resource=RESOURCE, key="456", value=NEW_VALUE, key_type="id",
actor=admin, resource=RESOURCE, key=456, value=NEW_VALUE,
).run()
assert key is None

View File

@ -46,9 +46,9 @@ def test_upsert_id_entry(
from superset.key_value.models import KeyValueEntry
key = UpsertKeyValueCommand(
actor=admin, resource=RESOURCE, key=ID_KEY, value=NEW_VALUE, key_type="id",
actor=admin, resource=RESOURCE, key=ID_KEY, value=NEW_VALUE,
).run()
assert key == ID_KEY
assert key.id == ID_KEY
entry = (
db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).autoflush(False).one()
)
@ -63,28 +63,23 @@ def test_upsert_uuid_entry(
from superset.key_value.models import KeyValueEntry
key = UpsertKeyValueCommand(
actor=admin, resource=RESOURCE, key=UUID_KEY, value=NEW_VALUE, key_type="uuid",
actor=admin, resource=RESOURCE, key=UUID_KEY, value=NEW_VALUE,
).run()
assert key == UUID_KEY
assert key.uuid == UUID_KEY
entry = (
db.session.query(KeyValueEntry)
.filter_by(uuid=UUID(UUID_KEY))
.autoflush(False)
.one()
db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one()
)
assert pickle.loads(entry.value) == NEW_VALUE
assert entry.changed_by_fk == admin.id
def test_upsert_missing_entry(
app_context: AppContext, admin: User, key_value_entry: KeyValueEntry,
) -> None:
def test_upsert_missing_entry(app_context: AppContext, admin: User) -> None:
from superset.key_value.commands.upsert import UpsertKeyValueCommand
from superset.key_value.models import KeyValueEntry
key = UpsertKeyValueCommand(
actor=admin, resource=RESOURCE, key="456", value=NEW_VALUE, key_type="id",
actor=admin, resource=RESOURCE, key=456, value=NEW_VALUE,
).run()
assert key == "456"
assert key.id == 456
db.session.query(KeyValueEntry).filter_by(id=456).delete()
db.session.commit()

View File

@ -16,102 +16,45 @@
# under the License.
from __future__ import annotations
import json
from typing import TYPE_CHECKING
from unittest.mock import patch
from uuid import UUID
if TYPE_CHECKING:
from superset.key_value.models import KeyValueEntry
import pytest
from flask.ctx import AppContext
from superset.key_value.types import Key
from superset.key_value.exceptions import KeyValueParseKeyError
from superset.key_value.types import KeyValueResource
RESOURCE = "my-resource"
UUID_KEY = "3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc"
ID_KEY = "123"
RESOURCE = KeyValueResource.APP
UUID_KEY = UUID("3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc")
ID_KEY = 123
@pytest.fixture
def key_value_entry(app_context: AppContext):
from superset.key_value.models import KeyValueEntry
return KeyValueEntry(
id=int(ID_KEY), uuid=UUID(UUID_KEY), value=json.dumps({"foo": "bar"}),
)
def test_parse_permalink_key_uuid_valid(app_context: AppContext) -> None:
from superset.key_value.utils import parse_permalink_key
assert parse_permalink_key(UUID_KEY) == Key(id=None, uuid=UUID(UUID_KEY))
def test_parse_permalink_key_id_invalid(app_context: AppContext) -> None:
from superset.key_value.utils import parse_permalink_key
with pytest.raises(ValueError):
parse_permalink_key(ID_KEY)
@patch("superset.key_value.utils.current_app.config", {"PERMALINK_KEY_TYPE": "id"})
def test_parse_permalink_key_id_valid(app_context: AppContext) -> None:
from superset.key_value.utils import parse_permalink_key
assert parse_permalink_key(ID_KEY) == Key(id=int(ID_KEY), uuid=None)
@patch("superset.key_value.utils.current_app.config", {"PERMALINK_KEY_TYPE": "id"})
def test_parse_permalink_key_uuid_invalid(app_context: AppContext) -> None:
from superset.key_value.utils import parse_permalink_key
with pytest.raises(ValueError):
parse_permalink_key(UUID_KEY)
def test_format_permalink_key_uuid(app_context: AppContext) -> None:
from superset.key_value.utils import format_permalink_key
assert format_permalink_key(Key(id=None, uuid=UUID(UUID_KEY))) == UUID_KEY
def test_format_permalink_key_id(app_context: AppContext) -> None:
from superset.key_value.utils import format_permalink_key
assert format_permalink_key(Key(id=int(ID_KEY), uuid=None)) == ID_KEY
def test_extract_key_uuid(
app_context: AppContext, key_value_entry: KeyValueEntry,
) -> None:
from superset.key_value.utils import extract_key
assert extract_key(key_value_entry, "id") == ID_KEY
def test_extract_key_id(
app_context: AppContext, key_value_entry: KeyValueEntry,
) -> None:
from superset.key_value.utils import extract_key
assert extract_key(key_value_entry, "uuid") == UUID_KEY
def test_get_filter_uuid(app_context: AppContext,) -> None:
def test_get_filter_uuid() -> None:
from superset.key_value.utils import get_filter
assert get_filter(resource=RESOURCE, key=UUID_KEY, key_type="uuid",) == {
assert get_filter(resource=RESOURCE, key=UUID_KEY) == {
"resource": RESOURCE,
"uuid": UUID(UUID_KEY),
"uuid": UUID_KEY,
}
def test_get_filter_id(app_context: AppContext,) -> None:
def test_get_filter_id() -> None:
from superset.key_value.utils import get_filter
assert get_filter(resource=RESOURCE, key=ID_KEY, key_type="id",) == {
assert get_filter(resource=RESOURCE, key=ID_KEY) == {
"resource": RESOURCE,
"id": int(ID_KEY),
"id": ID_KEY,
}
def test_encode_permalink_id_valid() -> None:
from superset.key_value.utils import encode_permalink_key
salt = "abc"
assert encode_permalink_key(1, salt) == "AyBn4lm9qG8"
def test_decode_permalink_id_invalid() -> None:
from superset.key_value.utils import decode_permalink_id
with pytest.raises(KeyValueParseKeyError):
decode_permalink_id("foo", "bar")