From 7c14968e6d3b83498c20d6896f5239bf520d2c79 Mon Sep 17 00:00:00 2001
From: John Bodley <4567245+john-bodley@users.noreply.github.com>
Date: Fri, 22 Mar 2024 13:39:28 +1300
Subject: [PATCH] fix(sql_parse): Ensure table extraction handles Jinja templating (#27470)

---
 superset/commands/sql_lab/execute.py |   4 +-
 superset/jinja_context.py            |  10 +--
 superset/models/sql_lab.py           |  40 ++++++----
 superset/security/manager.py         |  13 +---
 superset/sql_parse.py                | 105 +++++++++++++++++++++------
 superset/sqllab/query_render.py      |   3 +-
 tests/unit_tests/sql_parse_tests.py  |  41 +++++++++++
 7 files changed, 163 insertions(+), 53 deletions(-)

diff --git a/superset/commands/sql_lab/execute.py b/superset/commands/sql_lab/execute.py
index 5d955571d8..533264fb28 100644
--- a/superset/commands/sql_lab/execute.py
+++ b/superset/commands/sql_lab/execute.py
@@ -144,11 +144,13 @@ class ExecuteSqlCommand(BaseCommand):
         try:
             logger.info("Triggering query_id: %i", query.id)
+            # Necessary to check access before rendering the Jinjafied query, as
+            # some Jinja macros execute statements upon rendering.
+            self._validate_access(query)
             self._execution_context.set_query(query)
             rendered_query = self._sql_query_render.render(self._execution_context)
             validate_rendered_query = copy.copy(query)
             validate_rendered_query.sql = rendered_query
-            self._validate_access(validate_rendered_query)
             self._set_query_limit_if_required(rendered_query)
             self._query_dao.update(
                 query, {"limit": self._execution_context.query.limit}
diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index 54d1f54866..9edddf24d9 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -24,7 +24,7 @@ from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Unio
 import dateutil
 from flask import current_app, g, has_request_context, request
 from flask_babel import gettext as _
-from jinja2 import DebugUndefined
+from jinja2 import DebugUndefined, Environment
 from jinja2.sandbox import SandboxedEnvironment
 from sqlalchemy.engine.interfaces import Dialect
 from sqlalchemy.sql.expression import bindparam
@@ -462,11 +462,11 @@ class BaseTemplateProcessor:
         self._applied_filters = applied_filters
         self._removed_filters = removed_filters
         self._context: dict[str, Any] = {}
-        self._env = SandboxedEnvironment(undefined=DebugUndefined)
+        self.env: Environment = SandboxedEnvironment(undefined=DebugUndefined)
         self.set_context(**kwargs)

         # custom filters
-        self._env.filters["where_in"] = WhereInMacro(database.get_dialect())
+        self.env.filters["where_in"] = WhereInMacro(database.get_dialect())

     def set_context(self, **kwargs: Any) -> None:
         self._context.update(kwargs)
@@ -479,7 +479,7 @@ class BaseTemplateProcessor:
         >>> process_template(sql)
         "SELECT '2017-01-01T00:00:00'"
         """
-        template = self._env.from_string(sql)
+        template = self.env.from_string(sql)
         kwargs.update(self._context)

         context = validate_template_context(self.engine, kwargs)
@@ -623,7 +623,7 @@ class TrinoTemplateProcessor(PrestoTemplateProcessor):
     engine = "trino"

     def process_template(self, sql: str, **kwargs: Any) -> str:
-        template = self._env.from_string(sql)
+        template = self.env.from_string(sql)
         kwargs.update(self._context)

         # Backwards compatibility if migrating from Presto.
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index f4724d6dab..2d7384a74e 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -46,6 +46,7 @@ from sqlalchemy.orm import backref, relationship
 from sqlalchemy.sql.elements import ColumnElement, literal_column

 from superset import security_manager
+from superset.exceptions import SupersetSecurityException
 from superset.jinja_context import BaseTemplateProcessor, get_template_processor
 from superset.models.helpers import (
     AuditMixinNullable,
@@ -53,7 +54,7 @@ from superset.models.helpers import (
     ExtraJSONMixin,
     ImportExportMixin,
 )
-from superset.sql_parse import CtasMethod, ParsedQuery, Table
+from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table
 from superset.sqllab.limiting_factor import LimitingFactor
 from superset.utils.core import get_column_name, MediumText, QueryStatus, user_label
@@ -65,8 +66,25 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


+class SqlTablesMixin:  # pylint: disable=too-few-public-methods
+    @property
+    def sql_tables(self) -> list[Table]:
+        try:
+            return list(
+                extract_tables_from_jinja_sql(
+                    self.sql,  # type: ignore
+                    self.database.db_engine_spec.engine,  # type: ignore
+                )
+            )
+        except SupersetSecurityException:
+            return []
+
+
 class Query(
-    ExtraJSONMixin, ExploreMixin, Model
+    SqlTablesMixin,
+    ExtraJSONMixin,
+    ExploreMixin,
+    Model,
 ):  # pylint: disable=abstract-method,too-many-public-methods
     """ORM model for SQL query

@@ -181,10 +199,6 @@ class Query(
     def username(self) -> str:
         return self.user.username

-    @property
-    def sql_tables(self) -> list[Table]:
-        return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)
-
     @property
     def columns(self) -> list["TableColumn"]:
         from superset.connectors.sqla.models import (  # pylint: disable=import-outside-toplevel
@@ -355,7 +369,13 @@ class Query(
         return self.make_sqla_column_compatible(sqla_column, label)


-class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
+class SavedQuery(
+    SqlTablesMixin,
+    AuditMixinNullable,
+    ExtraJSONMixin,
+    ImportExportMixin,
+    Model,
+):
     """ORM model for SQL query"""

     __tablename__ = "saved_query"
@@ -425,12 +445,6 @@ class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
     def url(self) -> str:
         return f"/sqllab?savedQueryId={self.id}"

-    @property
-    def sql_tables(self) -> list[Table]:
-        return list(
-            ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
-        )
-
     @property
     def last_run_humanized(self) -> str:
         return naturaltime(datetime.now() - self.changed_on)
diff --git a/superset/security/manager.py b/superset/security/manager.py
index a532431433..2833e88645 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -52,14 +52,12 @@ from sqlalchemy.orm import eagerload
 from sqlalchemy.orm.mapper import Mapper
 from sqlalchemy.orm.query import Query as SqlaQuery

-from superset import sql_parse
 from superset.constants import RouteMethod
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import (
     DatasetInvalidPermissionEvaluationException,
     SupersetSecurityException,
 )
-from superset.jinja_context import get_template_processor
 from superset.security.guest_token import (
     GuestToken,
     GuestTokenResources,
@@ -68,6 +66,7 @@ from superset.security.guest_token import (
     GuestTokenUser,
     GuestUser,
 )
+from superset.sql_parse import extract_tables_from_jinja_sql
 from superset.superset_typing import Metric
 from superset.utils.core import (
     DatasourceName,
@@ -1961,16 +1960,12 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
             return

         if query:
-            # make sure the quuery is valid SQL by rendering any Jinja
-            processor = get_template_processor(database=query.database)
-            rendered_sql = processor.process_template(query.sql)
             default_schema = database.get_default_schema_for_query(query)
             tables = {
                 Table(table_.table, table_.schema or default_schema)
-                for table_ in sql_parse.ParsedQuery(
-                    rendered_sql,
-                    engine=database.db_engine_spec.engine,
-                ).tables
+                for table_ in extract_tables_from_jinja_sql(
+                    query.sql, database.db_engine_spec.engine
+                )
             }
         elif table:
             tables = {table}
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index db51991e22..58bca48a6e 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -16,16 +16,19 @@
 # under the License.
 # pylint: disable=too-many-lines
+from __future__ import annotations

 import logging
 import re
 import urllib.parse
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
-from typing import Any, cast, Optional
+from typing import Any, cast
+from unittest.mock import Mock

 import sqlparse
 from flask_babel import gettext as __
+from jinja2 import nodes
 from sqlalchemy import and_
 from sqlglot import exp, parse, parse_one
 from sqlglot.dialects import Dialects
@@ -142,7 +145,7 @@ class CtasMethod(StrEnum):
     VIEW = "VIEW"


-def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
+def _extract_limit_from_query(statement: TokenList) -> int | None:
     """
     Extract limit clause from SQL statement.

@@ -163,9 +166,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
     return None


-def extract_top_from_query(
-    statement: TokenList, top_keywords: set[str]
-) -> Optional[int]:
+def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None:
     """
     Extract top clause value from SQL statement.
@@ -189,7 +190,7 @@ def extract_top_from_query(
     return top


-def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
+def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
     """
     parse the SQL and return the CTE and rest of the block to the caller

@@ -197,7 +198,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     :return: CTE and remainder block to the caller
     """
-    cte: Optional[str] = None
+    cte: str | None = None
     remainder = sql
     stmt = sqlparse.parse(sql)[0]

@@ -215,7 +216,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     return cte, remainder


-def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
+def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
     """
     Strips comments from a SQL statement, does a simple test first
     to avoid always instantiating the expensive ParsedQuery constructor
@@ -239,8 +240,8 @@ class Table:
     """

     table: str
-    schema: Optional[str] = None
-    catalog: Optional[str] = None
+    schema: str | None = None
+    catalog: str | None = None

     def __str__(self) -> str:
         """
@@ -262,7 +263,7 @@ class ParsedQuery:
         self,
         sql_statement: str,
         strip_comments: bool = False,
-        engine: Optional[str] = None,
+        engine: str | None = None,
     ):
         if strip_comments:
             sql_statement = sqlparse.format(sql_statement, strip_comments=True)
@@ -271,7 +272,7 @@ class ParsedQuery:
         self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
         self._tables: set[Table] = set()
         self._alias_names: set[str] = set()
-        self._limit: Optional[int] = None
+        self._limit: int | None = None

         logger.debug("Parsing with sqlparse statement: %s", self.sql)
         self._parsed = sqlparse.parse(self.stripped())
@@ -382,7 +383,7 @@
         return source.name in ctes_in_scope

     @property
-    def limit(self) -> Optional[int]:
+    def limit(self) -> int | None:
         return self._limit

     def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
@@ -463,7 +464,7 @@
         return True

-    def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
+    def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
         for token in tokens:
             if self._is_identifier(token):
                 for identifier_token in token.tokens:
@@ -527,7 +528,7 @@
         return statements

     @staticmethod
-    def get_table(tlist: TokenList) -> Optional[Table]:
+    def get_table(tlist: TokenList) -> Table | None:
         """
         Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
         construct.
@@ -563,7 +564,7 @@
     def as_create_table(
         self,
         table_name: str,
-        schema_name: Optional[str] = None,
+        schema_name: str | None = None,
         overwrite: bool = False,
         method: CtasMethod = CtasMethod.TABLE,
     ) -> str:
@@ -723,8 +724,8 @@ def add_table_name(rls: TokenList, table: str) -> None:
 def get_rls_for_table(
     candidate: Token,
     database_id: int,
-    default_schema: Optional[str],
-) -> Optional[TokenList]:
+    default_schema: str | None,
+) -> TokenList | None:
     """
     Given a table name, return any associated RLS predicates.
     """
@@ -770,7 +771,7 @@
 def insert_rls_as_subquery(
     token_list: TokenList,
     database_id: int,
-    default_schema: Optional[str],
+    default_schema: str | None,
 ) -> TokenList:
     """
     Update a statement inplace applying any associated RLS predicates.
@@ -786,7 +787,7 @@
     This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
     databases.
""" - rls: Optional[TokenList] = None + rls: TokenList | None = None state = InsertRLSState.SCANNING for token in token_list.tokens: # Recurse into child token list @@ -862,7 +863,7 @@ def insert_rls_as_subquery( def insert_rls_in_predicate( token_list: TokenList, database_id: int, - default_schema: Optional[str], + default_schema: str | None, ) -> TokenList: """ Update a statement inplace applying any associated RLS predicates. @@ -873,7 +874,7 @@ def insert_rls_in_predicate( after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42 """ - rls: Optional[TokenList] = None + rls: TokenList | None = None state = InsertRLSState.SCANNING for token in token_list.tokens: # Recurse into child token list @@ -1007,7 +1008,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}") def extract_table_references( sql_text: str, sqla_dialect: str, show_warning: bool = True -) -> set["Table"]: +) -> set[Table]: """ Return all the dependencies from a SQL sql_text. """ @@ -1051,3 +1052,61 @@ def extract_table_references( Table(*[part["value"] for part in table["name"][::-1]]) for table in find_nodes_by_key(tree, "Table") } + + +def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]: + """ + Extract all table references in the Jinjafied SQL statement. + + Due to Jinja templating, a multiphase approach is necessary as the Jinjafied SQL + statement may represent invalid SQL which is non-parsable by SQLGlot. + + Firstly, we extract any tables referenced within the confines of specific Jinja + macros. Secondly, we replace these non-SQL Jinja calls with a pseudo-benign SQL + expression to help ensure that the resulting SQL statements are parsable by + SQLGlot. + + :param sql: The Jinjafied SQL statement + :param engine: The associated database engine + :returns: The set of tables referenced in the SQL statement + :raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement + """ + + from superset.jinja_context import ( # pylint: disable=import-outside-toplevel + get_template_processor, + ) + + # Mock the required database as the processor signature is exposed publically. + processor = get_template_processor(database=Mock(backend=engine)) + template = processor.env.parse(sql) + + tables = set() + + for node in template.find_all(nodes.Call): + if isinstance(node.node, nodes.Getattr) and node.node.attr in ( + "latest_partition", + "latest_sub_partition", + ): + # Extract the table referenced in the macro. + tables.add( + Table( + *[ + remove_quotes(part) + for part in node.args[0].value.split(".")[::-1] + if len(node.args) == 1 + ] + ) + ) + + # Replace the potentially problematic Jinja macro with some benign SQL. 
+            node.__class__ = nodes.TemplateData
+            node.fields = nodes.TemplateData.fields
+            node.data = "NULL"
+
+    return (
+        tables
+        | ParsedQuery(
+            sql_statement=processor.process_template(template),
+            engine=engine,
+        ).tables
+    )
diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py
index 5597bcb086..caf9a3cb2b 100644
--- a/superset/sqllab/query_render.py
+++ b/superset/sqllab/query_render.py
@@ -79,8 +79,7 @@ class SqlQueryRenderImpl(SqlQueryRender):
         sql_template_processor: BaseTemplateProcessor,
     ) -> None:
         if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
-            # pylint: disable=protected-access
-            syntax_tree = sql_template_processor._env.parse(rendered_query)
+            syntax_tree = sql_template_processor.env.parse(rendered_query)
             undefined_parameters = find_undeclared_variables(syntax_tree)
             if undefined_parameters:
                 self._raise_undefined_parameter_exception(
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 025108e9b5..81ea0e5a7a 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -32,6 +32,7 @@ from superset.exceptions import (
 from superset.sql_parse import (
     add_table_name,
     extract_table_references,
+    extract_tables_from_jinja_sql,
     get_rls_for_table,
     has_table_query,
     insert_rls_as_subquery,
@@ -1874,3 +1875,43 @@ WITH t AS (
 )
 SELECT * FROM t"""
     ).is_select()
+
+
+@pytest.mark.parametrize(
+    "engine",
+    [
+        "hive",
+        "presto",
+        "trino",
+    ],
+)
+@pytest.mark.parametrize(
+    "macro",
+    [
+        "latest_partition('foo.bar')",
+        "latest_sub_partition('foo.bar', baz='qux')",
+    ],
+)
+@pytest.mark.parametrize(
+    "sql,expected",
+    [
+        (
+            "SELECT '{{{{ {engine}.{macro} }}}}'",
+            {Table(table="bar", schema="foo")},
+        ),
+        (
+            "SELECT * FROM foo.baz WHERE quux = '{{{{ {engine}.{macro} }}}}'",
+            {Table(table="bar", schema="foo"), Table(table="baz", schema="foo")},
+        ),
+    ],
+)
+def test_extract_tables_from_jinja_sql(
+    engine: str,
+    macro: str,
+    sql: str,
+    expected: set[Table],
+) -> None:
+    assert (
+        extract_tables_from_jinja_sql(sql.format(engine=engine, macro=macro), engine)
+        == expected
+    )
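
A minimal usage sketch of the new extract_tables_from_jinja_sql helper, mirroring one of the parametrized cases in tests/unit_tests/sql_parse_tests.py above; the Presto engine and the foo.bar / foo.baz table names are illustrative only:

    from superset.sql_parse import Table, extract_tables_from_jinja_sql

    # A Jinjafied statement that is not valid SQL until the macro is rendered.
    # The macro argument ("foo.bar") still reveals the partitioned table, and the
    # macro call itself is replaced with NULL before SQLGlot parses the statement,
    # so the surrounding table ("foo.baz") is also picked up.
    sql = "SELECT * FROM foo.baz WHERE quux = '{{ presto.latest_partition('foo.bar') }}'"

    assert extract_tables_from_jinja_sql(sql, "presto") == {
        Table(table="bar", schema="foo"),
        Table(table="baz", schema="foo"),
    }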