diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 2d7384a74e..f22d774e88 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -73,7 +73,7 @@ class SqlTablesMixin: # pylint: disable=too-few-public-methods return list( extract_tables_from_jinja_sql( self.sql, # type: ignore - self.database.db_engine_spec.engine, # type: ignore + self.database, # type: ignore ) ) except SupersetSecurityException: diff --git a/superset/security/manager.py b/superset/security/manager.py index 2833e88645..e5a32e97a7 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1963,9 +1963,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods default_schema = database.get_default_schema_for_query(query) tables = { Table(table_.table, table_.schema or default_schema) - for table_ in extract_tables_from_jinja_sql( - query.sql, database.db_engine_spec.engine - ) + for table_ in extract_tables_from_jinja_sql(query.sql, database) } elif table: tables = {table} diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 11e4279aa2..e4e9b9a672 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -25,8 +25,7 @@ import re import urllib.parse from collections.abc import Iterable, Iterator from dataclasses import dataclass -from typing import Any, cast, Generic, TypeVar -from unittest.mock import Mock +from typing import Any, cast, Generic, TYPE_CHECKING, TypeVar import sqlglot import sqlparse @@ -75,6 +74,9 @@ try: except (ImportError, ModuleNotFoundError): sqloxide_parse = None +if TYPE_CHECKING: + from superset.models.core import Database + RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"} ON_KEYWORD = "ON" PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"} @@ -1509,7 +1511,7 @@ def extract_table_references( } -def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]: +def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]: """ Extract all table references in the Jinjafied SQL statement. @@ -1522,7 +1524,7 @@ def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Ta SQLGlot. :param sql: The Jinjafied SQL statement - :param engine: The associated database engine + :param database: The database associated with the SQL statement :returns: The set of tables referenced in the SQL statement :raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement """ @@ -1531,8 +1533,7 @@ def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Ta get_template_processor, ) - # Mock the required database as the processor signature is exposed publically. - processor = get_template_processor(database=Mock(backend=engine)) + processor = get_template_processor(database) template = processor.env.parse(sql) tables = set() @@ -1562,6 +1563,6 @@ def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Ta tables | ParsedQuery( sql_statement=processor.process_template(template), - engine=engine, + engine=database.db_engine_spec.engine, ).tables ) diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 79958b0743..eae43dd4c3 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -17,6 +17,7 @@ # pylint: disable=invalid-name, redefined-outer-name, too-many-lines from typing import Optional +from unittest.mock import Mock import pytest import sqlparse @@ -1959,7 +1960,10 @@ def test_extract_tables_from_jinja_sql( expected: set[Table], ) -> None: assert ( - extract_tables_from_jinja_sql(sql.format(engine=engine, macro=macro), engine) + extract_tables_from_jinja_sql( + sql=sql.format(engine=engine, macro=macro), + database=Mock(), + ) == expected )