From 5dfbab542422e6f68b020bc0bccf41caa3e1f248 Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Wed, 29 May 2024 10:51:28 +0100 Subject: [PATCH] fix: adds the ability to disallow SQL functions per engine (#28639) --- superset/config.py | 9 ++++ superset/db_engine_specs/base.py | 7 +++- superset/db_engine_specs/trino.py | 11 +++-- superset/exceptions.py | 15 +++++++ superset/sql_parse.py | 42 +++++++++++++++++++ .../unit_tests/db_engine_specs/test_trino.py | 24 ++++++----- tests/unit_tests/sql_parse_tests.py | 26 ++++++++++++ 7 files changed, 119 insertions(+), 15 deletions(-) diff --git a/superset/config.py b/superset/config.py index 3c92354322..aa8178d086 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1227,6 +1227,15 @@ DB_CONNECTION_MUTATOR = None # DB_SQLA_URI_VALIDATOR: Callable[[URL], None] | None = None +# A set of disallowed SQL functions per engine. This is used to restrict the use of +# unsafe SQL functions in SQL Lab and Charts. The keys of the dictionary are the engine +# names, and the values are sets of disallowed functions. +DISALLOWED_SQL_FUNCTIONS: dict[str, set[str]] = { + "postgresql": {"version", "query_to_xml", "inet_server_addr", "inet_client_addr"}, + "clickhouse": {"url"}, + "mysql": {"version"}, +} + # A function that intercepts the SQL to be executed and can alter it. 
# A common use case for this is around adding some sort of comment header to the SQL diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 6df0dc61aa..548fb390d8 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -62,7 +62,7 @@ from superset import sql_parse from superset.constants import TimeGrain as TimeGrainConstants from superset.databases.utils import get_table_metadata, make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.exceptions import OAuth2Error, OAuth2RedirectError +from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError from superset.sql_parse import ParsedQuery, SQLScript, Table from superset.superset_typing import ( OAuth2ClientConfig, @@ -1818,6 +1818,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ if not cls.allows_sql_comments: query = sql_parse.strip_comments_from_sql(query, engine=cls.engine) + disallowed_functions = current_app.config["DISALLOWED_SQL_FUNCTIONS"].get( + cls.engine, set() + ) + if sql_parse.check_sql_functions_exist(query, disallowed_functions, cls.engine): + raise DisallowedSQLFunction(disallowed_functions) if cls.arraysize: cursor.arraysize = cls.arraysize diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index eea00877d9..600f236b48 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -22,7 +22,7 @@ import threading import time from typing import Any, TYPE_CHECKING -from flask import current_app +from flask import current_app, Flask from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.exc import NoSuchTableError @@ -218,11 +218,14 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): execute_result: dict[str, Any] = {} execute_event = threading.Event() - def _execute(results: dict[str, Any], event: threading.Event) -> None: + def 
_execute( + results: dict[str, Any], event: threading.Event, app: Flask + ) -> None: logger.debug("Query %d: Running query: %s", query_id, sql) try: - cls.execute(cursor, sql, query.database) + with app.app_context(): + cls.execute(cursor, sql, query.database) except Exception as ex: # pylint: disable=broad-except results["error"] = ex finally: @@ -230,7 +233,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): execute_thread = threading.Thread( target=_execute, - args=(execute_result, execute_event), + args=(execute_result, execute_event, current_app._get_current_object()), # pylint: disable=protected-access ) execute_thread.start() diff --git a/superset/exceptions.py b/superset/exceptions.py index 0315ee30f4..47cd511f8f 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -358,6 +358,21 @@ class OAuth2Error(SupersetErrorException): ) +class DisallowedSQLFunction(SupersetErrorException): + """ + Disallowed function found on SQL statement + """ + + def __init__(self, functions: set[str]): + super().__init__( + SupersetError( + message=f"SQL statement contains disallowed function(s): {functions}", + error_type=SupersetErrorType.SYNTAX_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + class CreateKeyValueDistributedLockFailedException(Exception): """ Exception to signalize failure to acquire lock. diff --git a/superset/sql_parse.py b/superset/sql_parse.py index f32647042b..192a998c3f 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -39,6 +39,7 @@ from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope from sqlparse import keywords from sqlparse.lexer import Lexer from sqlparse.sql import ( + Function, Identifier, IdentifierList, Parenthesis, @@ -223,6 +224,19 @@ def get_cte_remainder_query(sql: str) -> tuple[str | None, str]: return cte, remainder +def check_sql_functions_exist( + sql: str, function_list: set[str], engine: str | None = None +) -> bool: + """ + Check if the SQL statement contains any of the specified functions. 
+ + :param sql: The SQL statement + :param function_list: The list of functions to search for + :param engine: The engine to use for parsing the SQL statement + """ + return ParsedQuery(sql, engine=engine).check_functions_exist(function_list) + + def strip_comments_from_sql(statement: str, engine: str | None = None) -> str: """ Strips comments from a SQL statement, does a simple test first @@ -743,6 +757,34 @@ class ParsedQuery: self._tables = self._extract_tables_from_sql() return self._tables + def _check_functions_exist_in_token( + self, token: Token, functions: set[str] + ) -> bool: + if ( + isinstance(token, Function) + and token.get_name() is not None + and token.get_name().lower() in functions + ): + return True + if hasattr(token, "tokens"): + for inner_token in token.tokens: + if self._check_functions_exist_in_token(inner_token, functions): + return True + return False + + def check_functions_exist(self, functions: set[str]) -> bool: + """ + Check if the SQL statement contains any of the specified functions. + + :param functions: A set of functions to search for + :return: True if the statement contains any of the specified functions + """ + for statement in self._parsed: + for token in statement.tokens: + if self._check_functions_exist_in_token(token, functions): + return True + return False + def _extract_tables_from_sql(self) -> set[Table]: """ Extract all table references in a query. 
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 88608f1e38..3a2ac91ad6 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -401,7 +401,7 @@ def test_handle_cursor_early_cancel( assert cancel_query_mock.call_args is None -def test_execute_with_cursor_in_parallel(mocker: MockerFixture): +def test_execute_with_cursor_in_parallel(app, mocker: MockerFixture): """Test that `execute_with_cursor` fetches query ID from the cursor""" from superset.db_engine_specs.trino import TrinoEngineSpec @@ -416,16 +416,20 @@ def test_execute_with_cursor_in_parallel(mocker: MockerFixture): mock_cursor.query_id = query_id mock_cursor.execute.side_effect = _mock_execute + with patch.dict( + "superset.config.DISALLOWED_SQL_FUNCTIONS", + {}, + clear=True, + ): + TrinoEngineSpec.execute_with_cursor( + cursor=mock_cursor, + sql="SELECT 1 FROM foo", + query=mock_query, + ) - TrinoEngineSpec.execute_with_cursor( - cursor=mock_cursor, - sql="SELECT 1 FROM foo", - query=mock_query, - ) - - mock_query.set_extra_json_key.assert_called_once_with( - key=QUERY_CANCEL_KEY, value=query_id - ) + mock_query.set_extra_json_key.assert_called_once_with( + key=QUERY_CANCEL_KEY, value=query_id + ) def test_get_columns(mocker: MockerFixture): diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 3b80b7e01d..6259d6272d 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -32,6 +32,7 @@ from superset.exceptions import ( ) from superset.sql_parse import ( add_table_name, + check_sql_functions_exist, extract_table_references, extract_tables_from_jinja_sql, get_rls_for_table, @@ -1215,6 +1216,31 @@ def test_strip_comments_from_sql() -> None: ) +def test_check_sql_functions_exist() -> None: + """ + Test that disallowed SQL functions are detected in SQL statements. 
+ """ + assert not ( + check_sql_functions_exist("select a, b from version", {"version"}, "postgresql") + ) + + assert check_sql_functions_exist("select version()", {"version"}, "postgresql") + + assert check_sql_functions_exist( + "select version from version()", {"version"}, "postgresql" + ) + + assert check_sql_functions_exist( + "select 1, a.version from (select version from version()) as a", + {"version"}, + "postgresql", + ) + + assert check_sql_functions_exist( + "select 1, a.version from (select version()) as a", {"version"}, "postgresql" + ) + + def test_sanitize_clause_valid(): # regular clauses assert sanitize_clause("col = 1") == "col = 1"