fix: Cleanup serialization and hashing code (#14317)

This commit is contained in:
Ben Reinhart 2021-04-26 14:04:40 -07:00 committed by GitHub
parent ebc938059b
commit 2a1235c0c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 124 additions and 38 deletions

View File

@ -15,12 +15,10 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=R
import hashlib
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, NamedTuple, Optional, Union
import simplejson as json
from flask_babel import gettext as _
from pandas import DataFrame
@ -40,6 +38,7 @@ from superset.utils.core import (
json_int_dttm_ser,
)
from superset.utils.date_parser import get_since_until, parse_human_timedelta
from superset.utils.hashing import md5_sha_from_dict
from superset.views.utils import get_time_range_endpoints
config = app.config
@ -333,14 +332,7 @@ class QueryObject:
if annotation_layers:
cache_dict["annotation_layers"] = annotation_layers
json_data = self.json_dumps(cache_dict, sort_keys=True)
return hashlib.md5(json_data.encode("utf-8")).hexdigest()
@staticmethod
def json_dumps(obj: Any, sort_keys: bool = False) -> str:
return json.dumps(
obj, default=json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys
)
return md5_sha_from_dict(cache_dict, default=json_int_dttm_ser, ignore_nan=True)
def exec_post_processing(self, df: DataFrame) -> DataFrame:
"""

View File

@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
import hashlib
import json
import logging
import re
@ -63,6 +62,7 @@ from superset.models.sql_types.base import literal_dttm_type_factory
from superset.sql_parse import ParsedQuery, Table
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
from superset.utils.hashing import md5_sha_from_str
if TYPE_CHECKING:
# prevent circular imports
@ -1145,7 +1145,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param label: Expected expression label
:return: Truncated label
"""
label = hashlib.md5(label.encode("utf-8")).hexdigest()
label = md5_sha_from_str(label)
# truncate hash if it exceeds max length
if cls.max_column_name_length and len(label) > cls.max_column_name_length:
label = label[: cls.max_column_name_length]

View File

@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import hashlib
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
@ -28,6 +27,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import SupersetErrorType
from superset.sql_parse import Table
from superset.utils import core as utils
from superset.utils.hashing import md5_sha_from_str
if TYPE_CHECKING:
from superset.models.core import Database # pragma: no cover
@ -141,7 +141,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
:param label: Expected expression label
:return: Conditionally mutated label
"""
label_hashed = "_" + hashlib.md5(label.encode("utf-8")).hexdigest()
label_hashed = "_" + md5_sha_from_str(label)
# if label starts with number, add underscore as first character
label_mutated = "_" + label if re.match(r"^\d", label) else label
@ -163,7 +163,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
:param label: expected expression label
:return: truncated label
"""
return "_" + hashlib.md5(label.encode("utf-8")).hexdigest()
return "_" + md5_sha_from_str(label)
@classmethod
def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:

View File

@ -56,6 +56,7 @@ from superset.models.user_attributes import UserAttribute
from superset.tasks.thumbnails import cache_dashboard_thumbnail
from superset.utils import core as utils
from superset.utils.decorators import debounce
from superset.utils.hashing import md5_sha_from_str
from superset.utils.urls import get_url_path
# pylint: disable=too-many-public-methods
@ -199,7 +200,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
Returns a MD5 HEX digest that makes this dashboard unique
"""
unique_string = f"{self.position_json}.{self.css}.{self.json_metadata}"
return utils.md5_hex(unique_string)
return md5_sha_from_str(unique_string)
@property
def thumbnail_url(self) -> str:

View File

@ -34,6 +34,7 @@ from superset.models.helpers import AuditMixinNullable, ImportExportMixin
from superset.models.tags import ChartUpdater
from superset.tasks.thumbnails import cache_chart_thumbnail
from superset.utils import core as utils
from superset.utils.hashing import md5_sha_from_str
from superset.utils.urls import get_url_path
from superset.viz import BaseViz, viz_types # type: ignore
@ -202,7 +203,7 @@ class Slice(
"""
Returns a MD5 HEX digest that makes this dashboard unique
"""
return utils.md5_hex(self.params or "")
return md5_sha_from_str(self.params or "")
@property
def thumbnail_url(self) -> str:

View File

@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import hashlib
import json
import logging
from datetime import datetime, timedelta
from functools import wraps
@ -30,20 +28,15 @@ from superset.extensions import cache_manager
from superset.models.cache import CacheKey
from superset.stats_logger import BaseStatsLogger
from superset.utils.core import json_int_dttm_ser
from superset.utils.hashing import md5_sha_from_dict
config = app.config # type: ignore
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
logger = logging.getLogger(__name__)
# TODO: DRY up cache key code
def json_dumps(obj: Any, sort_keys: bool = False) -> str:
return json.dumps(obj, default=json_int_dttm_ser, sort_keys=sort_keys)
def generate_cache_key(values_dict: Dict[str, Any], key_prefix: str = "") -> str:
json_data = json_dumps(values_dict, sort_keys=True)
hash_str = hashlib.md5(json_data.encode("utf-8")).hexdigest()
hash_str = md5_sha_from_dict(values_dict, default=json_int_dttm_ser)
return f"{key_prefix}{hash_str}"

View File

@ -19,7 +19,6 @@ import collections
import decimal
import errno
import functools
import hashlib
import json
import logging
import os
@ -99,6 +98,7 @@ from superset.exceptions import (
)
from superset.typing import FlaskResponse, FormData, Metric
from superset.utils.dates import datetime_to_epoch, EPOCH
from superset.utils.hashing import md5_sha_from_str
try:
from pydruid.utils.having import Having
@ -484,10 +484,6 @@ def list_minus(l: List[Any], minus: List[Any]) -> List[Any]:
return [o for o in l if o not in minus]
def md5_hex(data: str) -> str:
return hashlib.md5(data.encode()).hexdigest()
class DashboardEncoder(json.JSONEncoder):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
@ -1381,7 +1377,7 @@ def create_ssl_cert_file(certificate: str) -> str:
:return: The path to the certificate file
:raises CertificateException: If certificate is not valid/unparseable
"""
filename = f"{hashlib.md5(certificate.encode('utf-8')).hexdigest()}.crt"
filename = f"{md5_sha_from_str(certificate)}.crt"
cert_dir = current_app.config["SSL_CERT_PATH"]
path = cert_dir if cert_dir else tempfile.gettempdir()
path = os.path.join(path, filename)

View File

@ -15,14 +15,20 @@
# specific language governing permissions and limitations
# under the License.
import hashlib
import json
from typing import Any, Dict
from typing import Any, Callable, Dict, Optional
import simplejson as json
def md5_sha_from_str(val: str) -> str:
    """Return the hex-encoded MD5 digest of the UTF-8 encoding of ``val``."""
    digest = hashlib.md5(val.encode("utf-8"))
    return digest.hexdigest()
def md5_sha_from_dict(opts: Dict[Any, Any]) -> str:
json_data = json.dumps(opts, sort_keys=True)
def md5_sha_from_dict(
    obj: Dict[Any, Any],
    ignore_nan: bool = False,
    default: Optional[Callable[[Any], Any]] = None,
) -> str:
    """Serialize ``obj`` to JSON with sorted keys (so the digest is stable
    regardless of key insertion order) and return the MD5 hex digest of
    that serialization.

    ``ignore_nan`` and ``default`` are forwarded to ``json.dumps``
    (``ignore_nan`` serializes NaN as ``null``; ``default`` handles
    non-JSON-native values).
    """
    serialized = json.dumps(
        obj, sort_keys=True, ignore_nan=ignore_nan, default=default
    )
    return md5_sha_from_str(serialized)

View File

@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-self-use
import datetime
import math
from typing import Any
import pytest
from superset.utils.hashing import md5_sha_from_dict, md5_sha_from_str
def test_basic_md5_sha():
    """A dict hashes identically to its canonical sorted-key JSON string."""
    payload = {
        "product": "Coffee",
        "company": "Gobias Industries",
        "price_in_cents": 4000,
    }
    expected_json = (
        '{"company": "Gobias Industries", "price_in_cents": 4000, "product": "Coffee"}'
    )
    str_hash = md5_sha_from_str(expected_json)
    assert str_hash == md5_sha_from_dict(payload)
    assert str_hash == "35f22273cd6a6798b04f8ddef51135e3"
def test_sort_order_md5_sha():
    """Key insertion order must not affect the resulting digest."""
    first = {
        "product": "Coffee",
        "price_in_cents": 4000,
        "company": "Gobias Industries",
    }
    second = {
        "product": "Coffee",
        "company": "Gobias Industries",
        "price_in_cents": 4000,
    }
    assert md5_sha_from_dict(first) == md5_sha_from_dict(second)
    assert md5_sha_from_dict(first) == "35f22273cd6a6798b04f8ddef51135e3"
def test_custom_default_md5_sha():
    """A ``default`` callable controls how non-JSON-native values serialize."""

    def serialize_datetime(value: Any):
        # Replace datetimes with a fixed token so the digest is deterministic.
        if isinstance(value, datetime.datetime):
            return "<datetime>"

    payload = {
        "product": "Coffee",
        "company": "Gobias Industries",
        "datetime": datetime.datetime.now(),
    }
    expected_json = '{"company": "Gobias Industries", "datetime": "<datetime>", "product": "Coffee"}'
    str_hash = md5_sha_from_str(expected_json)
    assert str_hash == md5_sha_from_dict(payload, default=serialize_datetime)
    assert str_hash == "dc280121213aabcaeb8087aef268fd0d"
def test_ignore_nan_md5_sha():
    """``ignore_nan`` switches NaN serialization from ``NaN`` to ``null``."""
    payload = {
        "product": "Coffee",
        "company": "Gobias Industries",
        "price": math.nan,
    }

    # Default behavior: NaN is emitted literally.
    with_nan = '{"company": "Gobias Industries", "price": NaN, "product": "Coffee"}'
    assert md5_sha_from_str(with_nan) == md5_sha_from_dict(payload)
    assert md5_sha_from_str(with_nan) == "5d129d1dffebc0bacc734366476d586d"

    # With ignore_nan=True: NaN becomes null, so the digest changes.
    with_null = '{"company": "Gobias Industries", "price": null, "product": "Coffee"}'
    assert md5_sha_from_str(with_null) == md5_sha_from_dict(payload, ignore_nan=True)
    assert md5_sha_from_str(with_null) == "40e87d61f6add03816bccdeac5713b9f"

View File

@ -19,7 +19,6 @@ import unittest
import uuid
from datetime import date, datetime, time, timedelta
from decimal import Decimal
import hashlib
import json
import os
import re
@ -71,6 +70,7 @@ from superset.utils.core import (
zlib_decompress,
)
from superset.utils import schema
from superset.utils.hashing import md5_sha_from_str
from superset.views.utils import (
build_extra_filters,
get_form_data,
@ -960,7 +960,7 @@ class TestUtils(SupersetTestCase):
def test_ssl_certificate_file_creation(self):
path = create_ssl_cert_file(ssl_certificate)
expected_filename = hashlib.md5(ssl_certificate.encode("utf-8")).hexdigest()
expected_filename = md5_sha_from_str(ssl_certificate)
self.assertIn(expected_filename, path)
self.assertTrue(os.path.exists(path))