diff --git a/requirements.txt b/requirements.txt index dd8b28a65e..61895fd4a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,6 +19,7 @@ colorama==0.4.3 # via apache-superset (setup.py), flask-appbuilder contextlib2==0.6.0.post1 # via apache-superset (setup.py) croniter==0.3.31 # via apache-superset (setup.py) cryptography==2.8 # via apache-superset (setup.py) +dataclasses==0.6 # via apache-superset (setup.py) decorator==4.4.1 # via retry defusedxml==0.6.0 # via python3-openid flask-appbuilder==2.3.2 # via apache-superset (setup.py) diff --git a/setup.cfg b/setup.cfg index 9469118b2d..28a3ab3189 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,7 +45,7 @@ combine_as_imports = true include_trailing_comma = true line_length = 88 known_first_party = superset -known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml +known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dataclasses,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml multi_line_output = 3 order_by_type = false diff --git a/setup.py b/setup.py index 6c1484914e..e3e3cdc14c 100644 --- a/setup.py +++ b/setup.py @@ -75,6 +75,7 @@ setup( "contextlib2", "croniter>=0.3.28", "cryptography>=2.4.2", + "dataclasses<0.7", "flask>=1.1.0, <2.0.0", "flask-appbuilder>=2.3.2, <2.4.0", "flask-caching", diff --git a/superset/security/manager.py b/superset/security/manager.py index e3b4b1de9f..fac90685ab 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -50,6 +50,7 @@ if TYPE_CHECKING: from superset.common.query_context import QueryContext from superset.connectors.base.models import BaseDatasource from superset.models.core import Database + from superset.sql_parse import Table from superset.viz import BaseViz logger = logging.getLogger(__name__) @@ -290,26 +291,23 @@ class SupersetSecurityManager(SecurityManager): return conf.get("PERMISSION_INSTRUCTIONS_LINK") - def get_table_access_error_msg(self, tables: List[str]) -> str: + def get_table_access_error_msg(self, tables: Set["Table"]) -> str: """ Return the error message for the denied SQL tables. - Note the table names conform to the [[cluster.]schema.]table construct. 
- - :param tables: The list of denied SQL table names + :param tables: The set of denied SQL tables :returns: The error message """ - quoted_tables = [f"`{t}`" for t in tables] + + quoted_tables = [f"`{table}`" for table in tables] return f"""You need access to the following tables: {", ".join(quoted_tables)}, `all_database_access` or `all_datasource_access` permission""" - def get_table_access_link(self, tables: List[str]) -> Optional[str]: + def get_table_access_link(self, tables: Set["Table"]) -> Optional[str]: """ Return the access link for the denied SQL tables. - Note the table names conform to the [[cluster.]schema.]table construct. - - :param tables: The list of denied SQL table names + :param tables: The set of denied SQL tables :returns: The access URL """ @@ -318,23 +316,19 @@ class SupersetSecurityManager(SecurityManager): return conf.get("PERMISSION_INSTRUCTIONS_LINK") def can_access_datasource( - self, database: "Database", table_name: str, schema: Optional[str] = None - ) -> bool: - return self._datasource_access_by_name(database, table_name, schema=schema) - - def _datasource_access_by_name( - self, database: "Database", table_name: str, schema: Optional[str] = None + self, database: "Database", table: "Table", schema: Optional[str] = None ) -> bool: """ Return True if the user can access the SQL table, False otherwise. :param database: The SQL database - :param table_name: The SQL table name - :param schema: The Superset schema + :param table: The SQL table + :param schema: The fallback SQL schema if not present in the table :returns: Whether the use can access the SQL table """ from superset import db + from superset.connectors.sqla.models import SqlaTable if self.database_access(database) or self.all_datasource_access(): return True @@ -343,74 +337,33 @@ class SupersetSecurityManager(SecurityManager): if schema_perm and self.can_access("schema_access", schema_perm): return True - datasources = ConnectorRegistry.query_datasources_by_name( - db.session, database, table_name, schema=schema + datasources = SqlaTable.query_datasources_by_name( + db.session, database, table.table, schema=table.schema or schema ) for datasource in datasources: if self.can_access("datasource_access", datasource.perm): return True return False - def _get_schema_and_table( - self, table_in_query: str, schema: str - ) -> Tuple[str, str]: + def rejected_tables( + self, sql: str, database: "Database", schema: str + ) -> Set["Table"]: """ - Return the SQL schema/table tuple associated with the table extracted from the - SQL query. - - Note the table name conforms to the [[cluster.]schema.]table construct. - - :param table_in_query: The SQL table name - :param schema: The fallback SQL schema if not present in the table name - :returns: The SQL schema/table tuple - """ - - table_name_pieces = table_in_query.split(".") - if len(table_name_pieces) == 3: - return tuple(table_name_pieces[1:]) # type: ignore - elif len(table_name_pieces) == 2: - return tuple(table_name_pieces) # type: ignore - return (schema, table_name_pieces[0]) - - def _datasource_access_by_fullname( - self, database: "Database", table_in_query: str, schema: str - ) -> bool: - """ - Return True if the user can access the table extracted from the SQL query, False - otherwise. - - Note the table name conforms to the [[cluster.]schema.]table construct. 
- - :param database: The Superset database - :param table_in_query: The SQL table name - :param schema: The fallback SQL schema, i.e., if not present in the table name - :returns: Whether the user can access the SQL table - """ - - table_schema, table_name = self._get_schema_and_table(table_in_query, schema) - return self._datasource_access_by_name( - database, table_name, schema=table_schema - ) - - def rejected_tables(self, sql: str, database: "Database", schema: str) -> List[str]: - """ - Return the list of rejected SQL table names. - - Note the rejected table names conform to the [[cluster.]schema.]table construct. + Return the list of rejected SQL tables. :param sql: The SQL statement :param database: The SQL database :param schema: The SQL database schema - :returns: The rejected table names + :returns: The rejected tables """ - superset_query = sql_parse.ParsedQuery(sql) + query = sql_parse.ParsedQuery(sql) - return [ - t - for t in superset_query.tables - if not self._datasource_access_by_fullname(database, t, schema) - ] + return { + table + for table in query.tables + if not self.can_access_datasource(database, table, schema) + } def get_public_role(self) -> Optional[Any]: # Optional[self.role_model] from superset import conf @@ -493,7 +446,7 @@ class SupersetSecurityManager(SecurityManager): .filter(or_(SqlaTable.perm.in_(perms))) .distinct() ) - accessible_schemas.update([t.schema for t in tables]) + accessible_schemas.update([table.schema for table in tables]) return [s for s in schemas if s in accessible_schemas] diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 8cac2ffdab..34747e1fb9 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -16,8 +16,10 @@ # under the License. import logging from typing import List, Optional, Set +from urllib import parse import sqlparse +from dataclasses import dataclass from sqlparse.sql import ( Function, Identifier, @@ -57,10 +59,32 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]: return None +@dataclass(eq=True, frozen=True) +class Table: # pylint: disable=too-few-public-methods + """ + A fully qualified SQL table conforming to [[catalog.]schema.]table. + """ + + table: str + schema: Optional[str] = None + catalog: Optional[str] = None + + def __str__(self) -> str: + """ + Return the fully qualified SQL table name. 
+ """ + + return ".".join( + parse.quote(part, safe="").replace(".", "%2E") + for part in [self.catalog, self.schema, self.table] + if part + ) + + class ParsedQuery: def __init__(self, sql_statement: str): self.sql: str = sql_statement - self._table_names: Set[str] = set() + self._tables: Set[Table] = set() self._alias_names: Set[str] = set() self._limit: Optional[int] = None @@ -70,12 +94,15 @@ class ParsedQuery: self._limit = _extract_limit_from_query(statement) @property - def tables(self) -> Set[str]: - if not self._table_names: + def tables(self) -> Set[Table]: + if not self._tables: for statement in self._parsed: - self.__extract_from_token(statement) - self._table_names = self._table_names - self._alias_names - return self._table_names + self._extract_from_token(statement) + + self._tables = { + table for table in self._tables if str(table) not in self._alias_names + } + return self._tables @property def limit(self) -> Optional[int]: @@ -105,13 +132,13 @@ class ParsedQuery: return statements @staticmethod - def __get_full_name(tlist: TokenList) -> Optional[str]: + def _get_table(tlist: TokenList) -> Optional[Table]: """ - Return the full unquoted table name if valid, i.e., conforms to the following - [[cluster.]schema.]table construct. + Return the table if valid, i.e., conforms to the [[catalog.]schema.]table + construct. :param tlist: The SQL tokens - :returns: The valid full table name + :returns: The table if the name conforms """ # Strip the alias if present. @@ -127,18 +154,18 @@ class ParsedQuery: if ( len(tokens) in (1, 3, 5) - and all(imt(token, t=[Name, String]) for token in tokens[0::2]) + and all(imt(token, t=[Name, String]) for token in tokens[::2]) and all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2]) ): - return ".".join([remove_quotes(token.value) for token in tokens[0::2]]) + return Table(*[remove_quotes(token.value) for token in tokens[::-2]]) return None @staticmethod - def __is_identifier(token: Token) -> bool: + def _is_identifier(token: Token) -> bool: return isinstance(token, (IdentifierList, Identifier)) - def __process_tokenlist(self, token_list: TokenList): + def _process_tokenlist(self, token_list: TokenList): """ Add table names to table set @@ -146,9 +173,9 @@ class ParsedQuery: """ # exclude subselects if "(" not in str(token_list): - table_name = self.__get_full_name(token_list) - if table_name and not table_name.startswith(CTE_PREFIX): - self._table_names.add(table_name) + table = self._get_table(token_list) + if table and not table.table.startswith(CTE_PREFIX): + self._tables.add(table) return # store aliases @@ -158,7 +185,7 @@ class ParsedQuery: # some aliases are not parsed properly if token_list.tokens[0].ttype == Name: self._alias_names.add(token_list.tokens[0].value) - self.__extract_from_token(token_list) + self._extract_from_token(token_list) def as_create_table( self, @@ -184,9 +211,9 @@ class ParsedQuery: exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}" return exec_sql - def __extract_from_token(self, token: Token): # pylint: disable=too-many-branches + def _extract_from_token(self, token: Token): # pylint: disable=too-many-branches """ - Populate self._table_names from token + Populate self._tables from token :param token: instance of Token or child class, e.g. 
TokenList, to be processed """ @@ -196,8 +223,8 @@ class ParsedQuery: table_name_preceding_token = False for item in token.tokens: - if item.is_group and not self.__is_identifier(item): - self.__extract_from_token(item) + if item.is_group and not self._is_identifier(item): + self._extract_from_token(item) if item.ttype in Keyword and ( item.normalized in PRECEDES_TABLE_NAME @@ -212,15 +239,15 @@ class ParsedQuery: if table_name_preceding_token: if isinstance(item, Identifier): - self.__process_tokenlist(item) + self._process_tokenlist(item) elif isinstance(item, IdentifierList): for token2 in item.get_identifiers(): if isinstance(token2, TokenList): - self.__process_tokenlist(token2) + self._process_tokenlist(token2) elif isinstance(item, IdentifierList): for token2 in item.tokens: - if not self.__is_identifier(token2): - self.__extract_from_token(item) + if not self._is_identifier(token2): + self._extract_from_token(item) def set_or_update_query_limit(self, new_limit: int) -> str: """Returns the query with the specified limit. diff --git a/superset/views/core.py b/superset/views/core.py index 9688b89b9d..641a68b115 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -85,7 +85,7 @@ from superset.security.analytics_db_safety import ( check_sqlalchemy_uri, DBSecurityException, ) -from superset.sql_parse import ParsedQuery +from superset.sql_parse import ParsedQuery, Table from superset.sql_validators import get_validator_by_name from superset.utils import core as utils, dashboard_import_export from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes @@ -2083,7 +2083,9 @@ class Superset(BaseSupersetView): schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) table_name = utils.parse_js_uri_path_item(table_name) # Check that the user can access the datasource - if not self.appbuilder.sm.can_access_datasource(database, table_name, schema): + if not self.appbuilder.sm.can_access_datasource( + database, Table(table_name, schema), schema + ): stats_logger.incr( f"deprecated.{self.__class__.__name__}.select_star.permission_denied" ) diff --git a/superset/views/database/decorators.py b/superset/views/database/decorators.py index 3dd0e2acd7..322b42047f 100644 --- a/superset/views/database/decorators.py +++ b/superset/views/database/decorators.py @@ -22,6 +22,7 @@ from flask import g from flask_babel import lazy_gettext as _ from superset.models.core import Database +from superset.sql_parse import Table from superset.utils.core import parse_js_uri_path_item logger = logging.getLogger(__name__) @@ -45,7 +46,7 @@ def check_datasource_access(f): return self.response_404() # Check that the user can access the datasource if not self.appbuilder.sm.can_access_datasource( - database, table_name_parsed, schema_name_parsed + database, Table(table_name_parsed, schema_name_parsed), schema_name_parsed ): self.stats_logger.incr( f"permisssion_denied_{self.__class__.__name__}.select_star" diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py index 46e54ffaff..d0ee5d18c5 100644 --- a/tests/sql_parse_tests.py +++ b/tests/sql_parse_tests.py @@ -16,90 +16,102 @@ # under the License. 
import unittest -from superset import sql_parse +from superset.sql_parse import ParsedQuery, Table class SupersetTestCase(unittest.TestCase): def extract_tables(self, query): - sq = sql_parse.ParsedQuery(query) - return sq.tables + return ParsedQuery(query).tables + + def test_table(self): + self.assertEqual(str(Table("tbname")), "tbname") + self.assertEqual(str(Table("tbname", "schemaname")), "schemaname.tbname") + + self.assertEqual( + str(Table("tbname", "schemaname", "catalogname")), + "catalogname.schemaname.tbname", + ) + + self.assertEqual( + str(Table("tb.name", "schema/name", "catalog\name")), + "catalog%0Aame.schema%2Fname.tb%2Ename", + ) def test_simple_select(self): query = "SELECT * FROM tbname" - self.assertEqual({"tbname"}, self.extract_tables(query)) + self.assertEqual({Table("tbname")}, self.extract_tables(query)) query = "SELECT * FROM tbname foo" - self.assertEqual({"tbname"}, self.extract_tables(query)) + self.assertEqual({Table("tbname")}, self.extract_tables(query)) query = "SELECT * FROM tbname AS foo" - self.assertEqual({"tbname"}, self.extract_tables(query)) + self.assertEqual({Table("tbname")}, self.extract_tables(query)) # underscores query = "SELECT * FROM tb_name" - self.assertEqual({"tb_name"}, self.extract_tables(query)) + self.assertEqual({Table("tb_name")}, self.extract_tables(query)) # quotes query = 'SELECT * FROM "tbname"' - self.assertEqual({"tbname"}, self.extract_tables(query)) + self.assertEqual({Table("tbname")}, self.extract_tables(query)) # unicode encoding query = 'SELECT * FROM "tb_name" WHERE city = "Lübeck"' - self.assertEqual({"tb_name"}, self.extract_tables(query)) + self.assertEqual({Table("tb_name")}, self.extract_tables(query)) # schema self.assertEqual( - {"schemaname.tbname"}, + {Table("tbname", "schemaname")}, self.extract_tables("SELECT * FROM schemaname.tbname"), ) self.assertEqual( - {"schemaname.tbname"}, + {Table("tbname", "schemaname")}, self.extract_tables('SELECT * FROM "schemaname"."tbname"'), ) self.assertEqual( - {"schemaname.tbname"}, + {Table("tbname", "schemaname")}, self.extract_tables("SELECT * FROM schemaname.tbname foo"), ) self.assertEqual( - {"schemaname.tbname"}, + {Table("tbname", "schemaname")}, self.extract_tables("SELECT * FROM schemaname.tbname AS foo"), ) - # cluster self.assertEqual( - {"clustername.schemaname.tbname"}, - self.extract_tables("SELECT * FROM clustername.schemaname.tbname"), + {Table("tbname", "schemaname", "catalogname")}, + self.extract_tables("SELECT * FROM catalogname.schemaname.tbname"), ) # Ill-defined cluster/schema/table. 
self.assertEqual(set(), self.extract_tables("SELECT * FROM schemaname.")) self.assertEqual( - set(), self.extract_tables("SELECT * FROM clustername.schemaname.") + set(), self.extract_tables("SELECT * FROM catalogname.schemaname.") ) - self.assertEqual(set(), self.extract_tables("SELECT * FROM clustername..")) + self.assertEqual(set(), self.extract_tables("SELECT * FROM catalogname..")) self.assertEqual( - set(), self.extract_tables("SELECT * FROM clustername..tbname") + set(), self.extract_tables("SELECT * FROM catalogname..tbname") ) # quotes query = "SELECT field1, field2 FROM tb_name" - self.assertEqual({"tb_name"}, self.extract_tables(query)) + self.assertEqual({Table("tb_name")}, self.extract_tables(query)) query = "SELECT t1.f1, t2.f2 FROM t1, t2" - self.assertEqual({"t1", "t2"}, self.extract_tables(query)) + self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query)) def test_select_named_table(self): query = "SELECT a.date, a.field FROM left_table a LIMIT 10" - self.assertEqual({"left_table"}, self.extract_tables(query)) + self.assertEqual({Table("left_table")}, self.extract_tables(query)) def test_reverse_select(self): query = "FROM t1 SELECT field" - self.assertEqual({"t1"}, self.extract_tables(query)) + self.assertEqual({Table("t1")}, self.extract_tables(query)) def test_subselect(self): query = """ @@ -111,7 +123,9 @@ class SupersetTestCase(unittest.TestCase): ) sub, s2.t2 WHERE sub.resolution = 'NONE' """ - self.assertEqual({"s1.t1", "s2.t2"}, self.extract_tables(query)) + self.assertEqual( + {Table("t1", "s1"), Table("t2", "s2")}, self.extract_tables(query) + ) query = """ SELECT sub.* @@ -122,7 +136,7 @@ class SupersetTestCase(unittest.TestCase): ) sub WHERE sub.resolution = 'NONE' """ - self.assertEqual({"s1.t1"}, self.extract_tables(query)) + self.assertEqual({Table("t1", "s1")}, self.extract_tables(query)) query = """ SELECT * FROM t1 @@ -133,21 +147,24 @@ class SupersetTestCase(unittest.TestCase): WHERE ROW(5*t2.s1,77)= (SELECT 50,11*s1 FROM t4))); """ - self.assertEqual({"t1", "t2", "t3", "t4"}, self.extract_tables(query)) + self.assertEqual( + {Table("t1"), Table("t2"), Table("t3"), Table("t4")}, + self.extract_tables(query), + ) def test_select_in_expression(self): query = "SELECT f1, (SELECT count(1) FROM t2) FROM t1" - self.assertEqual({"t1", "t2"}, self.extract_tables(query)) + self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query)) def test_union(self): query = "SELECT * FROM t1 UNION SELECT * FROM t2" - self.assertEqual({"t1", "t2"}, self.extract_tables(query)) + self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query)) query = "SELECT * FROM t1 UNION ALL SELECT * FROM t2" - self.assertEqual({"t1", "t2"}, self.extract_tables(query)) + self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query)) query = "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2" - self.assertEqual({"t1", "t2"}, self.extract_tables(query)) + self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query)) def test_select_from_values(self): query = "SELECT * FROM VALUES (13, 42)" @@ -158,25 +175,25 @@ class SupersetTestCase(unittest.TestCase): SELECT ARRAY[1, 2, 3] AS my_array FROM t1 LIMIT 10 """ - self.assertEqual({"t1"}, self.extract_tables(query)) + self.assertEqual({Table("t1")}, self.extract_tables(query)) def test_select_if(self): query = """ SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL) FROM t1 LIMIT 10 """ - self.assertEqual({"t1"}, self.extract_tables(query)) + self.assertEqual({Table("t1")}, 
self.extract_tables(query)) # SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)? def test_show_tables(self): query = "SHOW TABLES FROM s1 like '%order%'" # TODO: figure out what should code do here - self.assertEqual({"s1"}, self.extract_tables(query)) + self.assertEqual({Table("s1")}, self.extract_tables(query)) # SHOW COLUMNS (FROM | IN) qualifiedName def test_show_columns(self): query = "SHOW COLUMNS FROM t1" - self.assertEqual({"t1"}, self.extract_tables(query)) + self.assertEqual({Table("t1")}, self.extract_tables(query)) def test_where_subquery(self): query = """ @@ -184,25 +201,25 @@ class SupersetTestCase(unittest.TestCase): FROM t1 WHERE regionkey = (SELECT max(regionkey) FROM t2) """ - self.assertEqual({"t1", "t2"}, self.extract_tables(query)) + self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query)) query = """ SELECT name FROM t1 WHERE regionkey IN (SELECT regionkey FROM t2) """ - self.assertEqual({"t1", "t2"}, self.extract_tables(query)) + self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query)) query = """ SELECT name FROM t1 WHERE regionkey EXISTS (SELECT regionkey FROM t2) """ - self.assertEqual({"t1", "t2"}, self.extract_tables(query)) + self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query)) # DESCRIBE | DESC qualifiedName def test_describe(self): - self.assertEqual({"t1"}, self.extract_tables("DESCRIBE t1")) + self.assertEqual({Table("t1")}, self.extract_tables("DESCRIBE t1")) # SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)? # (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))? @@ -211,11 +228,11 @@ class SupersetTestCase(unittest.TestCase): SHOW PARTITIONS FROM orders WHERE ds >= '2013-01-01' ORDER BY ds DESC; """ - self.assertEqual({"orders"}, self.extract_tables(query)) + self.assertEqual({Table("orders")}, self.extract_tables(query)) def test_join(self): query = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;" - self.assertEqual({"t1", "t2"}, self.extract_tables(query)) + self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query)) # subquery + join query = """ @@ -229,7 +246,9 @@ class SupersetTestCase(unittest.TestCase): ) b ON a.date = b.date """ - self.assertEqual({"left_table", "right_table"}, self.extract_tables(query)) + self.assertEqual( + {Table("left_table"), Table("right_table")}, self.extract_tables(query) + ) query = """ SELECT a.date, b.name FROM @@ -242,7 +261,9 @@ class SupersetTestCase(unittest.TestCase): ) b ON a.date = b.date """ - self.assertEqual({"left_table", "right_table"}, self.extract_tables(query)) + self.assertEqual( + {Table("left_table"), Table("right_table")}, self.extract_tables(query) + ) query = """ SELECT a.date, b.name FROM @@ -255,7 +276,9 @@ class SupersetTestCase(unittest.TestCase): ) b ON a.date = b.date """ - self.assertEqual({"left_table", "right_table"}, self.extract_tables(query)) + self.assertEqual( + {Table("left_table"), Table("right_table")}, self.extract_tables(query) + ) query = """ SELECT a.date, b.name FROM @@ -268,7 +291,9 @@ class SupersetTestCase(unittest.TestCase): ) b ON a.date = b.date """ - self.assertEqual({"left_table", "right_table"}, self.extract_tables(query)) + self.assertEqual( + {Table("left_table"), Table("right_table")}, self.extract_tables(query) + ) # TODO: add SEMI join support, SQL Parse does not handle it. 
# query = """ @@ -296,13 +321,16 @@ class SupersetTestCase(unittest.TestCase): WHERE ROW(5*t3.s1,77)= (SELECT 50,11*s1 FROM t4))); """ - self.assertEqual({"t1", "t3", "t4", "t6"}, self.extract_tables(query)) + self.assertEqual( + {Table("t1"), Table("t3"), Table("t4"), Table("t6")}, + self.extract_tables(query), + ) query = """ SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS) AS S1) AS S2) AS S3; """ - self.assertEqual({"EmployeeS"}, self.extract_tables(query)) + self.assertEqual({Table("EmployeeS")}, self.extract_tables(query)) def test_with(self): query = """ @@ -312,7 +340,9 @@ class SupersetTestCase(unittest.TestCase): z AS (SELECT b AS c FROM t3) SELECT c FROM z; """ - self.assertEqual({"t1", "t2", "t3"}, self.extract_tables(query)) + self.assertEqual( + {Table("t1"), Table("t2"), Table("t3")}, self.extract_tables(query) + ) query = """ WITH @@ -321,7 +351,7 @@ class SupersetTestCase(unittest.TestCase): z AS (SELECT b AS c FROM y) SELECT c FROM z; """ - self.assertEqual({"t1"}, self.extract_tables(query)) + self.assertEqual({Table("t1")}, self.extract_tables(query)) def test_reusing_aliases(self): query = """ @@ -329,22 +359,22 @@ class SupersetTestCase(unittest.TestCase): q2 as ( select key from src where key = '5') select * from (select key from q1) a; """ - self.assertEqual({"src"}, self.extract_tables(query)) + self.assertEqual({Table("src")}, self.extract_tables(query)) def test_multistatement(self): query = "SELECT * FROM t1; SELECT * FROM t2" - self.assertEqual({"t1", "t2"}, self.extract_tables(query)) + self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query)) query = "SELECT * FROM t1; SELECT * FROM t2;" - self.assertEqual({"t1", "t2"}, self.extract_tables(query)) + self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query)) def test_update_not_select(self): - sql = sql_parse.ParsedQuery("UPDATE t1 SET col1 = NULL") + sql = ParsedQuery("UPDATE t1 SET col1 = NULL") self.assertEqual(False, sql.is_select()) self.assertEqual(False, sql.is_readonly()) def test_explain(self): - sql = sql_parse.ParsedQuery("EXPLAIN SELECT 1") + sql = ParsedQuery("EXPLAIN SELECT 1") self.assertEqual(True, sql.is_explain()) self.assertEqual(False, sql.is_select()) @@ -367,7 +397,12 @@ class SupersetTestCase(unittest.TestCase): ORDER BY "sum__m_example" DESC LIMIT 10;""" self.assertEqual( - {"my_l_table", "my_b_table", "my_t_table", "inner_table"}, + { + Table("my_l_table"), + Table("my_b_table"), + Table("my_t_table"), + Table("inner_table"), + }, self.extract_tables(query), ) @@ -375,13 +410,19 @@ class SupersetTestCase(unittest.TestCase): query = """SELECT * FROM table_a AS a, table_b AS b, table_c as c WHERE a.id = b.id and b.id = c.id""" - self.assertEqual({"table_a", "table_b", "table_c"}, self.extract_tables(query)) + self.assertEqual( + {Table("table_a"), Table("table_b"), Table("table_c")}, + self.extract_tables(query), + ) def test_mixed_from_clause(self): query = """SELECT * FROM table_a AS a, (select * from table_b) AS b, table_c as c WHERE a.id = b.id and b.id = c.id""" - self.assertEqual({"table_a", "table_b", "table_c"}, self.extract_tables(query)) + self.assertEqual( + {Table("table_a"), Table("table_b"), Table("table_c")}, + self.extract_tables(query), + ) def test_nested_selects(self): query = """ @@ -389,13 +430,17 @@ class SupersetTestCase(unittest.TestCase): from INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA like "%bi%"),0x7e))); """ - self.assertEqual({"INFORMATION_SCHEMA.COLUMNS"}, self.extract_tables(query)) + self.assertEqual( + 
{Table("COLUMNS", "INFORMATION_SCHEMA")}, self.extract_tables(query) + ) query = """ select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME) from INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME="bi_achivement_daily"),0x7e))); """ - self.assertEqual({"INFORMATION_SCHEMA.COLUMNS"}, self.extract_tables(query)) + self.assertEqual( + {Table("COLUMNS", "INFORMATION_SCHEMA")}, self.extract_tables(query) + ) def test_complex_extract_tables3(self): query = """SELECT somecol AS somecol @@ -431,7 +476,10 @@ class SupersetTestCase(unittest.TestCase): WHERE 2=2 GROUP BY last_col LIMIT 50000;""" - self.assertEqual({"a", "b", "c", "d", "e", "f"}, self.extract_tables(query)) + self.assertEqual( + {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}, + self.extract_tables(query), + ) def test_complex_cte_with_prefix(self): query = """ @@ -446,23 +494,23 @@ class SupersetTestCase(unittest.TestCase): GROUP BY SalesYear, SalesPersonID ORDER BY SalesPersonID, SalesYear; """ - self.assertEqual({"SalesOrderHeader"}, self.extract_tables(query)) + self.assertEqual({Table("SalesOrderHeader")}, self.extract_tables(query)) def test_get_query_with_new_limit_comment(self): sql = "SELECT * FROM birth_names -- SOME COMMENT" - parsed = sql_parse.ParsedQuery(sql) + parsed = ParsedQuery(sql) newsql = parsed.set_or_update_query_limit(1000) self.assertEqual(newsql, sql + "\nLIMIT 1000") def test_get_query_with_new_limit_comment_with_limit(self): sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555" - parsed = sql_parse.ParsedQuery(sql) + parsed = ParsedQuery(sql) newsql = parsed.set_or_update_query_limit(1000) self.assertEqual(newsql, sql + "\nLIMIT 1000") def test_get_query_with_new_limit_lower(self): sql = "SELECT * FROM birth_names LIMIT 555" - parsed = sql_parse.ParsedQuery(sql) + parsed = ParsedQuery(sql) newsql = parsed.set_or_update_query_limit(1000) # not applied as new limit is higher expected = "SELECT * FROM birth_names LIMIT 555" @@ -470,7 +518,7 @@ class SupersetTestCase(unittest.TestCase): def test_get_query_with_new_limit_upper(self): sql = "SELECT * FROM birth_names LIMIT 1555" - parsed = sql_parse.ParsedQuery(sql) + parsed = ParsedQuery(sql) newsql = parsed.set_or_update_query_limit(1000) # applied as new limit is lower expected = "SELECT * FROM birth_names LIMIT 1000" @@ -481,7 +529,7 @@ class SupersetTestCase(unittest.TestCase): SELECT * FROM birth_names; SELECT * FROM birth_names LIMIT 1; """ - parsed = sql_parse.ParsedQuery(multi_sql) + parsed = ParsedQuery(multi_sql) statements = parsed.get_statements() self.assertEqual(len(statements), 2) expected = ["SELECT * FROM birth_names", "SELECT * FROM birth_names LIMIT 1"] @@ -494,7 +542,7 @@ class SupersetTestCase(unittest.TestCase): SELECT * FROM birth_names;;; SELECT * FROM birth_names LIMIT 1 """ - parsed = sql_parse.ParsedQuery(multi_sql) + parsed = ParsedQuery(multi_sql) statements = parsed.get_statements() self.assertEqual(len(statements), 4) expected = [ @@ -512,4 +560,4 @@ class SupersetTestCase(unittest.TestCase): match AS (SELECT * FROM f) SELECT * FROM match """ - self.assertEqual({"foo"}, self.extract_tables(query)) + self.assertEqual({Table("foo")}, self.extract_tables(query))