fix(mssql): support top syntax for limiting queries (#18746)

* SQL-TOP Fix For Database Engines

MSSQL does not support the LIMIT syntax in SQL. To limit the number of rows, MSSQL uses a different keyword, TOP. Added handling for the TOP and LIMIT clauses based on the database engine.

* Teradata code for top clause handling removed from teradata.py

Removed the Teradata-specific TOP clause handling from teradata.py, since a generic implementation of the same logic was added to the base engine spec.

* Changes to handle CTE along with TOP in complex SQLs

Added changes to handle the TOP clause in queries with CTEs, for DB engines that do not support inline CTEs.

* Test cases for TOP unit testing in MSSQL

Added multiple unit test cases for MSSQL TOP clause handling, including queries that combine TOP with CTEs.

* Corrected the name of the select_keywords key in the base engine spec

Corrected the name of the select_keywords key in the base engine spec.

* Changes as per review.

Made the required corrections based on code review to improve code readability and cleanliness.

* Review changes to correct lint and typo issues

Made the changes according to the review comments.

* fix linting errors

* fix teradata tests

* add coverage

* lint

* Code cleanliness

Moved the top/limit flag check from sql_lab to core.

* Changed for code cleanliness

Changes for keeping code cleanliness

* Corrected lint issue

Corrected lint issue.

* Code cleanliness

Code cleanliness

Co-authored-by: Ville Brofeldt <ville.v.brofeldt@gmail.com>
This commit is contained in:
Sujith Kumar S 2022-02-21 13:28:39 +05:30 committed by GitHub
parent a29153778e
commit 7e51b200b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 196 additions and 302 deletions

View File

@ -300,6 +300,16 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# If True, then it will allow in subquery ,
# if False it will allow as regular CTE
allows_cte_in_subquery = True
# Whether the engine supports the LIMIT clause
# If True, row limits are applied to the SQL with a LIMIT clause
# If False, row limits are applied to the SQL with a TOP clause instead
allow_limit_clause = True
# Keywords that start a SELECT statement; used when parsing SQL
# for engines that limit rows with a TOP clause
select_keywords: Set[str] = {"SELECT"}
# Keywords that limit the number of returned rows; used when parsing SQL
# for engines that limit rows with a TOP clause
top_keywords: Set[str] = {"TOP"}
force_column_alias_quotes = False
arraysize = 0
@ -649,6 +659,71 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return sql
@classmethod
def apply_top_to_sql(cls, sql: str, limit: int) -> str:
    """
    Alters the SQL statement to apply a TOP clause

    The effective limit is the smaller of ``limit`` and the value of a TOP
    clause already present in the query, so a limit written in the query is
    never raised. When the engine does not allow CTEs in subqueries
    (``allows_cte_in_subquery`` is False), a leading CTE is split off, the TOP
    clause is applied to the remainder of the statement and the CTE is
    prepended again, separated by a newline.

    Comments are stripped from the returned statement and all runs of
    whitespace are collapsed to single spaces.

    :param sql: SQL query
    :param limit: Maximum number of rows to be returned by the query
    :return: SQL query with top clause
    """
    sql = sql.strip(" \t\n;")
    sql_statement = sqlparse.format(sql, strip_comments=True)
    query_limit: Optional[int] = sql_parse.extract_top_from_query(
        sql_statement, cls.top_keywords
    )
    # Use the smaller of the requested limit and the one already in the query
    if not limit:
        final_limit = query_limit
    elif query_limit is not None and query_limit < limit:
        final_limit = query_limit
    else:
        final_limit = limit

    # Always tokenise the comment-free statement: line breaks are collapsed
    # below, so a surviving "--" comment would swallow the rest of the query.
    cte = ""
    str_statement = sql_statement
    if not cls.allows_cte_in_subquery:
        cte_part, sql_remainder = sql_parse.get_cte_remainder_query(sql_statement)
        if cte_part:
            cte = cte_part + "\n"
            str_statement = str(sql_remainder)

    # Split on any whitespace (spaces, tabs, line breaks)
    tokens = str_statement.split()
    if final_limit is None:
        # Neither a requested limit nor a TOP value in the query: there is
        # nothing to apply, so do not emit a bogus "TOP None".
        return cte + " ".join(tokens)

    select_keywords = {keyword.upper() for keyword in cls.select_keywords}
    top_keywords = {keyword.upper() for keyword in cls.top_keywords}

    if cls.top_not_in_sql(str_statement):
        first_select = next(
            (
                index
                for index, word in enumerate(tokens)
                if word.upper() in select_keywords
            ),
            None,
        )
        # Without a SELECT keyword there is no place to put the TOP clause
        if first_select is not None:
            insert_at = first_select + 1
            # TOP has to follow the ALL/DISTINCT quantifier, not precede it
            if insert_at < len(tokens) and tokens[insert_at].upper() in {
                "ALL",
                "DISTINCT",
            }:
                insert_at += 1
            tokens[insert_at:insert_at] = ["TOP", str(final_limit)]

    # NOTE(review): every TOP keyword followed by a number is rewritten,
    # including those inside subqueries - confirm this is intended.
    new_tokens = []
    previous_is_top_keyword = False
    for token in tokens:
        # Only the token directly after a TOP keyword can be its row count;
        # unrelated numbers later in the query must stay untouched.
        if previous_is_top_keyword and token.isdigit():
            token = str(final_limit)
        previous_is_top_keyword = token.upper() in top_keywords
        new_tokens.append(token)
    sql = " ".join(new_tokens)
    return cte + sql
@classmethod
def top_not_in_sql(cls, sql: str) -> bool:
    """
    Check that none of the engine's TOP keywords occurs in the SQL statement

    Keywords from ``cls.top_keywords`` are matched case-insensitively and only
    as whole words, so identifiers that merely contain a keyword (for example
    ``laptops`` for ``TOP``) are not mistaken for a TOP clause.

    :param sql: SQL query
    :return: True if no TOP keyword is present in the query, False otherwise
    """
    import re  # pylint: disable=import-outside-toplevel

    return not any(
        re.search(r"\b" + re.escape(keyword) + r"\b", sql, re.IGNORECASE)
        for keyword in cls.top_keywords
    )
@classmethod
def get_limit_from_sql(cls, sql: str) -> Optional[int]:
"""

View File

@ -48,6 +48,7 @@ class MssqlEngineSpec(BaseEngineSpec):
limit_method = LimitMethod.WRAP_SQL
max_column_name_length = 128
allows_cte_in_subquery = False
allow_limit_clause = False
_time_grain_expressions = {
None: "{col}",

View File

@ -15,216 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from typing import Optional, Set
import sqlparse
from sqlparse.sql import (
Identifier,
IdentifierList,
Parenthesis,
remove_quotes,
Token,
TokenList,
)
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.sql_parse import Table
PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
CTE_PREFIX = "CTE__"
JOIN = " JOIN"
def _extract_limit_from_query_td(statement: TokenList) -> Optional[int]:
    """
    Extract the row limit of a Teradata statement.

    The statement is converted to a string and split on whitespace. The first
    word that matches TOP or SAMPLE (case-insensitively) and has a successor
    decides the result: the successor is returned when it parses as an
    integer, otherwise None is returned. Later keywords are not inspected.

    :param statement: SQL statement; only its string representation is used
    :return: limit extracted from the query, None if no limit is present
    """
    td_limit_keywords = {"TOP", "SAMPLE"}
    # Split on any whitespace so tabs and bare carriage returns between the
    # keyword and its number do not hide the limit
    words = str(statement).split()
    # Pair every word with its successor; the last word has none and is skipped
    for word, following in zip(words, words[1:]):
        if word.upper() in td_limit_keywords:
            try:
                return int(following)
            except ValueError:
                return None
    return None
class ParsedQueryTeradata:
    """
    Parsed Teradata SQL with access to its tables and its TOP/SAMPLE limit.

    The limit is read with ``_extract_limit_from_query_td`` and can be applied
    or lowered with ``set_or_update_query_limit_td``.
    """

    def __init__(
        self, sql_statement: str, strip_comments: bool = False, uri_type: str = "None"
    ):
        """
        Parse the SQL and extract its limit.

        :param sql_statement: SQL to parse
        :param strip_comments: remove comments from the SQL before parsing
        :param uri_type: stored as ``self.uri_type``; not used by this class
        """
        if strip_comments:
            sql_statement = sqlparse.format(sql_statement, strip_comments=True)

        self.sql: str = sql_statement
        self._tables: Set[Table] = set()
        self._alias_names: Set[str] = set()
        self._limit: Optional[int] = None
        self.uri_type: str = uri_type

        self._parsed = sqlparse.parse(self.stripped())
        # With several statements, the limit of the last one is kept
        for statement in self._parsed:
            self._limit = _extract_limit_from_query_td(statement)

    @property
    def tables(self) -> Set[Table]:
        """
        Tables referenced by the SQL, excluding names that are only aliases.

        The set is computed on first access and reused while it is non-empty.
        """
        if not self._tables:
            for statement in self._parsed:
                self._extract_from_token(statement)

            self._tables = {
                table for table in self._tables if str(table) not in self._alias_names
            }
        return self._tables

    def stripped(self) -> str:
        """Return the SQL without surrounding whitespace and semicolons."""
        return self.sql.strip(" \t\n;")

    def _extract_from_token(self, token: Token) -> None:
        """
        <Identifier> store a list of subtokens and <IdentifierList> store lists of
        subtoken list.

        It extracts <IdentifierList> and <Identifier> from :param token: and loops
        through all subtokens recursively. It finds table_name_preceding_token and
        passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
        self._tables.

        :param token: instance of Token or child class, e.g. TokenList, to be processed
        """
        if not hasattr(token, "tokens"):
            return

        table_name_preceding_token = False

        for item in token.tokens:
            # Descend into groups that are not plain identifiers (or that start
            # with a parenthesis, i.e. subselects)
            if item.is_group and (
                not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
            ):
                self._extract_from_token(item)

            # FROM, JOIN, ... announce that a table name follows
            if item.ttype in Keyword and (
                item.normalized in PRECEDES_TABLE_NAME or item.normalized.endswith(JOIN)
            ):
                table_name_preceding_token = True
                continue

            # Any other keyword ends the table-name position
            if item.ttype in Keyword:
                table_name_preceding_token = False
                continue

            if table_name_preceding_token:
                if isinstance(item, Identifier):
                    self._process_tokenlist(item)
                elif isinstance(item, IdentifierList):
                    for item_list in item.get_identifiers():
                        if isinstance(item_list, TokenList):
                            self._process_tokenlist(item_list)
            elif isinstance(item, IdentifierList):
                if any(not self._is_identifier(entry) for entry in item.tokens):
                    self._extract_from_token(item)

    @staticmethod
    def _get_table(tlist: TokenList) -> Optional[Table]:
        """
        Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
        construct.

        :param tlist: The SQL tokens
        :returns: The table if the name conforms
        """
        # Strip the alias if present.
        idx = len(tlist.tokens)

        if tlist.has_alias():
            ws_idx, _ = tlist.token_next_by(t=Whitespace)

            if ws_idx != -1:
                idx = ws_idx

        tokens = tlist.tokens[:idx]

        # A qualified name alternates name parts and dots: 1, 3 or 5 tokens
        odd_token_number = len(tokens) in (1, 3, 5)
        qualified_name_parts = all(
            imt(token, t=[Name, String]) for token in tokens[::2]
        )
        dot_separators = all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
        if odd_token_number and qualified_name_parts and dot_separators:
            # Reverse order so the table name comes first, then schema, catalog
            return Table(*[remove_quotes(token.value) for token in tokens[::-2]])

        return None

    @staticmethod
    def _is_identifier(token: Token) -> bool:
        """Return True if the token is an Identifier or an IdentifierList."""
        return isinstance(token, (IdentifierList, Identifier))

    def _process_tokenlist(self, token_list: TokenList) -> None:
        """
        Add table names to table set

        :param token_list: TokenList to be processed
        """
        # exclude subselects
        if "(" not in str(token_list):
            table = self._get_table(token_list)
            if table and not table.table.startswith(CTE_PREFIX):
                self._tables.add(table)
            return

        # store aliases
        if token_list.has_alias():
            self._alias_names.add(token_list.get_alias())

        # some aliases are not parsed properly
        if token_list.tokens[0].ttype == Name:
            self._alias_names.add(token_list.tokens[0].value)
        self._extract_from_token(token_list)

    def set_or_update_query_limit_td(self, new_limit: int) -> str:
        """
        Return the first statement limited to at most ``new_limit`` rows.

        A TOP clause is inserted after the first SELECT/SEL keyword when the
        statement has no TOP or SAMPLE keyword; otherwise the number following
        such a keyword is replaced. A limit already in the query is only ever
        lowered, never raised. Keywords are matched case-insensitively.

        :param new_limit: Maximum number of rows to be returned by the query
        :return: the rewritten SQL with whitespace collapsed to single spaces
        """
        td_sel_keywords = {"SELECT", "SEL"}
        td_limit_keywords = {"TOP", "SAMPLE"}
        if not self._parsed:
            # Empty SQL parses to no statement: nothing to rewrite
            return self.stripped()
        statement = self._parsed[0]

        # Compare against None so an explicit "TOP 0" is not mistaken for
        # "no limit" and raised to new_limit
        if self._limit is None or new_limit < self._limit:
            final_limit = new_limit
        else:
            final_limit = self._limit

        str_statement = str(statement)
        str_statement = str_statement.replace("\n", " ").replace("\r", "")
        tokens = [part for part in str_statement.rstrip().split(" ") if part]

        # Detect the keyword on token level regardless of case as well, so a
        # lowercase "top" does not get a second TOP clause inserted
        has_limit_keyword = any(
            part.upper() in td_limit_keywords for part in tokens
        ) or not limit_not_in_sql(str_statement, td_limit_keywords)

        if not has_limit_keyword:
            first_select = next(
                (
                    index
                    for index, word in enumerate(tokens)
                    if word.upper() in td_sel_keywords
                ),
                None,
            )
            # Without a SELECT/SEL keyword there is no place for a TOP clause
            if first_select is not None:
                tokens[first_select + 1 : first_select + 1] = [
                    "TOP",
                    str(final_limit),
                ]

        new_tokens = []
        previous_is_limit_keyword = False
        for token in tokens:
            # Only the token directly after TOP/SAMPLE can be the row count;
            # unrelated numbers later in the query must stay untouched
            if previous_is_limit_keyword and token.isdigit():
                token = str(final_limit)
            previous_is_limit_keyword = token.upper() in td_limit_keywords
            new_tokens.append(token)

        return " ".join(new_tokens)
class TeradataEngineSpec(BaseEngineSpec):
@ -234,6 +25,9 @@ class TeradataEngineSpec(BaseEngineSpec):
engine_name = "Teradata"
limit_method = LimitMethod.WRAP_SQL
max_column_name_length = 30 # since 14.10 this is 128
allow_limit_clause = False
select_keywords = {"SELECT", "SEL"}
top_keywords = {"TOP", "SAMPLE"}
_time_grain_expressions = {
None: "{col}",
@ -253,29 +47,3 @@ class TeradataEngineSpec(BaseEngineSpec):
"AT 0)) AT 0) + (({col} MOD 86400) * INTERVAL '00:00:01' "
"HOUR TO SECOND) AS TIMESTAMP(0))"
)
@classmethod
def apply_limit_to_sql(
    cls, sql: str, limit: int, database: str = "Database", force: bool = False
) -> str:
    """
    Apply a row limit to the SQL statement by means of a TOP clause

    Replaces the similar function in base.py, because Teradata doesn't support
    the LIMIT syntax. The ``database`` and ``force`` arguments are accepted
    but not used here.

    :param sql: SQL query
    :param limit: Maximum number of rows to be returned by the query
    :param database: Database instance
    :param force: not used by this implementation
    :return: SQL query with limit clause
    """
    return ParsedQueryTeradata(sql).set_or_update_query_limit_td(limit)
def limit_not_in_sql(sql: str, limit_words: Set[str]) -> bool:
    """
    Check that none of the limit keywords occurs in the SQL statement.

    Keywords are matched case-insensitively and only as whole words, so a
    lowercase ``top`` is recognised while identifiers that merely contain a
    keyword (for example ``laptops``) are not.

    :param sql: SQL query
    :param limit_words: keywords that limit the number of returned rows
    :return: True if no keyword is present in the query, False otherwise
    """
    import re  # pylint: disable=import-outside-toplevel

    return not any(
        re.search(r"\b" + re.escape(limit_word) + r"\b", sql, re.IGNORECASE)
        for limit_word in limit_words
    )

View File

@ -488,7 +488,9 @@ class Database(
def apply_limit_to_sql(
self, sql: str, limit: int = 1000, force: bool = False
) -> str:
return self.db_engine_spec.apply_limit_to_sql(sql, limit, self, force=force)
if self.db_engine_spec.allow_limit_clause:
return self.db_engine_spec.apply_limit_to_sql(sql, limit, self, force=force)
return self.db_engine_spec.apply_top_to_sql(sql, limit)
def safe_sqlalchemy_uri(self) -> str:
return self.sqlalchemy_uri

View File

@ -292,7 +292,8 @@ def apply_limit_if_exists(
) -> str:
if query.limit and increased_limit:
# We are fetching one more than the requested limit in order
# to test whether there are more rows than the limit.
# to test whether there are more rows than the limit. Depending on what the
# DB engine supports, the limit is applied as a TOP or a LIMIT clause.
# Later, the extra row will be dropped before sending
# the results back to the user.
sql = database.apply_limit_to_sql(sql, increased_limit, force=True)

View File

@ -18,7 +18,7 @@ import logging
import re
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Set
from typing import List, Optional, Set, Tuple
from urllib import parse
import sqlparse
@ -30,7 +30,16 @@ from sqlparse.sql import (
Token,
TokenList,
)
from sqlparse.tokens import DDL, DML, Keyword, Name, Punctuation, String, Whitespace
from sqlparse.tokens import (
CTE,
DDL,
DML,
Keyword,
Name,
Punctuation,
String,
Whitespace,
)
from sqlparse.utils import imt
from superset.exceptions import QueryClauseValidationException
@ -78,6 +87,58 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
return None
def extract_top_from_query(
    statement: TokenList, top_keywords: Set[str]
) -> Optional[int]:
    """
    Extract top clause value from SQL statement.

    The statement is converted to a string and split on whitespace. The first
    word whose upper-cased form is in ``top_keywords`` and that has a
    successor decides the result: the successor is returned when it parses as
    an integer, otherwise None is returned. Later keywords are not inspected.

    :param statement: SQL statement; only its string representation is used
    :param top_keywords: keywords that are considered as synonyms to TOP
    :return: top value extracted from query, None if no top value present in statement
    """
    # Split on any whitespace so tabs and bare carriage returns between the
    # keyword and its number do not hide the TOP value
    words = str(statement).split()
    # Pair every word with its successor; the last word has none and is skipped
    for word, following in zip(words, words[1:]):
        if word.upper() in top_keywords:
            try:
                return int(following)
            except ValueError:
                return None
    return None
def get_cte_remainder_query(sql: str) -> Tuple[Optional[str], str]:
    """
    parse the SQL and return the CTE and rest of the block to the caller

    Only the first statement of ``sql`` is inspected. If its first meaningful
    token is not a CTE keyword (or nothing follows that keyword), the SQL is
    returned unchanged with None as the CTE.

    :param sql: SQL query
    :return: CTE and remainder block to the caller
    """
    cte: Optional[str] = None
    remainder = sql
    statements = sqlparse.parse(sql)
    # An empty SQL string parses to no statement at all
    if not statements:
        return cte, remainder
    stmt = statements[0]

    # The first meaningful token for CTE will be with WITH
    idx, token = stmt.token_next(-1, skip_ws=True, skip_cm=True)
    if not (token and token.ttype == CTE):
        return cte, remainder
    # The token following WITH holds the CTE definition
    idx, token = stmt.token_next(idx)
    if token is None:
        # A dangling WITH without a definition: there is nothing to split off
        return cte, remainder
    idx = stmt.token_index(token) + 1

    # extract rest of the SQLs after CTE
    remainder = "".join(str(item) for item in stmt.tokens[idx:]).strip()
    cte = f"WITH {token.value}"

    return cte, remainder
def strip_comments_from_sql(statement: str) -> str:
"""
Strips comments from a SQL statement, does a simple test first

View File

@ -231,6 +231,37 @@ def test_cte_query_parsing(
assert actual == expected
@pytest.mark.parametrize(
    "original,expected,top",
    [
        ("SEL TOP 1000 * FROM My_table", "SEL TOP 100 * FROM My_table", 100),
        ("SEL TOP 1000 * FROM My_table;", "SEL TOP 100 * FROM My_table", 100),
        ("SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table", 10000),
        ("SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table", 1000),
        (
            """with abc as (select * from test union select * from test1)
select TOP 100 * from currency""",
            """WITH abc as (select * from test union select * from test1)
select TOP 100 * from currency""",
            1000,
        ),
        ("SELECT 1 as cnt", "SELECT TOP 10 1 as cnt", 10),
        (
            "select TOP 1000 * from abc where id=1",
            "select TOP 10 * from abc where id=1",
            10,
        ),
    ],
)
def test_top_query_parsing(
    app_context: AppContext, original: str, expected: str, top: int
) -> None:
    """
    Ensure ``MssqlEngineSpec.apply_top_to_sql`` applies the TOP limit correctly.

    The cases cover lowering an existing TOP value, keeping a TOP value that is
    already at or below the requested limit, a query with a leading CTE, and
    inserting TOP into a query that has none.
    """
    # Imported inside the test, as done for the other engine spec tests
    from superset.db_engine_specs.mssql import MssqlEngineSpec

    actual = MssqlEngineSpec.apply_top_to_sql(original, top)
    assert actual == expected
def test_extract_errors(app_context: AppContext) -> None:
"""
Test that custom error messages are extracted correctly.

View File

@ -15,73 +15,28 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
import pytest
from flask.ctx import AppContext
def test_ParsedQueryTeradata_lower_limit(app_context: AppContext) -> None:
@pytest.mark.parametrize(
"limit,original,expected",
[
(100, "SEL TOP 1000 * FROM My_table", "SEL TOP 100 * FROM My_table"),
(100, "SEL TOP 1000 * FROM My_table;", "SEL TOP 100 * FROM My_table"),
(10000, "SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table"),
(1000, "SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table"),
(100, "SELECT TOP 1000 * FROM My_table", "SELECT TOP 100 * FROM My_table"),
(100, "SEL SAMPLE 1000 * FROM My_table", "SEL SAMPLE 100 * FROM My_table"),
(10000, "SEL SAMPLE 1000 * FROM My_table", "SEL SAMPLE 1000 * FROM My_table"),
],
)
def test_apply_top_to_sql_limit(
app_context: AppContext, limit: int, original: str, expected: str,
) -> None:
"""
Test the custom ``ParsedQueryTeradata`` that calls ``_extract_limit_from_query_td(``
The CLass looks for Teradata limit keywords TOP and SAMPLE vs LIMIT in
other dialects.
Ensure limits are applied to the query correctly
"""
from superset.db_engine_specs.teradata import TeradataEngineSpec
sql = "SEL TOP 1000 * FROM My_table;"
limit = 100
assert str(TeradataEngineSpec.apply_limit_to_sql(sql, limit, "Database")) == (
"SEL TOP 100 * FROM My_table"
)
def test_ParsedQueryTeradata_higher_limit(app_context: AppContext) -> None:
"""
Test the custom ``ParsedQueryTeradata`` that calls ``_extract_limit_from_query_td(``
The CLass looks for Teradata limit keywords TOP and SAMPLE vs LIMIT in
other dialects.
"""
from superset.db_engine_specs.teradata import TeradataEngineSpec
sql = "SEL TOP 1000 * FROM My_table;"
limit = 10000
assert str(TeradataEngineSpec.apply_limit_to_sql(sql, limit, "Database")) == (
"SEL TOP 1000 * FROM My_table"
)
def test_ParsedQueryTeradata_equal_limit(app_context: AppContext) -> None:
"""
Test the custom ``ParsedQueryTeradata`` that calls ``_extract_limit_from_query_td(``
The CLass looks for Teradata limit keywords TOP and SAMPLE vs LIMIT in
other dialects.
"""
from superset.db_engine_specs.teradata import TeradataEngineSpec
sql = "SEL TOP 1000 * FROM My_table;"
limit = 1000
assert str(TeradataEngineSpec.apply_limit_to_sql(sql, limit, "Database")) == (
"SEL TOP 1000 * FROM My_table"
)
def test_ParsedQueryTeradata_no_limit(app_context: AppContext) -> None:
"""
Test the custom ``ParsedQueryTeradata`` that calls ``_extract_limit_from_query_td(``
The CLass looks for Teradata limit keywords TOP and SAMPLE vs LIMIT in
other dialects.
"""
from superset.db_engine_specs.teradata import TeradataEngineSpec
sql = "SEL * FROM My_table;"
limit = 1000
assert str(TeradataEngineSpec.apply_limit_to_sql(sql, limit, "Database")) == (
"SEL TOP 1000 * FROM My_table"
)
assert TeradataEngineSpec.apply_top_to_sql(original, limit) == expected