fix(sql_parse): Provide more lenient logic when extracting latest[_sub]_partition (#28152)

This commit is contained in:
John Bodley 2024-04-25 22:02:25 -07:00 committed by GitHub
parent 1e47e65ac5
commit c5e7d870f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 27 deletions

View File

@ -1554,16 +1554,19 @@ def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]:
"latest_partition",
"latest_sub_partition",
):
# Extract the table referenced in the macro.
tables.add(
Table(
*[
remove_quotes(part.strip())
for part in node.args[0].as_const().split(".")[::-1]
if len(node.args) == 1
]
# Try to extract the table referenced in the macro.
try:
tables.add(
Table(
*[
remove_quotes(part.strip())
for part in node.args[0].as_const().split(".")[::-1]
if len(node.args) == 1
]
)
)
)
except nodes.Impossible:
pass
# Replace the potentially problematic Jinja macro with some benign SQL.
node.__class__ = nodes.TemplateData

View File

@ -1857,36 +1857,40 @@ def test_sqlstatement() -> None:
],
)
@pytest.mark.parametrize(
"macro",
[
"latest_partition('foo.bar')",
"latest_partition(' foo.bar ')", # Non-atypical user error which works
"latest_partition('foo.%s'|format('bar'))",
"latest_sub_partition('foo.bar', baz='qux')",
],
)
@pytest.mark.parametrize(
"sql,expected",
"macro,expected",
[
(
"SELECT '{{{{ {engine}.{macro} }}}}'",
"latest_partition('foo.bar')",
{Table(table="bar", schema="foo")},
),
(
"SELECT * FROM foo.baz WHERE quux = '{{{{ {engine}.{macro} }}}}'",
{Table(table="bar", schema="foo"), Table(table="baz", schema="foo")},
"latest_partition(' foo.bar ')", # Non-atypical user error which works
{Table(table="bar", schema="foo")},
),
(
"latest_partition('foo.%s'|format('bar'))",
{Table(table="bar", schema="foo")},
),
(
"latest_sub_partition('foo.bar', baz='qux')",
{Table(table="bar", schema="foo")},
),
(
"latest_partition('foo.%s'|format(str('bar')))",
set(),
),
(
"latest_partition('foo.{}'.format('bar'))",
set(),
),
],
)
def test_extract_tables_from_jinja_sql(
engine: str,
macro: str,
sql: str,
expected: set[Table],
engine: str, macro: str, expected: set[Table]
) -> None:
assert (
extract_tables_from_jinja_sql(
sql=sql.format(engine=engine, macro=macro),
sql=f"'{{{{ {engine}.{macro} }}}}'",
database=Mock(),
)
== expected