mirror of https://github.com/apache/superset.git
feat: improve logic in is_select (#17329)
* feat: improve logic in is_select * Add more edge cases
This commit is contained in:
parent
9a4ab1026e
commit
93bafa0e6a
|
@ -29,7 +29,7 @@ from sqlparse.sql import (
|
|||
Token,
|
||||
TokenList,
|
||||
)
|
||||
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
|
||||
from sqlparse.tokens import DDL, DML, Keyword, Name, Punctuation, String, Whitespace
|
||||
from sqlparse.utils import imt
|
||||
|
||||
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
|
||||
|
@ -133,7 +133,26 @@ class ParsedQuery:
|
|||
def is_select(self) -> bool:
    """
    Return whether the query is a read-only ``SELECT`` statement.

    The first parsed statement is classified with sqlparse. When sqlparse
    reports the type as ``UNKNOWN`` (for example a CTE written as ``AS(``,
    which it parses as a function call), the top-level tokens are inspected
    explicitly: any DDL, or any DML other than ``SELECT``, disqualifies the
    statement.

    :returns: ``True`` if the statement only selects data, else ``False``
    """
    # make sure we strip comments; prevents a bug with comments in the CTE
    parsed = sqlparse.parse(self.strip_comments())
    if not parsed:
        # nothing to classify (e.g. empty SQL); avoid an IndexError below
        return False

    statement = parsed[0]
    statement_type = statement.get_type()
    if statement_type == "SELECT":
        return True
    if statement_type != "UNKNOWN":
        return False

    # for `UNKNOWN`, check all DDL/DML explicitly: only `SELECT` DML is allowed,
    # and no DDL is allowed
    has_select = False
    for token in statement:
        if token.ttype == DDL:
            return False
        if token.ttype == DML:
            # `token.value` keeps the casing used in the query, so compare
            # case-insensitively; otherwise a lowercase `select` would be
            # mistaken for a non-SELECT DML keyword
            if token.value.upper() != "SELECT":
                return False
            has_select = True

    # return false on `EXPLAIN`, `SET`, `SHOW`, etc.
    if statement[0].ttype == Keyword:
        return False

    return has_select
|
||||
|
||||
def is_valid_ctas(self) -> bool:
|
||||
parsed = sqlparse.parse(self.strip_comments())
|
||||
|
@ -150,7 +169,7 @@ class ParsedQuery:
|
|||
)
|
||||
|
||||
# Explain statements will only be the first statement
|
||||
return statements_without_comments.startswith("EXPLAIN")
|
||||
return statements_without_comments.upper().startswith("EXPLAIN")
|
||||
|
||||
def is_show(self) -> bool:
|
||||
# Remove comments
|
||||
|
|
|
@ -15,10 +15,17 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
import sqlparse
|
||||
|
||||
from superset.sql_parse import ParsedQuery
|
||||
|
||||
|
||||
def test_cte_with_comments():
|
||||
def test_cte_with_comments_is_select():
|
||||
"""
|
||||
Some CTES with comments are not correctly identified as SELECTS.
|
||||
"""
|
||||
sql = ParsedQuery(
|
||||
"""WITH blah AS
|
||||
(SELECT * FROM core_dev.manager_team),
|
||||
|
@ -44,3 +51,43 @@ SELECT * FROM blah
|
|||
INNER JOIN blah2 ON blah2.team_id = blah.team_id"""
|
||||
)
|
||||
assert sql.is_select()
|
||||
|
||||
|
||||
def test_cte_is_select():
    """
    A CTE whose body follows ``AS(`` without a space must still be
    identified as a SELECT.
    """
    # `AS(` gets parsed as a function
    query = """WITH foo AS(
SELECT
FLOOR(__time TO WEEK) AS "week",
name,
COUNT(DISTINCT user_id) AS "unique_users"
FROM "druid"."my_table"
GROUP BY 1,2
)
SELECT
f.week,
f.name,
f.unique_users
FROM foo f"""
    parsed_query = ParsedQuery(query)
    assert parsed_query.is_select()
|
||||
|
||||
|
||||
def test_unknown_select():
    """
    Test that `is_select` works when sqlparse fails to identify the type.
    """
    # (statement, whether `is_select` should accept it)
    cases = [
        ("WITH foo AS(SELECT 1) SELECT 1", True),
        ("WITH foo AS(SELECT 1) INSERT INTO my_table (a) VALUES (1)", False),
        ("WITH foo AS(SELECT 1) DELETE FROM my_table", False),
    ]
    for statement, expected in cases:
        # every case must be one sqlparse cannot classify on its own
        assert sqlparse.parse(statement)[0].get_type() == "UNKNOWN"
        assert bool(ParsedQuery(statement).is_select()) is expected
|
||||
|
|
Loading…
Reference in New Issue