From 7e51b200b42f598c3344548f5f64521814a1d3cd Mon Sep 17 00:00:00 2001 From: Sujith Kumar S <31705464+sujiplr@users.noreply.github.com> Date: Mon, 21 Feb 2022 13:28:39 +0530 Subject: [PATCH] fix(mssql): support top syntax for limiting queries (#18746) * SQL-TOP Fix For Database Engines MSSQL does not support the LIMIT syntax in SQL. To limit rows, MSSQL uses a different keyword, TOP. Added fixes for handling the TOP and LIMIT clauses based on the database engine. * Teradata code for top clause handling removed from teradata.py Teradata code for top clause handling was removed from the teradata.py file, since we added a generic section in the base engine for the same. * Changes to handle CTE along with TOP in complex SQLs Added changes to handle the TOP command in CTEs, for DB engines which do not support inline CTEs. * Test cases for TOP unit testing in MSSQL Added multiple unit test cases for MSSQL TOP command handling, including in combination with CTEs. * Corrected the select_keywords name key in base engine Corrected the select_keywords name key in the base engine. * Changes as per review. Made the required corrections based on code review to maintain good code readability and code cleanliness. * Review changes to correct lint and typo issues Made the changes according to the review comments. * fix linting errors * fix teradata tests * add coverage * lint * Code cleanliness Moved the top/limit flag check from sql_lab to core. * Changed for code cleanliness Changes for keeping code cleanliness * Corrected lint issue Corrected lint issue. 
* Code cleanliness Code cleanliness Co-authored-by: Ville Brofeldt --- superset/db_engine_specs/base.py | 75 ++++++ superset/db_engine_specs/mssql.py | 1 + superset/db_engine_specs/teradata.py | 238 +----------------- superset/models/core.py | 4 +- superset/sql_lab.py | 3 +- superset/sql_parse.py | 65 ++++- .../unit_tests/db_engine_specs/test_mssql.py | 31 +++ .../db_engine_specs/test_teradata.py | 81 ++---- 8 files changed, 196 insertions(+), 302 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 764f3fde70..d7e457baa8 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -300,6 +300,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # If True, then it will allow in subquery , # if False it will allow as regular CTE allows_cte_in_subquery = True + # Whether allow LIMIT clause in the SQL + # If True, then the database engine is allowed for LIMIT clause + # If False, then the database engine is allowed for TOP clause + allow_limit_clause = True + # This set will give keywords for select statements + # to consider for the engines with TOP SQL parsing + select_keywords: Set[str] = {"SELECT"} + # This set will give the keywords for data limit statements + # to consider for the engines with TOP SQL parsing + top_keywords: Set[str] = {"TOP"} force_column_alias_quotes = False arraysize = 0 @@ -649,6 +659,71 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return sql + @classmethod + def apply_top_to_sql(cls, sql: str, limit: int) -> str: + """ + Alters the SQL statement to apply a TOP clause + :param limit: Maximum number of rows to be returned by the query + :param sql: SQL query + :return: SQL query with top clause + """ + + cte = None + sql_remainder = None + sql = sql.strip(" \t\n;") + sql_statement = sqlparse.format(sql, strip_comments=True) + query_limit: Optional[int] = sql_parse.extract_top_from_query( + sql_statement, cls.top_keywords + ) + if 
not limit: + final_limit = query_limit + elif int(query_limit or 0) < limit and query_limit is not None: + final_limit = query_limit + else: + final_limit = limit + if not cls.allows_cte_in_subquery: + cte, sql_remainder = sql_parse.get_cte_remainder_query(sql_statement) + if cte: + str_statement = str(sql_remainder) + cte = cte + "\n" + else: + cte = "" + str_statement = str(sql) + str_statement = str_statement.replace("\n", " ").replace("\r", "") + + tokens = str_statement.rstrip().split(" ") + tokens = [token for token in tokens if token] + if cls.top_not_in_sql(str_statement): + selects = [ + i + for i, word in enumerate(tokens) + if word.upper() in cls.select_keywords + ] + first_select = selects[0] + tokens.insert(first_select + 1, "TOP") + tokens.insert(first_select + 2, str(final_limit)) + + next_is_limit_token = False + new_tokens = [] + + for token in tokens: + if token in cls.top_keywords: + next_is_limit_token = True + elif next_is_limit_token: + if token.isdigit(): + token = str(final_limit) + next_is_limit_token = False + new_tokens.append(token) + sql = " ".join(new_tokens) + return cte + sql + + @classmethod + def top_not_in_sql(cls, sql: str) -> bool: + for top_word in cls.top_keywords: + if top_word.upper() in sql.upper(): + return False + return True + @classmethod def get_limit_from_sql(cls, sql: str) -> Optional[int]: """ diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index e5c66e046a..158e73adea 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -48,6 +48,7 @@ class MssqlEngineSpec(BaseEngineSpec): limit_method = LimitMethod.WRAP_SQL max_column_name_length = 128 allows_cte_in_subquery = False + allow_limit_clause = False _time_grain_expressions = { None: "{col}", diff --git a/superset/db_engine_specs/teradata.py b/superset/db_engine_specs/teradata.py index 8e7589980b..bd2ee51605 100644 --- a/superset/db_engine_specs/teradata.py +++ b/superset/db_engine_specs/teradata.py 
@@ -15,216 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Optional, Set - -import sqlparse -from sqlparse.sql import ( - Identifier, - IdentifierList, - Parenthesis, - remove_quotes, - Token, - TokenList, -) -from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace -from sqlparse.utils import imt - from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod -from superset.sql_parse import Table - -PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"} -CTE_PREFIX = "CTE__" -JOIN = " JOIN" - - -def _extract_limit_from_query_td(statement: TokenList) -> Optional[int]: - td_limit_keywork = {"TOP", "SAMPLE"} - str_statement = str(statement) - str_statement = str_statement.replace("\n", " ").replace("\r", "") - token = str_statement.rstrip().split(" ") - token = [part for part in token if part] - limit = None - - for i, _ in enumerate(token): - if token[i].upper() in td_limit_keywork and len(token) - 1 > i: - try: - limit = int(token[i + 1]) - except ValueError: - limit = None - break - return limit - - -class ParsedQueryTeradata: - def __init__( - self, sql_statement: str, strip_comments: bool = False, uri_type: str = "None" - ): - - if strip_comments: - sql_statement = sqlparse.format(sql_statement, strip_comments=True) - - self.sql: str = sql_statement - self._tables: Set[Table] = set() - self._alias_names: Set[str] = set() - self._limit: Optional[int] = None - self.uri_type: str = uri_type - - self._parsed = sqlparse.parse(self.stripped()) - for statement in self._parsed: - self._limit = _extract_limit_from_query_td(statement) - - @property - def tables(self) -> Set[Table]: - if not self._tables: - for statement in self._parsed: - self._extract_from_token(statement) - - self._tables = { - table for table in self._tables if str(table) not in self._alias_names - } - return self._tables - - def stripped(self) -> str: - return self.sql.strip(" \t\n;") - 
- def _extract_from_token(self, token: Token) -> None: - """ - store a list of subtokens and store lists of - subtoken list. - - It extracts and from :param token: and loops - through all subtokens recursively. It finds table_name_preceding_token and - passes and to self._process_tokenlist to populate - - self._tables. - - :param token: instance of Token or child class, e.g. TokenList, to be processed - """ - if not hasattr(token, "tokens"): - return - - table_name_preceding_token = False - - for item in token.tokens: - if item.is_group and ( - not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis) - ): - self._extract_from_token(item) - - if item.ttype in Keyword and ( - item.normalized in PRECEDES_TABLE_NAME or item.normalized.endswith(JOIN) - ): - table_name_preceding_token = True - continue - - if item.ttype in Keyword: - table_name_preceding_token = False - continue - if table_name_preceding_token: - if isinstance(item, Identifier): - self._process_tokenlist(item) - elif isinstance(item, IdentifierList): - for item_list in item.get_identifiers(): - if isinstance(item_list, TokenList): - self._process_tokenlist(item_list) - elif isinstance(item, IdentifierList): - if any(not self._is_identifier(ItemList) for ItemList in item.tokens): - self._extract_from_token(item) - - @staticmethod - def _get_table(tlist: TokenList) -> Optional[Table]: - """ - Return the table if valid, i.e., conforms to the [[catalog.]schema.]table - construct. - - :param tlist: The SQL tokens - :returns: The table if the name conforms - """ - - # Strip the alias if present. 
- idx = len(tlist.tokens) - - if tlist.has_alias(): - ws_idx, _ = tlist.token_next_by(t=Whitespace) - - if ws_idx != -1: - idx = ws_idx - - tokens = tlist.tokens[:idx] - - odd_token_number = len(tokens) in (1, 3, 5) - qualified_name_parts = all( - imt(token, t=[Name, String]) for token in tokens[::2] - ) - dot_separators = all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2]) - if odd_token_number and qualified_name_parts and dot_separators: - return Table(*[remove_quotes(token.value) for token in tokens[::-2]]) - - return None - - @staticmethod - def _is_identifier(token: Token) -> bool: - return isinstance(token, (IdentifierList, Identifier)) - - def _process_tokenlist(self, token_list: TokenList) -> None: - """ - Add table names to table set - - :param token_list: TokenList to be processed - """ - # exclude subselects - if "(" not in str(token_list): - table = self._get_table(token_list) - if table and not table.table.startswith(CTE_PREFIX): - self._tables.add(table) - return - - # store aliases - if token_list.has_alias(): - self._alias_names.add(token_list.get_alias()) - - # some aliases are not parsed properly - if token_list.tokens[0].ttype == Name: - self._alias_names.add(token_list.tokens[0].value) - self._extract_from_token(token_list) - - def set_or_update_query_limit_td(self, new_limit: int) -> str: - td_sel_keywords = {"SELECT", "SEL"} - td_limit_keywords = {"TOP", "SAMPLE"} - statement = self._parsed[0] - - if not self._limit: - final_limit = new_limit - elif new_limit < self._limit: - final_limit = new_limit - else: - final_limit = self._limit - - str_statement = str(statement) - str_statement = str_statement.replace("\n", " ").replace("\r", "") - - tokens = str_statement.rstrip().split(" ") - tokens = [token for token in tokens if token] - - if limit_not_in_sql(str_statement, td_limit_keywords): - selects = [i for i, word in enumerate(tokens) if word in td_sel_keywords] - first_select = selects[0] - tokens.insert(first_select + 1, "TOP") - 
tokens.insert(first_select + 2, str(final_limit)) - - next_is_limit_token = False - new_tokens = [] - - for token in tokens: - if token.upper() in td_limit_keywords: - next_is_limit_token = True - elif next_is_limit_token: - if token.isdigit(): - token = str(final_limit) - next_is_limit_token = False - new_tokens.append(token) - - return " ".join(new_tokens) class TeradataEngineSpec(BaseEngineSpec): @@ -234,6 +25,9 @@ class TeradataEngineSpec(BaseEngineSpec): engine_name = "Teradata" limit_method = LimitMethod.WRAP_SQL max_column_name_length = 30 # since 14.10 this is 128 + allow_limit_clause = False + select_keywords = {"SELECT", "SEL"} + top_keywords = {"TOP", "SAMPLE"} _time_grain_expressions = { None: "{col}", @@ -253,29 +47,3 @@ class TeradataEngineSpec(BaseEngineSpec): "AT 0)) AT 0) + (({col} MOD 86400) * INTERVAL '00:00:01' " "HOUR TO SECOND) AS TIMESTAMP(0))" ) - - @classmethod - def apply_limit_to_sql( - cls, sql: str, limit: int, database: str = "Database", force: bool = False - ) -> str: - """ - Alters the SQL statement to apply a TOP clause - The function overwrites similar function in base.py because Teradata doesn't - support LIMIT syntax - :param sql: SQL query - :param limit: Maximum number of rows to be returned by the query - :param database: Database instance - :return: SQL query with limit clause - """ - - parsed_query = ParsedQueryTeradata(sql) - sql = parsed_query.set_or_update_query_limit_td(limit) - - return sql - - -def limit_not_in_sql(sql: str, limit_words: Set[str]) -> bool: - for limit_word in limit_words: - if limit_word in sql: - return False - return True diff --git a/superset/models/core.py b/superset/models/core.py index 05970fe774..7798ddf059 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -488,7 +488,9 @@ class Database( def apply_limit_to_sql( self, sql: str, limit: int = 1000, force: bool = False ) -> str: - return self.db_engine_spec.apply_limit_to_sql(sql, limit, self, force=force) + if 
self.db_engine_spec.allow_limit_clause: + return self.db_engine_spec.apply_limit_to_sql(sql, limit, self, force=force) + return self.db_engine_spec.apply_top_to_sql(sql, limit) def safe_sqlalchemy_uri(self) -> str: return self.sqlalchemy_uri diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 27b188355a..8fac419cf0 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -292,7 +292,8 @@ def apply_limit_if_exists( ) -> str: if query.limit and increased_limit: # We are fetching one more than the requested limit in order - # to test whether there are more rows than the limit. + # to test whether there are more rows than the limit. According to the DB + # Engine support it will choose top or limit parse # Later, the extra row will be dropped before sending # the results back to the user. sql = database.apply_limit_to_sql(sql, increased_limit, force=True) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 34dbaa5531..b5b614cf25 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -18,7 +18,7 @@ import logging import re from dataclasses import dataclass from enum import Enum -from typing import List, Optional, Set +from typing import List, Optional, Set, Tuple from urllib import parse import sqlparse @@ -30,7 +30,16 @@ from sqlparse.sql import ( Token, TokenList, ) -from sqlparse.tokens import DDL, DML, Keyword, Name, Punctuation, String, Whitespace +from sqlparse.tokens import ( + CTE, + DDL, + DML, + Keyword, + Name, + Punctuation, + String, + Whitespace, +) from sqlparse.utils import imt from superset.exceptions import QueryClauseValidationException @@ -78,6 +87,58 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]: return None +def extract_top_from_query( + statement: TokenList, top_keywords: Set[str] +) -> Optional[int]: + """ + Extract top clause value from SQL statement. 
+ + :param statement: SQL statement + :param top_keywords: keywords that are considered as synonyms to TOP + :return: top value extracted from query, None if no top value present in statement + """ + + str_statement = str(statement) + str_statement = str_statement.replace("\n", " ").replace("\r", "") + token = str_statement.rstrip().split(" ") + token = [part for part in token if part] + top = None + for i, _ in enumerate(token): + if token[i].upper() in top_keywords and len(token) - 1 > i: + try: + top = int(token[i + 1]) + except ValueError: + top = None + break + return top + + +def get_cte_remainder_query(sql: str) -> Tuple[Optional[str], str]: + """ + parse the SQL and return the CTE and rest of the block to the caller + + :param sql: SQL query + :return: CTE and remainder block to the caller + + """ + cte: Optional[str] = None + remainder = sql + stmt = sqlparse.parse(sql)[0] + + # The first meaningful token for CTE will be with WITH + idx, token = stmt.token_next(-1, skip_ws=True, skip_cm=True) + if not (token and token.ttype == CTE): + return cte, remainder + idx, token = stmt.token_next(idx) + idx = stmt.token_index(token) + 1 + + # extract rest of the SQLs after CTE + remainder = "".join(str(token) for token in stmt.tokens[idx:]).strip() + cte = f"WITH {token.value}" + + return cte, remainder + + def strip_comments_from_sql(statement: str) -> str: """ Strips comments from a SQL statement, does a simple test first diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index 250b8158fa..5c8848280b 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -231,6 +231,37 @@ def test_cte_query_parsing( assert actual == expected +@pytest.mark.parametrize( + "original,expected,top", + [ + ("SEL TOP 1000 * FROM My_table", "SEL TOP 100 * FROM My_table", 100), + ("SEL TOP 1000 * FROM My_table;", "SEL TOP 100 * FROM My_table", 100), + ("SEL TOP 1000 * FROM 
My_table;", "SEL TOP 1000 * FROM My_table", 10000), + ("SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table", 1000), + ( + """with abc as (select * from test union select * from test1) +select TOP 100 * from currency""", + """WITH abc as (select * from test union select * from test1) +select TOP 100 * from currency""", + 1000, + ), + ("SELECT 1 as cnt", "SELECT TOP 10 1 as cnt", 10), + ( + "select TOP 1000 * from abc where id=1", + "select TOP 10 * from abc where id=1", + 10, + ), + ], +) +def test_top_query_parsing( + app_context: AppContext, original: TypeEngine, expected: str, top: int +) -> None: + from superset.db_engine_specs.mssql import MssqlEngineSpec + + actual = MssqlEngineSpec.apply_top_to_sql(original, top) + assert actual == expected + + def test_extract_errors(app_context: AppContext) -> None: """ Test that custom error messages are extracted correctly. diff --git a/tests/unit_tests/db_engine_specs/test_teradata.py b/tests/unit_tests/db_engine_specs/test_teradata.py index 8d9fc08c4a..11978737ab 100644 --- a/tests/unit_tests/db_engine_specs/test_teradata.py +++ b/tests/unit_tests/db_engine_specs/test_teradata.py @@ -15,73 +15,28 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=unused-argument, import-outside-toplevel, protected-access - +import pytest from flask.ctx import AppContext -def test_ParsedQueryTeradata_lower_limit(app_context: AppContext) -> None: +@pytest.mark.parametrize( + "limit,original,expected", + [ + (100, "SEL TOP 1000 * FROM My_table", "SEL TOP 100 * FROM My_table"), + (100, "SEL TOP 1000 * FROM My_table;", "SEL TOP 100 * FROM My_table"), + (10000, "SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table"), + (1000, "SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table"), + (100, "SELECT TOP 1000 * FROM My_table", "SELECT TOP 100 * FROM My_table"), + (100, "SEL SAMPLE 1000 * FROM My_table", "SEL SAMPLE 100 * FROM My_table"), + (10000, "SEL SAMPLE 1000 * FROM My_table", "SEL SAMPLE 1000 * FROM My_table"), + ], +) +def test_apply_top_to_sql_limit( + app_context: AppContext, limit: int, original: str, expected: str, +) -> None: """ - Test the custom ``ParsedQueryTeradata`` that calls ``_extract_limit_from_query_td(`` - - The CLass looks for Teradata limit keywords TOP and SAMPLE vs LIMIT in - other dialects. + Ensure limits are applied to the query correctly """ from superset.db_engine_specs.teradata import TeradataEngineSpec - sql = "SEL TOP 1000 * FROM My_table;" - limit = 100 - - assert str(TeradataEngineSpec.apply_limit_to_sql(sql, limit, "Database")) == ( - "SEL TOP 100 * FROM My_table" - ) - - -def test_ParsedQueryTeradata_higher_limit(app_context: AppContext) -> None: - """ - Test the custom ``ParsedQueryTeradata`` that calls ``_extract_limit_from_query_td(`` - - The CLass looks for Teradata limit keywords TOP and SAMPLE vs LIMIT in - other dialects. 
- """ - from superset.db_engine_specs.teradata import TeradataEngineSpec - - sql = "SEL TOP 1000 * FROM My_table;" - limit = 10000 - - assert str(TeradataEngineSpec.apply_limit_to_sql(sql, limit, "Database")) == ( - "SEL TOP 1000 * FROM My_table" - ) - - -def test_ParsedQueryTeradata_equal_limit(app_context: AppContext) -> None: - """ - Test the custom ``ParsedQueryTeradata`` that calls ``_extract_limit_from_query_td(`` - - The CLass looks for Teradata limit keywords TOP and SAMPLE vs LIMIT in - other dialects. - """ - from superset.db_engine_specs.teradata import TeradataEngineSpec - - sql = "SEL TOP 1000 * FROM My_table;" - limit = 1000 - - assert str(TeradataEngineSpec.apply_limit_to_sql(sql, limit, "Database")) == ( - "SEL TOP 1000 * FROM My_table" - ) - - -def test_ParsedQueryTeradata_no_limit(app_context: AppContext) -> None: - """ - Test the custom ``ParsedQueryTeradata`` that calls ``_extract_limit_from_query_td(`` - - The CLass looks for Teradata limit keywords TOP and SAMPLE vs LIMIT in - other dialects. - """ - from superset.db_engine_specs.teradata import TeradataEngineSpec - - sql = "SEL * FROM My_table;" - limit = 1000 - - assert str(TeradataEngineSpec.apply_limit_to_sql(sql, limit, "Database")) == ( - "SEL TOP 1000 * FROM My_table" - ) + assert TeradataEngineSpec.apply_top_to_sql(original, limit) == expected