mirror of
https://github.com/apache/superset.git
synced 2024-09-17 19:19:38 -04:00
fix(sql_parse): Ensure table extraction handles Jinja templating (#27470)
This commit is contained in:
parent
4ff331a66c
commit
7c14968e6d
@ -144,11 +144,13 @@ class ExecuteSqlCommand(BaseCommand):
|
|||||||
try:
|
try:
|
||||||
logger.info("Triggering query_id: %i", query.id)
|
logger.info("Triggering query_id: %i", query.id)
|
||||||
|
|
||||||
|
# Necessary to check access before rendering the Jinjafied query as
|
||||||
|
# some Jinja macros execute statements upon rendering.
|
||||||
|
self._validate_access(query)
|
||||||
self._execution_context.set_query(query)
|
self._execution_context.set_query(query)
|
||||||
rendered_query = self._sql_query_render.render(self._execution_context)
|
rendered_query = self._sql_query_render.render(self._execution_context)
|
||||||
validate_rendered_query = copy.copy(query)
|
validate_rendered_query = copy.copy(query)
|
||||||
validate_rendered_query.sql = rendered_query
|
validate_rendered_query.sql = rendered_query
|
||||||
self._validate_access(validate_rendered_query)
|
|
||||||
self._set_query_limit_if_required(rendered_query)
|
self._set_query_limit_if_required(rendered_query)
|
||||||
self._query_dao.update(
|
self._query_dao.update(
|
||||||
query, {"limit": self._execution_context.query.limit}
|
query, {"limit": self._execution_context.query.limit}
|
||||||
|
@ -24,7 +24,7 @@ from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Unio
|
|||||||
import dateutil
|
import dateutil
|
||||||
from flask import current_app, g, has_request_context, request
|
from flask import current_app, g, has_request_context, request
|
||||||
from flask_babel import gettext as _
|
from flask_babel import gettext as _
|
||||||
from jinja2 import DebugUndefined
|
from jinja2 import DebugUndefined, Environment
|
||||||
from jinja2.sandbox import SandboxedEnvironment
|
from jinja2.sandbox import SandboxedEnvironment
|
||||||
from sqlalchemy.engine.interfaces import Dialect
|
from sqlalchemy.engine.interfaces import Dialect
|
||||||
from sqlalchemy.sql.expression import bindparam
|
from sqlalchemy.sql.expression import bindparam
|
||||||
@ -462,11 +462,11 @@ class BaseTemplateProcessor:
|
|||||||
self._applied_filters = applied_filters
|
self._applied_filters = applied_filters
|
||||||
self._removed_filters = removed_filters
|
self._removed_filters = removed_filters
|
||||||
self._context: dict[str, Any] = {}
|
self._context: dict[str, Any] = {}
|
||||||
self._env = SandboxedEnvironment(undefined=DebugUndefined)
|
self.env: Environment = SandboxedEnvironment(undefined=DebugUndefined)
|
||||||
self.set_context(**kwargs)
|
self.set_context(**kwargs)
|
||||||
|
|
||||||
# custom filters
|
# custom filters
|
||||||
self._env.filters["where_in"] = WhereInMacro(database.get_dialect())
|
self.env.filters["where_in"] = WhereInMacro(database.get_dialect())
|
||||||
|
|
||||||
def set_context(self, **kwargs: Any) -> None:
|
def set_context(self, **kwargs: Any) -> None:
|
||||||
self._context.update(kwargs)
|
self._context.update(kwargs)
|
||||||
@ -479,7 +479,7 @@ class BaseTemplateProcessor:
|
|||||||
>>> process_template(sql)
|
>>> process_template(sql)
|
||||||
"SELECT '2017-01-01T00:00:00'"
|
"SELECT '2017-01-01T00:00:00'"
|
||||||
"""
|
"""
|
||||||
template = self._env.from_string(sql)
|
template = self.env.from_string(sql)
|
||||||
kwargs.update(self._context)
|
kwargs.update(self._context)
|
||||||
|
|
||||||
context = validate_template_context(self.engine, kwargs)
|
context = validate_template_context(self.engine, kwargs)
|
||||||
@ -623,7 +623,7 @@ class TrinoTemplateProcessor(PrestoTemplateProcessor):
|
|||||||
engine = "trino"
|
engine = "trino"
|
||||||
|
|
||||||
def process_template(self, sql: str, **kwargs: Any) -> str:
|
def process_template(self, sql: str, **kwargs: Any) -> str:
|
||||||
template = self._env.from_string(sql)
|
template = self.env.from_string(sql)
|
||||||
kwargs.update(self._context)
|
kwargs.update(self._context)
|
||||||
|
|
||||||
# Backwards compatibility if migrating from Presto.
|
# Backwards compatibility if migrating from Presto.
|
||||||
|
@ -46,6 +46,7 @@ from sqlalchemy.orm import backref, relationship
|
|||||||
from sqlalchemy.sql.elements import ColumnElement, literal_column
|
from sqlalchemy.sql.elements import ColumnElement, literal_column
|
||||||
|
|
||||||
from superset import security_manager
|
from superset import security_manager
|
||||||
|
from superset.exceptions import SupersetSecurityException
|
||||||
from superset.jinja_context import BaseTemplateProcessor, get_template_processor
|
from superset.jinja_context import BaseTemplateProcessor, get_template_processor
|
||||||
from superset.models.helpers import (
|
from superset.models.helpers import (
|
||||||
AuditMixinNullable,
|
AuditMixinNullable,
|
||||||
@ -53,7 +54,7 @@ from superset.models.helpers import (
|
|||||||
ExtraJSONMixin,
|
ExtraJSONMixin,
|
||||||
ImportExportMixin,
|
ImportExportMixin,
|
||||||
)
|
)
|
||||||
from superset.sql_parse import CtasMethod, ParsedQuery, Table
|
from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table
|
||||||
from superset.sqllab.limiting_factor import LimitingFactor
|
from superset.sqllab.limiting_factor import LimitingFactor
|
||||||
from superset.utils.core import get_column_name, MediumText, QueryStatus, user_label
|
from superset.utils.core import get_column_name, MediumText, QueryStatus, user_label
|
||||||
|
|
||||||
@ -65,8 +66,25 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SqlTablesMixin:  # pylint: disable=too-few-public-methods
    """
    Mixin exposing the tables referenced by a model's (possibly Jinjafied) SQL.

    The host class is expected to provide ``sql`` and ``database`` attributes.
    """

    @property
    def sql_tables(self) -> list[Table]:
        """
        Return the tables referenced by ``self.sql``.

        :returns: The referenced tables, or an empty list when the extraction raises
            a ``SupersetSecurityException``
        """

        try:
            referenced = extract_tables_from_jinja_sql(
                self.sql,  # type: ignore
                self.database.db_engine_spec.engine,  # type: ignore
            )
        except SupersetSecurityException:
            return []

        return list(referenced)
|
||||||
|
|
||||||
|
|
||||||
class Query(
|
class Query(
|
||||||
ExtraJSONMixin, ExploreMixin, Model
|
SqlTablesMixin,
|
||||||
|
ExtraJSONMixin,
|
||||||
|
ExploreMixin,
|
||||||
|
Model,
|
||||||
): # pylint: disable=abstract-method,too-many-public-methods
|
): # pylint: disable=abstract-method,too-many-public-methods
|
||||||
"""ORM model for SQL query
|
"""ORM model for SQL query
|
||||||
|
|
||||||
@ -181,10 +199,6 @@ class Query(
|
|||||||
def username(self) -> str:
|
def username(self) -> str:
|
||||||
return self.user.username
|
return self.user.username
|
||||||
|
|
||||||
@property
|
|
||||||
def sql_tables(self) -> list[Table]:
|
|
||||||
return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def columns(self) -> list["TableColumn"]:
|
def columns(self) -> list["TableColumn"]:
|
||||||
from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
|
from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
|
||||||
@ -355,7 +369,13 @@ class Query(
|
|||||||
return self.make_sqla_column_compatible(sqla_column, label)
|
return self.make_sqla_column_compatible(sqla_column, label)
|
||||||
|
|
||||||
|
|
||||||
class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
|
class SavedQuery(
|
||||||
|
SqlTablesMixin,
|
||||||
|
AuditMixinNullable,
|
||||||
|
ExtraJSONMixin,
|
||||||
|
ImportExportMixin,
|
||||||
|
Model,
|
||||||
|
):
|
||||||
"""ORM model for SQL query"""
|
"""ORM model for SQL query"""
|
||||||
|
|
||||||
__tablename__ = "saved_query"
|
__tablename__ = "saved_query"
|
||||||
@ -425,12 +445,6 @@ class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
|
|||||||
def url(self) -> str:
|
def url(self) -> str:
|
||||||
return f"/sqllab?savedQueryId={self.id}"
|
return f"/sqllab?savedQueryId={self.id}"
|
||||||
|
|
||||||
@property
|
|
||||||
def sql_tables(self) -> list[Table]:
|
|
||||||
return list(
|
|
||||||
ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def last_run_humanized(self) -> str:
|
def last_run_humanized(self) -> str:
|
||||||
return naturaltime(datetime.now() - self.changed_on)
|
return naturaltime(datetime.now() - self.changed_on)
|
||||||
|
@ -52,14 +52,12 @@ from sqlalchemy.orm import eagerload
|
|||||||
from sqlalchemy.orm.mapper import Mapper
|
from sqlalchemy.orm.mapper import Mapper
|
||||||
from sqlalchemy.orm.query import Query as SqlaQuery
|
from sqlalchemy.orm.query import Query as SqlaQuery
|
||||||
|
|
||||||
from superset import sql_parse
|
|
||||||
from superset.constants import RouteMethod
|
from superset.constants import RouteMethod
|
||||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
from superset.exceptions import (
|
from superset.exceptions import (
|
||||||
DatasetInvalidPermissionEvaluationException,
|
DatasetInvalidPermissionEvaluationException,
|
||||||
SupersetSecurityException,
|
SupersetSecurityException,
|
||||||
)
|
)
|
||||||
from superset.jinja_context import get_template_processor
|
|
||||||
from superset.security.guest_token import (
|
from superset.security.guest_token import (
|
||||||
GuestToken,
|
GuestToken,
|
||||||
GuestTokenResources,
|
GuestTokenResources,
|
||||||
@ -68,6 +66,7 @@ from superset.security.guest_token import (
|
|||||||
GuestTokenUser,
|
GuestTokenUser,
|
||||||
GuestUser,
|
GuestUser,
|
||||||
)
|
)
|
||||||
|
from superset.sql_parse import extract_tables_from_jinja_sql
|
||||||
from superset.superset_typing import Metric
|
from superset.superset_typing import Metric
|
||||||
from superset.utils.core import (
|
from superset.utils.core import (
|
||||||
DatasourceName,
|
DatasourceName,
|
||||||
@ -1961,16 +1960,12 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||||||
return
|
return
|
||||||
|
|
||||||
if query:
|
if query:
|
||||||
# make sure the quuery is valid SQL by rendering any Jinja
|
|
||||||
processor = get_template_processor(database=query.database)
|
|
||||||
rendered_sql = processor.process_template(query.sql)
|
|
||||||
default_schema = database.get_default_schema_for_query(query)
|
default_schema = database.get_default_schema_for_query(query)
|
||||||
tables = {
|
tables = {
|
||||||
Table(table_.table, table_.schema or default_schema)
|
Table(table_.table, table_.schema or default_schema)
|
||||||
for table_ in sql_parse.ParsedQuery(
|
for table_ in extract_tables_from_jinja_sql(
|
||||||
rendered_sql,
|
query.sql, database.db_engine_spec.engine
|
||||||
engine=database.db_engine_spec.engine,
|
)
|
||||||
).tables
|
|
||||||
}
|
}
|
||||||
elif table:
|
elif table:
|
||||||
tables = {table}
|
tables = {table}
|
||||||
|
@ -16,16 +16,19 @@
|
|||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
from collections.abc import Iterable, Iterator
|
from collections.abc import Iterable, Iterator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, cast, Optional
|
from typing import Any, cast
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import sqlparse
|
import sqlparse
|
||||||
from flask_babel import gettext as __
|
from flask_babel import gettext as __
|
||||||
|
from jinja2 import nodes
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_
|
||||||
from sqlglot import exp, parse, parse_one
|
from sqlglot import exp, parse, parse_one
|
||||||
from sqlglot.dialects import Dialects
|
from sqlglot.dialects import Dialects
|
||||||
@ -142,7 +145,7 @@ class CtasMethod(StrEnum):
|
|||||||
VIEW = "VIEW"
|
VIEW = "VIEW"
|
||||||
|
|
||||||
|
|
||||||
def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
|
def _extract_limit_from_query(statement: TokenList) -> int | None:
|
||||||
"""
|
"""
|
||||||
Extract limit clause from SQL statement.
|
Extract limit clause from SQL statement.
|
||||||
|
|
||||||
@ -163,9 +166,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def extract_top_from_query(
|
def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None:
|
||||||
statement: TokenList, top_keywords: set[str]
|
|
||||||
) -> Optional[int]:
|
|
||||||
"""
|
"""
|
||||||
Extract top clause value from SQL statement.
|
Extract top clause value from SQL statement.
|
||||||
|
|
||||||
@ -189,7 +190,7 @@ def extract_top_from_query(
|
|||||||
return top
|
return top
|
||||||
|
|
||||||
|
|
||||||
def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
|
def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
|
||||||
"""
|
"""
|
||||||
parse the SQL and return the CTE and rest of the block to the caller
|
parse the SQL and return the CTE and rest of the block to the caller
|
||||||
|
|
||||||
@ -197,7 +198,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
|
|||||||
:return: CTE and remainder block to the caller
|
:return: CTE and remainder block to the caller
|
||||||
|
|
||||||
"""
|
"""
|
||||||
cte: Optional[str] = None
|
cte: str | None = None
|
||||||
remainder = sql
|
remainder = sql
|
||||||
stmt = sqlparse.parse(sql)[0]
|
stmt = sqlparse.parse(sql)[0]
|
||||||
|
|
||||||
@ -215,7 +216,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
|
|||||||
return cte, remainder
|
return cte, remainder
|
||||||
|
|
||||||
|
|
||||||
def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
|
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
|
||||||
"""
|
"""
|
||||||
Strips comments from a SQL statement, does a simple test first
|
Strips comments from a SQL statement, does a simple test first
|
||||||
to avoid always instantiating the expensive ParsedQuery constructor
|
to avoid always instantiating the expensive ParsedQuery constructor
|
||||||
@ -239,8 +240,8 @@ class Table:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
table: str
|
table: str
|
||||||
schema: Optional[str] = None
|
schema: str | None = None
|
||||||
catalog: Optional[str] = None
|
catalog: str | None = None
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
"""
|
"""
|
||||||
@ -262,7 +263,7 @@ class ParsedQuery:
|
|||||||
self,
|
self,
|
||||||
sql_statement: str,
|
sql_statement: str,
|
||||||
strip_comments: bool = False,
|
strip_comments: bool = False,
|
||||||
engine: Optional[str] = None,
|
engine: str | None = None,
|
||||||
):
|
):
|
||||||
if strip_comments:
|
if strip_comments:
|
||||||
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
|
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
|
||||||
@ -271,7 +272,7 @@ class ParsedQuery:
|
|||||||
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
||||||
self._tables: set[Table] = set()
|
self._tables: set[Table] = set()
|
||||||
self._alias_names: set[str] = set()
|
self._alias_names: set[str] = set()
|
||||||
self._limit: Optional[int] = None
|
self._limit: int | None = None
|
||||||
|
|
||||||
logger.debug("Parsing with sqlparse statement: %s", self.sql)
|
logger.debug("Parsing with sqlparse statement: %s", self.sql)
|
||||||
self._parsed = sqlparse.parse(self.stripped())
|
self._parsed = sqlparse.parse(self.stripped())
|
||||||
@ -382,7 +383,7 @@ class ParsedQuery:
|
|||||||
return source.name in ctes_in_scope
|
return source.name in ctes_in_scope
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def limit(self) -> Optional[int]:
|
def limit(self) -> int | None:
|
||||||
return self._limit
|
return self._limit
|
||||||
|
|
||||||
def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
|
def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
|
||||||
@ -463,7 +464,7 @@ class ParsedQuery:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
|
def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
if self._is_identifier(token):
|
if self._is_identifier(token):
|
||||||
for identifier_token in token.tokens:
|
for identifier_token in token.tokens:
|
||||||
@ -527,7 +528,7 @@ class ParsedQuery:
|
|||||||
return statements
|
return statements
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_table(tlist: TokenList) -> Optional[Table]:
|
def get_table(tlist: TokenList) -> Table | None:
|
||||||
"""
|
"""
|
||||||
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
|
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
|
||||||
construct.
|
construct.
|
||||||
@ -563,7 +564,7 @@ class ParsedQuery:
|
|||||||
def as_create_table(
|
def as_create_table(
|
||||||
self,
|
self,
|
||||||
table_name: str,
|
table_name: str,
|
||||||
schema_name: Optional[str] = None,
|
schema_name: str | None = None,
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
method: CtasMethod = CtasMethod.TABLE,
|
method: CtasMethod = CtasMethod.TABLE,
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -723,8 +724,8 @@ def add_table_name(rls: TokenList, table: str) -> None:
|
|||||||
def get_rls_for_table(
|
def get_rls_for_table(
|
||||||
candidate: Token,
|
candidate: Token,
|
||||||
database_id: int,
|
database_id: int,
|
||||||
default_schema: Optional[str],
|
default_schema: str | None,
|
||||||
) -> Optional[TokenList]:
|
) -> TokenList | None:
|
||||||
"""
|
"""
|
||||||
Given a table name, return any associated RLS predicates.
|
Given a table name, return any associated RLS predicates.
|
||||||
"""
|
"""
|
||||||
@ -770,7 +771,7 @@ def get_rls_for_table(
|
|||||||
def insert_rls_as_subquery(
|
def insert_rls_as_subquery(
|
||||||
token_list: TokenList,
|
token_list: TokenList,
|
||||||
database_id: int,
|
database_id: int,
|
||||||
default_schema: Optional[str],
|
default_schema: str | None,
|
||||||
) -> TokenList:
|
) -> TokenList:
|
||||||
"""
|
"""
|
||||||
Update a statement inplace applying any associated RLS predicates.
|
Update a statement inplace applying any associated RLS predicates.
|
||||||
@ -786,7 +787,7 @@ def insert_rls_as_subquery(
|
|||||||
This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
|
This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
|
||||||
databases.
|
databases.
|
||||||
"""
|
"""
|
||||||
rls: Optional[TokenList] = None
|
rls: TokenList | None = None
|
||||||
state = InsertRLSState.SCANNING
|
state = InsertRLSState.SCANNING
|
||||||
for token in token_list.tokens:
|
for token in token_list.tokens:
|
||||||
# Recurse into child token list
|
# Recurse into child token list
|
||||||
@ -862,7 +863,7 @@ def insert_rls_as_subquery(
|
|||||||
def insert_rls_in_predicate(
|
def insert_rls_in_predicate(
|
||||||
token_list: TokenList,
|
token_list: TokenList,
|
||||||
database_id: int,
|
database_id: int,
|
||||||
default_schema: Optional[str],
|
default_schema: str | None,
|
||||||
) -> TokenList:
|
) -> TokenList:
|
||||||
"""
|
"""
|
||||||
Update a statement inplace applying any associated RLS predicates.
|
Update a statement inplace applying any associated RLS predicates.
|
||||||
@ -873,7 +874,7 @@ def insert_rls_in_predicate(
|
|||||||
after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
|
after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
|
||||||
|
|
||||||
"""
|
"""
|
||||||
rls: Optional[TokenList] = None
|
rls: TokenList | None = None
|
||||||
state = InsertRLSState.SCANNING
|
state = InsertRLSState.SCANNING
|
||||||
for token in token_list.tokens:
|
for token in token_list.tokens:
|
||||||
# Recurse into child token list
|
# Recurse into child token list
|
||||||
@ -1007,7 +1008,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")
|
|||||||
|
|
||||||
def extract_table_references(
|
def extract_table_references(
|
||||||
sql_text: str, sqla_dialect: str, show_warning: bool = True
|
sql_text: str, sqla_dialect: str, show_warning: bool = True
|
||||||
) -> set["Table"]:
|
) -> set[Table]:
|
||||||
"""
|
"""
|
||||||
Return all the dependencies from a SQL sql_text.
|
Return all the dependencies from a SQL sql_text.
|
||||||
"""
|
"""
|
||||||
@ -1051,3 +1052,61 @@ def extract_table_references(
|
|||||||
Table(*[part["value"] for part in table["name"][::-1]])
|
Table(*[part["value"] for part in table["name"][::-1]])
|
||||||
for table in find_nodes_by_key(tree, "Table")
|
for table in find_nodes_by_key(tree, "Table")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]:
    """
    Extract all table references in the Jinjafied SQL statement.

    Due to Jinja templating, a multiphase approach is necessary as the Jinjafied SQL
    statement may represent invalid SQL which is non-parsable by SQLGlot.

    Firstly, we extract any tables referenced within the confines of specific Jinja
    macros. Secondly, we replace these non-SQL Jinja calls with a pseudo-benign SQL
    expression to help ensure that the resulting SQL statements are parsable by
    SQLGlot.

    :param sql: The Jinjafied SQL statement
    :param engine: The associated database engine
    :returns: The set of tables referenced in the SQL statement
    :raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement
        or if the table referenced by a partition macro cannot be determined
    """

    # pylint: disable=import-outside-toplevel
    from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
    from superset.exceptions import SupersetSecurityException
    from superset.jinja_context import get_template_processor

    # Mock the required database as the processor signature is exposed publicly.
    processor = get_template_processor(database=Mock(backend=engine))
    template = processor.env.parse(sql)

    tables: set[Table] = set()

    for node in template.find_all(nodes.Call):
        if not (
            isinstance(node.node, nodes.Getattr)
            and node.node.attr in ("latest_partition", "latest_sub_partition")
        ):
            continue

        # The table can only be resolved statically when the macro is invoked with
        # exactly one positional argument which is a string literal of the form
        # [[catalog.]schema.]table. Anything else must be rejected rather than
        # skipped, otherwise the macro could reference a table which is never
        # subjected to the access check.
        table_arg = node.args[0] if len(node.args) == 1 else None
        parts: list[str] = []

        if isinstance(table_arg, nodes.Const) and isinstance(table_arg.value, str):
            parts = [remove_quotes(part) for part in table_arg.value.split(".")[::-1]]

        if not 1 <= len(parts) <= 3:
            raise SupersetSecurityException(
                SupersetError(
                    error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
                    message=(
                        "Unable to determine the table referenced by the "
                        f"{node.node.attr} Jinja macro"
                    ),
                    level=ErrorLevel.ERROR,
                )
            )

        # Extract the table referenced in the macro.
        tables.add(Table(*parts))

        # Replace the potentially problematic Jinja macro with some benign SQL.
        node.__class__ = nodes.TemplateData
        node.fields = nodes.TemplateData.fields
        node.data = "NULL"

    parsed_query = ParsedQuery(
        sql_statement=processor.process_template(template),
        engine=engine,
    )

    return tables | parsed_query.tables
|
||||||
|
@ -79,8 +79,7 @@ class SqlQueryRenderImpl(SqlQueryRender):
|
|||||||
sql_template_processor: BaseTemplateProcessor,
|
sql_template_processor: BaseTemplateProcessor,
|
||||||
) -> None:
|
) -> None:
|
||||||
if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
|
if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
|
||||||
# pylint: disable=protected-access
|
syntax_tree = sql_template_processor.env.parse(rendered_query)
|
||||||
syntax_tree = sql_template_processor._env.parse(rendered_query)
|
|
||||||
undefined_parameters = find_undeclared_variables(syntax_tree)
|
undefined_parameters = find_undeclared_variables(syntax_tree)
|
||||||
if undefined_parameters:
|
if undefined_parameters:
|
||||||
self._raise_undefined_parameter_exception(
|
self._raise_undefined_parameter_exception(
|
||||||
|
@ -32,6 +32,7 @@ from superset.exceptions import (
|
|||||||
from superset.sql_parse import (
|
from superset.sql_parse import (
|
||||||
add_table_name,
|
add_table_name,
|
||||||
extract_table_references,
|
extract_table_references,
|
||||||
|
extract_tables_from_jinja_sql,
|
||||||
get_rls_for_table,
|
get_rls_for_table,
|
||||||
has_table_query,
|
has_table_query,
|
||||||
insert_rls_as_subquery,
|
insert_rls_as_subquery,
|
||||||
@ -1874,3 +1875,43 @@ WITH t AS (
|
|||||||
)
|
)
|
||||||
SELECT * FROM t"""
|
SELECT * FROM t"""
|
||||||
).is_select()
|
).is_select()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "engine",
    [
        "hive",
        "presto",
        "trino",
    ],
)
@pytest.mark.parametrize(
    "macro",
    [
        "latest_partition('foo.bar')",
        "latest_sub_partition('foo.bar', baz='qux')",
    ],
)
@pytest.mark.parametrize(
    "sql,expected",
    [
        (
            "SELECT '{{{{ {engine}.{macro} }}}}'",
            {Table(table="bar", schema="foo")},
        ),
        (
            "SELECT * FROM foo.baz WHERE quux = '{{{{ {engine}.{macro} }}}}'",
            {Table(table="bar", schema="foo"), Table(table="baz", schema="foo")},
        ),
    ],
)
def test_extract_tables_from_jinja_sql(
    engine: str,
    macro: str,
    sql: str,
    expected: set[Table],
) -> None:
    """
    Tables referenced inside partition macros are extracted alongside those
    referenced directly in the SQL, for every engine/macro combination.
    """

    # The SQL templates double their braces so that ``str.format`` leaves a literal
    # ``{{ engine.macro }}`` Jinja expression in the statement.
    jinja_sql = sql.format(engine=engine, macro=macro)
    extracted = extract_tables_from_jinja_sql(jinja_sql, engine)

    assert extracted == expected
|
||||||
|
Loading…
Reference in New Issue
Block a user