mirror of https://github.com/apache/superset.git
fix: adds the ability to disallow SQL functions per engine (#28639)
This commit is contained in:
parent
6575cacc5d
commit
5dfbab5424
|
@ -1227,6 +1227,15 @@ DB_CONNECTION_MUTATOR = None
|
||||||
#
|
#
|
||||||
DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None
|
DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None
|
||||||
|
|
||||||
|
# A set of disallowed SQL functions per engine. This is used to restrict the use of
|
||||||
|
# unsafe SQL functions in SQL Lab and Charts. The keys of the dictionary are the engine
|
||||||
|
# names, and the values are sets of disallowed functions.
|
||||||
|
DISALLOWED_SQL_FUNCTIONS: dict[str, set[str]] = {
|
||||||
|
"postgresql": {"version", "query_to_xml", "inet_server_addr", "inet_client_addr"},
|
||||||
|
"clickhouse": {"url"},
|
||||||
|
"mysql": {"version"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# A function that intercepts the SQL to be executed and can alter it.
|
# A function that intercepts the SQL to be executed and can alter it.
|
||||||
# A common use case for this is around adding some sort of comment header to the SQL
|
# A common use case for this is around adding some sort of comment header to the SQL
|
||||||
|
|
|
@ -62,7 +62,7 @@ from superset import sql_parse
|
||||||
from superset.constants import TimeGrain as TimeGrainConstants
|
from superset.constants import TimeGrain as TimeGrainConstants
|
||||||
from superset.databases.utils import get_table_metadata, make_url_safe
|
from superset.databases.utils import get_table_metadata, make_url_safe
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError
|
||||||
from superset.sql_parse import ParsedQuery, SQLScript, Table
|
from superset.sql_parse import ParsedQuery, SQLScript, Table
|
||||||
from superset.superset_typing import (
|
from superset.superset_typing import (
|
||||||
OAuth2ClientConfig,
|
OAuth2ClientConfig,
|
||||||
|
@ -1818,6 +1818,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||||
"""
|
"""
|
||||||
if not cls.allows_sql_comments:
|
if not cls.allows_sql_comments:
|
||||||
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
|
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
|
||||||
|
disallowed_functions = current_app.config["DISALLOWED_SQL_FUNCTIONS"].get(
|
||||||
|
cls.engine, set()
|
||||||
|
)
|
||||||
|
if sql_parse.check_sql_functions_exist(query, disallowed_functions, cls.engine):
|
||||||
|
raise DisallowedSQLFunction(disallowed_functions)
|
||||||
|
|
||||||
if cls.arraysize:
|
if cls.arraysize:
|
||||||
cursor.arraysize = cls.arraysize
|
cursor.arraysize = cls.arraysize
|
||||||
|
|
|
@ -22,7 +22,7 @@ import threading
|
||||||
import time
|
import time
|
||||||
from typing import Any, TYPE_CHECKING
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
from flask import current_app
|
from flask import current_app, Flask
|
||||||
from sqlalchemy.engine.reflection import Inspector
|
from sqlalchemy.engine.reflection import Inspector
|
||||||
from sqlalchemy.engine.url import URL
|
from sqlalchemy.engine.url import URL
|
||||||
from sqlalchemy.exc import NoSuchTableError
|
from sqlalchemy.exc import NoSuchTableError
|
||||||
|
@ -218,11 +218,14 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
||||||
execute_result: dict[str, Any] = {}
|
execute_result: dict[str, Any] = {}
|
||||||
execute_event = threading.Event()
|
execute_event = threading.Event()
|
||||||
|
|
||||||
def _execute(results: dict[str, Any], event: threading.Event) -> None:
|
def _execute(
|
||||||
|
results: dict[str, Any], event: threading.Event, app: Flask
|
||||||
|
) -> None:
|
||||||
logger.debug("Query %d: Running query: %s", query_id, sql)
|
logger.debug("Query %d: Running query: %s", query_id, sql)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cls.execute(cursor, sql, query.database)
|
with app.app_context():
|
||||||
|
cls.execute(cursor, sql, query.database)
|
||||||
except Exception as ex: # pylint: disable=broad-except
|
except Exception as ex: # pylint: disable=broad-except
|
||||||
results["error"] = ex
|
results["error"] = ex
|
||||||
finally:
|
finally:
|
||||||
|
@ -230,7 +233,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
||||||
|
|
||||||
execute_thread = threading.Thread(
|
execute_thread = threading.Thread(
|
||||||
target=_execute,
|
target=_execute,
|
||||||
args=(execute_result, execute_event),
|
args=(execute_result, execute_event, current_app._get_current_object()), # pylint: disable=protected-access
|
||||||
)
|
)
|
||||||
execute_thread.start()
|
execute_thread.start()
|
||||||
|
|
||||||
|
|
|
@ -358,6 +358,21 @@ class OAuth2Error(SupersetErrorException):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DisallowedSQLFunction(SupersetErrorException):
|
||||||
|
"""
|
||||||
|
Disallowed function found on SQL statement
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, functions: set[str]):
|
||||||
|
super().__init__(
|
||||||
|
SupersetError(
|
||||||
|
message=f"SQL statement contains disallowed function(s): {functions}",
|
||||||
|
error_type=SupersetErrorType.SYNTAX_ERROR,
|
||||||
|
level=ErrorLevel.ERROR,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CreateKeyValueDistributedLockFailedException(Exception):
|
class CreateKeyValueDistributedLockFailedException(Exception):
|
||||||
"""
|
"""
|
||||||
Exception to signalize failure to acquire lock.
|
Exception to signalize failure to acquire lock.
|
||||||
|
|
|
@ -39,6 +39,7 @@ from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
||||||
from sqlparse import keywords
|
from sqlparse import keywords
|
||||||
from sqlparse.lexer import Lexer
|
from sqlparse.lexer import Lexer
|
||||||
from sqlparse.sql import (
|
from sqlparse.sql import (
|
||||||
|
Function,
|
||||||
Identifier,
|
Identifier,
|
||||||
IdentifierList,
|
IdentifierList,
|
||||||
Parenthesis,
|
Parenthesis,
|
||||||
|
@ -223,6 +224,19 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
|
||||||
return cte, remainder
|
return cte, remainder
|
||||||
|
|
||||||
|
|
||||||
|
def check_sql_functions_exist(
|
||||||
|
sql: str, function_list: set[str], engine: str | None = None
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the SQL statement contains any of the specified functions.
|
||||||
|
|
||||||
|
:param sql: The SQL statement
|
||||||
|
:param function_list: The list of functions to search for
|
||||||
|
:param engine: The engine to use for parsing the SQL statement
|
||||||
|
"""
|
||||||
|
return ParsedQuery(sql, engine=engine).check_functions_exist(function_list)
|
||||||
|
|
||||||
|
|
||||||
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
|
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
|
||||||
"""
|
"""
|
||||||
Strips comments from a SQL statement, does a simple test first
|
Strips comments from a SQL statement, does a simple test first
|
||||||
|
@ -743,6 +757,34 @@ class ParsedQuery:
|
||||||
self._tables = self._extract_tables_from_sql()
|
self._tables = self._extract_tables_from_sql()
|
||||||
return self._tables
|
return self._tables
|
||||||
|
|
||||||
|
def _check_functions_exist_in_token(
|
||||||
|
self, token: Token, functions: set[str]
|
||||||
|
) -> bool:
|
||||||
|
if (
|
||||||
|
isinstance(token, Function)
|
||||||
|
and token.get_name() is not None
|
||||||
|
and token.get_name().lower() in functions
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
if hasattr(token, "tokens"):
|
||||||
|
for inner_token in token.tokens:
|
||||||
|
if self._check_functions_exist_in_token(inner_token, functions):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def check_functions_exist(self, functions: set[str]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the SQL statement contains any of the specified functions.
|
||||||
|
|
||||||
|
:param functions: A set of functions to search for
|
||||||
|
:return: True if the statement contains any of the specified functions
|
||||||
|
"""
|
||||||
|
for statement in self._parsed:
|
||||||
|
for token in statement.tokens:
|
||||||
|
if self._check_functions_exist_in_token(token, functions):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def _extract_tables_from_sql(self) -> set[Table]:
|
def _extract_tables_from_sql(self) -> set[Table]:
|
||||||
"""
|
"""
|
||||||
Extract all table references in a query.
|
Extract all table references in a query.
|
||||||
|
|
|
@ -401,7 +401,7 @@ def test_handle_cursor_early_cancel(
|
||||||
assert cancel_query_mock.call_args is None
|
assert cancel_query_mock.call_args is None
|
||||||
|
|
||||||
|
|
||||||
def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
|
def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture):
|
||||||
"""Test that `execute_with_cursor` fetches query ID from the cursor"""
|
"""Test that `execute_with_cursor` fetches query ID from the cursor"""
|
||||||
from superset.db_engine_specs.trino import TrinoEngineSpec
|
from superset.db_engine_specs.trino import TrinoEngineSpec
|
||||||
|
|
||||||
|
@ -416,16 +416,20 @@ def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
|
||||||
mock_cursor.query_id = query_id
|
mock_cursor.query_id = query_id
|
||||||
|
|
||||||
mock_cursor.execute.side_effect = _mock_execute
|
mock_cursor.execute.side_effect = _mock_execute
|
||||||
|
with patch.dict(
|
||||||
|
"superset.config.DISALLOWED_SQL_FUNCTIONS",
|
||||||
|
{},
|
||||||
|
clear=True,
|
||||||
|
):
|
||||||
|
TrinoEngineSpec.execute_with_cursor(
|
||||||
|
cursor=mock_cursor,
|
||||||
|
sql="SELECT 1 FROM foo",
|
||||||
|
query=mock_query,
|
||||||
|
)
|
||||||
|
|
||||||
TrinoEngineSpec.execute_with_cursor(
|
mock_query.set_extra_json_key.assert_called_once_with(
|
||||||
cursor=mock_cursor,
|
key=QUERY_CANCEL_KEY, value=query_id
|
||||||
sql="SELECT 1 FROM foo",
|
)
|
||||||
query=mock_query,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_query.set_extra_json_key.assert_called_once_with(
|
|
||||||
key=QUERY_CANCEL_KEY, value=query_id
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_columns(mocker: MockerFixture):
|
def test_get_columns(mocker: MockerFixture):
|
||||||
|
|
|
@ -32,6 +32,7 @@ from superset.exceptions import (
|
||||||
)
|
)
|
||||||
from superset.sql_parse import (
|
from superset.sql_parse import (
|
||||||
add_table_name,
|
add_table_name,
|
||||||
|
check_sql_functions_exist,
|
||||||
extract_table_references,
|
extract_table_references,
|
||||||
extract_tables_from_jinja_sql,
|
extract_tables_from_jinja_sql,
|
||||||
get_rls_for_table,
|
get_rls_for_table,
|
||||||
|
@ -1215,6 +1216,31 @@ def test_strip_comments_from_sql() -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_sql_functions_exist() -> None:
|
||||||
|
"""
|
||||||
|
Test that comments are stripped out correctly.
|
||||||
|
"""
|
||||||
|
assert not (
|
||||||
|
check_sql_functions_exist("select a, b from version", {"version"}, "postgresql")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert check_sql_functions_exist("select version()", {"version"}, "postgresql")
|
||||||
|
|
||||||
|
assert check_sql_functions_exist(
|
||||||
|
"select version from version()", {"version"}, "postgresql"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert check_sql_functions_exist(
|
||||||
|
"select 1, a.version from (select version from version()) as a",
|
||||||
|
{"version"},
|
||||||
|
"postgresql",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert check_sql_functions_exist(
|
||||||
|
"select 1, a.version from (select version()) as a", {"version"}, "postgresql"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_clause_valid():
|
def test_sanitize_clause_valid():
|
||||||
# regular clauses
|
# regular clauses
|
||||||
assert sanitize_clause("col = 1") == "col = 1"
|
assert sanitize_clause("col = 1") == "col = 1"
|
||||||
|
|
Loading…
Reference in New Issue