[fix] Fixing SQL parsing issue (#7374)

This commit is contained in:
John Bodley 2019-05-01 22:07:01 -07:00 committed by GitHub
parent ee78fd7b3d
commit fb627ba376
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 19 deletions

View File

@ -18,7 +18,7 @@
import logging import logging
import sqlparse import sqlparse
from sqlparse.sql import Identifier, IdentifierList from sqlparse.sql import Identifier, IdentifierList, Token, TokenList
from sqlparse.tokens import Keyword, Name from sqlparse.tokens import Keyword, Name
RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'} RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
@ -75,32 +75,32 @@ class ParsedQuery(object):
return statements return statements
@staticmethod @staticmethod
def __get_full_name(identifier): def __get_full_name(tlist: TokenList):
if len(identifier.tokens) > 2 and identifier.tokens[1].value == '.': if len(tlist.tokens) > 2 and tlist.tokens[1].value == '.':
return '{}.{}'.format(identifier.tokens[0].value, return '{}.{}'.format(tlist.tokens[0].value,
identifier.tokens[2].value) tlist.tokens[2].value)
return identifier.get_real_name() return tlist.get_real_name()
@staticmethod @staticmethod
def __is_identifier(token): def __is_identifier(token: Token):
return isinstance(token, (IdentifierList, Identifier)) return isinstance(token, (IdentifierList, Identifier))
def __process_identifier(self, identifier): def __process_tokenlist(self, tlist: TokenList):
# exclude subselects # exclude subselects
if '(' not in str(identifier): if '(' not in str(tlist):
table_name = self.__get_full_name(identifier) table_name = self.__get_full_name(tlist)
if table_name and not table_name.startswith(CTE_PREFIX): if table_name and not table_name.startswith(CTE_PREFIX):
self._table_names.add(table_name) self._table_names.add(table_name)
return return
# store aliases # store aliases
if hasattr(identifier, 'get_alias'): if tlist.has_alias():
self._alias_names.add(identifier.get_alias()) self._alias_names.add(tlist.get_alias())
if hasattr(identifier, 'tokens'):
# some aliases are not parsed properly # some aliases are not parsed properly
if identifier.tokens[0].ttype == Name: if tlist.tokens[0].ttype == Name:
self._alias_names.add(identifier.tokens[0].value) self._alias_names.add(tlist.tokens[0].value)
self.__extract_from_token(identifier) self.__extract_from_token(tlist)
def as_create_table(self, table_name, overwrite=False): def as_create_table(self, table_name, overwrite=False):
"""Reformats the query into the create table as query. """Reformats the query into the create table as query.
@ -144,10 +144,11 @@ class ParsedQuery(object):
if table_name_preceding_token: if table_name_preceding_token:
if isinstance(item, Identifier): if isinstance(item, Identifier):
self.__process_identifier(item) self.__process_tokenlist(item)
elif isinstance(item, IdentifierList): elif isinstance(item, IdentifierList):
for token in item.get_identifiers(): for token in item.get_identifiers():
self.__process_identifier(token) if isinstance(token, TokenList):
self.__process_tokenlist(token)
elif isinstance(item, IdentifierList): elif isinstance(item, IdentifierList):
for token in item.tokens: for token in item.tokens:
if not self.__is_identifier(token): if not self.__is_identifier(token):

View File

@ -462,3 +462,12 @@ class SupersetTestCase(unittest.TestCase):
'SELECT * FROM ab_user LIMIT 1', 'SELECT * FROM ab_user LIMIT 1',
] ]
self.assertEquals(statements, expected) self.assertEquals(statements, expected)
def test_identifier_list_with_keyword_as_alias(self):
    """Check table extraction when a CTE alias is the word ``match``.

    The query defines two CTEs, ``f`` and ``match``, where ``match``
    selects from ``f`` and the final SELECT reads from ``match``.
    The only table expected back from ``extract_tables`` is ``foo``.
    """
    query = """
WITH
f AS (SELECT * FROM foo),
match AS (SELECT * FROM f)
SELECT * FROM match
"""
    # assertEqual is used instead of the deprecated assertEquals alias,
    # which was removed from unittest in Python 3.12.
    self.assertEqual({'foo'}, self.extract_tables(query))