mirror of https://github.com/apache/superset.git
feat(sqlparse): improve table parsing (#26476)
This commit is contained in:
parent
d34874cf2b
commit
c0b57bd1c3
|
@ -141,7 +141,9 @@ geographiclib==1.52
|
|||
geopy==2.2.0
|
||||
# via apache-superset
|
||||
greenlet==2.0.2
|
||||
# via shillelagh
|
||||
# via
|
||||
# shillelagh
|
||||
# sqlalchemy
|
||||
gunicorn==21.2.0
|
||||
# via apache-superset
|
||||
hashids==1.3.1
|
||||
|
@ -155,7 +157,10 @@ idna==3.2
|
|||
# email-validator
|
||||
# requests
|
||||
importlib-metadata==6.6.0
|
||||
# via apache-superset
|
||||
# via
|
||||
# apache-superset
|
||||
# flask
|
||||
# shillelagh
|
||||
importlib-resources==5.12.0
|
||||
# via limits
|
||||
isodate==0.6.0
|
||||
|
@ -327,6 +332,8 @@ sqlalchemy-utils==0.38.3
|
|||
# via
|
||||
# apache-superset
|
||||
# flask-appbuilder
|
||||
sqlglot==20.8.0
|
||||
# via apache-superset
|
||||
sqlparse==0.4.4
|
||||
# via apache-superset
|
||||
sshtunnel==0.4.0
|
||||
|
@ -376,7 +383,9 @@ wtforms-json==0.3.5
|
|||
xlsxwriter==3.0.7
|
||||
# via apache-superset
|
||||
zipp==3.15.0
|
||||
# via importlib-metadata
|
||||
# via
|
||||
# importlib-metadata
|
||||
# importlib-resources
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
|
|
|
@ -24,10 +24,6 @@ db-dtypes==1.1.1
|
|||
# via pandas-gbq
|
||||
docker==6.1.1
|
||||
# via -r requirements/testing.in
|
||||
exceptiongroup==1.1.1
|
||||
# via pytest
|
||||
ephem==4.1.4
|
||||
# via lunarcalendar
|
||||
flask-testing==0.8.1
|
||||
# via -r requirements/testing.in
|
||||
fonttools==4.39.4
|
||||
|
@ -121,6 +117,8 @@ pyee==9.0.4
|
|||
# via playwright
|
||||
pyfakefs==5.2.2
|
||||
# via -r requirements/testing.in
|
||||
pyhive[presto]==0.7.0
|
||||
# via apache-superset
|
||||
pytest==7.3.1
|
||||
# via
|
||||
# -r requirements/testing.in
|
||||
|
|
1
setup.py
1
setup.py
|
@ -125,6 +125,7 @@ setup(
|
|||
"slack_sdk>=3.19.0, <4",
|
||||
"sqlalchemy>=1.4, <2",
|
||||
"sqlalchemy-utils>=0.38.3, <0.39",
|
||||
"sqlglot>=20,<21",
|
||||
"sqlparse>=0.4.4, <0.5",
|
||||
"tabulate>=0.8.9, <0.9",
|
||||
"typing-extensions>=4, <5",
|
||||
|
|
|
@ -70,7 +70,10 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
|
|||
table.normalize_columns = self._base_model.normalize_columns
|
||||
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
|
||||
table.is_sqllab_view = True
|
||||
table.sql = ParsedQuery(self._base_model.sql).stripped()
|
||||
table.sql = ParsedQuery(
|
||||
self._base_model.sql,
|
||||
engine=database.db_engine_spec.engine,
|
||||
).stripped()
|
||||
db.session.add(table)
|
||||
cols = []
|
||||
for config_ in self._base_model.columns:
|
||||
|
|
|
@ -115,7 +115,10 @@ class SqlResultExportCommand(BaseCommand):
|
|||
limit = None
|
||||
else:
|
||||
sql = self._query.executed_sql
|
||||
limit = ParsedQuery(sql).limit
|
||||
limit = ParsedQuery(
|
||||
sql,
|
||||
engine=self._query.database.db_engine_spec.engine,
|
||||
).limit
|
||||
if limit is not None and self._query.limiting_factor in {
|
||||
LimitingFactor.QUERY,
|
||||
LimitingFactor.DROPDOWN,
|
||||
|
|
|
@ -1457,7 +1457,7 @@ class SqlaTable(
|
|||
return self.get_sqla_table(), None
|
||||
|
||||
from_sql = self.get_rendered_sql(template_processor)
|
||||
parsed_query = ParsedQuery(from_sql)
|
||||
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
|
||||
if not (
|
||||
parsed_query.is_unknown()
|
||||
or self.db_engine_spec.is_readonly_query(parsed_query)
|
||||
|
|
|
@ -111,7 +111,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
|
|||
sql = dataset.get_template_processor().process_template(
|
||||
dataset.sql, **dataset.template_params_dict
|
||||
)
|
||||
parsed_query = ParsedQuery(sql)
|
||||
parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
|
||||
if not db_engine_spec.is_readonly_query(parsed_query):
|
||||
raise SupersetSecurityException(
|
||||
SupersetError(
|
||||
|
|
|
@ -899,7 +899,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return database.compile_sqla_query(qry)
|
||||
|
||||
if cls.limit_method == LimitMethod.FORCE_LIMIT:
|
||||
parsed_query = sql_parse.ParsedQuery(sql)
|
||||
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
|
||||
sql = parsed_query.set_or_update_query_limit(limit, force=force)
|
||||
|
||||
return sql
|
||||
|
@ -980,7 +980,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
:param sql: SQL query
|
||||
:return: Value of limit clause in query
|
||||
"""
|
||||
parsed_query = sql_parse.ParsedQuery(sql)
|
||||
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
|
||||
return parsed_query.limit
|
||||
|
||||
@classmethod
|
||||
|
@ -992,7 +992,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
:param limit: New limit to insert/replace into query
|
||||
:return: Query with new limit
|
||||
"""
|
||||
parsed_query = sql_parse.ParsedQuery(sql)
|
||||
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
|
||||
return parsed_query.set_or_update_query_limit(limit)
|
||||
|
||||
@classmethod
|
||||
|
@ -1487,7 +1487,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
:param database: Database instance
|
||||
:return: Dictionary with different costs
|
||||
"""
|
||||
parsed_query = ParsedQuery(statement)
|
||||
parsed_query = ParsedQuery(statement, engine=cls.engine)
|
||||
sql = parsed_query.stripped()
|
||||
sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
|
||||
mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"]
|
||||
|
@ -1522,7 +1522,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
"Database does not support cost estimation"
|
||||
)
|
||||
|
||||
parsed_query = sql_parse.ParsedQuery(sql)
|
||||
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
|
||||
statements = parsed_query.get_statements()
|
||||
|
||||
costs = []
|
||||
|
@ -1583,7 +1583,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
:return:
|
||||
"""
|
||||
if not cls.allows_sql_comments:
|
||||
query = sql_parse.strip_comments_from_sql(query)
|
||||
query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
|
||||
|
||||
if cls.arraysize:
|
||||
cursor.arraysize = cls.arraysize
|
||||
|
|
|
@ -435,7 +435,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
if not cls.get_allow_cost_estimate(extra):
|
||||
raise SupersetException("Database does not support cost estimation")
|
||||
|
||||
parsed_query = sql_parse.ParsedQuery(sql)
|
||||
parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
|
||||
statements = parsed_query.get_statements()
|
||||
costs = []
|
||||
for statement in statements:
|
||||
|
|
|
@ -1093,7 +1093,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
"""
|
||||
|
||||
from_sql = self.get_rendered_sql(template_processor)
|
||||
parsed_query = ParsedQuery(from_sql)
|
||||
parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
|
||||
if not (
|
||||
parsed_query.is_unknown()
|
||||
or self.db_engine_spec.is_readonly_query(parsed_query)
|
||||
|
|
|
@ -183,7 +183,7 @@ class Query(
|
|||
|
||||
@property
|
||||
def sql_tables(self) -> list[Table]:
|
||||
return list(ParsedQuery(self.sql).tables)
|
||||
return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)
|
||||
|
||||
@property
|
||||
def columns(self) -> list["TableColumn"]:
|
||||
|
@ -427,7 +427,9 @@ class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
|
|||
|
||||
@property
|
||||
def sql_tables(self) -> list[Table]:
|
||||
return list(ParsedQuery(self.sql).tables)
|
||||
return list(
|
||||
ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
|
||||
)
|
||||
|
||||
@property
|
||||
def last_run_humanized(self) -> str:
|
||||
|
|
|
@ -1876,7 +1876,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
default_schema = database.get_default_schema_for_query(query)
|
||||
tables = {
|
||||
Table(table_.table, table_.schema or default_schema)
|
||||
for table_ in sql_parse.ParsedQuery(query.sql).tables
|
||||
for table_ in sql_parse.ParsedQuery(
|
||||
query.sql,
|
||||
engine=database.db_engine_spec.engine,
|
||||
).tables
|
||||
}
|
||||
elif table:
|
||||
tables = {table}
|
||||
|
|
|
@ -199,7 +199,7 @@ def execute_sql_statement(
|
|||
database: Database = query.database
|
||||
db_engine_spec = database.db_engine_spec
|
||||
|
||||
parsed_query = ParsedQuery(sql_statement)
|
||||
parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
|
||||
if is_feature_enabled("RLS_IN_SQLLAB"):
|
||||
# There are two ways to insert RLS: either replacing the table with a subquery
|
||||
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
|
||||
|
@ -219,7 +219,8 @@ def execute_sql_statement(
|
|||
database.id,
|
||||
query.schema,
|
||||
)
|
||||
)
|
||||
),
|
||||
engine=db_engine_spec.engine,
|
||||
)
|
||||
|
||||
sql = parsed_query.stripped()
|
||||
|
@ -409,7 +410,11 @@ def execute_sql_statements(
|
|||
)
|
||||
|
||||
# Breaking down into multiple statements
|
||||
parsed_query = ParsedQuery(rendered_query, strip_comments=True)
|
||||
parsed_query = ParsedQuery(
|
||||
rendered_query,
|
||||
strip_comments=True,
|
||||
engine=db_engine_spec.engine,
|
||||
)
|
||||
if not db_engine_spec.run_multiple_statements_as_one:
|
||||
statements = parsed_query.get_statements()
|
||||
logger.info(
|
||||
|
|
|
@ -14,15 +14,22 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
import urllib.parse
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast, Optional
|
||||
from urllib import parse
|
||||
|
||||
import sqlparse
|
||||
from sqlalchemy import and_
|
||||
from sqlglot import exp, parse, parse_one
|
||||
from sqlglot.dialects import Dialects
|
||||
from sqlglot.errors import ParseError
|
||||
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
||||
from sqlparse import keywords
|
||||
from sqlparse.lexer import Lexer
|
||||
from sqlparse.sql import (
|
||||
|
@ -53,7 +60,7 @@ from superset.utils.backports import StrEnum
|
|||
|
||||
try:
|
||||
from sqloxide import parse_sql as sqloxide_parse
|
||||
except: # pylint: disable=bare-except
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
sqloxide_parse = None
|
||||
|
||||
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
|
||||
|
@ -72,6 +79,59 @@ sqlparser_sql_regex.insert(25, (r"'(''|\\\\|\\|[^'])*'", sqlparse.tokens.String.
|
|||
lex.set_SQL_REGEX(sqlparser_sql_regex)
|
||||
|
||||
|
||||
# mapping between DB engine specs and sqlglot dialects
|
||||
SQLGLOT_DIALECTS = {
|
||||
"ascend": Dialects.HIVE,
|
||||
"awsathena": Dialects.PRESTO,
|
||||
"bigquery": Dialects.BIGQUERY,
|
||||
"clickhouse": Dialects.CLICKHOUSE,
|
||||
"clickhousedb": Dialects.CLICKHOUSE,
|
||||
"cockroachdb": Dialects.POSTGRES,
|
||||
# "crate": ???
|
||||
# "databend": ???
|
||||
"databricks": Dialects.DATABRICKS,
|
||||
# "db2": ???
|
||||
# "dremio": ???
|
||||
"drill": Dialects.DRILL,
|
||||
# "druid": ???
|
||||
"duckdb": Dialects.DUCKDB,
|
||||
# "dynamodb": ???
|
||||
# "elasticsearch": ???
|
||||
# "exa": ???
|
||||
# "firebird": ???
|
||||
# "firebolt": ???
|
||||
"gsheets": Dialects.SQLITE,
|
||||
"hana": Dialects.POSTGRES,
|
||||
"hive": Dialects.HIVE,
|
||||
# "ibmi": ???
|
||||
# "impala": ???
|
||||
# "kustokql": ???
|
||||
# "kylin": ???
|
||||
# "mssql": ???
|
||||
"mysql": Dialects.MYSQL,
|
||||
"netezza": Dialects.POSTGRES,
|
||||
# "ocient": ???
|
||||
# "odelasticsearch": ???
|
||||
"oracle": Dialects.ORACLE,
|
||||
# "pinot": ???
|
||||
"postgresql": Dialects.POSTGRES,
|
||||
"presto": Dialects.PRESTO,
|
||||
"pydoris": Dialects.DORIS,
|
||||
"redshift": Dialects.REDSHIFT,
|
||||
# "risingwave": ???
|
||||
# "rockset": ???
|
||||
"shillelagh": Dialects.SQLITE,
|
||||
"snowflake": Dialects.SNOWFLAKE,
|
||||
# "solr": ???
|
||||
"sqlite": Dialects.SQLITE,
|
||||
"starrocks": Dialects.STARROCKS,
|
||||
"superset": Dialects.SQLITE,
|
||||
"teradatasql": Dialects.TERADATA,
|
||||
"trino": Dialects.TRINO,
|
||||
"vertica": Dialects.POSTGRES,
|
||||
}
|
||||
|
||||
|
||||
class CtasMethod(StrEnum):
|
||||
TABLE = "TABLE"
|
||||
VIEW = "VIEW"
|
||||
|
@ -150,7 +210,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
|
|||
return cte, remainder
|
||||
|
||||
|
||||
def strip_comments_from_sql(statement: str) -> str:
|
||||
def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
|
||||
"""
|
||||
Strips comments from a SQL statement, does a simple test first
|
||||
to avoid always instantiating the expensive ParsedQuery constructor
|
||||
|
@ -160,7 +220,11 @@ def strip_comments_from_sql(statement: str) -> str:
|
|||
:param statement: A string with the SQL statement
|
||||
:return: SQL statement without comments
|
||||
"""
|
||||
return ParsedQuery(statement).strip_comments() if "--" in statement else statement
|
||||
return (
|
||||
ParsedQuery(statement, engine=engine).strip_comments()
|
||||
if "--" in statement
|
||||
else statement
|
||||
)
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
|
@ -179,7 +243,7 @@ class Table:
|
|||
"""
|
||||
|
||||
return ".".join(
|
||||
parse.quote(part, safe="").replace(".", "%2E")
|
||||
urllib.parse.quote(part, safe="").replace(".", "%2E")
|
||||
for part in [self.catalog, self.schema, self.table]
|
||||
if part
|
||||
)
|
||||
|
@ -189,11 +253,17 @@ class Table:
|
|||
|
||||
|
||||
class ParsedQuery:
|
||||
def __init__(self, sql_statement: str, strip_comments: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
sql_statement: str,
|
||||
strip_comments: bool = False,
|
||||
engine: Optional[str] = None,
|
||||
):
|
||||
if strip_comments:
|
||||
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
|
||||
|
||||
self.sql: str = sql_statement
|
||||
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
||||
self._tables: set[Table] = set()
|
||||
self._alias_names: set[str] = set()
|
||||
self._limit: Optional[int] = None
|
||||
|
@ -206,14 +276,94 @@ class ParsedQuery:
|
|||
@property
|
||||
def tables(self) -> set[Table]:
|
||||
if not self._tables:
|
||||
for statement in self._parsed:
|
||||
self._extract_from_token(statement)
|
||||
|
||||
self._tables = {
|
||||
table for table in self._tables if str(table) not in self._alias_names
|
||||
}
|
||||
self._tables = self._extract_tables_from_sql()
|
||||
return self._tables
|
||||
|
||||
def _extract_tables_from_sql(self) -> set[Table]:
|
||||
"""
|
||||
Extract all table references in a query.
|
||||
|
||||
Note: this uses sqlglot, since it's better at catching more edge cases.
|
||||
"""
|
||||
try:
|
||||
statements = parse(self.sql, dialect=self._dialect)
|
||||
except ParseError:
|
||||
logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
|
||||
return set()
|
||||
|
||||
return {
|
||||
table
|
||||
for statement in statements
|
||||
for table in self._extract_tables_from_statement(statement)
|
||||
if statement
|
||||
}
|
||||
|
||||
def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
|
||||
"""
|
||||
Extract all table references in a single statement.
|
||||
|
||||
Please not that this is not trivial; consider the following queries:
|
||||
|
||||
DESCRIBE some_table;
|
||||
SHOW PARTITIONS FROM some_table;
|
||||
WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
|
||||
|
||||
See the unit tests for other tricky cases.
|
||||
"""
|
||||
sources: Iterable[exp.Table]
|
||||
|
||||
if isinstance(statement, exp.Describe):
|
||||
# A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
|
||||
# query for all tables.
|
||||
sources = statement.find_all(exp.Table)
|
||||
elif isinstance(statement, exp.Command):
|
||||
# Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
|
||||
# `SELECT` statetement in order to extract tables.
|
||||
literal = statement.find(exp.Literal)
|
||||
if not literal:
|
||||
return set()
|
||||
|
||||
pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect)
|
||||
sources = pseudo_query.find_all(exp.Table)
|
||||
else:
|
||||
sources = [
|
||||
source
|
||||
for scope in traverse_scope(statement)
|
||||
for source in scope.sources.values()
|
||||
if isinstance(source, exp.Table) and not self._is_cte(source, scope)
|
||||
]
|
||||
|
||||
return {
|
||||
Table(
|
||||
source.name,
|
||||
source.db if source.db != "" else None,
|
||||
source.catalog if source.catalog != "" else None,
|
||||
)
|
||||
for source in sources
|
||||
}
|
||||
|
||||
def _is_cte(self, source: exp.Table, scope: Scope) -> bool:
|
||||
"""
|
||||
Is the source a CTE?
|
||||
|
||||
CTEs in the parent scope look like tables (and are represented by
|
||||
exp.Table objects), but should not be considered as such;
|
||||
otherwise a user with access to table `foo` could access any table
|
||||
with a query like this:
|
||||
|
||||
WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
|
||||
|
||||
"""
|
||||
parent_sources = scope.parent.sources if scope.parent else {}
|
||||
ctes_in_scope = {
|
||||
name
|
||||
for name, parent_scope in parent_sources.items()
|
||||
if isinstance(parent_scope, Scope)
|
||||
and parent_scope.scope_type == ScopeType.CTE
|
||||
}
|
||||
|
||||
return source.name in ctes_in_scope
|
||||
|
||||
@property
|
||||
def limit(self) -> Optional[int]:
|
||||
return self._limit
|
||||
|
@ -393,28 +543,6 @@ class ParsedQuery:
|
|||
def _is_identifier(token: Token) -> bool:
|
||||
return isinstance(token, (IdentifierList, Identifier))
|
||||
|
||||
def _process_tokenlist(self, token_list: TokenList) -> None:
|
||||
"""
|
||||
Add table names to table set
|
||||
|
||||
:param token_list: TokenList to be processed
|
||||
"""
|
||||
# exclude subselects
|
||||
if "(" not in str(token_list):
|
||||
table = self.get_table(token_list)
|
||||
if table and not table.table.startswith(CTE_PREFIX):
|
||||
self._tables.add(table)
|
||||
return
|
||||
|
||||
# store aliases
|
||||
if token_list.has_alias():
|
||||
self._alias_names.add(token_list.get_alias())
|
||||
|
||||
# some aliases are not parsed properly
|
||||
if token_list.tokens[0].ttype == Name:
|
||||
self._alias_names.add(token_list.tokens[0].value)
|
||||
self._extract_from_token(token_list)
|
||||
|
||||
def as_create_table(
|
||||
self,
|
||||
table_name: str,
|
||||
|
@ -441,50 +569,6 @@ class ParsedQuery:
|
|||
exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
|
||||
return exec_sql
|
||||
|
||||
def _extract_from_token(self, token: Token) -> None:
|
||||
"""
|
||||
<Identifier> store a list of subtokens and <IdentifierList> store lists of
|
||||
subtoken list.
|
||||
|
||||
It extracts <IdentifierList> and <Identifier> from :param token: and loops
|
||||
through all subtokens recursively. It finds table_name_preceding_token and
|
||||
passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
|
||||
self._tables.
|
||||
|
||||
:param token: instance of Token or child class, e.g. TokenList, to be processed
|
||||
"""
|
||||
if not hasattr(token, "tokens"):
|
||||
return
|
||||
|
||||
table_name_preceding_token = False
|
||||
|
||||
for item in token.tokens:
|
||||
if item.is_group and (
|
||||
not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
|
||||
):
|
||||
self._extract_from_token(item)
|
||||
|
||||
if item.ttype in Keyword and (
|
||||
item.normalized in PRECEDES_TABLE_NAME
|
||||
or item.normalized.endswith(" JOIN")
|
||||
):
|
||||
table_name_preceding_token = True
|
||||
continue
|
||||
|
||||
if item.ttype in Keyword:
|
||||
table_name_preceding_token = False
|
||||
continue
|
||||
if table_name_preceding_token:
|
||||
if isinstance(item, Identifier):
|
||||
self._process_tokenlist(item)
|
||||
elif isinstance(item, IdentifierList):
|
||||
for token2 in item.get_identifiers():
|
||||
if isinstance(token2, TokenList):
|
||||
self._process_tokenlist(token2)
|
||||
elif isinstance(item, IdentifierList):
|
||||
if any(not self._is_identifier(token2) for token2 in item.tokens):
|
||||
self._extract_from_token(item)
|
||||
|
||||
def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
|
||||
"""Returns the query with the specified limit.
|
||||
|
||||
|
@ -881,7 +965,7 @@ def insert_rls_in_predicate(
|
|||
|
||||
|
||||
# mapping between sqloxide and SQLAlchemy dialects
|
||||
SQLOXITE_DIALECTS = {
|
||||
SQLOXIDE_DIALECTS = {
|
||||
"ansi": {"trino", "trinonative", "presto"},
|
||||
"hive": {"hive", "databricks"},
|
||||
"ms": {"mssql"},
|
||||
|
@ -914,7 +998,7 @@ def extract_table_references(
|
|||
tree = None
|
||||
|
||||
if sqloxide_parse:
|
||||
for dialect, sqla_dialects in SQLOXITE_DIALECTS.items():
|
||||
for dialect, sqla_dialects in SQLOXIDE_DIALECTS.items():
|
||||
if sqla_dialect in sqla_dialects:
|
||||
break
|
||||
sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)
|
||||
|
|
|
@ -50,7 +50,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
|||
) -> Optional[SQLValidationAnnotation]:
|
||||
# pylint: disable=too-many-locals
|
||||
db_engine_spec = database.db_engine_spec
|
||||
parsed_query = ParsedQuery(statement)
|
||||
parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
|
||||
sql = parsed_query.stripped()
|
||||
|
||||
# Hook to allow environment-specific mutation (usually comments) to the SQL
|
||||
|
@ -154,7 +154,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
|||
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
|
||||
VALIDATE) SELECT 1 FROM default.mytable.
|
||||
"""
|
||||
parsed_query = ParsedQuery(sql)
|
||||
parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
|
||||
statements = parsed_query.get_statements()
|
||||
|
||||
logger.info("Validating %i statement(s)", len(statements))
|
||||
|
|
|
@ -58,7 +58,11 @@ class SqlQueryRenderImpl(SqlQueryRender):
|
|||
database=query_model.database, query=query_model
|
||||
)
|
||||
|
||||
parsed_query = ParsedQuery(query_model.sql, strip_comments=True)
|
||||
parsed_query = ParsedQuery(
|
||||
query_model.sql,
|
||||
strip_comments=True,
|
||||
engine=query_model.database.db_engine_spec.engine,
|
||||
)
|
||||
rendered_query = sql_template_processor.process_template(
|
||||
parsed_query.stripped(), **execution_context.template_params
|
||||
)
|
||||
|
|
|
@ -40,11 +40,11 @@ from superset.sql_parse import (
|
|||
)
|
||||
|
||||
|
||||
def extract_tables(query: str) -> set[Table]:
|
||||
def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]:
|
||||
"""
|
||||
Helper function to extract tables referenced in a query.
|
||||
"""
|
||||
return ParsedQuery(query).tables
|
||||
return ParsedQuery(query, engine=engine).tables
|
||||
|
||||
|
||||
def test_table() -> None:
|
||||
|
@ -96,8 +96,13 @@ def test_extract_tables() -> None:
|
|||
Table("left_table")
|
||||
}
|
||||
|
||||
# reverse select
|
||||
assert extract_tables("FROM t1 SELECT field") == {Table("t1")}
|
||||
assert extract_tables(
|
||||
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;"
|
||||
) == {Table("forbidden_table")}
|
||||
|
||||
assert extract_tables(
|
||||
"select * from (select * from forbidden_table) forbidden_table"
|
||||
) == {Table("forbidden_table")}
|
||||
|
||||
|
||||
def test_extract_tables_subselect() -> None:
|
||||
|
@ -263,14 +268,16 @@ def test_extract_tables_illdefined() -> None:
|
|||
assert extract_tables("SELECT * FROM schemaname.") == set()
|
||||
assert extract_tables("SELECT * FROM catalogname.schemaname.") == set()
|
||||
assert extract_tables("SELECT * FROM catalogname..") == set()
|
||||
assert extract_tables("SELECT * FROM catalogname..tbname") == set()
|
||||
assert extract_tables("SELECT * FROM catalogname..tbname") == {
|
||||
Table(table="tbname", schema=None, catalog="catalogname")
|
||||
}
|
||||
|
||||
|
||||
def test_extract_tables_show_tables_from() -> None:
|
||||
"""
|
||||
Test ``SHOW TABLES FROM``.
|
||||
"""
|
||||
assert extract_tables("SHOW TABLES FROM s1 like '%order%'") == set()
|
||||
assert extract_tables("SHOW TABLES FROM s1 like '%order%'", "mysql") == set()
|
||||
|
||||
|
||||
def test_extract_tables_show_columns_from() -> None:
|
||||
|
@ -311,7 +318,7 @@ WHERE regionkey IN (SELECT regionkey FROM t2)
|
|||
"""
|
||||
SELECT name
|
||||
FROM t1
|
||||
WHERE regionkey EXISTS (SELECT regionkey FROM t2)
|
||||
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
|
||||
"""
|
||||
)
|
||||
== {Table("t1"), Table("t2")}
|
||||
|
@ -526,6 +533,18 @@ select * from (select key from q1) a
|
|||
== {Table("src")}
|
||||
)
|
||||
|
||||
# weird query with circular dependency
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
with src as ( select key from q2 where key = '5'),
|
||||
q2 as ( select key from src where key = '5')
|
||||
select * from (select key from src) a
|
||||
"""
|
||||
)
|
||||
== set()
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_multistatement() -> None:
|
||||
"""
|
||||
|
@ -665,7 +684,8 @@ def test_extract_tables_nested_select() -> None:
|
|||
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
|
||||
from INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
|
||||
"""
|
||||
""",
|
||||
"mysql",
|
||||
)
|
||||
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
||||
)
|
||||
|
@ -676,7 +696,8 @@ WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
|
|||
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
|
||||
from INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
|
||||
"""
|
||||
""",
|
||||
"mysql",
|
||||
)
|
||||
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
||||
)
|
||||
|
@ -1306,6 +1327,14 @@ def test_sqlparse_issue_652():
|
|||
"(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
|
||||
True,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_has_table_query(sql: str, expected: bool) -> None:
|
||||
|
@ -1790,13 +1819,17 @@ def test_extract_table_references(mocker: MockerFixture) -> None:
|
|||
assert extract_table_references(
|
||||
sql,
|
||||
"trino",
|
||||
) == {Table(table="other_table", schema=None, catalog=None)}
|
||||
) == {
|
||||
Table(table="table", schema=None, catalog=None),
|
||||
Table(table="other_table", schema=None, catalog=None),
|
||||
}
|
||||
logger.warning.assert_called_once()
|
||||
|
||||
logger = mocker.patch("superset.migrations.shared.utils.logger")
|
||||
sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
|
||||
assert extract_table_references(sql, "trino", show_warning=False) == {
|
||||
Table(table="other_table", schema=None, catalog=None)
|
||||
Table(table="table", schema=None, catalog=None),
|
||||
Table(table="other_table", schema=None, catalog=None),
|
||||
}
|
||||
logger.warning.assert_not_called()
|
||||
|
||||
|
|
Loading…
Reference in New Issue