mirror of https://github.com/apache/superset.git
[fix] Fixing SQL parsing issue (#7374)
This commit is contained in:
parent
ee78fd7b3d
commit
fb627ba376
|
@ -18,7 +18,7 @@
|
|||
import logging
|
||||
|
||||
import sqlparse
|
||||
from sqlparse.sql import Identifier, IdentifierList
|
||||
from sqlparse.sql import Identifier, IdentifierList, Token, TokenList
|
||||
from sqlparse.tokens import Keyword, Name
|
||||
|
||||
RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
|
||||
|
@ -75,32 +75,32 @@ class ParsedQuery(object):
|
|||
return statements
|
||||
|
||||
@staticmethod
|
||||
def __get_full_name(identifier):
|
||||
if len(identifier.tokens) > 2 and identifier.tokens[1].value == '.':
|
||||
return '{}.{}'.format(identifier.tokens[0].value,
|
||||
identifier.tokens[2].value)
|
||||
return identifier.get_real_name()
|
||||
def __get_full_name(tlist: TokenList):
    """Return the table name carried by ``tlist``.

    When the list holds more than two tokens and the second one is a
    ``.``, the first and third token values are joined as
    ``first.third``. Otherwise the result is whatever
    ``tlist.get_real_name()`` returns.
    """
    parts = tlist.tokens
    is_dotted = len(parts) > 2 and parts[1].value == '.'
    if not is_dotted:
        return tlist.get_real_name()
    return '{}.{}'.format(parts[0].value, parts[2].value)
|
||||
|
||||
@staticmethod
|
||||
def __is_identifier(token):
|
||||
def __is_identifier(token: Token):
    """Return True when ``token`` is an ``Identifier`` or ``IdentifierList``."""
    identifier_types = (IdentifierList, Identifier)
    return isinstance(token, identifier_types)
|
||||
|
||||
def __process_identifier(self, identifier):
|
||||
def __process_tokenlist(self, tlist: TokenList):
|
||||
# exclude subselects
|
||||
if '(' not in str(identifier):
|
||||
table_name = self.__get_full_name(identifier)
|
||||
if '(' not in str(tlist):
|
||||
table_name = self.__get_full_name(tlist)
|
||||
if table_name and not table_name.startswith(CTE_PREFIX):
|
||||
self._table_names.add(table_name)
|
||||
return
|
||||
|
||||
# store aliases
|
||||
if hasattr(identifier, 'get_alias'):
|
||||
self._alias_names.add(identifier.get_alias())
|
||||
if hasattr(identifier, 'tokens'):
|
||||
# some aliases are not parsed properly
|
||||
if identifier.tokens[0].ttype == Name:
|
||||
self._alias_names.add(identifier.tokens[0].value)
|
||||
self.__extract_from_token(identifier)
|
||||
if tlist.has_alias():
|
||||
self._alias_names.add(tlist.get_alias())
|
||||
|
||||
# some aliases are not parsed properly
|
||||
if tlist.tokens[0].ttype == Name:
|
||||
self._alias_names.add(tlist.tokens[0].value)
|
||||
self.__extract_from_token(tlist)
|
||||
|
||||
def as_create_table(self, table_name, overwrite=False):
|
||||
"""Reformats the query into the create table as query.
|
||||
|
@ -144,10 +144,11 @@ class ParsedQuery(object):
|
|||
|
||||
if table_name_preceding_token:
|
||||
if isinstance(item, Identifier):
|
||||
self.__process_identifier(item)
|
||||
self.__process_tokenlist(item)
|
||||
elif isinstance(item, IdentifierList):
|
||||
for token in item.get_identifiers():
|
||||
self.__process_identifier(token)
|
||||
if isinstance(token, TokenList):
|
||||
self.__process_tokenlist(token)
|
||||
elif isinstance(item, IdentifierList):
|
||||
for token in item.tokens:
|
||||
if not self.__is_identifier(token):
|
||||
|
|
|
@ -462,3 +462,12 @@ class SupersetTestCase(unittest.TestCase):
|
|||
'SELECT * FROM ab_user LIMIT 1',
|
||||
]
|
||||
self.assertEquals(statements, expected)
|
||||
|
||||
def test_identifier_list_with_keyword_as_alias(self):
    """Only the real table is extracted when a CTE alias is named ``match``.

    The query defines two CTEs, ``f`` and ``match``, and selects from
    ``match``; ``foo`` is the only name expected back from
    ``extract_tables``.
    """
    query = """
    WITH
    f AS (SELECT * FROM foo),
    match AS (SELECT * FROM f)
    SELECT * FROM match
    """
    # assertEqual rather than the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual({'foo'}, self.extract_tables(query))
|
||||
|
|
Loading…
Reference in New Issue