mirror of https://github.com/apache/superset.git
feat: add decorator to guard public APIs (#12635)
* feat: add decorator to guard public APIs * Add unit tests * Refactor to use unit tests * Parameterize tests * Remove decorator
This commit is contained in:
parent
b1c203b7d9
commit
4255c22d01
|
@ -856,7 +856,7 @@ DB_CONNECTION_MUTATOR = None
|
|||
# The use case can be adding some sort of comment header
|
||||
# with information such as the username and worker node information
|
||||
#
|
||||
# def SQL_QUERY_MUTATOR(sql, username, security_manager):
|
||||
# def SQL_QUERY_MUTATOR(sql, user_name, security_manager, database):
|
||||
# dttm = datetime.now().isoformat()
|
||||
# return f"-- [SQL LAB] {user_name} {dttm}\n{sql}"
|
||||
SQL_QUERY_MUTATOR = None
|
||||
|
|
|
@ -29,11 +29,13 @@ from celery.exceptions import SoftTimeLimitExceeded
|
|||
from celery.task.base import Task
|
||||
from flask_babel import lazy_gettext as _
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from superset import app, results_backend, results_backend_use_msgpack, security_manager
|
||||
from superset.dataframe import df_to_records
|
||||
from superset.db_engine_specs import BaseEngineSpec
|
||||
from superset.extensions import celery_app
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.result_set import SupersetResultSet
|
||||
from superset.sql_parse import CtasMethod, ParsedQuery
|
||||
|
@ -47,13 +49,25 @@ from superset.utils.core import (
|
|||
from superset.utils.dates import now_as_float
|
||||
from superset.utils.decorators import stats_timing
|
||||
|
||||
|
||||
# pylint: disable=unused-argument, redefined-outer-name
|
||||
def dummy_sql_query_mutator(
    sql: str,
    user_name: Optional[str],
    security_manager: LocalProxy,
    database: Database,
) -> str:
    """Pass-through stand-in for SQL_QUERY_MUTATOR.

    Only ``sql`` is used; the remaining parameters are accepted and ignored.

    :param sql: the SQL text to (not) mutate
    :param user_name: ignored
    :param security_manager: ignored
    :param database: ignored
    :returns: ``sql`` exactly as it was received
    """
    return sql
|
||||
|
||||
|
||||
# Module-level settings read once from the Flask application config.
config = app.config
stats_logger = config["STATS_LOGGER"]
SQLLAB_TIMEOUT = config["SQLLAB_ASYNC_TIME_LIMIT_SEC"]
# The hard limit is the configured limit plus 60 seconds.
SQLLAB_HARD_TIMEOUT = SQLLAB_TIMEOUT + 60
SQL_MAX_ROW = config["SQL_MAX_ROW"]
SQLLAB_CTAS_NO_LIMIT = config["SQLLAB_CTAS_NO_LIMIT"]
# A missing or falsy SQL_QUERY_MUTATOR falls back to the no-op mutator, so the
# name always refers to a callable. The earlier strict
# ``config["SQL_QUERY_MUTATOR"]`` assignment was dead (immediately overwritten)
# and would have raised KeyError before this fallback could apply.
SQL_QUERY_MUTATOR = config.get("SQL_QUERY_MUTATOR") or dummy_sql_query_mutator
log_query = config["QUERY_LOGGER"]
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -195,8 +209,7 @@ def execute_sql_statement(
|
|||
sql = database.apply_limit_to_sql(sql, query.limit)
|
||||
|
||||
# Hook to allow environment-specific mutation (usually comments) to the SQL
|
||||
if SQL_QUERY_MUTATOR:
|
||||
sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
|
||||
sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
|
||||
|
||||
try:
|
||||
if log_query:
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from base64 import b85encode
|
||||
from hashlib import md5
|
||||
from inspect import (
|
||||
getmembers,
|
||||
getsourcefile,
|
||||
getsourcelines,
|
||||
isclass,
|
||||
isfunction,
|
||||
isroutine,
|
||||
signature,
|
||||
)
|
||||
from textwrap import indent
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
def compute_hash(obj: Callable[..., Any]) -> str:
    """Return the interface hash of a plain function or a class.

    Functions are delegated to ``compute_func_hash`` and classes to
    ``compute_class_hash``.

    :param obj: the function or class whose interface is hashed
    :returns: the hash string produced by the matching helper
    :raises TypeError: if ``obj`` is neither a function nor a class
    """
    if isfunction(obj):
        return compute_func_hash(obj)

    if isclass(obj):
        return compute_class_hash(obj)

    # TypeError is a subclass of Exception, so callers that caught the
    # previous generic Exception keep working; the message is unchanged.
    raise TypeError(f"Invalid object: {obj}")
|
||||
|
||||
|
||||
def compute_func_hash(function: Callable[..., Any]) -> str:
    """Return the base85-encoded MD5 digest of the function's signature text.

    :param function: the callable whose signature is hashed
    :returns: the digest of ``str(signature(function))`` as a base85 string
    """
    signature_text = str(signature(function))
    digest = md5(signature_text.encode()).digest()
    return b85encode(digest).decode("utf-8")
|
||||
|
||||
|
||||
def compute_class_hash(class_: Callable[..., Any]) -> str:
    """Return the base85-encoded MD5 digest of a class's public routines.

    Every routine whose name does not start with an underscore, plus
    ``__init__``, contributes its name followed by its signature text, in
    sorted order.

    :param class_: the class whose public interface is hashed
    :returns: the digest as a base85 string
    """
    exposed = sorted(
        member
        for member in getmembers(class_, predicate=isroutine)
        if member[0] == "__init__" or not member[0].startswith("_")
    )

    digest = md5()
    for routine_name, routine in exposed:
        digest.update((routine_name + str(signature(routine))).encode())

    return b85encode(digest.digest()).decode("utf-8")
|
||||
|
||||
|
||||
def get_warning_message(obj: Callable[..., Any], expected_hash: str) -> str:
    """Build the message shown when a public interface hash does not match.

    :param obj: the function or class whose interface changed
    :param expected_hash: the hash value the message tells the reader to record
    :returns: text naming the object, its source location and source code
        (placeholders are used when ``inspect`` cannot locate the source) and
        the hash to record
    """
    try:
        sourcefile = getsourcefile(obj)
    except (OSError, TypeError):
        # inspect raises for objects without a Python source file, e.g. built-ins
        sourcefile = "<unknown file>"

    try:
        sourcelines, lineno = getsourcelines(obj)
    except (OSError, TypeError):
        # Fall back to placeholders so the caller still gets the hash guidance
        # instead of an unrelated inspect error.
        code, lineno = " <source unavailable>", "<unknown>"
    else:
        code = indent("".join(sourcelines), " ")

    return (
        f"The object `{obj.__name__}` (in {sourcefile} "
        f"line {lineno}) has a public interface which has currently been "
        "modified. This MUST only be released in a new major version of "
        "Superset according to SIP-57. To remove this warning message "
        f"update the associated hash to '{expected_hash}'.\n\n{code}"
    )
|
|
@ -0,0 +1,106 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# pylint: disable=no-self-use
|
||||
import pytest
|
||||
|
||||
from superset.sql_lab import dummy_sql_query_mutator
|
||||
from superset.utils.public_interfaces import compute_hash, get_warning_message
|
||||
from tests.base_tests import SupersetTestCase
|
||||
|
||||
# Registry of the public interfaces exposed by Superset, mapped to the
# hash of their current interface. Interfaces may only change, and their
# hashes may only be updated here, in a new major version of Superset.
hashes = {dummy_sql_query_mutator: "Kv%NM3b;7BcpoD2wbPkW"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("interface,expected_hash", list(hashes.items()))
def test_public_interfaces(interface, expected_hash):
    """Fail when a registered public interface no longer matches its hash."""
    actual_hash = compute_hash(interface)
    # The freshly computed hash goes into the message because that is the
    # value the reader is told to record.
    failure_message = get_warning_message(interface, actual_hash)
    assert actual_hash == expected_hash, failure_message
|
||||
|
||||
|
||||
def test_func_hash():
    """Check that a function's hash changes when its signature changes."""

    def some_function(a, b):
        return a + b

    hash_before = compute_hash(some_function)

    # pylint: disable=function-redefined
    def some_function(a, b, c):
        return a + b + c

    hash_after = compute_hash(some_function)

    assert hash_before != hash_after
|
||||
|
||||
|
||||
def test_class_hash():
    """Check which class changes do and do not alter the class hash."""

    # pylint: disable=too-few-public-methods, invalid-name
    class SomeClass:
        def __init__(self, a, b):
            self.a = a
            self.b = b

        def add(self):
            return self.a + self.b

    baseline_hash = compute_hash(SomeClass)

    # a different __init__ signature must produce a different hash
    # pylint: disable=function-redefined, too-few-public-methods, invalid-name
    class SomeClass:
        def __init__(self, a, b, c):
            self.a = a
            self.b = b
            self.c = c

        def add(self):
            return self.a + self.b

    changed_init_hash = compute_hash(SomeClass)
    assert baseline_hash != changed_init_hash

    # a renamed public method must produce a different hash
    # pylint: disable=function-redefined, too-few-public-methods, invalid-name
    class SomeClass:
        def __init__(self, a, b):
            self.a = a
            self.b = b

        def sum(self):
            return self.a + self.b

    renamed_method_hash = compute_hash(SomeClass)
    assert baseline_hash != renamed_method_hash

    # an extra private method must leave the hash untouched
    # pylint: disable=function-redefined, too-few-public-methods, invalid-name
    class SomeClass:
        def __init__(self, a, b):
            self.a = a
            self.b = b

        def add(self):
            return self._sum()

        def _sum(self):
            return self.a + self.b

    private_method_hash = compute_hash(SomeClass)
    assert baseline_hash == private_method_hash
|
Loading…
Reference in New Issue