mirror of https://github.com/apache/superset.git
chore: add annotations to `sql_parse.py` (#27520)
This commit is contained in:
parent
d2c90013fc
commit
024b88a40d
|
@ -17,12 +17,14 @@
|
|||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import urllib.parse
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast, Optional, Union
|
||||
from typing import Any, cast
|
||||
|
||||
import sqlglot
|
||||
import sqlparse
|
||||
|
@ -138,7 +140,7 @@ class CtasMethod(StrEnum):
|
|||
VIEW = "VIEW"
|
||||
|
||||
|
||||
def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
|
||||
def _extract_limit_from_query(statement: TokenList) -> int | None:
|
||||
"""
|
||||
Extract limit clause from SQL statement.
|
||||
|
||||
|
@ -159,9 +161,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
|
|||
return None
|
||||
|
||||
|
||||
def extract_top_from_query(
|
||||
statement: TokenList, top_keywords: set[str]
|
||||
) -> Optional[int]:
|
||||
def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None:
|
||||
"""
|
||||
Extract top clause value from SQL statement.
|
||||
|
||||
|
@ -185,7 +185,7 @@ def extract_top_from_query(
|
|||
return top
|
||||
|
||||
|
||||
def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
|
||||
def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
|
||||
"""
|
||||
parse the SQL and return the CTE and rest of the block to the caller
|
||||
|
||||
|
@ -193,7 +193,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
|
|||
:return: CTE and remainder block to the caller
|
||||
|
||||
"""
|
||||
cte: Optional[str] = None
|
||||
cte: str | None = None
|
||||
remainder = sql
|
||||
stmt = sqlparse.parse(sql)[0]
|
||||
|
||||
|
@ -211,7 +211,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
|
|||
return cte, remainder
|
||||
|
||||
|
||||
def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
|
||||
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
|
||||
"""
|
||||
Strips comments from a SQL statement, does a simple test first
|
||||
to avoid always instantiating the expensive ParsedQuery constructor
|
||||
|
@ -235,8 +235,8 @@ class Table:
|
|||
"""
|
||||
|
||||
table: str
|
||||
schema: Optional[str] = None
|
||||
catalog: Optional[str] = None
|
||||
schema: str | None = None
|
||||
catalog: str | None = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
|
@ -255,7 +255,7 @@ class Table:
|
|||
|
||||
def extract_tables_from_statement(
|
||||
statement: exp.Expression,
|
||||
dialect: Optional[Dialects],
|
||||
dialect: Dialects | None,
|
||||
) -> set[Table]:
|
||||
"""
|
||||
Extract all table references in a single statement.
|
||||
|
@ -334,7 +334,7 @@ class SQLScript:
|
|||
def __init__(
|
||||
self,
|
||||
query: str,
|
||||
engine: Optional[str] = None,
|
||||
engine: str | None = None,
|
||||
):
|
||||
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
||||
|
||||
|
@ -375,8 +375,8 @@ class SQLStatement:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
statement: Union[str, exp.Expression],
|
||||
engine: Optional[str] = None,
|
||||
statement: str | exp.Expression,
|
||||
engine: str | None = None,
|
||||
):
|
||||
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
||||
|
||||
|
@ -394,7 +394,7 @@ class SQLStatement:
|
|||
@staticmethod
|
||||
def _parse_statement(
|
||||
sql_statement: str,
|
||||
dialect: Optional[Dialects],
|
||||
dialect: Dialects | None,
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Parse a single SQL statement.
|
||||
|
@ -437,7 +437,7 @@ class ParsedQuery:
|
|||
self,
|
||||
sql_statement: str,
|
||||
strip_comments: bool = False,
|
||||
engine: Optional[str] = None,
|
||||
engine: str | None = None,
|
||||
):
|
||||
if strip_comments:
|
||||
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
|
||||
|
@ -446,7 +446,7 @@ class ParsedQuery:
|
|||
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
||||
self._tables: set[Table] = set()
|
||||
self._alias_names: set[str] = set()
|
||||
self._limit: Optional[int] = None
|
||||
self._limit: int | None = None
|
||||
|
||||
logger.debug("Parsing with sqlparse statement: %s", self.sql)
|
||||
self._parsed = sqlparse.parse(self.stripped())
|
||||
|
@ -550,7 +550,7 @@ class ParsedQuery:
|
|||
return source.name in ctes_in_scope
|
||||
|
||||
@property
|
||||
def limit(self) -> Optional[int]:
|
||||
def limit(self) -> int | None:
|
||||
return self._limit
|
||||
|
||||
def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
|
@ -631,7 +631,7 @@ class ParsedQuery:
|
|||
|
||||
return True
|
||||
|
||||
def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
|
||||
def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
|
||||
for token in tokens:
|
||||
if self._is_identifier(token):
|
||||
for identifier_token in token.tokens:
|
||||
|
@ -695,7 +695,7 @@ class ParsedQuery:
|
|||
return statements
|
||||
|
||||
@staticmethod
|
||||
def get_table(tlist: TokenList) -> Optional[Table]:
|
||||
def get_table(tlist: TokenList) -> Table | None:
|
||||
"""
|
||||
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
|
||||
construct.
|
||||
|
@ -731,7 +731,7 @@ class ParsedQuery:
|
|||
def as_create_table(
|
||||
self,
|
||||
table_name: str,
|
||||
schema_name: Optional[str] = None,
|
||||
schema_name: str | None = None,
|
||||
overwrite: bool = False,
|
||||
method: CtasMethod = CtasMethod.TABLE,
|
||||
) -> str:
|
||||
|
@ -891,8 +891,8 @@ def add_table_name(rls: TokenList, table: str) -> None:
|
|||
def get_rls_for_table(
|
||||
candidate: Token,
|
||||
database_id: int,
|
||||
default_schema: Optional[str],
|
||||
) -> Optional[TokenList]:
|
||||
default_schema: str | None,
|
||||
) -> TokenList | None:
|
||||
"""
|
||||
Given a table name, return any associated RLS predicates.
|
||||
"""
|
||||
|
@ -938,7 +938,7 @@ def get_rls_for_table(
|
|||
def insert_rls_as_subquery(
|
||||
token_list: TokenList,
|
||||
database_id: int,
|
||||
default_schema: Optional[str],
|
||||
default_schema: str | None,
|
||||
) -> TokenList:
|
||||
"""
|
||||
Update a statement inplace applying any associated RLS predicates.
|
||||
|
@ -954,7 +954,7 @@ def insert_rls_as_subquery(
|
|||
This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
|
||||
databases.
|
||||
"""
|
||||
rls: Optional[TokenList] = None
|
||||
rls: TokenList | None = None
|
||||
state = InsertRLSState.SCANNING
|
||||
for token in token_list.tokens:
|
||||
# Recurse into child token list
|
||||
|
@ -1030,7 +1030,7 @@ def insert_rls_as_subquery(
|
|||
def insert_rls_in_predicate(
|
||||
token_list: TokenList,
|
||||
database_id: int,
|
||||
default_schema: Optional[str],
|
||||
default_schema: str | None,
|
||||
) -> TokenList:
|
||||
"""
|
||||
Update a statement inplace applying any associated RLS predicates.
|
||||
|
@ -1041,7 +1041,7 @@ def insert_rls_in_predicate(
|
|||
after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
|
||||
|
||||
"""
|
||||
rls: Optional[TokenList] = None
|
||||
rls: TokenList | None = None
|
||||
state = InsertRLSState.SCANNING
|
||||
for token in token_list.tokens:
|
||||
# Recurse into child token list
|
||||
|
@ -1175,7 +1175,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")
|
|||
|
||||
def extract_table_references(
|
||||
sql_text: str, sqla_dialect: str, show_warning: bool = True
|
||||
) -> set["Table"]:
|
||||
) -> set[Table]:
|
||||
"""
|
||||
Return all the dependencies from a SQL sql_text.
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue