fix: adds the ability to disallow SQL functions per engine (#28639)

This commit is contained in:
Daniel Vaz Gaspar 2024-05-29 10:51:28 +01:00 committed by GitHub
parent 6575cacc5d
commit 5dfbab5424
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 119 additions and 15 deletions

View File

@ -1227,6 +1227,15 @@ DB_CONNECTION_MUTATOR = None
#
DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None
# A set of disallowed SQL functions per engine. This is used to restrict the use of
# unsafe SQL functions in SQL Lab and Charts. The keys of the dictionary are the engine
# names, and the values are sets of disallowed functions.
DISALLOWED_SQL_FUNCTIONS: dict[str, set[str]] = {
"postgresql": {"version", "query_to_xml", "inet_server_addr", "inet_client_addr"},
"clickhouse": {"url"},
"mysql": {"version"},
}
# A function that intercepts the SQL to be executed and can alter it.
# A common use case for this is around adding some sort of comment header to the SQL

View File

@ -62,7 +62,7 @@ from superset import sql_parse
from superset.constants import TimeGrain as TimeGrainConstants
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
from superset.sql_parse import ParsedQuery, SQLScript, Table
from superset.superset_typing import (
OAuth2ClientConfig,
@ -1818,6 +1818,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""
if not cls.allows_sql_comments:
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
disallowed_functions = current_app.config["DISALLOWED_SQL_FUNCTIONS"].get(
cls.engine, set()
)
if sql_parse.check_sql_functions_exist(query, disallowed_functions, cls.engine):
raise DisallowedSQLFunction(disallowed_functions)
if cls.arraysize:
cursor.arraysize = cls.arraysize

View File

@ -22,7 +22,7 @@ import threading
import time
from typing import Any, TYPE_CHECKING
from flask import current_app
from flask import current_app, Flask
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchTableError
@ -218,11 +218,14 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
execute_result: dict[str, Any] = {}
execute_event = threading.Event()
def _execute(results: dict[str, Any], event: threading.Event) -> None:
def _execute(
results: dict[str, Any], event: threading.Event, app: Flask
) -> None:
logger.debug("Query %d: Running query: %s", query_id, sql)
try:
cls.execute(cursor, sql, query.database)
with app.app_context():
cls.execute(cursor, sql, query.database)
except Exception as ex: # pylint: disable=broad-except
results["error"] = ex
finally:
@ -230,7 +233,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
execute_thread = threading.Thread(
target=_execute,
args=(execute_result, execute_event),
args=(execute_result, execute_event, current_app._get_current_object()), # pylint: disable=protected-access
)
execute_thread.start()

View File

@ -358,6 +358,21 @@ class OAuth2Error(SupersetErrorException):
)
class DisallowedSQLFunction(SupersetErrorException):
"""
Disallowed function found on SQL statement
"""
def __init__(self, functions: set[str]):
super().__init__(
SupersetError(
message=f"SQL statement contains disallowed function(s): {functions}",
error_type=SupersetErrorType.SYNTAX_ERROR,
level=ErrorLevel.ERROR,
)
)
class CreateKeyValueDistributedLockFailedException(Exception):
"""
Exception to signalize failure to acquire lock.

View File

@ -39,6 +39,7 @@ from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
Function,
Identifier,
IdentifierList,
Parenthesis,
@ -223,6 +224,19 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
return cte, remainder
def check_sql_functions_exist(
sql: str, function_list: set[str], engine: str | None = None
) -> bool:
"""
Check if the SQL statement contains any of the specified functions.
:param sql: The SQL statement
:param function_list: The list of functions to search for
:param engine: The engine to use for parsing the SQL statement
"""
return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
"""
Strips comments from a SQL statement, does a simple test first
@ -743,6 +757,34 @@ class ParsedQuery:
self._tables = self._extract_tables_from_sql()
return self._tables
def _check_functions_exist_in_token(
self, token: Token, functions: set[str]
) -> bool:
if (
isinstance(token, Function)
and token.get_name() is not None
and token.get_name().lower() in functions
):
return True
if hasattr(token, "tokens"):
for inner_token in token.tokens:
if self._check_functions_exist_in_token(inner_token, functions):
return True
return False
def check_functions_exist(self, functions: set[str]) -> bool:
"""
Check if the SQL statement contains any of the specified functions.
:param functions: A set of functions to search for
:return: True if the statement contains any of the specified functions
"""
for statement in self._parsed:
for token in statement.tokens:
if self._check_functions_exist_in_token(token, functions):
return True
return False
def _extract_tables_from_sql(self) -> set[Table]:
"""
Extract all table references in a query.

View File

@ -401,7 +401,7 @@ def test_handle_cursor_early_cancel(
assert cancel_query_mock.call_args is None
def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
"""Test that `execute_with_cursor` fetches query ID from the cursor"""
from superset.db_engine_specs.trino import TrinoEngineSpec
@ -416,16 +416,20 @@ def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
mock_cursor.query_id = query_id
mock_cursor.execute.side_effect = _mock_execute
with patch.dict(
"superset.config.DISALLOWED_SQL_FUNCTIONS",
{},
clear=True,
):
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
)
mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)
mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)
def test_get_columns(mocker: MockerFixture):

View File

@ -32,6 +32,7 @@ from superset.exceptions import (
)
from superset.sql_parse import (
add_table_name,
check_sql_functions_exist,
extract_table_references,
extract_tables_from_jinja_sql,
get_rls_for_table,
@ -1215,6 +1216,31 @@ def test_strip_comments_from_sql() -> None:
)
def test_check_sql_functions_exist() -> None:
"""
Test that comments are stripped out correctly.
"""
assert not (
check_sql_functions_exist("select a, b from version", {"version"}, "postgresql")
)
assert check_sql_functions_exist("select version()", {"version"}, "postgresql")
assert check_sql_functions_exist(
"select version from version()", {"version"}, "postgresql"
)
assert check_sql_functions_exist(
"select 1, a.version from (select version from version()) as a",
{"version"},
"postgresql",
)
assert check_sql_functions_exist(
"select 1, a.version from (select version()) as a", {"version"}, "postgresql"
)
def test_sanitize_clause_valid():
# regular clauses
assert sanitize_clause("col = 1") == "col = 1"