feat(sqlparse): improve table parsing (#26476)

This commit is contained in:
Beto Dealmeida 2024-01-22 11:16:50 -05:00 committed by GitHub
parent d34874cf2b
commit c0b57bd1c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 265 additions and 120 deletions

View File

@ -141,7 +141,9 @@ geographiclib==1.52
geopy==2.2.0
# via apache-superset
greenlet==2.0.2
# via shillelagh
# via
# shillelagh
# sqlalchemy
gunicorn==21.2.0
# via apache-superset
hashids==1.3.1
@ -155,7 +157,10 @@ idna==3.2
# email-validator
# requests
importlib-metadata==6.6.0
# via apache-superset
# via
# apache-superset
# flask
# shillelagh
importlib-resources==5.12.0
# via limits
isodate==0.6.0
@ -327,6 +332,8 @@ sqlalchemy-utils==0.38.3
# via
# apache-superset
# flask-appbuilder
sqlglot==20.8.0
# via apache-superset
sqlparse==0.4.4
# via apache-superset
sshtunnel==0.4.0
@ -376,7 +383,9 @@ wtforms-json==0.3.5
xlsxwriter==3.0.7
# via apache-superset
zipp==3.15.0
# via importlib-metadata
# via
# importlib-metadata
# importlib-resources
# The following packages are considered to be unsafe in a requirements file:
# setuptools

View File

@ -24,10 +24,6 @@ db-dtypes==1.1.1
# via pandas-gbq
docker==6.1.1
# via -r requirements/testing.in
exceptiongroup==1.1.1
# via pytest
ephem==4.1.4
# via lunarcalendar
flask-testing==0.8.1
# via -r requirements/testing.in
fonttools==4.39.4
@ -121,6 +117,8 @@ pyee==9.0.4
# via playwright
pyfakefs==5.2.2
# via -r requirements/testing.in
pyhive[presto]==0.7.0
# via apache-superset
pytest==7.3.1
# via
# -r requirements/testing.in

View File

@ -125,6 +125,7 @@ setup(
"slack_sdk>=3.19.0, <4",
"sqlalchemy>=1.4, <2",
"sqlalchemy-utils>=0.38.3, <0.39",
"sqlglot>=20,<21",
"sqlparse>=0.4.4, <0.5",
"tabulate>=0.8.9, <0.9",
"typing-extensions>=4, <5",

View File

@ -70,7 +70,10 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
table.normalize_columns = self._base_model.normalize_columns
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
table.is_sqllab_view = True
table.sql = ParsedQuery(self._base_model.sql).stripped()
table.sql = ParsedQuery(
self._base_model.sql,
engine=database.db_engine_spec.engine,
).stripped()
db.session.add(table)
cols = []
for config_ in self._base_model.columns:

View File

@ -115,7 +115,10 @@ class SqlResultExportCommand(BaseCommand):
limit = None
else:
sql = self._query.executed_sql
limit = ParsedQuery(sql).limit
limit = ParsedQuery(
sql,
engine=self._query.database.db_engine_spec.engine,
).limit
if limit is not None and self._query.limiting_factor in {
LimitingFactor.QUERY,
LimitingFactor.DROPDOWN,

View File

@ -1457,7 +1457,7 @@ class SqlaTable(
return self.get_sqla_table(), None
from_sql = self.get_rendered_sql(template_processor)
parsed_query = ParsedQuery(from_sql)
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)

View File

@ -111,7 +111,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
sql = dataset.get_template_processor().process_template(
dataset.sql, **dataset.template_params_dict
)
parsed_query = ParsedQuery(sql)
parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
if not db_engine_spec.is_readonly_query(parsed_query):
raise SupersetSecurityException(
SupersetError(

View File

@ -899,7 +899,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return database.compile_sqla_query(qry)
if cls.limit_method == LimitMethod.FORCE_LIMIT:
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
sql = parsed_query.set_or_update_query_limit(limit, force=force)
return sql
@ -980,7 +980,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param sql: SQL query
:return: Value of limit clause in query
"""
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
return parsed_query.limit
@classmethod
@ -992,7 +992,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param limit: New limit to insert/replace into query
:return: Query with new limit
"""
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
return parsed_query.set_or_update_query_limit(limit)
@classmethod
@ -1487,7 +1487,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param database: Database instance
:return: Dictionary with different costs
"""
parsed_query = ParsedQuery(statement)
parsed_query = ParsedQuery(statement, engine=cls.engine)
sql = parsed_query.stripped()
sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"]
@ -1522,7 +1522,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"Database does not support cost estimation"
)
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
costs = []
@ -1583,7 +1583,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:return:
"""
if not cls.allows_sql_comments:
query = sql_parse.strip_comments_from_sql(query)
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
if cls.arraysize:
cursor.arraysize = cls.arraysize

View File

@ -435,7 +435,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
if not cls.get_allow_cost_estimate(extra):
raise SupersetException("Database does not support cost estimation")
parsed_query = sql_parse.ParsedQuery(sql)
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
costs = []
for statement in statements:

View File

@ -1093,7 +1093,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
"""
from_sql = self.get_rendered_sql(template_processor)
parsed_query = ParsedQuery(from_sql)
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)

View File

@ -183,7 +183,7 @@ class Query(
@property
def sql_tables(self) -> list[Table]:
return list(ParsedQuery(self.sql).tables)
return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)
@property
def columns(self) -> list["TableColumn"]:
@ -427,7 +427,9 @@ class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
@property
def sql_tables(self) -> list[Table]:
return list(ParsedQuery(self.sql).tables)
return list(
ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
)
@property
def last_run_humanized(self) -> str:

View File

@ -1876,7 +1876,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
default_schema = database.get_default_schema_for_query(query)
tables = {
Table(table_.table, table_.schema or default_schema)
for table_ in sql_parse.ParsedQuery(query.sql).tables
for table_ in sql_parse.ParsedQuery(
query.sql,
engine=database.db_engine_spec.engine,
).tables
}
elif table:
tables = {table}

View File

@ -199,7 +199,7 @@ def execute_sql_statement(
database: Database = query.database
db_engine_spec = database.db_engine_spec
parsed_query = ParsedQuery(sql_statement)
parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
if is_feature_enabled("RLS_IN_SQLLAB"):
# There are two ways to insert RLS: either replacing the table with a subquery
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
@ -219,7 +219,8 @@ def execute_sql_statement(
database.id,
query.schema,
)
)
),
engine=db_engine_spec.engine,
)
sql = parsed_query.stripped()
@ -409,7 +410,11 @@ def execute_sql_statements(
)
# Breaking down into multiple statements
parsed_query = ParsedQuery(rendered_query, strip_comments=True)
parsed_query = ParsedQuery(
rendered_query,
strip_comments=True,
engine=db_engine_spec.engine,
)
if not db_engine_spec.run_multiple_statements_as_one:
statements = parsed_query.get_statements()
logger.info(

View File

@ -14,15 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines
import logging
import re
from collections.abc import Iterator
import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast, Optional
from urllib import parse
import sqlparse
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects import Dialects
from sqlglot.errors import ParseError
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
@ -53,7 +60,7 @@ from superset.utils.backports import StrEnum
try:
from sqloxide import parse_sql as sqloxide_parse
except: # pylint: disable=bare-except
except (ImportError, ModuleNotFoundError):
sqloxide_parse = None
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
@ -72,6 +79,59 @@ sqlparser_sql_regex.insert(25, (r"'(''|\\\\|\\|[^'])*'", sqlparse.tokens.String.
lex.set_SQL_REGEX(sqlparser_sql_regex)
# mapping between DB engine specs and sqlglot dialects
SQLGLOT_DIALECTS = {
"ascend": Dialects.HIVE,
"awsathena": Dialects.PRESTO,
"bigquery": Dialects.BIGQUERY,
"clickhouse": Dialects.CLICKHOUSE,
"clickhousedb": Dialects.CLICKHOUSE,
"cockroachdb": Dialects.POSTGRES,
# "crate": ???
# "databend": ???
"databricks": Dialects.DATABRICKS,
# "db2": ???
# "dremio": ???
"drill": Dialects.DRILL,
# "druid": ???
"duckdb": Dialects.DUCKDB,
# "dynamodb": ???
# "elasticsearch": ???
# "exa": ???
# "firebird": ???
# "firebolt": ???
"gsheets": Dialects.SQLITE,
"hana": Dialects.POSTGRES,
"hive": Dialects.HIVE,
# "ibmi": ???
# "impala": ???
# "kustokql": ???
# "kylin": ???
# "mssql": ???
"mysql": Dialects.MYSQL,
"netezza": Dialects.POSTGRES,
# "ocient": ???
# "odelasticsearch": ???
"oracle": Dialects.ORACLE,
# "pinot": ???
"postgresql": Dialects.POSTGRES,
"presto": Dialects.PRESTO,
"pydoris": Dialects.DORIS,
"redshift": Dialects.REDSHIFT,
# "risingwave": ???
# "rockset": ???
"shillelagh": Dialects.SQLITE,
"snowflake": Dialects.SNOWFLAKE,
# "solr": ???
"sqlite": Dialects.SQLITE,
"starrocks": Dialects.STARROCKS,
"superset": Dialects.SQLITE,
"teradatasql": Dialects.TERADATA,
"trino": Dialects.TRINO,
"vertica": Dialects.POSTGRES,
}
class CtasMethod(StrEnum):
TABLE = "TABLE"
VIEW = "VIEW"
@ -150,7 +210,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
return cte, remainder
def strip_comments_from_sql(statement: str) -> str:
def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
"""
Strips comments from a SQL statement, does a simple test first
to avoid always instantiating the expensive ParsedQuery constructor
@ -160,7 +220,11 @@ def strip_comments_from_sql(statement: str) -> str:
:param statement: A string with the SQL statement
:return: SQL statement without comments
"""
return ParsedQuery(statement).strip_comments() if "--" in statement else statement
return (
ParsedQuery(statement, engine=engine).strip_comments()
if "--" in statement
else statement
)
@dataclass(eq=True, frozen=True)
@ -179,7 +243,7 @@ class Table:
"""
return ".".join(
parse.quote(part, safe="").replace(".", "%2E")
urllib.parse.quote(part, safe="").replace(".", "%2E")
for part in [self.catalog, self.schema, self.table]
if part
)
@ -189,11 +253,17 @@ class Table:
class ParsedQuery:
def __init__(self, sql_statement: str, strip_comments: bool = False):
def __init__(
self,
sql_statement: str,
strip_comments: bool = False,
engine: Optional[str] = None,
):
if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
self.sql: str = sql_statement
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._tables: set[Table] = set()
self._alias_names: set[str] = set()
self._limit: Optional[int] = None
@ -206,14 +276,94 @@ class ParsedQuery:
@property
def tables(self) -> set[Table]:
if not self._tables:
for statement in self._parsed:
self._extract_from_token(statement)
self._tables = {
table for table in self._tables if str(table) not in self._alias_names
}
self._tables = self._extract_tables_from_sql()
return self._tables
def _extract_tables_from_sql(self) -> set[Table]:
"""
Extract all table references in a query.
Note: this uses sqlglot, since it's better at catching more edge cases.
"""
try:
statements = parse(self.sql, dialect=self._dialect)
except ParseError:
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
return set()
return {
table
for statement in statements
for table in self._extract_tables_from_statement(statement)
if statement
}
def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
"""
Extract all table references in a single statement.
Please not that this is not trivial; consider the following queries:
DESCRIBE some_table;
SHOW PARTITIONS FROM some_table;
WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
See the unit tests for other tricky cases.
"""
sources: Iterable[exp.Table]
if isinstance(statement, exp.Describe):
# A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
# query for all tables.
sources = statement.find_all(exp.Table)
elif isinstance(statement, exp.Command):
# Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
# `SELECT` statetement in order to extract tables.
literal = statement.find(exp.Literal)
if not literal:
return set()
pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect)
sources = pseudo_query.find_all(exp.Table)
else:
sources = [
source
for scope in traverse_scope(statement)
for source in scope.sources.values()
if isinstance(source, exp.Table) and not self._is_cte(source, scope)
]
return {
Table(
source.name,
source.db if source.db != "" else None,
source.catalog if source.catalog != "" else None,
)
for source in sources
}
def _is_cte(self, source: exp.Table, scope: Scope) -> bool:
"""
Is the source a CTE?
CTEs in the parent scope look like tables (and are represented by
exp.Table objects), but should not be considered as such;
otherwise a user with access to table `foo` could access any table
with a query like this:
WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
"""
parent_sources = scope.parent.sources if scope.parent else {}
ctes_in_scope = {
name
for name, parent_scope in parent_sources.items()
if isinstance(parent_scope, Scope)
and parent_scope.scope_type == ScopeType.CTE
}
return source.name in ctes_in_scope
@property
def limit(self) -> Optional[int]:
return self._limit
@ -393,28 +543,6 @@ class ParsedQuery:
def _is_identifier(token: Token) -> bool:
return isinstance(token, (IdentifierList, Identifier))
def _process_tokenlist(self, token_list: TokenList) -> None:
"""
Add table names to table set
:param token_list: TokenList to be processed
"""
# exclude subselects
if "(" not in str(token_list):
table = self.get_table(token_list)
if table and not table.table.startswith(CTE_PREFIX):
self._tables.add(table)
return
# store aliases
if token_list.has_alias():
self._alias_names.add(token_list.get_alias())
# some aliases are not parsed properly
if token_list.tokens[0].ttype == Name:
self._alias_names.add(token_list.tokens[0].value)
self._extract_from_token(token_list)
def as_create_table(
self,
table_name: str,
@ -441,50 +569,6 @@ class ParsedQuery:
exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
return exec_sql
def _extract_from_token(self, token: Token) -> None:
"""
<Identifier> store a list of subtokens and <IdentifierList> store lists of
subtoken list.
It extracts <IdentifierList> and <Identifier> from :param token: and loops
through all subtokens recursively. It finds table_name_preceding_token and
passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
self._tables.
:param token: instance of Token or child class, e.g. TokenList, to be processed
"""
if not hasattr(token, "tokens"):
return
table_name_preceding_token = False
for item in token.tokens:
if item.is_group and (
not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
):
self._extract_from_token(item)
if item.ttype in Keyword and (
item.normalized in PRECEDES_TABLE_NAME
or item.normalized.endswith(" JOIN")
):
table_name_preceding_token = True
continue
if item.ttype in Keyword:
table_name_preceding_token = False
continue
if table_name_preceding_token:
if isinstance(item, Identifier):
self._process_tokenlist(item)
elif isinstance(item, IdentifierList):
for token2 in item.get_identifiers():
if isinstance(token2, TokenList):
self._process_tokenlist(token2)
elif isinstance(item, IdentifierList):
if any(not self._is_identifier(token2) for token2 in item.tokens):
self._extract_from_token(item)
def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
"""Returns the query with the specified limit.
@ -881,7 +965,7 @@ def insert_rls_in_predicate(
# mapping between sqloxide and SQLAlchemy dialects
SQLOXITE_DIALECTS = {
SQLOXIDE_DIALECTS = {
"ansi": {"trino", "trinonative", "presto"},
"hive": {"hive", "databricks"},
"ms": {"mssql"},
@ -914,7 +998,7 @@ def extract_table_references(
tree = None
if sqloxide_parse:
for dialect, sqla_dialects in SQLOXITE_DIALECTS.items():
for dialect, sqla_dialects in SQLOXIDE_DIALECTS.items():
if sqla_dialect in sqla_dialects:
break
sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)

View File

@ -50,7 +50,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
) -> Optional[SQLValidationAnnotation]:
# pylint: disable=too-many-locals
db_engine_spec = database.db_engine_spec
parsed_query = ParsedQuery(statement)
parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
sql = parsed_query.stripped()
# Hook to allow environment-specific mutation (usually comments) to the SQL
@ -154,7 +154,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
VALIDATE) SELECT 1 FROM default.mytable.
"""
parsed_query = ParsedQuery(sql)
parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
statements = parsed_query.get_statements()
logger.info("Validating %i statement(s)", len(statements))

View File

@ -58,7 +58,11 @@ class SqlQueryRenderImpl(SqlQueryRender):
database=query_model.database, query=query_model
)
parsed_query = ParsedQuery(query_model.sql, strip_comments=True)
parsed_query = ParsedQuery(
query_model.sql,
strip_comments=True,
engine=query_model.database.db_engine_spec.engine,
)
rendered_query = sql_template_processor.process_template(
parsed_query.stripped(), **execution_context.template_params
)

View File

@ -40,11 +40,11 @@ from superset.sql_parse import (
)
def extract_tables(query: str) -> set[Table]:
def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]:
"""
Helper function to extract tables referenced in a query.
"""
return ParsedQuery(query).tables
return ParsedQuery(query, engine=engine).tables
def test_table() -> None:
@ -96,8 +96,13 @@ def test_extract_tables() -> None:
Table("left_table")
}
# reverse select
assert extract_tables("FROM t1 SELECT field") == {Table("t1")}
assert extract_tables(
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;"
) == {Table("forbidden_table")}
assert extract_tables(
"select * from (select * from forbidden_table) forbidden_table"
) == {Table("forbidden_table")}
def test_extract_tables_subselect() -> None:
@ -263,14 +268,16 @@ def test_extract_tables_illdefined() -> None:
assert extract_tables("SELECT * FROM schemaname.") == set()
assert extract_tables("SELECT * FROM catalogname.schemaname.") == set()
assert extract_tables("SELECT * FROM catalogname..") == set()
assert extract_tables("SELECT * FROM catalogname..tbname") == set()
assert extract_tables("SELECT * FROM catalogname..tbname") == {
Table(table="tbname", schema=None, catalog="catalogname")
}
def test_extract_tables_show_tables_from() -> None:
"""
Test ``SHOW TABLES FROM``.
"""
assert extract_tables("SHOW TABLES FROM s1 like '%order%'") == set()
assert extract_tables("SHOW TABLES FROM s1 like '%order%'", "mysql") == set()
def test_extract_tables_show_columns_from() -> None:
@ -311,7 +318,7 @@ WHERE regionkey IN (SELECT regionkey FROM t2)
"""
SELECT name
FROM t1
WHERE regionkey EXISTS (SELECT regionkey FROM t2)
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
"""
)
== {Table("t1"), Table("t2")}
@ -526,6 +533,18 @@ select * from (select key from q1) a
== {Table("src")}
)
# weird query with circular dependency
assert (
extract_tables(
"""
with src as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from src) a
"""
)
== set()
)
def test_extract_tables_multistatement() -> None:
"""
@ -665,7 +684,8 @@ def test_extract_tables_nested_select() -> None:
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
"""
""",
"mysql",
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
@ -676,7 +696,8 @@ WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
"""
""",
"mysql",
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
@ -1306,6 +1327,14 @@ def test_sqlparse_issue_652():
"(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
True,
),
(
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
True,
),
(
"SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
True,
),
],
)
def test_has_table_query(sql: str, expected: bool) -> None:
@ -1790,13 +1819,17 @@ def test_extract_table_references(mocker: MockerFixture) -> None:
assert extract_table_references(
sql,
"trino",
) == {Table(table="other_table", schema=None, catalog=None)}
) == {
Table(table="table", schema=None, catalog=None),
Table(table="other_table", schema=None, catalog=None),
}
logger.warning.assert_called_once()
logger = mocker.patch("superset.migrations.shared.utils.logger")
sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
assert extract_table_references(sql, "trino", show_warning=False) == {
Table(table="other_table", schema=None, catalog=None)
Table(table="table", schema=None, catalog=None),
Table(table="other_table", schema=None, catalog=None),
}
logger.warning.assert_not_called()