# -*- coding: utf-8 -*- from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import unittest from superset import sql_parse class SupersetTestCase(unittest.TestCase): def extract_tables(self, query): sq = sql_parse.SupersetQuery(query) return sq.tables def test_simple_select(self): query = 'SELECT * FROM tbname' self.assertEquals({'tbname'}, self.extract_tables(query)) # underscores query = 'SELECT * FROM tb_name' self.assertEquals({'tb_name'}, self.extract_tables(query)) # quotes query = 'SELECT * FROM "tbname"' self.assertEquals({'tbname'}, self.extract_tables(query)) # unicode encoding query = 'SELECT * FROM "tb_name" WHERE city = "Lübeck"' self.assertEquals({'tb_name'}, self.extract_tables(query)) # schema self.assertEquals( {'schemaname.tbname'}, self.extract_tables('SELECT * FROM schemaname.tbname')) # quotes query = 'SELECT field1, field2 FROM tb_name' self.assertEquals({'tb_name'}, self.extract_tables(query)) query = 'SELECT t1.f1, t2.f2 FROM t1, t2' self.assertEquals({'t1', 't2'}, self.extract_tables(query)) def test_select_named_table(self): query = 'SELECT a.date, a.field FROM left_table a LIMIT 10' self.assertEquals( {'left_table'}, self.extract_tables(query)) def test_reverse_select(self): query = 'FROM t1 SELECT field' self.assertEquals({'t1'}, self.extract_tables(query)) def test_subselect(self): query = """ SELECT sub.* FROM ( SELECT * FROM s1.t1 WHERE day_of_week = 'Friday' ) sub, s2.t2 WHERE sub.resolution = 'NONE' """ self.assertEquals({'s1.t1', 's2.t2'}, self.extract_tables(query)) query = """ SELECT sub.* FROM ( SELECT * FROM s1.t1 WHERE day_of_week = 'Friday' ) sub WHERE sub.resolution = 'NONE' """ self.assertEquals({'s1.t1'}, self.extract_tables(query)) query = """ SELECT * FROM t1 WHERE s11 > ANY (SELECT COUNT(*) /* no hint */ FROM t2 WHERE NOT EXISTS (SELECT * FROM t3 WHERE ROW(5*t2.s1,77)= (SELECT 50,11*s1 FROM t4))); """ self.assertEquals({'t1', 't2', 't3', 't4'}, self.extract_tables(query)) def test_select_in_expression(self): query = 'SELECT f1, (SELECT count(1) FROM t2) FROM t1' self.assertEquals({'t1', 't2'}, self.extract_tables(query)) def test_union(self): query = 'SELECT * FROM t1 UNION SELECT * FROM t2' self.assertEquals({'t1', 't2'}, self.extract_tables(query)) query = 'SELECT * FROM t1 UNION ALL SELECT * FROM t2' self.assertEquals({'t1', 't2'}, self.extract_tables(query)) query = 'SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2' self.assertEquals({'t1', 't2'}, self.extract_tables(query)) def test_select_from_values(self): query = 'SELECT * FROM VALUES (13, 42)' self.assertFalse(self.extract_tables(query)) def test_select_array(self): query = """ SELECT ARRAY[1, 2, 3] AS my_array FROM t1 LIMIT 10 """ self.assertEquals({'t1'}, self.extract_tables(query)) def test_select_if(self): query = """ SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL) FROM t1 LIMIT 10 """ self.assertEquals({'t1'}, self.extract_tables(query)) # SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)? def test_show_tables(self): query = "SHOW TABLES FROM s1 like '%order%'" # TODO: figure out what should code do here self.assertEquals({'s1'}, self.extract_tables(query)) # SHOW COLUMNS (FROM | IN) qualifiedName def test_show_columns(self): query = 'SHOW COLUMNS FROM t1' self.assertEquals({'t1'}, self.extract_tables(query)) def test_where_subquery(self): query = """ SELECT name FROM t1 WHERE regionkey = (SELECT max(regionkey) FROM t2) """ self.assertEquals({'t1', 't2'}, self.extract_tables(query)) query = """ SELECT name FROM t1 WHERE regionkey IN (SELECT regionkey FROM t2) """ self.assertEquals({'t1', 't2'}, self.extract_tables(query)) query = """ SELECT name FROM t1 WHERE regionkey EXISTS (SELECT regionkey FROM t2) """ self.assertEquals({'t1', 't2'}, self.extract_tables(query)) # DESCRIBE | DESC qualifiedName def test_describe(self): self.assertEquals({'t1'}, self.extract_tables('DESCRIBE t1')) self.assertEquals({'t1'}, self.extract_tables('DESC t1')) # SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)? # (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))? def test_show_partitions(self): query = """ SHOW PARTITIONS FROM orders WHERE ds >= '2013-01-01' ORDER BY ds DESC; """ self.assertEquals({'orders'}, self.extract_tables(query)) def test_join(self): query = 'SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;' self.assertEquals({'t1', 't2'}, self.extract_tables(query)) # subquery + join query = """ SELECT a.date, b.name FROM left_table a JOIN ( SELECT CAST((b.year) as VARCHAR) date, name FROM right_table ) b ON a.date = b.date """ self.assertEquals({'left_table', 'right_table'}, self.extract_tables(query)) query = """ SELECT a.date, b.name FROM left_table a LEFT INNER JOIN ( SELECT CAST((b.year) as VARCHAR) date, name FROM right_table ) b ON a.date = b.date """ self.assertEquals({'left_table', 'right_table'}, self.extract_tables(query)) query = """ SELECT a.date, b.name FROM left_table a RIGHT OUTER JOIN ( SELECT CAST((b.year) as VARCHAR) date, name FROM right_table ) b ON a.date = b.date """ self.assertEquals({'left_table', 'right_table'}, self.extract_tables(query)) query = """ SELECT a.date, b.name FROM left_table a FULL OUTER JOIN ( SELECT CAST((b.year) as VARCHAR) date, name FROM right_table ) b ON a.date = b.date """ self.assertEquals({'left_table', 'right_table'}, self.extract_tables(query)) # TODO: add SEMI join support, SQL Parse does not handle it. # query = """ # SELECT a.date, b.name FROM # left_table a # LEFT SEMI JOIN ( # SELECT # CAST((b.year) as VARCHAR) date, # name # FROM right_table # ) b # ON a.date = b.date # """ # self.assertEquals({'left_table', 'right_table'}, # sql_parse.extract_tables(query)) def test_combinations(self): query = """ SELECT * FROM t1 WHERE s11 > ANY (SELECT * FROM t1 UNION ALL SELECT * FROM ( SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a) tmp_join WHERE NOT EXISTS (SELECT * FROM t3 WHERE ROW(5*t3.s1,77)= (SELECT 50,11*s1 FROM t4))); """ self.assertEquals({'t1', 't3', 't4', 't6'}, self.extract_tables(query)) query = """ SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS) AS S1) AS S2) AS S3; """ self.assertEquals({'EmployeeS'}, self.extract_tables(query)) def test_with(self): query = """ WITH x AS (SELECT a FROM t1), y AS (SELECT a AS b FROM t2), z AS (SELECT b AS c FROM t3) SELECT c FROM z; """ self.assertEquals({'t1', 't2', 't3'}, self.extract_tables(query)) query = """ WITH x AS (SELECT a FROM t1), y AS (SELECT a AS b FROM x), z AS (SELECT b AS c FROM y) SELECT c FROM z; """ self.assertEquals({'t1'}, self.extract_tables(query)) def test_reusing_aliases(self): query = """ with q1 as ( select key from q2 where key = '5'), q2 as ( select key from src where key = '5') select * from (select key from q1) a; """ self.assertEquals({'src'}, self.extract_tables(query)) def multistatement(self): query = 'SELECT * FROM t1; SELECT * FROM t2' self.assertEquals({'t1', 't2'}, self.extract_tables(query)) query = 'SELECT * FROM t1; SELECT * FROM t2;' self.assertEquals({'t1', 't2'}, self.extract_tables(query))