feat: improve logic in is_select (#17329)

* feat: improve logic in is_select

* Add more edge cases
This commit is contained in:
Beto Dealmeida 2021-11-02 17:30:12 -07:00 committed by GitHub
parent 9a4ab1026e
commit 93bafa0e6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 70 additions and 4 deletions

View File

@ -29,7 +29,7 @@ from sqlparse.sql import (
Token,
TokenList,
)
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.tokens import DDL, DML, Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
@ -133,7 +133,26 @@ class ParsedQuery:
def is_select(self) -> bool:
# make sure we strip comments; prevents a bug with coments in the CTE
parsed = sqlparse.parse(self.strip_comments())
return parsed[0].get_type() == "SELECT"
if parsed[0].get_type() == "SELECT":
return True
if parsed[0].get_type() != "UNKNOWN":
return False
# for `UNKNOWN`, check all DDL/DML explicitly: only `SELECT` DML is allowed,
# and no DDL is allowed
if any(token.ttype == DDL for token in parsed[0]) or any(
token.ttype == DML and token.value != "SELECT" for token in parsed[0]
):
return False
# return false on `EXPLAIN`, `SET`, `SHOW`, etc.
if parsed[0][0].ttype == Keyword:
return False
return any(
token.ttype == DML and token.value == "SELECT" for token in parsed[0]
)
def is_valid_ctas(self) -> bool:
parsed = sqlparse.parse(self.strip_comments())
@ -150,7 +169,7 @@ class ParsedQuery:
)
# Explain statements will only be the first statement
return statements_without_comments.startswith("EXPLAIN")
return statements_without_comments.upper().startswith("EXPLAIN")
def is_show(self) -> bool:
# Remove comments

View File

@ -15,10 +15,17 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
import sqlparse
from superset.sql_parse import ParsedQuery
def test_cte_with_comments():
def test_cte_with_comments_is_select():
"""
Some CTES with comments are not correctly identified as SELECTS.
"""
sql = ParsedQuery(
"""WITH blah AS
(SELECT * FROM core_dev.manager_team),
@ -44,3 +51,43 @@ SELECT * FROM blah
INNER JOIN blah2 ON blah2.team_id = blah.team_id"""
)
assert sql.is_select()
def test_cte_is_select():
"""
Some CTEs are not correctly identified as SELECTS.
"""
# `AS(` gets parsed as a function
sql = ParsedQuery(
"""WITH foo AS(
SELECT
FLOOR(__time TO WEEK) AS "week",
name,
COUNT(DISTINCT user_id) AS "unique_users"
FROM "druid"."my_table"
GROUP BY 1,2
)
SELECT
f.week,
f.name,
f.unique_users
FROM foo f"""
)
assert sql.is_select()
def test_unknown_select():
"""
Test that `is_select` works when sqlparse fails to identify the type.
"""
sql = "WITH foo AS(SELECT 1) SELECT 1"
assert sqlparse.parse(sql)[0].get_type() == "UNKNOWN"
assert ParsedQuery(sql).is_select()
sql = "WITH foo AS(SELECT 1) INSERT INTO my_table (a) VALUES (1)"
assert sqlparse.parse(sql)[0].get_type() == "UNKNOWN"
assert not ParsedQuery(sql).is_select()
sql = "WITH foo AS(SELECT 1) DELETE FROM my_table"
assert sqlparse.parse(sql)[0].get_type() == "UNKNOWN"
assert not ParsedQuery(sql).is_select()