fix(sql_parse): Ensure table extraction handles Jinja templating (#27470)

This commit is contained in:
John Bodley 2024-03-22 13:39:28 +13:00 committed by Michael S. Molina
parent 4ff331a66c
commit 7c14968e6d
7 changed files with 163 additions and 53 deletions

View File

@ -144,11 +144,13 @@ class ExecuteSqlCommand(BaseCommand):
try: try:
logger.info("Triggering query_id: %i", query.id) logger.info("Triggering query_id: %i", query.id)
# Necessary to check access before rendering the Jinjafied query as the
# some Jinja macros execute statements upon rendering.
self._validate_access(query)
self._execution_context.set_query(query) self._execution_context.set_query(query)
rendered_query = self._sql_query_render.render(self._execution_context) rendered_query = self._sql_query_render.render(self._execution_context)
validate_rendered_query = copy.copy(query) validate_rendered_query = copy.copy(query)
validate_rendered_query.sql = rendered_query validate_rendered_query.sql = rendered_query
self._validate_access(validate_rendered_query)
self._set_query_limit_if_required(rendered_query) self._set_query_limit_if_required(rendered_query)
self._query_dao.update( self._query_dao.update(
query, {"limit": self._execution_context.query.limit} query, {"limit": self._execution_context.query.limit}

View File

@ -24,7 +24,7 @@ from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Unio
import dateutil import dateutil
from flask import current_app, g, has_request_context, request from flask import current_app, g, has_request_context, request
from flask_babel import gettext as _ from flask_babel import gettext as _
from jinja2 import DebugUndefined from jinja2 import DebugUndefined, Environment
from jinja2.sandbox import SandboxedEnvironment from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.expression import bindparam from sqlalchemy.sql.expression import bindparam
@ -462,11 +462,11 @@ class BaseTemplateProcessor:
self._applied_filters = applied_filters self._applied_filters = applied_filters
self._removed_filters = removed_filters self._removed_filters = removed_filters
self._context: dict[str, Any] = {} self._context: dict[str, Any] = {}
self._env = SandboxedEnvironment(undefined=DebugUndefined) self.env: Environment = SandboxedEnvironment(undefined=DebugUndefined)
self.set_context(**kwargs) self.set_context(**kwargs)
# custom filters # custom filters
self._env.filters["where_in"] = WhereInMacro(database.get_dialect()) self.env.filters["where_in"] = WhereInMacro(database.get_dialect())
def set_context(self, **kwargs: Any) -> None: def set_context(self, **kwargs: Any) -> None:
self._context.update(kwargs) self._context.update(kwargs)
@ -479,7 +479,7 @@ class BaseTemplateProcessor:
>>> process_template(sql) >>> process_template(sql)
"SELECT '2017-01-01T00:00:00'" "SELECT '2017-01-01T00:00:00'"
""" """
template = self._env.from_string(sql) template = self.env.from_string(sql)
kwargs.update(self._context) kwargs.update(self._context)
context = validate_template_context(self.engine, kwargs) context = validate_template_context(self.engine, kwargs)
@ -623,7 +623,7 @@ class TrinoTemplateProcessor(PrestoTemplateProcessor):
engine = "trino" engine = "trino"
def process_template(self, sql: str, **kwargs: Any) -> str: def process_template(self, sql: str, **kwargs: Any) -> str:
template = self._env.from_string(sql) template = self.env.from_string(sql)
kwargs.update(self._context) kwargs.update(self._context)
# Backwards compatibility if migrating from Presto. # Backwards compatibility if migrating from Presto.

View File

@ -46,6 +46,7 @@ from sqlalchemy.orm import backref, relationship
from sqlalchemy.sql.elements import ColumnElement, literal_column from sqlalchemy.sql.elements import ColumnElement, literal_column
from superset import security_manager from superset import security_manager
from superset.exceptions import SupersetSecurityException
from superset.jinja_context import BaseTemplateProcessor, get_template_processor from superset.jinja_context import BaseTemplateProcessor, get_template_processor
from superset.models.helpers import ( from superset.models.helpers import (
AuditMixinNullable, AuditMixinNullable,
@ -53,7 +54,7 @@ from superset.models.helpers import (
ExtraJSONMixin, ExtraJSONMixin,
ImportExportMixin, ImportExportMixin,
) )
from superset.sql_parse import CtasMethod, ParsedQuery, Table from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table
from superset.sqllab.limiting_factor import LimitingFactor from superset.sqllab.limiting_factor import LimitingFactor
from superset.utils.core import get_column_name, MediumText, QueryStatus, user_label from superset.utils.core import get_column_name, MediumText, QueryStatus, user_label
@ -65,8 +66,25 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SqlTablesMixin:  # pylint: disable=too-few-public-methods
    """Mixin exposing the tables referenced by a model's (Jinjafied) SQL."""

    @property
    def sql_tables(self) -> list[Table]:
        # Delegate table extraction to the Jinja-aware parser; if the SQL
        # cannot be parsed the extractor raises SupersetSecurityException,
        # in which case we deliberately report no tables rather than fail.
        try:
            extracted = extract_tables_from_jinja_sql(
                self.sql,  # type: ignore
                self.database.db_engine_spec.engine,  # type: ignore
            )
        except SupersetSecurityException:
            return []
        return list(extracted)
class Query( class Query(
ExtraJSONMixin, ExploreMixin, Model SqlTablesMixin,
ExtraJSONMixin,
ExploreMixin,
Model,
): # pylint: disable=abstract-method,too-many-public-methods ): # pylint: disable=abstract-method,too-many-public-methods
"""ORM model for SQL query """ORM model for SQL query
@ -181,10 +199,6 @@ class Query(
def username(self) -> str: def username(self) -> str:
return self.user.username return self.user.username
@property
def sql_tables(self) -> list[Table]:
return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)
@property @property
def columns(self) -> list["TableColumn"]: def columns(self) -> list["TableColumn"]:
from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
@ -355,7 +369,13 @@ class Query(
return self.make_sqla_column_compatible(sqla_column, label) return self.make_sqla_column_compatible(sqla_column, label)
class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model): class SavedQuery(
SqlTablesMixin,
AuditMixinNullable,
ExtraJSONMixin,
ImportExportMixin,
Model,
):
"""ORM model for SQL query""" """ORM model for SQL query"""
__tablename__ = "saved_query" __tablename__ = "saved_query"
@ -425,12 +445,6 @@ class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
def url(self) -> str: def url(self) -> str:
return f"/sqllab?savedQueryId={self.id}" return f"/sqllab?savedQueryId={self.id}"
@property
def sql_tables(self) -> list[Table]:
return list(
ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
)
@property @property
def last_run_humanized(self) -> str: def last_run_humanized(self) -> str:
return naturaltime(datetime.now() - self.changed_on) return naturaltime(datetime.now() - self.changed_on)

View File

@ -52,14 +52,12 @@ from sqlalchemy.orm import eagerload
from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import Query as SqlaQuery from sqlalchemy.orm.query import Query as SqlaQuery
from superset import sql_parse
from superset.constants import RouteMethod from superset.constants import RouteMethod
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import ( from superset.exceptions import (
DatasetInvalidPermissionEvaluationException, DatasetInvalidPermissionEvaluationException,
SupersetSecurityException, SupersetSecurityException,
) )
from superset.jinja_context import get_template_processor
from superset.security.guest_token import ( from superset.security.guest_token import (
GuestToken, GuestToken,
GuestTokenResources, GuestTokenResources,
@ -68,6 +66,7 @@ from superset.security.guest_token import (
GuestTokenUser, GuestTokenUser,
GuestUser, GuestUser,
) )
from superset.sql_parse import extract_tables_from_jinja_sql
from superset.superset_typing import Metric from superset.superset_typing import Metric
from superset.utils.core import ( from superset.utils.core import (
DatasourceName, DatasourceName,
@ -1961,16 +1960,12 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return return
if query: if query:
# make sure the query is valid SQL by rendering any Jinja
processor = get_template_processor(database=query.database)
rendered_sql = processor.process_template(query.sql)
default_schema = database.get_default_schema_for_query(query) default_schema = database.get_default_schema_for_query(query)
tables = { tables = {
Table(table_.table, table_.schema or default_schema) Table(table_.table, table_.schema or default_schema)
for table_ in sql_parse.ParsedQuery( for table_ in extract_tables_from_jinja_sql(
rendered_sql, query.sql, database.db_engine_spec.engine
engine=database.db_engine_spec.engine, )
).tables
} }
elif table: elif table:
tables = {table} tables = {table}

View File

@ -16,16 +16,19 @@
# under the License. # under the License.
# pylint: disable=too-many-lines # pylint: disable=too-many-lines
from __future__ import annotations
import logging import logging
import re import re
import urllib.parse import urllib.parse
from collections.abc import Iterable, Iterator from collections.abc import Iterable, Iterator
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, cast, Optional from typing import Any, cast
from unittest.mock import Mock
import sqlparse import sqlparse
from flask_babel import gettext as __ from flask_babel import gettext as __
from jinja2 import nodes
from sqlalchemy import and_ from sqlalchemy import and_
from sqlglot import exp, parse, parse_one from sqlglot import exp, parse, parse_one
from sqlglot.dialects import Dialects from sqlglot.dialects import Dialects
@ -142,7 +145,7 @@ class CtasMethod(StrEnum):
VIEW = "VIEW" VIEW = "VIEW"
def _extract_limit_from_query(statement: TokenList) -> Optional[int]: def _extract_limit_from_query(statement: TokenList) -> int | None:
""" """
Extract limit clause from SQL statement. Extract limit clause from SQL statement.
@ -163,9 +166,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
return None return None
def extract_top_from_query( def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None:
statement: TokenList, top_keywords: set[str]
) -> Optional[int]:
""" """
Extract top clause value from SQL statement. Extract top clause value from SQL statement.
@ -189,7 +190,7 @@ def extract_top_from_query(
return top return top
def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]: def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
""" """
parse the SQL and return the CTE and rest of the block to the caller parse the SQL and return the CTE and rest of the block to the caller
@ -197,7 +198,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
:return: CTE and remainder block to the caller :return: CTE and remainder block to the caller
""" """
cte: Optional[str] = None cte: str | None = None
remainder = sql remainder = sql
stmt = sqlparse.parse(sql)[0] stmt = sqlparse.parse(sql)[0]
@ -215,7 +216,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
return cte, remainder return cte, remainder
def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str: def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
""" """
Strips comments from a SQL statement, does a simple test first Strips comments from a SQL statement, does a simple test first
to avoid always instantiating the expensive ParsedQuery constructor to avoid always instantiating the expensive ParsedQuery constructor
@ -239,8 +240,8 @@ class Table:
""" """
table: str table: str
schema: Optional[str] = None schema: str | None = None
catalog: Optional[str] = None catalog: str | None = None
def __str__(self) -> str: def __str__(self) -> str:
""" """
@ -262,7 +263,7 @@ class ParsedQuery:
self, self,
sql_statement: str, sql_statement: str,
strip_comments: bool = False, strip_comments: bool = False,
engine: Optional[str] = None, engine: str | None = None,
): ):
if strip_comments: if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True) sql_statement = sqlparse.format(sql_statement, strip_comments=True)
@ -271,7 +272,7 @@ class ParsedQuery:
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._tables: set[Table] = set() self._tables: set[Table] = set()
self._alias_names: set[str] = set() self._alias_names: set[str] = set()
self._limit: Optional[int] = None self._limit: int | None = None
logger.debug("Parsing with sqlparse statement: %s", self.sql) logger.debug("Parsing with sqlparse statement: %s", self.sql)
self._parsed = sqlparse.parse(self.stripped()) self._parsed = sqlparse.parse(self.stripped())
@ -382,7 +383,7 @@ class ParsedQuery:
return source.name in ctes_in_scope return source.name in ctes_in_scope
@property @property
def limit(self) -> Optional[int]: def limit(self) -> int | None:
return self._limit return self._limit
def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]: def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
@ -463,7 +464,7 @@ class ParsedQuery:
return True return True
def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]: def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
for token in tokens: for token in tokens:
if self._is_identifier(token): if self._is_identifier(token):
for identifier_token in token.tokens: for identifier_token in token.tokens:
@ -527,7 +528,7 @@ class ParsedQuery:
return statements return statements
@staticmethod @staticmethod
def get_table(tlist: TokenList) -> Optional[Table]: def get_table(tlist: TokenList) -> Table | None:
""" """
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
construct. construct.
@ -563,7 +564,7 @@ class ParsedQuery:
def as_create_table( def as_create_table(
self, self,
table_name: str, table_name: str,
schema_name: Optional[str] = None, schema_name: str | None = None,
overwrite: bool = False, overwrite: bool = False,
method: CtasMethod = CtasMethod.TABLE, method: CtasMethod = CtasMethod.TABLE,
) -> str: ) -> str:
@ -723,8 +724,8 @@ def add_table_name(rls: TokenList, table: str) -> None:
def get_rls_for_table( def get_rls_for_table(
candidate: Token, candidate: Token,
database_id: int, database_id: int,
default_schema: Optional[str], default_schema: str | None,
) -> Optional[TokenList]: ) -> TokenList | None:
""" """
Given a table name, return any associated RLS predicates. Given a table name, return any associated RLS predicates.
""" """
@ -770,7 +771,7 @@ def get_rls_for_table(
def insert_rls_as_subquery( def insert_rls_as_subquery(
token_list: TokenList, token_list: TokenList,
database_id: int, database_id: int,
default_schema: Optional[str], default_schema: str | None,
) -> TokenList: ) -> TokenList:
""" """
Update a statement inplace applying any associated RLS predicates. Update a statement inplace applying any associated RLS predicates.
@ -786,7 +787,7 @@ def insert_rls_as_subquery(
This method is safer than ``insert_rls_in_predicate``, but doesn't work in all This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
databases. databases.
""" """
rls: Optional[TokenList] = None rls: TokenList | None = None
state = InsertRLSState.SCANNING state = InsertRLSState.SCANNING
for token in token_list.tokens: for token in token_list.tokens:
# Recurse into child token list # Recurse into child token list
@ -862,7 +863,7 @@ def insert_rls_as_subquery(
def insert_rls_in_predicate( def insert_rls_in_predicate(
token_list: TokenList, token_list: TokenList,
database_id: int, database_id: int,
default_schema: Optional[str], default_schema: str | None,
) -> TokenList: ) -> TokenList:
""" """
Update a statement inplace applying any associated RLS predicates. Update a statement inplace applying any associated RLS predicates.
@ -873,7 +874,7 @@ def insert_rls_in_predicate(
after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42 after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
""" """
rls: Optional[TokenList] = None rls: TokenList | None = None
state = InsertRLSState.SCANNING state = InsertRLSState.SCANNING
for token in token_list.tokens: for token in token_list.tokens:
# Recurse into child token list # Recurse into child token list
@ -1007,7 +1008,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")
def extract_table_references( def extract_table_references(
sql_text: str, sqla_dialect: str, show_warning: bool = True sql_text: str, sqla_dialect: str, show_warning: bool = True
) -> set["Table"]: ) -> set[Table]:
""" """
Return all the dependencies from a SQL sql_text. Return all the dependencies from a SQL sql_text.
""" """
@ -1051,3 +1052,61 @@ def extract_table_references(
Table(*[part["value"] for part in table["name"][::-1]]) Table(*[part["value"] for part in table["name"][::-1]])
for table in find_nodes_by_key(tree, "Table") for table in find_nodes_by_key(tree, "Table")
} }
def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]:
    """
    Extract all table references in the Jinjafied SQL statement.

    Due to Jinja templating, a multiphase approach is necessary as the Jinjafied SQL
    statement may represent invalid SQL which is non-parsable by SQLGlot.

    Firstly, we extract any tables referenced within the confines of specific Jinja
    macros. Secondly, we replace these non-SQL Jinja calls with a pseudo-benign SQL
    expression to help ensure that the resulting SQL statements are parsable by
    SQLGlot.

    :param sql: The Jinjafied SQL statement
    :param engine: The associated database engine
    :returns: The set of tables referenced in the SQL statement
    :raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement
    """

    from superset.jinja_context import (  # pylint: disable=import-outside-toplevel
        get_template_processor,
    )

    # Mock the required database as the processor signature is exposed publicly.
    processor = get_template_processor(database=Mock(backend=engine))
    template = processor.env.parse(sql)

    tables = set()

    for node in template.find_all(nodes.Call):
        if isinstance(node.node, nodes.Getattr) and node.node.attr in (
            "latest_partition",
            "latest_sub_partition",
        ):
            # Extract the table referenced in the macro. Guard the argument
            # check *before* building the Table: the original comprehension
            # guard produced an empty argument list for zero-arg calls,
            # making Table(*[]) raise TypeError, and node.args[0].value
            # raise AttributeError for non-constant arguments.
            if len(node.args) == 1 and isinstance(node.args[0], nodes.Const):
                tables.add(
                    Table(
                        *[
                            remove_quotes(part)
                            for part in node.args[0].value.split(".")[::-1]
                        ]
                    )
                )

            # Replace the potentially problematic Jinja macro with some benign SQL.
            node.__class__ = nodes.TemplateData
            node.fields = nodes.TemplateData.fields
            node.data = "NULL"

    return (
        tables
        | ParsedQuery(
            sql_statement=processor.process_template(template),
            engine=engine,
        ).tables
    )

View File

@ -79,8 +79,7 @@ class SqlQueryRenderImpl(SqlQueryRender):
sql_template_processor: BaseTemplateProcessor, sql_template_processor: BaseTemplateProcessor,
) -> None: ) -> None:
if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"): if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
# pylint: disable=protected-access syntax_tree = sql_template_processor.env.parse(rendered_query)
syntax_tree = sql_template_processor._env.parse(rendered_query)
undefined_parameters = find_undeclared_variables(syntax_tree) undefined_parameters = find_undeclared_variables(syntax_tree)
if undefined_parameters: if undefined_parameters:
self._raise_undefined_parameter_exception( self._raise_undefined_parameter_exception(

View File

@ -32,6 +32,7 @@ from superset.exceptions import (
from superset.sql_parse import ( from superset.sql_parse import (
add_table_name, add_table_name,
extract_table_references, extract_table_references,
extract_tables_from_jinja_sql,
get_rls_for_table, get_rls_for_table,
has_table_query, has_table_query,
insert_rls_as_subquery, insert_rls_as_subquery,
@ -1874,3 +1875,43 @@ WITH t AS (
) )
SELECT * FROM t""" SELECT * FROM t"""
).is_select() ).is_select()
@pytest.mark.parametrize(
    "engine",
    [
        "hive",
        "presto",
        "trino",
    ],
)
@pytest.mark.parametrize(
    "macro",
    [
        "latest_partition('foo.bar')",
        "latest_sub_partition('foo.bar', baz='qux')",
    ],
)
@pytest.mark.parametrize(
    "sql,expected",
    [
        (
            "SELECT '{{{{ {engine}.{macro} }}}}'",
            {Table(table="bar", schema="foo")},
        ),
        (
            "SELECT * FROM foo.baz WHERE quux = '{{{{ {engine}.{macro} }}}}'",
            {Table(table="bar", schema="foo"), Table(table="baz", schema="foo")},
        ),
    ],
)
def test_extract_tables_from_jinja_sql(
    engine: str,
    macro: str,
    sql: str,
    expected: set[Table],
) -> None:
    # Render the Jinjafied SQL for the given engine/macro combination, then
    # verify the extractor reports exactly the expected set of tables.
    rendered = sql.format(engine=engine, macro=macro)
    actual = extract_tables_from_jinja_sql(rendered, engine)
    assert actual == expected