mirror of https://github.com/apache/superset.git
fix(sql_parse): Ensure table extraction handles Jinja templating (#27470)
This commit is contained in:
parent
a8c01f4cad
commit
b25dd0c055
|
@ -144,11 +144,13 @@ class ExecuteSqlCommand(BaseCommand):
|
|||
try:
|
||||
logger.info("Triggering query_id: %i", query.id)
|
||||
|
||||
# Necessary to check access before rendering the Jinjafied query as the
|
||||
# some Jinja macros execute statements upon rendering.
|
||||
self._validate_access(query)
|
||||
self._execution_context.set_query(query)
|
||||
rendered_query = self._sql_query_render.render(self._execution_context)
|
||||
validate_rendered_query = copy.copy(query)
|
||||
validate_rendered_query.sql = rendered_query
|
||||
self._validate_access(validate_rendered_query)
|
||||
self._set_query_limit_if_required(rendered_query)
|
||||
self._query_dao.update(
|
||||
query, {"limit": self._execution_context.query.limit}
|
||||
|
|
|
@ -24,7 +24,7 @@ from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Unio
|
|||
import dateutil
|
||||
from flask import current_app, has_request_context, request
|
||||
from flask_babel import gettext as _
|
||||
from jinja2 import DebugUndefined
|
||||
from jinja2 import DebugUndefined, Environment
|
||||
from jinja2.sandbox import SandboxedEnvironment
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.sql.expression import bindparam
|
||||
|
@ -479,11 +479,11 @@ class BaseTemplateProcessor:
|
|||
self._applied_filters = applied_filters
|
||||
self._removed_filters = removed_filters
|
||||
self._context: dict[str, Any] = {}
|
||||
self._env = SandboxedEnvironment(undefined=DebugUndefined)
|
||||
self.env: Environment = SandboxedEnvironment(undefined=DebugUndefined)
|
||||
self.set_context(**kwargs)
|
||||
|
||||
# custom filters
|
||||
self._env.filters["where_in"] = WhereInMacro(database.get_dialect())
|
||||
self.env.filters["where_in"] = WhereInMacro(database.get_dialect())
|
||||
|
||||
def set_context(self, **kwargs: Any) -> None:
|
||||
self._context.update(kwargs)
|
||||
|
@ -496,7 +496,7 @@ class BaseTemplateProcessor:
|
|||
>>> process_template(sql)
|
||||
"SELECT '2017-01-01T00:00:00'"
|
||||
"""
|
||||
template = self._env.from_string(sql)
|
||||
template = self.env.from_string(sql)
|
||||
kwargs.update(self._context)
|
||||
|
||||
context = validate_template_context(self.engine, kwargs)
|
||||
|
@ -643,7 +643,7 @@ class TrinoTemplateProcessor(PrestoTemplateProcessor):
|
|||
engine = "trino"
|
||||
|
||||
def process_template(self, sql: str, **kwargs: Any) -> str:
|
||||
template = self._env.from_string(sql)
|
||||
template = self.env.from_string(sql)
|
||||
kwargs.update(self._context)
|
||||
|
||||
# Backwards compatibility if migrating from Presto.
|
||||
|
|
|
@ -46,6 +46,7 @@ from sqlalchemy.orm import backref, relationship
|
|||
from sqlalchemy.sql.elements import ColumnElement, literal_column
|
||||
|
||||
from superset import security_manager
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.jinja_context import BaseTemplateProcessor, get_template_processor
|
||||
from superset.models.helpers import (
|
||||
AuditMixinNullable,
|
||||
|
@ -53,7 +54,7 @@ from superset.models.helpers import (
|
|||
ExtraJSONMixin,
|
||||
ImportExportMixin,
|
||||
)
|
||||
from superset.sql_parse import CtasMethod, ParsedQuery, Table
|
||||
from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table
|
||||
from superset.sqllab.limiting_factor import LimitingFactor
|
||||
from superset.utils.core import get_column_name, MediumText, QueryStatus, user_label
|
||||
|
||||
|
@ -65,8 +66,25 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SqlTablesMixin: # pylint: disable=too-few-public-methods
|
||||
@property
|
||||
def sql_tables(self) -> list[Table]:
|
||||
try:
|
||||
return list(
|
||||
extract_tables_from_jinja_sql(
|
||||
self.sql, # type: ignore
|
||||
self.database.db_engine_spec.engine, # type: ignore
|
||||
)
|
||||
)
|
||||
except SupersetSecurityException:
|
||||
return []
|
||||
|
||||
|
||||
class Query(
|
||||
ExtraJSONMixin, ExploreMixin, Model
|
||||
SqlTablesMixin,
|
||||
ExtraJSONMixin,
|
||||
ExploreMixin,
|
||||
Model,
|
||||
): # pylint: disable=abstract-method,too-many-public-methods
|
||||
"""ORM model for SQL query
|
||||
|
||||
|
@ -181,10 +199,6 @@ class Query(
|
|||
def username(self) -> str:
|
||||
return self.user.username
|
||||
|
||||
@property
|
||||
def sql_tables(self) -> list[Table]:
|
||||
return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)
|
||||
|
||||
@property
|
||||
def columns(self) -> list["TableColumn"]:
|
||||
from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
|
||||
|
@ -355,7 +369,13 @@ class Query(
|
|||
return self.make_sqla_column_compatible(sqla_column, label)
|
||||
|
||||
|
||||
class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
|
||||
class SavedQuery(
|
||||
SqlTablesMixin,
|
||||
AuditMixinNullable,
|
||||
ExtraJSONMixin,
|
||||
ImportExportMixin,
|
||||
Model,
|
||||
):
|
||||
"""ORM model for SQL query"""
|
||||
|
||||
__tablename__ = "saved_query"
|
||||
|
@ -425,12 +445,6 @@ class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
|
|||
def url(self) -> str:
|
||||
return f"/sqllab?savedQueryId={self.id}"
|
||||
|
||||
@property
|
||||
def sql_tables(self) -> list[Table]:
|
||||
return list(
|
||||
ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
|
||||
)
|
||||
|
||||
@property
|
||||
def last_run_humanized(self) -> str:
|
||||
return naturaltime(datetime.now() - self.changed_on)
|
||||
|
|
|
@ -52,14 +52,12 @@ from sqlalchemy.orm import eagerload
|
|||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.orm.query import Query as SqlaQuery
|
||||
|
||||
from superset import sql_parse
|
||||
from superset.constants import RouteMethod
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import (
|
||||
DatasetInvalidPermissionEvaluationException,
|
||||
SupersetSecurityException,
|
||||
)
|
||||
from superset.jinja_context import get_template_processor
|
||||
from superset.security.guest_token import (
|
||||
GuestToken,
|
||||
GuestTokenResources,
|
||||
|
@ -68,6 +66,7 @@ from superset.security.guest_token import (
|
|||
GuestTokenUser,
|
||||
GuestUser,
|
||||
)
|
||||
from superset.sql_parse import extract_tables_from_jinja_sql
|
||||
from superset.superset_typing import Metric
|
||||
from superset.utils.core import (
|
||||
DatasourceName,
|
||||
|
@ -1961,16 +1960,12 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
return
|
||||
|
||||
if query:
|
||||
# make sure the quuery is valid SQL by rendering any Jinja
|
||||
processor = get_template_processor(database=query.database)
|
||||
rendered_sql = processor.process_template(query.sql)
|
||||
default_schema = database.get_default_schema_for_query(query)
|
||||
tables = {
|
||||
Table(table_.table, table_.schema or default_schema)
|
||||
for table_ in sql_parse.ParsedQuery(
|
||||
rendered_sql,
|
||||
engine=database.db_engine_spec.engine,
|
||||
).tables
|
||||
for table_ in extract_tables_from_jinja_sql(
|
||||
query.sql, database.db_engine_spec.engine
|
||||
)
|
||||
}
|
||||
elif table:
|
||||
tables = {table}
|
||||
|
|
|
@ -25,10 +25,12 @@ import urllib.parse
|
|||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
from unittest.mock import Mock
|
||||
|
||||
import sqlglot
|
||||
import sqlparse
|
||||
from flask_babel import gettext as __
|
||||
from jinja2 import nodes
|
||||
from sqlalchemy import and_
|
||||
from sqlglot import exp, parse, parse_one
|
||||
from sqlglot.dialects.dialect import Dialect, Dialects
|
||||
|
@ -1232,3 +1234,61 @@ def extract_table_references(
|
|||
Table(*[part["value"] for part in table["name"][::-1]])
|
||||
for table in find_nodes_by_key(tree, "Table")
|
||||
}
|
||||
|
||||
|
||||
def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]:
|
||||
"""
|
||||
Extract all table references in the Jinjafied SQL statement.
|
||||
|
||||
Due to Jinja templating, a multiphase approach is necessary as the Jinjafied SQL
|
||||
statement may represent invalid SQL which is non-parsable by SQLGlot.
|
||||
|
||||
Firstly, we extract any tables referenced within the confines of specific Jinja
|
||||
macros. Secondly, we replace these non-SQL Jinja calls with a pseudo-benign SQL
|
||||
expression to help ensure that the resulting SQL statements are parsable by
|
||||
SQLGlot.
|
||||
|
||||
:param sql: The Jinjafied SQL statement
|
||||
:param engine: The associated database engine
|
||||
:returns: The set of tables referenced in the SQL statement
|
||||
:raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement
|
||||
"""
|
||||
|
||||
from superset.jinja_context import ( # pylint: disable=import-outside-toplevel
|
||||
get_template_processor,
|
||||
)
|
||||
|
||||
# Mock the required database as the processor signature is exposed publically.
|
||||
processor = get_template_processor(database=Mock(backend=engine))
|
||||
template = processor.env.parse(sql)
|
||||
|
||||
tables = set()
|
||||
|
||||
for node in template.find_all(nodes.Call):
|
||||
if isinstance(node.node, nodes.Getattr) and node.node.attr in (
|
||||
"latest_partition",
|
||||
"latest_sub_partition",
|
||||
):
|
||||
# Extract the table referenced in the macro.
|
||||
tables.add(
|
||||
Table(
|
||||
*[
|
||||
remove_quotes(part)
|
||||
for part in node.args[0].value.split(".")[::-1]
|
||||
if len(node.args) == 1
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Replace the potentially problematic Jinja macro with some benign SQL.
|
||||
node.__class__ = nodes.TemplateData
|
||||
node.fields = nodes.TemplateData.fields
|
||||
node.data = "NULL"
|
||||
|
||||
return (
|
||||
tables
|
||||
| ParsedQuery(
|
||||
sql_statement=processor.process_template(template),
|
||||
engine=engine,
|
||||
).tables
|
||||
)
|
||||
|
|
|
@ -79,8 +79,7 @@ class SqlQueryRenderImpl(SqlQueryRender):
|
|||
sql_template_processor: BaseTemplateProcessor,
|
||||
) -> None:
|
||||
if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
|
||||
# pylint: disable=protected-access
|
||||
syntax_tree = sql_template_processor._env.parse(rendered_query)
|
||||
syntax_tree = sql_template_processor.env.parse(rendered_query)
|
||||
undefined_parameters = find_undeclared_variables(syntax_tree)
|
||||
if undefined_parameters:
|
||||
self._raise_undefined_parameter_exception(
|
||||
|
|
|
@ -32,6 +32,7 @@ from superset.exceptions import (
|
|||
from superset.sql_parse import (
|
||||
add_table_name,
|
||||
extract_table_references,
|
||||
extract_tables_from_jinja_sql,
|
||||
get_rls_for_table,
|
||||
has_table_query,
|
||||
insert_rls_as_subquery,
|
||||
|
@ -1909,3 +1910,43 @@ def test_sqlstatement() -> None:
|
|||
|
||||
statement = SQLStatement("SET a=1")
|
||||
assert statement.get_settings() == {"a": "1"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"engine",
|
||||
[
|
||||
"hive",
|
||||
"presto",
|
||||
"trino",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"macro",
|
||||
[
|
||||
"latest_partition('foo.bar')",
|
||||
"latest_sub_partition('foo.bar', baz='qux')",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"sql,expected",
|
||||
[
|
||||
(
|
||||
"SELECT '{{{{ {engine}.{macro} }}}}'",
|
||||
{Table(table="bar", schema="foo")},
|
||||
),
|
||||
(
|
||||
"SELECT * FROM foo.baz WHERE quux = '{{{{ {engine}.{macro} }}}}'",
|
||||
{Table(table="bar", schema="foo"), Table(table="baz", schema="foo")},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tables_from_jinja_sql(
|
||||
engine: str,
|
||||
macro: str,
|
||||
sql: str,
|
||||
expected: set[Table],
|
||||
) -> None:
|
||||
assert (
|
||||
extract_tables_from_jinja_sql(sql.format(engine=engine, macro=macro), engine)
|
||||
== expected
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue