diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 58dc210e2b..eeaecb3ad6 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -17,12 +17,14 @@ # pylint: disable=too-many-lines +from __future__ import annotations + import logging import re import urllib.parse from collections.abc import Iterable, Iterator from dataclasses import dataclass -from typing import Any, cast, Optional, Union +from typing import Any, cast import sqlglot import sqlparse @@ -138,7 +140,7 @@ class CtasMethod(StrEnum): VIEW = "VIEW" -def _extract_limit_from_query(statement: TokenList) -> Optional[int]: +def _extract_limit_from_query(statement: TokenList) -> int | None: """ Extract limit clause from SQL statement. @@ -159,9 +161,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]: return None -def extract_top_from_query( - statement: TokenList, top_keywords: set[str] -) -> Optional[int]: +def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None: """ Extract top clause value from SQL statement. @@ -185,7 +185,7 @@ def extract_top_from_query( return top -def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]: +def get_cte_remainder_query(sql: str) -> tuple[str | None, str]: """ parse the SQL and return the CTE and rest of the block to the caller @@ -193,7 +193,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]: :return: CTE and remainder block to the caller """ - cte: Optional[str] = None + cte: str | None = None remainder = sql stmt = sqlparse.parse(sql)[0] @@ -211,7 +211,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]: return cte, remainder -def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str: +def strip_comments_from_sql(statement: str, engine: str | None = None) -> str: """ Strips comments from a SQL statement, does a simple test first to avoid always instantiating the expensive ParsedQuery constructor @@ -235,8 +235,8 @@ class Table: """ table: str - schema: Optional[str] = None - catalog: Optional[str] = None + schema: str | None = None + catalog: str | None = None def __str__(self) -> str: """ @@ -255,7 +255,7 @@ class Table: def extract_tables_from_statement( statement: exp.Expression, - dialect: Optional[Dialects], + dialect: Dialects | None, ) -> set[Table]: """ Extract all table references in a single statement. @@ -334,7 +334,7 @@ class SQLScript: def __init__( self, query: str, - engine: Optional[str] = None, + engine: str | None = None, ): dialect = SQLGLOT_DIALECTS.get(engine) if engine else None @@ -375,8 +375,8 @@ class SQLStatement: def __init__( self, - statement: Union[str, exp.Expression], - engine: Optional[str] = None, + statement: str | exp.Expression, + engine: str | None = None, ): dialect = SQLGLOT_DIALECTS.get(engine) if engine else None @@ -394,7 +394,7 @@ class SQLStatement: @staticmethod def _parse_statement( sql_statement: str, - dialect: Optional[Dialects], + dialect: Dialects | None, ) -> exp.Expression: """ Parse a single SQL statement. @@ -437,7 +437,7 @@ class ParsedQuery: self, sql_statement: str, strip_comments: bool = False, - engine: Optional[str] = None, + engine: str | None = None, ): if strip_comments: sql_statement = sqlparse.format(sql_statement, strip_comments=True) @@ -446,7 +446,7 @@ class ParsedQuery: self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None self._tables: set[Table] = set() self._alias_names: set[str] = set() - self._limit: Optional[int] = None + self._limit: int | None = None logger.debug("Parsing with sqlparse statement: %s", self.sql) self._parsed = sqlparse.parse(self.stripped()) @@ -550,7 +550,7 @@ class ParsedQuery: return source.name in ctes_in_scope @property - def limit(self) -> Optional[int]: + def limit(self) -> int | None: return self._limit def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]: @@ -631,7 +631,7 @@ class ParsedQuery: return True - def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]: + def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None: for token in tokens: if self._is_identifier(token): for identifier_token in token.tokens: @@ -695,7 +695,7 @@ class ParsedQuery: return statements @staticmethod - def get_table(tlist: TokenList) -> Optional[Table]: + def get_table(tlist: TokenList) -> Table | None: """ Return the table if valid, i.e., conforms to the [[catalog.]schema.]table construct. @@ -731,7 +731,7 @@ class ParsedQuery: def as_create_table( self, table_name: str, - schema_name: Optional[str] = None, + schema_name: str | None = None, overwrite: bool = False, method: CtasMethod = CtasMethod.TABLE, ) -> str: @@ -891,8 +891,8 @@ def add_table_name(rls: TokenList, table: str) -> None: def get_rls_for_table( candidate: Token, database_id: int, - default_schema: Optional[str], -) -> Optional[TokenList]: + default_schema: str | None, +) -> TokenList | None: """ Given a table name, return any associated RLS predicates. """ @@ -938,7 +938,7 @@ def get_rls_for_table( def insert_rls_as_subquery( token_list: TokenList, database_id: int, - default_schema: Optional[str], + default_schema: str | None, ) -> TokenList: """ Update a statement inplace applying any associated RLS predicates. @@ -954,7 +954,7 @@ def insert_rls_as_subquery( This method is safer than ``insert_rls_in_predicate``, but doesn't work in all databases. """ - rls: Optional[TokenList] = None + rls: TokenList | None = None state = InsertRLSState.SCANNING for token in token_list.tokens: # Recurse into child token list @@ -1030,7 +1030,7 @@ def insert_rls_as_subquery( def insert_rls_in_predicate( token_list: TokenList, database_id: int, - default_schema: Optional[str], + default_schema: str | None, ) -> TokenList: """ Update a statement inplace applying any associated RLS predicates. @@ -1041,7 +1041,7 @@ def insert_rls_in_predicate( after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42 """ - rls: Optional[TokenList] = None + rls: TokenList | None = None state = InsertRLSState.SCANNING for token in token_list.tokens: # Recurse into child token list @@ -1175,7 +1175,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}") def extract_table_references( sql_text: str, sqla_dialect: str, show_warning: bool = True -) -> set["Table"]: +) -> set[Table]: """ Return all the dependencies from a SQL sql_text. """