fix: Leverage actual database for rendering Jinjarized SQL (#27646)

This commit is contained in:
John Bodley 2024-03-27 08:12:25 +13:00 committed by GitHub
parent ed9e542781
commit 28cbedb82f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 15 additions and 12 deletions

View File

@ -73,7 +73,7 @@ class SqlTablesMixin: # pylint: disable=too-few-public-methods
return list(
extract_tables_from_jinja_sql(
self.sql, # type: ignore
self.database.db_engine_spec.engine, # type: ignore
self.database, # type: ignore
)
)
except SupersetSecurityException:

View File

@ -1963,9 +1963,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
default_schema = database.get_default_schema_for_query(query)
tables = {
Table(table_.table, table_.schema or default_schema)
for table_ in extract_tables_from_jinja_sql(
query.sql, database.db_engine_spec.engine
)
for table_ in extract_tables_from_jinja_sql(query.sql, database)
}
elif table:
tables = {table}

View File

@ -25,8 +25,7 @@ import re
import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast, Generic, TypeVar
from unittest.mock import Mock
from typing import Any, cast, Generic, TYPE_CHECKING, TypeVar
import sqlglot
import sqlparse
@ -75,6 +74,9 @@ try:
except (ImportError, ModuleNotFoundError):
sqloxide_parse = None
if TYPE_CHECKING:
from superset.models.core import Database
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
ON_KEYWORD = "ON"
PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
@ -1509,7 +1511,7 @@ def extract_table_references(
}
def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]:
def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]:
"""
Extract all table references in the Jinjafied SQL statement.
@ -1522,7 +1524,7 @@ def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Ta
SQLGlot.
:param sql: The Jinjafied SQL statement
:param engine: The associated database engine
:param database: The database associated with the SQL statement
:returns: The set of tables referenced in the SQL statement
:raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement
"""
@ -1531,8 +1533,7 @@ def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Ta
get_template_processor,
)
# Mock the required database as the processor signature is exposed publically.
processor = get_template_processor(database=Mock(backend=engine))
processor = get_template_processor(database)
template = processor.env.parse(sql)
tables = set()
@ -1562,6 +1563,6 @@ def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Ta
tables
| ParsedQuery(
sql_statement=processor.process_template(template),
engine=engine,
engine=database.db_engine_spec.engine,
).tables
)

View File

@ -17,6 +17,7 @@
# pylint: disable=invalid-name, redefined-outer-name, too-many-lines
from typing import Optional
from unittest.mock import Mock
import pytest
import sqlparse
@ -1959,7 +1960,10 @@ def test_extract_tables_from_jinja_sql(
expected: set[Table],
) -> None:
assert (
extract_tables_from_jinja_sql(sql.format(engine=engine, macro=macro), engine)
extract_tables_from_jinja_sql(
sql=sql.format(engine=engine, macro=macro),
database=Mock(),
)
== expected
)