[fix] Fixing SQL parsing issue (#7374)

This commit is contained in:
John Bodley 2019-05-01 22:07:01 -07:00 committed by GitHub
parent ee78fd7b3d
commit fb627ba376
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 19 deletions

View File

@@ -18,7 +18,7 @@
import logging
import sqlparse
from sqlparse.sql import Identifier, IdentifierList
from sqlparse.sql import Identifier, IdentifierList, Token, TokenList
from sqlparse.tokens import Keyword, Name
RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
@@ -75,32 +75,32 @@ class ParsedQuery(object):
return statements
@staticmethod
def __get_full_name(tlist: TokenList):
    """Return the name carried by a parsed token list.

    When the list has more than two tokens and the second one is a
    literal ``.``, the first and third token values are joined as
    ``first.third``. Otherwise the result of ``tlist.get_real_name()``
    is returned unchanged.

    :param tlist: token list to read the name from
    :returns: the dotted name, or whatever ``get_real_name()`` yields
    """
    # A '.' in second position marks a qualified, two-part name.
    if len(tlist.tokens) > 2 and tlist.tokens[1].value == '.':
        return '{}.{}'.format(tlist.tokens[0].value,
                              tlist.tokens[2].value)
    return tlist.get_real_name()
@staticmethod
def __is_identifier(token: Token) -> bool:
    """Tell whether ``token`` is an ``Identifier`` or ``IdentifierList``.

    :param token: parsed token to classify
    :returns: ``True`` for either identifier type, ``False`` otherwise
    """
    return isinstance(token, (IdentifierList, Identifier))
def __process_tokenlist(self, tlist: TokenList) -> None:
    """Record the table or alias described by ``tlist``.

    A token list whose text contains no ``(`` is treated as a plain
    table reference: its full name is added to ``self._table_names``
    unless it is empty or starts with ``CTE_PREFIX``. Anything else is
    treated as a subselect: its alias is added to ``self._alias_names``
    and the list is passed on to ``__extract_from_token``.

    :param tlist: token list found where a table name is expected
    """
    # exclude subselects
    if '(' not in str(tlist):
        table_name = self.__get_full_name(tlist)
        if table_name and not table_name.startswith(CTE_PREFIX):
            self._table_names.add(table_name)
        return

    # store aliases
    if tlist.has_alias():
        self._alias_names.add(tlist.get_alias())

    # some aliases are not parsed properly
    if tlist.tokens[0].ttype == Name:
        self._alias_names.add(tlist.tokens[0].value)
    self.__extract_from_token(tlist)
def as_create_table(self, table_name, overwrite=False):
"""Reformats the query into the create table as query.
@@ -144,10 +144,11 @@ class ParsedQuery(object):
if table_name_preceding_token:
if isinstance(item, Identifier):
self.__process_identifier(item)
self.__process_tokenlist(item)
elif isinstance(item, IdentifierList):
for token in item.get_identifiers():
self.__process_identifier(token)
if isinstance(token, TokenList):
self.__process_tokenlist(token)
elif isinstance(item, IdentifierList):
for token in item.tokens:
if not self.__is_identifier(token):

View File

@@ -462,3 +462,12 @@ class SupersetTestCase(unittest.TestCase):
'SELECT * FROM ab_user LIMIT 1',
]
self.assertEquals(statements, expected)
def test_identifier_list_with_keyword_as_alias(self):
    """Only ``foo`` is reported when a CTE in the list is named ``match``.

    The query defines two CTEs, ``f`` and ``match``, and selects from
    ``match``; the sole table expected from ``extract_tables`` is
    ``foo``.
    """
    query = """
    WITH
        f AS (SELECT * FROM foo),
        match AS (SELECT * FROM f)
    SELECT * FROM match
    """
    # assertEqual, not the deprecated assertEquals alias, which newer
    # unittest versions no longer provide.
    self.assertEqual({'foo'}, self.extract_tables(query))