mirror of https://github.com/apache/superset.git
fix: adds the ability to disallow SQL functions per engine (#28639)
This commit is contained in:
parent
6575cacc5d
commit
5dfbab5424
|
@ -1227,6 +1227,15 @@ DB_CONNECTION_MUTATOR = None
|
|||
#
|
||||
DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None
|
||||
|
||||
# Disallowed SQL functions, keyed by database engine name. Each value is the set of
# function names that may not be used on that engine; this restricts the use of
# unsafe SQL functions in SQL Lab and Charts. Function names found in a query are
# lowercased before being compared against these sets, so entries must be lowercase.
DISALLOWED_SQL_FUNCTIONS: dict[str, set[str]] = {
    "postgresql": {
        "version",
        "query_to_xml",
        "inet_server_addr",
        "inet_client_addr",
    },
    "clickhouse": {"url"},
    "mysql": {"version"},
}
|
||||
|
||||
|
||||
# A function that intercepts the SQL to be executed and can alter it.
|
||||
# A common use case for this is around adding some sort of comment header to the SQL
|
||||
|
|
|
@ -62,7 +62,7 @@ from superset import sql_parse
|
|||
from superset.constants import TimeGrain as TimeGrainConstants
|
||||
from superset.databases.utils import get_table_metadata, make_url_safe
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
||||
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
|
||||
from superset.sql_parse import ParsedQuery, SQLScript, Table
|
||||
from superset.superset_typing import (
|
||||
OAuth2ClientConfig,
|
||||
|
@ -1818,6 +1818,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
"""
|
||||
if not cls.allows_sql_comments:
|
||||
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
|
||||
disallowed_functions = current_app.config["DISALLOWED_SQL_FUNCTIONS"].get(
|
||||
cls.engine, set()
|
||||
)
|
||||
if sql_parse.check_sql_functions_exist(query, disallowed_functions, cls.engine):
|
||||
raise DisallowedSQLFunction(disallowed_functions)
|
||||
|
||||
if cls.arraysize:
|
||||
cursor.arraysize = cls.arraysize
|
||||
|
|
|
@ -22,7 +22,7 @@ import threading
|
|||
import time
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from flask import current_app
|
||||
from flask import current_app, Flask
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
|
@ -218,11 +218,14 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
execute_result: dict[str, Any] = {}
|
||||
execute_event = threading.Event()
|
||||
|
||||
def _execute(results: dict[str, Any], event: threading.Event) -> None:
|
||||
def _execute(
|
||||
results: dict[str, Any], event: threading.Event, app: Flask
|
||||
) -> None:
|
||||
logger.debug("Query %d: Running query: %s", query_id, sql)
|
||||
|
||||
try:
|
||||
cls.execute(cursor, sql, query.database)
|
||||
with app.app_context():
|
||||
cls.execute(cursor, sql, query.database)
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
results["error"] = ex
|
||||
finally:
|
||||
|
@ -230,7 +233,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
|
||||
execute_thread = threading.Thread(
|
||||
target=_execute,
|
||||
args=(execute_result, execute_event),
|
||||
args=(execute_result, execute_event, current_app._get_current_object()), # pylint: disable=protected-access
|
||||
)
|
||||
execute_thread.start()
|
||||
|
||||
|
|
|
@ -358,6 +358,21 @@ class OAuth2Error(SupersetErrorException):
|
|||
)
|
||||
|
||||
|
||||
class DisallowedSQLFunction(SupersetErrorException):
    """
    Raised when a SQL statement contains a disallowed function.

    :param functions: Set of function names; it is interpolated as-is into the
        error message
    """

    def __init__(self, functions: set[str]):
        # Build the error payload first, then hand it to the base exception.
        error = SupersetError(
            message=f"SQL statement contains disallowed function(s): {functions}",
            error_type=SupersetErrorType.SYNTAX_ERROR,
            level=ErrorLevel.ERROR,
        )
        super().__init__(error)
|
||||
|
||||
|
||||
class CreateKeyValueDistributedLockFailedException(Exception):
|
||||
"""
|
||||
Exception to signalize failure to acquire lock.
|
||||
|
|
|
@ -39,6 +39,7 @@ from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
|||
from sqlparse import keywords
|
||||
from sqlparse.lexer import Lexer
|
||||
from sqlparse.sql import (
|
||||
Function,
|
||||
Identifier,
|
||||
IdentifierList,
|
||||
Parenthesis,
|
||||
|
@ -223,6 +224,19 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
|
|||
return cte, remainder
|
||||
|
||||
|
||||
def check_sql_functions_exist(
    sql: str, function_list: set[str], engine: str | None = None
) -> bool:
    """
    Check if the SQL statement contains any of the specified functions.

    :param sql: The SQL statement
    :param function_list: The set of function names to search for
    :param engine: The engine to use for parsing the SQL statement
    :return: True if the statement contains any of the specified functions
    """
    # With nothing to search for the answer is always False, so skip parsing the
    # statement altogether. Engines without configured restrictions reach this
    # function with an empty set.
    if not function_list:
        return False

    return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
|
||||
|
||||
|
||||
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
|
||||
"""
|
||||
Strips comments from a SQL statement, does a simple test first
|
||||
|
@ -743,6 +757,34 @@ class ParsedQuery:
|
|||
self._tables = self._extract_tables_from_sql()
|
||||
return self._tables
|
||||
|
||||
def _check_functions_exist_in_token(
    self, token: Token, functions: set[str]
) -> bool:
    """
    Recursively check whether a token, or any token nested in it, is a call to
    one of the given functions.

    :param token: The parsed token to inspect
    :param functions: A set of function names to search for; the token name is
        lowercased before the membership test
    :return: True if the token or one of its descendants matches
    """
    # Direct hit: the token itself is a function call with a matching name.
    if isinstance(token, Function):
        name = token.get_name()
        if name is not None and name.lower() in functions:
            return True

    # Otherwise descend into child tokens, if this token has any.
    return any(
        self._check_functions_exist_in_token(child, functions)
        for child in getattr(token, "tokens", ())
    )
|
||||
|
||||
def check_functions_exist(self, functions: set[str]) -> bool:
    """
    Check if the SQL statement contains any of the specified functions.

    Every top-level token of every parsed statement is inspected; the search
    stops at the first match.

    :param functions: A set of functions to search for
    :return: True if the statement contains any of the specified functions
    """
    return any(
        self._check_functions_exist_in_token(top_level, functions)
        for parsed_statement in self._parsed
        for top_level in parsed_statement.tokens
    )
|
||||
|
||||
def _extract_tables_from_sql(self) -> set[Table]:
|
||||
"""
|
||||
Extract all table references in a query.
|
||||
|
|
|
@ -401,7 +401,7 @@ def test_handle_cursor_early_cancel(
|
|||
assert cancel_query_mock.call_args is None
|
||||
|
||||
|
||||
def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
|
||||
def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
|
||||
"""Test that `execute_with_cursor` fetches query ID from the cursor"""
|
||||
from superset.db_engine_specs.trino import TrinoEngineSpec
|
||||
|
||||
|
@ -416,16 +416,20 @@ def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
|
|||
mock_cursor.query_id = query_id
|
||||
|
||||
mock_cursor.execute.side_effect = _mock_execute
|
||||
with patch.dict(
|
||||
"superset.config.DISALLOWED_SQL_FUNCTIONS",
|
||||
{},
|
||||
clear=True,
|
||||
):
|
||||
TrinoEngineSpec.execute_with_cursor(
|
||||
cursor=mock_cursor,
|
||||
sql="SELECT 1 FROM foo",
|
||||
query=mock_query,
|
||||
)
|
||||
|
||||
TrinoEngineSpec.execute_with_cursor(
|
||||
cursor=mock_cursor,
|
||||
sql="SELECT 1 FROM foo",
|
||||
query=mock_query,
|
||||
)
|
||||
|
||||
mock_query.set_extra_json_key.assert_called_once_with(
|
||||
key=QUERY_CANCEL_KEY, value=query_id
|
||||
)
|
||||
mock_query.set_extra_json_key.assert_called_once_with(
|
||||
key=QUERY_CANCEL_KEY, value=query_id
|
||||
)
|
||||
|
||||
|
||||
def test_get_columns(mocker: MockerFixture):
|
||||
|
|
|
@ -32,6 +32,7 @@ from superset.exceptions import (
|
|||
)
|
||||
from superset.sql_parse import (
|
||||
add_table_name,
|
||||
check_sql_functions_exist,
|
||||
extract_table_references,
|
||||
extract_tables_from_jinja_sql,
|
||||
get_rls_for_table,
|
||||
|
@ -1215,6 +1216,31 @@ def test_strip_comments_from_sql() -> None:
|
|||
)
|
||||
|
||||
|
||||
def test_check_sql_functions_exist() -> None:
    """
    Test that the presence of specific SQL functions is detected correctly.
    """
    functions = {"version"}
    engine = "postgresql"

    # A table that merely shares its name with the function is not a match.
    assert not (
        check_sql_functions_exist("select a, b from version", functions, engine)
    )

    # Calls in the projection, in the FROM clause and inside subqueries all match.
    statements_calling_function = (
        "select version()",
        "select version from version()",
        "select 1, a.version from (select version from version()) as a",
        "select 1, a.version from (select version()) as a",
    )
    for sql in statements_calling_function:
        assert check_sql_functions_exist(sql, functions, engine)
|
||||
|
||||
|
||||
def test_sanitize_clause_valid():
|
||||
# regular clauses
|
||||
assert sanitize_clause("col = 1") == "col = 1"
|
||||
|
|
Loading…
Reference in New Issue