chore: add annotations to `sql_parse.py` (#27520)

This commit is contained in:
Beto Dealmeida 2024-03-14 18:16:06 -04:00 committed by GitHub
parent d2c90013fc
commit 024b88a40d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 28 additions and 28 deletions

View File

@ -17,12 +17,14 @@
# pylint: disable=too-many-lines
from __future__ import annotations
import logging
import re
import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast, Optional, Union
from typing import Any, cast
import sqlglot
import sqlparse
@ -138,7 +140,7 @@ class CtasMethod(StrEnum):
VIEW = "VIEW"
def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
def _extract_limit_from_query(statement: TokenList) -> int | None:
"""
Extract limit clause from SQL statement.
@ -159,9 +161,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
return None
def extract_top_from_query(
statement: TokenList, top_keywords: set[str]
) -> Optional[int]:
def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None:
"""
Extract top clause value from SQL statement.
@ -185,7 +185,7 @@ def extract_top_from_query(
return top
def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
"""
parse the SQL and return the CTE and rest of the block to the caller
@ -193,7 +193,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
:return: CTE and remainder block to the caller
"""
cte: Optional[str] = None
cte: str | None = None
remainder = sql
stmt = sqlparse.parse(sql)[0]
@ -211,7 +211,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
return cte, remainder
def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
"""
Strips comments from a SQL statement, does a simple test first
to avoid always instantiating the expensive ParsedQuery constructor
@ -235,8 +235,8 @@ class Table:
"""
table: str
schema: Optional[str] = None
catalog: Optional[str] = None
schema: str | None = None
catalog: str | None = None
def __str__(self) -> str:
"""
@ -255,7 +255,7 @@ class Table:
def extract_tables_from_statement(
statement: exp.Expression,
dialect: Optional[Dialects],
dialect: Dialects | None,
) -> set[Table]:
"""
Extract all table references in a single statement.
@ -334,7 +334,7 @@ class SQLScript:
def __init__(
self,
query: str,
engine: Optional[str] = None,
engine: str | None = None,
):
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
@ -375,8 +375,8 @@ class SQLStatement:
def __init__(
self,
statement: Union[str, exp.Expression],
engine: Optional[str] = None,
statement: str | exp.Expression,
engine: str | None = None,
):
dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
@ -394,7 +394,7 @@ class SQLStatement:
@staticmethod
def _parse_statement(
sql_statement: str,
dialect: Optional[Dialects],
dialect: Dialects | None,
) -> exp.Expression:
"""
Parse a single SQL statement.
@ -437,7 +437,7 @@ class ParsedQuery:
self,
sql_statement: str,
strip_comments: bool = False,
engine: Optional[str] = None,
engine: str | None = None,
):
if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
@ -446,7 +446,7 @@ class ParsedQuery:
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._tables: set[Table] = set()
self._alias_names: set[str] = set()
self._limit: Optional[int] = None
self._limit: int | None = None
logger.debug("Parsing with sqlparse statement: %s", self.sql)
self._parsed = sqlparse.parse(self.stripped())
@ -550,7 +550,7 @@ class ParsedQuery:
return source.name in ctes_in_scope
@property
def limit(self) -> Optional[int]:
def limit(self) -> int | None:
return self._limit
def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
@ -631,7 +631,7 @@ class ParsedQuery:
return True
def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
for token in tokens:
if self._is_identifier(token):
for identifier_token in token.tokens:
@ -695,7 +695,7 @@ class ParsedQuery:
return statements
@staticmethod
def get_table(tlist: TokenList) -> Optional[Table]:
def get_table(tlist: TokenList) -> Table | None:
"""
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
construct.
@ -731,7 +731,7 @@ class ParsedQuery:
def as_create_table(
self,
table_name: str,
schema_name: Optional[str] = None,
schema_name: str | None = None,
overwrite: bool = False,
method: CtasMethod = CtasMethod.TABLE,
) -> str:
@ -891,8 +891,8 @@ def add_table_name(rls: TokenList, table: str) -> None:
def get_rls_for_table(
candidate: Token,
database_id: int,
default_schema: Optional[str],
) -> Optional[TokenList]:
default_schema: str | None,
) -> TokenList | None:
"""
Given a table name, return any associated RLS predicates.
"""
@ -938,7 +938,7 @@ def get_rls_for_table(
def insert_rls_as_subquery(
token_list: TokenList,
database_id: int,
default_schema: Optional[str],
default_schema: str | None,
) -> TokenList:
"""
Update a statement inplace applying any associated RLS predicates.
@ -954,7 +954,7 @@ def insert_rls_as_subquery(
This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
databases.
"""
rls: Optional[TokenList] = None
rls: TokenList | None = None
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# Recurse into child token list
@ -1030,7 +1030,7 @@ def insert_rls_as_subquery(
def insert_rls_in_predicate(
token_list: TokenList,
database_id: int,
default_schema: Optional[str],
default_schema: str | None,
) -> TokenList:
"""
Update a statement inplace applying any associated RLS predicates.
@ -1041,7 +1041,7 @@ def insert_rls_in_predicate(
after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
"""
rls: Optional[TokenList] = None
rls: TokenList | None = None
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# Recurse into child token list
@ -1175,7 +1175,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")
def extract_table_references(
sql_text: str, sqla_dialect: str, show_warning: bool = True
) -> set["Table"]:
) -> set[Table]:
"""
Return all the dependencies from a SQL sql_text.
"""