[sql] Adding lightweight Table class (#9649)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2020-04-30 08:38:02 -07:00 committed by GitHub
parent f7f60cc75d
commit 3b0f8e9c8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 202 additions and 169 deletions

View File

@ -19,6 +19,7 @@ colorama==0.4.3 # via apache-superset (setup.py), flask-appbuilder
contextlib2==0.6.0.post1 # via apache-superset (setup.py) contextlib2==0.6.0.post1 # via apache-superset (setup.py)
croniter==0.3.31 # via apache-superset (setup.py) croniter==0.3.31 # via apache-superset (setup.py)
cryptography==2.8 # via apache-superset (setup.py) cryptography==2.8 # via apache-superset (setup.py)
dataclasses==0.6 # via apache-superset (setup.py)
decorator==4.4.1 # via retry decorator==4.4.1 # via retry
defusedxml==0.6.0 # via python3-openid defusedxml==0.6.0 # via python3-openid
flask-appbuilder==2.3.2 # via apache-superset (setup.py) flask-appbuilder==2.3.2 # via apache-superset (setup.py)

View File

@ -45,7 +45,7 @@ combine_as_imports = true
include_trailing_comma = true include_trailing_comma = true
line_length = 88 line_length = 88
known_first_party = superset known_first_party = superset
known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dataclasses,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3 multi_line_output = 3
order_by_type = false order_by_type = false

View File

@ -75,6 +75,7 @@ setup(
"contextlib2", "contextlib2",
"croniter>=0.3.28", "croniter>=0.3.28",
"cryptography>=2.4.2", "cryptography>=2.4.2",
"dataclasses<0.7",
"flask>=1.1.0, <2.0.0", "flask>=1.1.0, <2.0.0",
"flask-appbuilder>=2.3.2, <2.4.0", "flask-appbuilder>=2.3.2, <2.4.0",
"flask-caching", "flask-caching",

View File

@ -50,6 +50,7 @@ if TYPE_CHECKING:
from superset.common.query_context import QueryContext from superset.common.query_context import QueryContext
from superset.connectors.base.models import BaseDatasource from superset.connectors.base.models import BaseDatasource
from superset.models.core import Database from superset.models.core import Database
from superset.sql_parse import Table
from superset.viz import BaseViz from superset.viz import BaseViz
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -290,26 +291,23 @@ class SupersetSecurityManager(SecurityManager):
return conf.get("PERMISSION_INSTRUCTIONS_LINK") return conf.get("PERMISSION_INSTRUCTIONS_LINK")
def get_table_access_error_msg(self, tables: List[str]) -> str: def get_table_access_error_msg(self, tables: Set["Table"]) -> str:
""" """
Return the error message for the denied SQL tables. Return the error message for the denied SQL tables.
Note the table names conform to the [[cluster.]schema.]table construct. :param tables: The set of denied SQL tables
:param tables: The list of denied SQL table names
:returns: The error message :returns: The error message
""" """
quoted_tables = [f"`{t}`" for t in tables]
quoted_tables = [f"`{table}`" for table in tables]
return f"""You need access to the following tables: {", ".join(quoted_tables)}, return f"""You need access to the following tables: {", ".join(quoted_tables)},
`all_database_access` or `all_datasource_access` permission""" `all_database_access` or `all_datasource_access` permission"""
def get_table_access_link(self, tables: List[str]) -> Optional[str]: def get_table_access_link(self, tables: Set["Table"]) -> Optional[str]:
""" """
Return the access link for the denied SQL tables. Return the access link for the denied SQL tables.
Note the table names conform to the [[cluster.]schema.]table construct. :param tables: The set of denied SQL tables
:param tables: The list of denied SQL table names
:returns: The access URL :returns: The access URL
""" """
@ -318,23 +316,19 @@ class SupersetSecurityManager(SecurityManager):
return conf.get("PERMISSION_INSTRUCTIONS_LINK") return conf.get("PERMISSION_INSTRUCTIONS_LINK")
def can_access_datasource( def can_access_datasource(
self, database: "Database", table_name: str, schema: Optional[str] = None self, database: "Database", table: "Table", schema: Optional[str] = None
) -> bool:
return self._datasource_access_by_name(database, table_name, schema=schema)
def _datasource_access_by_name(
self, database: "Database", table_name: str, schema: Optional[str] = None
) -> bool: ) -> bool:
""" """
Return True if the user can access the SQL table, False otherwise. Return True if the user can access the SQL table, False otherwise.
:param database: The SQL database :param database: The SQL database
:param table_name: The SQL table name :param table: The SQL table
:param schema: The Superset schema :param schema: The fallback SQL schema if not present in the table
:returns: Whether the user can access the SQL table :returns: Whether the user can access the SQL table
""" """
from superset import db from superset import db
from superset.connectors.sqla.models import SqlaTable
if self.database_access(database) or self.all_datasource_access(): if self.database_access(database) or self.all_datasource_access():
return True return True
@ -343,74 +337,33 @@ class SupersetSecurityManager(SecurityManager):
if schema_perm and self.can_access("schema_access", schema_perm): if schema_perm and self.can_access("schema_access", schema_perm):
return True return True
datasources = ConnectorRegistry.query_datasources_by_name( datasources = SqlaTable.query_datasources_by_name(
db.session, database, table_name, schema=schema db.session, database, table.table, schema=table.schema or schema
) )
for datasource in datasources: for datasource in datasources:
if self.can_access("datasource_access", datasource.perm): if self.can_access("datasource_access", datasource.perm):
return True return True
return False return False
def _get_schema_and_table( def rejected_tables(
self, table_in_query: str, schema: str self, sql: str, database: "Database", schema: str
) -> Tuple[str, str]: ) -> Set["Table"]:
""" """
Return the SQL schema/table tuple associated with the table extracted from the Return the list of rejected SQL tables.
SQL query.
Note the table name conforms to the [[cluster.]schema.]table construct.
:param table_in_query: The SQL table name
:param schema: The fallback SQL schema if not present in the table name
:returns: The SQL schema/table tuple
"""
table_name_pieces = table_in_query.split(".")
if len(table_name_pieces) == 3:
return tuple(table_name_pieces[1:]) # type: ignore
elif len(table_name_pieces) == 2:
return tuple(table_name_pieces) # type: ignore
return (schema, table_name_pieces[0])
def _datasource_access_by_fullname(
self, database: "Database", table_in_query: str, schema: str
) -> bool:
"""
Return True if the user can access the table extracted from the SQL query, False
otherwise.
Note the table name conforms to the [[cluster.]schema.]table construct.
:param database: The Superset database
:param table_in_query: The SQL table name
:param schema: The fallback SQL schema, i.e., if not present in the table name
:returns: Whether the user can access the SQL table
"""
table_schema, table_name = self._get_schema_and_table(table_in_query, schema)
return self._datasource_access_by_name(
database, table_name, schema=table_schema
)
def rejected_tables(self, sql: str, database: "Database", schema: str) -> List[str]:
"""
Return the list of rejected SQL table names.
Note the rejected table names conform to the [[cluster.]schema.]table construct.
:param sql: The SQL statement :param sql: The SQL statement
:param database: The SQL database :param database: The SQL database
:param schema: The SQL database schema :param schema: The SQL database schema
:returns: The rejected table names :returns: The rejected tables
""" """
superset_query = sql_parse.ParsedQuery(sql) query = sql_parse.ParsedQuery(sql)
return [ return {
t table
for t in superset_query.tables for table in query.tables
if not self._datasource_access_by_fullname(database, t, schema) if not self.can_access_datasource(database, table, schema)
] }
def get_public_role(self) -> Optional[Any]: # Optional[self.role_model] def get_public_role(self) -> Optional[Any]: # Optional[self.role_model]
from superset import conf from superset import conf
@ -493,7 +446,7 @@ class SupersetSecurityManager(SecurityManager):
.filter(or_(SqlaTable.perm.in_(perms))) .filter(or_(SqlaTable.perm.in_(perms)))
.distinct() .distinct()
) )
accessible_schemas.update([t.schema for t in tables]) accessible_schemas.update([table.schema for table in tables])
return [s for s in schemas if s in accessible_schemas] return [s for s in schemas if s in accessible_schemas]

View File

@ -16,8 +16,10 @@
# under the License. # under the License.
import logging import logging
from typing import List, Optional, Set from typing import List, Optional, Set
from urllib import parse
import sqlparse import sqlparse
from dataclasses import dataclass
from sqlparse.sql import ( from sqlparse.sql import (
Function, Function,
Identifier, Identifier,
@ -57,10 +59,32 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
return None return None
@dataclass(eq=True, frozen=True)
class Table: # pylint: disable=too-few-public-methods
"""
A fully qualified SQL table conforming to [[catalog.]schema.]table.
"""
table: str
schema: Optional[str] = None
catalog: Optional[str] = None
def __str__(self) -> str:
"""
Return the fully qualified SQL table name.
"""
return ".".join(
parse.quote(part, safe="").replace(".", "%2E")
for part in [self.catalog, self.schema, self.table]
if part
)
class ParsedQuery: class ParsedQuery:
def __init__(self, sql_statement: str): def __init__(self, sql_statement: str):
self.sql: str = sql_statement self.sql: str = sql_statement
self._table_names: Set[str] = set() self._tables: Set[Table] = set()
self._alias_names: Set[str] = set() self._alias_names: Set[str] = set()
self._limit: Optional[int] = None self._limit: Optional[int] = None
@ -70,12 +94,15 @@ class ParsedQuery:
self._limit = _extract_limit_from_query(statement) self._limit = _extract_limit_from_query(statement)
@property @property
def tables(self) -> Set[str]: def tables(self) -> Set[Table]:
if not self._table_names: if not self._tables:
for statement in self._parsed: for statement in self._parsed:
self.__extract_from_token(statement) self._extract_from_token(statement)
self._table_names = self._table_names - self._alias_names
return self._table_names self._tables = {
table for table in self._tables if str(table) not in self._alias_names
}
return self._tables
@property @property
def limit(self) -> Optional[int]: def limit(self) -> Optional[int]:
@ -105,13 +132,13 @@ class ParsedQuery:
return statements return statements
@staticmethod @staticmethod
def __get_full_name(tlist: TokenList) -> Optional[str]: def _get_table(tlist: TokenList) -> Optional[Table]:
""" """
Return the full unquoted table name if valid, i.e., conforms to the following Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
[[cluster.]schema.]table construct. construct.
:param tlist: The SQL tokens :param tlist: The SQL tokens
:returns: The valid full table name :returns: The table if the name conforms
""" """
# Strip the alias if present. # Strip the alias if present.
@ -127,18 +154,18 @@ class ParsedQuery:
if ( if (
len(tokens) in (1, 3, 5) len(tokens) in (1, 3, 5)
and all(imt(token, t=[Name, String]) for token in tokens[0::2]) and all(imt(token, t=[Name, String]) for token in tokens[::2])
and all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2]) and all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
): ):
return ".".join([remove_quotes(token.value) for token in tokens[0::2]]) return Table(*[remove_quotes(token.value) for token in tokens[::-2]])
return None return None
@staticmethod @staticmethod
def __is_identifier(token: Token) -> bool: def _is_identifier(token: Token) -> bool:
return isinstance(token, (IdentifierList, Identifier)) return isinstance(token, (IdentifierList, Identifier))
def __process_tokenlist(self, token_list: TokenList): def _process_tokenlist(self, token_list: TokenList):
""" """
Add table names to table set Add table names to table set
@ -146,9 +173,9 @@ class ParsedQuery:
""" """
# exclude subselects # exclude subselects
if "(" not in str(token_list): if "(" not in str(token_list):
table_name = self.__get_full_name(token_list) table = self._get_table(token_list)
if table_name and not table_name.startswith(CTE_PREFIX): if table and not table.table.startswith(CTE_PREFIX):
self._table_names.add(table_name) self._tables.add(table)
return return
# store aliases # store aliases
@ -158,7 +185,7 @@ class ParsedQuery:
# some aliases are not parsed properly # some aliases are not parsed properly
if token_list.tokens[0].ttype == Name: if token_list.tokens[0].ttype == Name:
self._alias_names.add(token_list.tokens[0].value) self._alias_names.add(token_list.tokens[0].value)
self.__extract_from_token(token_list) self._extract_from_token(token_list)
def as_create_table( def as_create_table(
self, self,
@ -184,9 +211,9 @@ class ParsedQuery:
exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}" exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}"
return exec_sql return exec_sql
def __extract_from_token(self, token: Token): # pylint: disable=too-many-branches def _extract_from_token(self, token: Token): # pylint: disable=too-many-branches
""" """
Populate self._table_names from token Populate self._tables from token
:param token: instance of Token or child class, e.g. TokenList, to be processed :param token: instance of Token or child class, e.g. TokenList, to be processed
""" """
@ -196,8 +223,8 @@ class ParsedQuery:
table_name_preceding_token = False table_name_preceding_token = False
for item in token.tokens: for item in token.tokens:
if item.is_group and not self.__is_identifier(item): if item.is_group and not self._is_identifier(item):
self.__extract_from_token(item) self._extract_from_token(item)
if item.ttype in Keyword and ( if item.ttype in Keyword and (
item.normalized in PRECEDES_TABLE_NAME item.normalized in PRECEDES_TABLE_NAME
@ -212,15 +239,15 @@ class ParsedQuery:
if table_name_preceding_token: if table_name_preceding_token:
if isinstance(item, Identifier): if isinstance(item, Identifier):
self.__process_tokenlist(item) self._process_tokenlist(item)
elif isinstance(item, IdentifierList): elif isinstance(item, IdentifierList):
for token2 in item.get_identifiers(): for token2 in item.get_identifiers():
if isinstance(token2, TokenList): if isinstance(token2, TokenList):
self.__process_tokenlist(token2) self._process_tokenlist(token2)
elif isinstance(item, IdentifierList): elif isinstance(item, IdentifierList):
for token2 in item.tokens: for token2 in item.tokens:
if not self.__is_identifier(token2): if not self._is_identifier(token2):
self.__extract_from_token(item) self._extract_from_token(item)
def set_or_update_query_limit(self, new_limit: int) -> str: def set_or_update_query_limit(self, new_limit: int) -> str:
"""Returns the query with the specified limit. """Returns the query with the specified limit.

View File

@ -85,7 +85,7 @@ from superset.security.analytics_db_safety import (
check_sqlalchemy_uri, check_sqlalchemy_uri,
DBSecurityException, DBSecurityException,
) )
from superset.sql_parse import ParsedQuery from superset.sql_parse import ParsedQuery, Table
from superset.sql_validators import get_validator_by_name from superset.sql_validators import get_validator_by_name
from superset.utils import core as utils, dashboard_import_export from superset.utils import core as utils, dashboard_import_export
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
@ -2083,7 +2083,9 @@ class Superset(BaseSupersetView):
schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
table_name = utils.parse_js_uri_path_item(table_name) table_name = utils.parse_js_uri_path_item(table_name)
# Check that the user can access the datasource # Check that the user can access the datasource
if not self.appbuilder.sm.can_access_datasource(database, table_name, schema): if not self.appbuilder.sm.can_access_datasource(
database, Table(table_name, schema), schema
):
stats_logger.incr( stats_logger.incr(
f"deprecated.{self.__class__.__name__}.select_star.permission_denied" f"deprecated.{self.__class__.__name__}.select_star.permission_denied"
) )

View File

@ -22,6 +22,7 @@ from flask import g
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
from superset.models.core import Database from superset.models.core import Database
from superset.sql_parse import Table
from superset.utils.core import parse_js_uri_path_item from superset.utils.core import parse_js_uri_path_item
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -45,7 +46,7 @@ def check_datasource_access(f):
return self.response_404() return self.response_404()
# Check that the user can access the datasource # Check that the user can access the datasource
if not self.appbuilder.sm.can_access_datasource( if not self.appbuilder.sm.can_access_datasource(
database, table_name_parsed, schema_name_parsed database, Table(table_name_parsed, schema_name_parsed), schema_name_parsed
): ):
self.stats_logger.incr( self.stats_logger.incr(
f"permisssion_denied_{self.__class__.__name__}.select_star" f"permisssion_denied_{self.__class__.__name__}.select_star"

View File

@ -16,90 +16,102 @@
# under the License. # under the License.
import unittest import unittest
from superset import sql_parse from superset.sql_parse import ParsedQuery, Table
class SupersetTestCase(unittest.TestCase): class SupersetTestCase(unittest.TestCase):
def extract_tables(self, query): def extract_tables(self, query):
sq = sql_parse.ParsedQuery(query) return ParsedQuery(query).tables
return sq.tables
def test_table(self):
self.assertEqual(str(Table("tbname")), "tbname")
self.assertEqual(str(Table("tbname", "schemaname")), "schemaname.tbname")
self.assertEqual(
str(Table("tbname", "schemaname", "catalogname")),
"catalogname.schemaname.tbname",
)
self.assertEqual(
str(Table("tb.name", "schema/name", "catalog\name")),
"catalog%0Aame.schema%2Fname.tb%2Ename",
)
def test_simple_select(self): def test_simple_select(self):
query = "SELECT * FROM tbname" query = "SELECT * FROM tbname"
self.assertEqual({"tbname"}, self.extract_tables(query)) self.assertEqual({Table("tbname")}, self.extract_tables(query))
query = "SELECT * FROM tbname foo" query = "SELECT * FROM tbname foo"
self.assertEqual({"tbname"}, self.extract_tables(query)) self.assertEqual({Table("tbname")}, self.extract_tables(query))
query = "SELECT * FROM tbname AS foo" query = "SELECT * FROM tbname AS foo"
self.assertEqual({"tbname"}, self.extract_tables(query)) self.assertEqual({Table("tbname")}, self.extract_tables(query))
# underscores # underscores
query = "SELECT * FROM tb_name" query = "SELECT * FROM tb_name"
self.assertEqual({"tb_name"}, self.extract_tables(query)) self.assertEqual({Table("tb_name")}, self.extract_tables(query))
# quotes # quotes
query = 'SELECT * FROM "tbname"' query = 'SELECT * FROM "tbname"'
self.assertEqual({"tbname"}, self.extract_tables(query)) self.assertEqual({Table("tbname")}, self.extract_tables(query))
# unicode encoding # unicode encoding
query = 'SELECT * FROM "tb_name" WHERE city = "Lübeck"' query = 'SELECT * FROM "tb_name" WHERE city = "Lübeck"'
self.assertEqual({"tb_name"}, self.extract_tables(query)) self.assertEqual({Table("tb_name")}, self.extract_tables(query))
# schema # schema
self.assertEqual( self.assertEqual(
{"schemaname.tbname"}, {Table("tbname", "schemaname")},
self.extract_tables("SELECT * FROM schemaname.tbname"), self.extract_tables("SELECT * FROM schemaname.tbname"),
) )
self.assertEqual( self.assertEqual(
{"schemaname.tbname"}, {Table("tbname", "schemaname")},
self.extract_tables('SELECT * FROM "schemaname"."tbname"'), self.extract_tables('SELECT * FROM "schemaname"."tbname"'),
) )
self.assertEqual( self.assertEqual(
{"schemaname.tbname"}, {Table("tbname", "schemaname")},
self.extract_tables("SELECT * FROM schemaname.tbname foo"), self.extract_tables("SELECT * FROM schemaname.tbname foo"),
) )
self.assertEqual( self.assertEqual(
{"schemaname.tbname"}, {Table("tbname", "schemaname")},
self.extract_tables("SELECT * FROM schemaname.tbname AS foo"), self.extract_tables("SELECT * FROM schemaname.tbname AS foo"),
) )
# cluster
self.assertEqual( self.assertEqual(
{"clustername.schemaname.tbname"}, {Table("tbname", "schemaname", "catalogname")},
self.extract_tables("SELECT * FROM clustername.schemaname.tbname"), self.extract_tables("SELECT * FROM catalogname.schemaname.tbname"),
) )
# Ill-defined cluster/schema/table. # Ill-defined cluster/schema/table.
self.assertEqual(set(), self.extract_tables("SELECT * FROM schemaname.")) self.assertEqual(set(), self.extract_tables("SELECT * FROM schemaname."))
self.assertEqual( self.assertEqual(
set(), self.extract_tables("SELECT * FROM clustername.schemaname.") set(), self.extract_tables("SELECT * FROM catalogname.schemaname.")
) )
self.assertEqual(set(), self.extract_tables("SELECT * FROM clustername..")) self.assertEqual(set(), self.extract_tables("SELECT * FROM catalogname.."))
self.assertEqual( self.assertEqual(
set(), self.extract_tables("SELECT * FROM clustername..tbname") set(), self.extract_tables("SELECT * FROM catalogname..tbname")
) )
# quotes # quotes
query = "SELECT field1, field2 FROM tb_name" query = "SELECT field1, field2 FROM tb_name"
self.assertEqual({"tb_name"}, self.extract_tables(query)) self.assertEqual({Table("tb_name")}, self.extract_tables(query))
query = "SELECT t1.f1, t2.f2 FROM t1, t2" query = "SELECT t1.f1, t2.f2 FROM t1, t2"
self.assertEqual({"t1", "t2"}, self.extract_tables(query)) self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
def test_select_named_table(self): def test_select_named_table(self):
query = "SELECT a.date, a.field FROM left_table a LIMIT 10" query = "SELECT a.date, a.field FROM left_table a LIMIT 10"
self.assertEqual({"left_table"}, self.extract_tables(query)) self.assertEqual({Table("left_table")}, self.extract_tables(query))
def test_reverse_select(self): def test_reverse_select(self):
query = "FROM t1 SELECT field" query = "FROM t1 SELECT field"
self.assertEqual({"t1"}, self.extract_tables(query)) self.assertEqual({Table("t1")}, self.extract_tables(query))
def test_subselect(self): def test_subselect(self):
query = """ query = """
@ -111,7 +123,9 @@ class SupersetTestCase(unittest.TestCase):
) sub, s2.t2 ) sub, s2.t2
WHERE sub.resolution = 'NONE' WHERE sub.resolution = 'NONE'
""" """
self.assertEqual({"s1.t1", "s2.t2"}, self.extract_tables(query)) self.assertEqual(
{Table("t1", "s1"), Table("t2", "s2")}, self.extract_tables(query)
)
query = """ query = """
SELECT sub.* SELECT sub.*
@ -122,7 +136,7 @@ class SupersetTestCase(unittest.TestCase):
) sub ) sub
WHERE sub.resolution = 'NONE' WHERE sub.resolution = 'NONE'
""" """
self.assertEqual({"s1.t1"}, self.extract_tables(query)) self.assertEqual({Table("t1", "s1")}, self.extract_tables(query))
query = """ query = """
SELECT * FROM t1 SELECT * FROM t1
@ -133,21 +147,24 @@ class SupersetTestCase(unittest.TestCase):
WHERE ROW(5*t2.s1,77)= WHERE ROW(5*t2.s1,77)=
(SELECT 50,11*s1 FROM t4))); (SELECT 50,11*s1 FROM t4)));
""" """
self.assertEqual({"t1", "t2", "t3", "t4"}, self.extract_tables(query)) self.assertEqual(
{Table("t1"), Table("t2"), Table("t3"), Table("t4")},
self.extract_tables(query),
)
def test_select_in_expression(self): def test_select_in_expression(self):
query = "SELECT f1, (SELECT count(1) FROM t2) FROM t1" query = "SELECT f1, (SELECT count(1) FROM t2) FROM t1"
self.assertEqual({"t1", "t2"}, self.extract_tables(query)) self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
def test_union(self): def test_union(self):
query = "SELECT * FROM t1 UNION SELECT * FROM t2" query = "SELECT * FROM t1 UNION SELECT * FROM t2"
self.assertEqual({"t1", "t2"}, self.extract_tables(query)) self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = "SELECT * FROM t1 UNION ALL SELECT * FROM t2" query = "SELECT * FROM t1 UNION ALL SELECT * FROM t2"
self.assertEqual({"t1", "t2"}, self.extract_tables(query)) self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2" query = "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2"
self.assertEqual({"t1", "t2"}, self.extract_tables(query)) self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
def test_select_from_values(self): def test_select_from_values(self):
query = "SELECT * FROM VALUES (13, 42)" query = "SELECT * FROM VALUES (13, 42)"
@ -158,25 +175,25 @@ class SupersetTestCase(unittest.TestCase):
SELECT ARRAY[1, 2, 3] AS my_array SELECT ARRAY[1, 2, 3] AS my_array
FROM t1 LIMIT 10 FROM t1 LIMIT 10
""" """
self.assertEqual({"t1"}, self.extract_tables(query)) self.assertEqual({Table("t1")}, self.extract_tables(query))
def test_select_if(self): def test_select_if(self):
query = """ query = """
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL) SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
FROM t1 LIMIT 10 FROM t1 LIMIT 10
""" """
self.assertEqual({"t1"}, self.extract_tables(query)) self.assertEqual({Table("t1")}, self.extract_tables(query))
# SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)? # SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)?
def test_show_tables(self): def test_show_tables(self):
query = "SHOW TABLES FROM s1 like '%order%'" query = "SHOW TABLES FROM s1 like '%order%'"
# TODO: figure out what should code do here # TODO: figure out what should code do here
self.assertEqual({"s1"}, self.extract_tables(query)) self.assertEqual({Table("s1")}, self.extract_tables(query))
# SHOW COLUMNS (FROM | IN) qualifiedName # SHOW COLUMNS (FROM | IN) qualifiedName
def test_show_columns(self): def test_show_columns(self):
query = "SHOW COLUMNS FROM t1" query = "SHOW COLUMNS FROM t1"
self.assertEqual({"t1"}, self.extract_tables(query)) self.assertEqual({Table("t1")}, self.extract_tables(query))
def test_where_subquery(self): def test_where_subquery(self):
query = """ query = """
@ -184,25 +201,25 @@ class SupersetTestCase(unittest.TestCase):
FROM t1 FROM t1
WHERE regionkey = (SELECT max(regionkey) FROM t2) WHERE regionkey = (SELECT max(regionkey) FROM t2)
""" """
self.assertEqual({"t1", "t2"}, self.extract_tables(query)) self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = """ query = """
SELECT name SELECT name
FROM t1 FROM t1
WHERE regionkey IN (SELECT regionkey FROM t2) WHERE regionkey IN (SELECT regionkey FROM t2)
""" """
self.assertEqual({"t1", "t2"}, self.extract_tables(query)) self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = """ query = """
SELECT name SELECT name
FROM t1 FROM t1
WHERE regionkey EXISTS (SELECT regionkey FROM t2) WHERE regionkey EXISTS (SELECT regionkey FROM t2)
""" """
self.assertEqual({"t1", "t2"}, self.extract_tables(query)) self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
# DESCRIBE | DESC qualifiedName # DESCRIBE | DESC qualifiedName
def test_describe(self): def test_describe(self):
self.assertEqual({"t1"}, self.extract_tables("DESCRIBE t1")) self.assertEqual({Table("t1")}, self.extract_tables("DESCRIBE t1"))
# SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)? # SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
# (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))? # (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
@ -211,11 +228,11 @@ class SupersetTestCase(unittest.TestCase):
SHOW PARTITIONS FROM orders SHOW PARTITIONS FROM orders
WHERE ds >= '2013-01-01' ORDER BY ds DESC; WHERE ds >= '2013-01-01' ORDER BY ds DESC;
""" """
self.assertEqual({"orders"}, self.extract_tables(query)) self.assertEqual({Table("orders")}, self.extract_tables(query))
def test_join(self): def test_join(self):
query = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;" query = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
self.assertEqual({"t1", "t2"}, self.extract_tables(query)) self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
# subquery + join # subquery + join
query = """ query = """
@ -229,7 +246,9 @@ class SupersetTestCase(unittest.TestCase):
) b ) b
ON a.date = b.date ON a.date = b.date
""" """
self.assertEqual({"left_table", "right_table"}, self.extract_tables(query)) self.assertEqual(
{Table("left_table"), Table("right_table")}, self.extract_tables(query)
)
query = """ query = """
SELECT a.date, b.name FROM SELECT a.date, b.name FROM
@ -242,7 +261,9 @@ class SupersetTestCase(unittest.TestCase):
) b ) b
ON a.date = b.date ON a.date = b.date
""" """
self.assertEqual({"left_table", "right_table"}, self.extract_tables(query)) self.assertEqual(
{Table("left_table"), Table("right_table")}, self.extract_tables(query)
)
query = """ query = """
SELECT a.date, b.name FROM SELECT a.date, b.name FROM
@ -255,7 +276,9 @@ class SupersetTestCase(unittest.TestCase):
) b ) b
ON a.date = b.date ON a.date = b.date
""" """
self.assertEqual({"left_table", "right_table"}, self.extract_tables(query)) self.assertEqual(
{Table("left_table"), Table("right_table")}, self.extract_tables(query)
)
query = """ query = """
SELECT a.date, b.name FROM SELECT a.date, b.name FROM
@ -268,7 +291,9 @@ class SupersetTestCase(unittest.TestCase):
) b ) b
ON a.date = b.date ON a.date = b.date
""" """
self.assertEqual({"left_table", "right_table"}, self.extract_tables(query)) self.assertEqual(
{Table("left_table"), Table("right_table")}, self.extract_tables(query)
)
# TODO: add SEMI join support, SQL Parse does not handle it. # TODO: add SEMI join support, SQL Parse does not handle it.
# query = """ # query = """
@ -296,13 +321,16 @@ class SupersetTestCase(unittest.TestCase):
WHERE ROW(5*t3.s1,77)= WHERE ROW(5*t3.s1,77)=
(SELECT 50,11*s1 FROM t4))); (SELECT 50,11*s1 FROM t4)));
""" """
self.assertEqual({"t1", "t3", "t4", "t6"}, self.extract_tables(query)) self.assertEqual(
{Table("t1"), Table("t3"), Table("t4"), Table("t6")},
self.extract_tables(query),
)
query = """ query = """
SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS) SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS)
AS S1) AS S2) AS S3; AS S1) AS S2) AS S3;
""" """
self.assertEqual({"EmployeeS"}, self.extract_tables(query)) self.assertEqual({Table("EmployeeS")}, self.extract_tables(query))
def test_with(self): def test_with(self):
query = """ query = """
@ -312,7 +340,9 @@ class SupersetTestCase(unittest.TestCase):
z AS (SELECT b AS c FROM t3) z AS (SELECT b AS c FROM t3)
SELECT c FROM z; SELECT c FROM z;
""" """
self.assertEqual({"t1", "t2", "t3"}, self.extract_tables(query)) self.assertEqual(
{Table("t1"), Table("t2"), Table("t3")}, self.extract_tables(query)
)
query = """ query = """
WITH WITH
@ -321,7 +351,7 @@ class SupersetTestCase(unittest.TestCase):
z AS (SELECT b AS c FROM y) z AS (SELECT b AS c FROM y)
SELECT c FROM z; SELECT c FROM z;
""" """
self.assertEqual({"t1"}, self.extract_tables(query)) self.assertEqual({Table("t1")}, self.extract_tables(query))
def test_reusing_aliases(self): def test_reusing_aliases(self):
query = """ query = """
@ -329,22 +359,22 @@ class SupersetTestCase(unittest.TestCase):
q2 as ( select key from src where key = '5') q2 as ( select key from src where key = '5')
select * from (select key from q1) a; select * from (select key from q1) a;
""" """
self.assertEqual({"src"}, self.extract_tables(query)) self.assertEqual({Table("src")}, self.extract_tables(query))
def test_multistatement(self): def test_multistatement(self):
query = "SELECT * FROM t1; SELECT * FROM t2" query = "SELECT * FROM t1; SELECT * FROM t2"
self.assertEqual({"t1", "t2"}, self.extract_tables(query)) self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = "SELECT * FROM t1; SELECT * FROM t2;" query = "SELECT * FROM t1; SELECT * FROM t2;"
self.assertEqual({"t1", "t2"}, self.extract_tables(query)) self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
def test_update_not_select(self): def test_update_not_select(self):
sql = sql_parse.ParsedQuery("UPDATE t1 SET col1 = NULL") sql = ParsedQuery("UPDATE t1 SET col1 = NULL")
self.assertEqual(False, sql.is_select()) self.assertEqual(False, sql.is_select())
self.assertEqual(False, sql.is_readonly()) self.assertEqual(False, sql.is_readonly())
def test_explain(self): def test_explain(self):
sql = sql_parse.ParsedQuery("EXPLAIN SELECT 1") sql = ParsedQuery("EXPLAIN SELECT 1")
self.assertEqual(True, sql.is_explain()) self.assertEqual(True, sql.is_explain())
self.assertEqual(False, sql.is_select()) self.assertEqual(False, sql.is_select())
@ -367,7 +397,12 @@ class SupersetTestCase(unittest.TestCase):
ORDER BY "sum__m_example" DESC ORDER BY "sum__m_example" DESC
LIMIT 10;""" LIMIT 10;"""
self.assertEqual( self.assertEqual(
{"my_l_table", "my_b_table", "my_t_table", "inner_table"}, {
Table("my_l_table"),
Table("my_b_table"),
Table("my_t_table"),
Table("inner_table"),
},
self.extract_tables(query), self.extract_tables(query),
) )
@ -375,13 +410,19 @@ class SupersetTestCase(unittest.TestCase):
query = """SELECT * query = """SELECT *
FROM table_a AS a, table_b AS b, table_c as c FROM table_a AS a, table_b AS b, table_c as c
WHERE a.id = b.id and b.id = c.id""" WHERE a.id = b.id and b.id = c.id"""
self.assertEqual({"table_a", "table_b", "table_c"}, self.extract_tables(query)) self.assertEqual(
{Table("table_a"), Table("table_b"), Table("table_c")},
self.extract_tables(query),
)
def test_mixed_from_clause(self): def test_mixed_from_clause(self):
query = """SELECT * query = """SELECT *
FROM table_a AS a, (select * from table_b) AS b, table_c as c FROM table_a AS a, (select * from table_b) AS b, table_c as c
WHERE a.id = b.id and b.id = c.id""" WHERE a.id = b.id and b.id = c.id"""
self.assertEqual({"table_a", "table_b", "table_c"}, self.extract_tables(query)) self.assertEqual(
{Table("table_a"), Table("table_b"), Table("table_c")},
self.extract_tables(query),
)
def test_nested_selects(self): def test_nested_selects(self):
query = """ query = """
@ -389,13 +430,17 @@ class SupersetTestCase(unittest.TestCase):
from INFORMATION_SCHEMA.COLUMNS from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e))); WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
""" """
self.assertEqual({"INFORMATION_SCHEMA.COLUMNS"}, self.extract_tables(query)) self.assertEqual(
{Table("COLUMNS", "INFORMATION_SCHEMA")}, self.extract_tables(query)
)
query = """ query = """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME) select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achivement_daily"),0x7e))); WHERE TABLE_NAME="bi_achivement_daily"),0x7e)));
""" """
self.assertEqual({"INFORMATION_SCHEMA.COLUMNS"}, self.extract_tables(query)) self.assertEqual(
{Table("COLUMNS", "INFORMATION_SCHEMA")}, self.extract_tables(query)
)
def test_complex_extract_tables3(self): def test_complex_extract_tables3(self):
query = """SELECT somecol AS somecol query = """SELECT somecol AS somecol
@ -431,7 +476,10 @@ class SupersetTestCase(unittest.TestCase):
WHERE 2=2 WHERE 2=2
GROUP BY last_col GROUP BY last_col
LIMIT 50000;""" LIMIT 50000;"""
self.assertEqual({"a", "b", "c", "d", "e", "f"}, self.extract_tables(query)) self.assertEqual(
{Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")},
self.extract_tables(query),
)
def test_complex_cte_with_prefix(self): def test_complex_cte_with_prefix(self):
query = """ query = """
@ -446,23 +494,23 @@ class SupersetTestCase(unittest.TestCase):
GROUP BY SalesYear, SalesPersonID GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear; ORDER BY SalesPersonID, SalesYear;
""" """
self.assertEqual({"SalesOrderHeader"}, self.extract_tables(query)) self.assertEqual({Table("SalesOrderHeader")}, self.extract_tables(query))
def test_get_query_with_new_limit_comment(self): def test_get_query_with_new_limit_comment(self):
sql = "SELECT * FROM birth_names -- SOME COMMENT" sql = "SELECT * FROM birth_names -- SOME COMMENT"
parsed = sql_parse.ParsedQuery(sql) parsed = ParsedQuery(sql)
newsql = parsed.set_or_update_query_limit(1000) newsql = parsed.set_or_update_query_limit(1000)
self.assertEqual(newsql, sql + "\nLIMIT 1000") self.assertEqual(newsql, sql + "\nLIMIT 1000")
def test_get_query_with_new_limit_comment_with_limit(self): def test_get_query_with_new_limit_comment_with_limit(self):
sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555" sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
parsed = sql_parse.ParsedQuery(sql) parsed = ParsedQuery(sql)
newsql = parsed.set_or_update_query_limit(1000) newsql = parsed.set_or_update_query_limit(1000)
self.assertEqual(newsql, sql + "\nLIMIT 1000") self.assertEqual(newsql, sql + "\nLIMIT 1000")
def test_get_query_with_new_limit_lower(self): def test_get_query_with_new_limit_lower(self):
sql = "SELECT * FROM birth_names LIMIT 555" sql = "SELECT * FROM birth_names LIMIT 555"
parsed = sql_parse.ParsedQuery(sql) parsed = ParsedQuery(sql)
newsql = parsed.set_or_update_query_limit(1000) newsql = parsed.set_or_update_query_limit(1000)
# not applied as new limit is higher # not applied as new limit is higher
expected = "SELECT * FROM birth_names LIMIT 555" expected = "SELECT * FROM birth_names LIMIT 555"
@ -470,7 +518,7 @@ class SupersetTestCase(unittest.TestCase):
def test_get_query_with_new_limit_upper(self): def test_get_query_with_new_limit_upper(self):
sql = "SELECT * FROM birth_names LIMIT 1555" sql = "SELECT * FROM birth_names LIMIT 1555"
parsed = sql_parse.ParsedQuery(sql) parsed = ParsedQuery(sql)
newsql = parsed.set_or_update_query_limit(1000) newsql = parsed.set_or_update_query_limit(1000)
# applied as new limit is lower # applied as new limit is lower
expected = "SELECT * FROM birth_names LIMIT 1000" expected = "SELECT * FROM birth_names LIMIT 1000"
@ -481,7 +529,7 @@ class SupersetTestCase(unittest.TestCase):
SELECT * FROM birth_names; SELECT * FROM birth_names;
SELECT * FROM birth_names LIMIT 1; SELECT * FROM birth_names LIMIT 1;
""" """
parsed = sql_parse.ParsedQuery(multi_sql) parsed = ParsedQuery(multi_sql)
statements = parsed.get_statements() statements = parsed.get_statements()
self.assertEqual(len(statements), 2) self.assertEqual(len(statements), 2)
expected = ["SELECT * FROM birth_names", "SELECT * FROM birth_names LIMIT 1"] expected = ["SELECT * FROM birth_names", "SELECT * FROM birth_names LIMIT 1"]
@ -494,7 +542,7 @@ class SupersetTestCase(unittest.TestCase):
SELECT * FROM birth_names;;; SELECT * FROM birth_names;;;
SELECT * FROM birth_names LIMIT 1 SELECT * FROM birth_names LIMIT 1
""" """
parsed = sql_parse.ParsedQuery(multi_sql) parsed = ParsedQuery(multi_sql)
statements = parsed.get_statements() statements = parsed.get_statements()
self.assertEqual(len(statements), 4) self.assertEqual(len(statements), 4)
expected = [ expected = [
@ -512,4 +560,4 @@ class SupersetTestCase(unittest.TestCase):
match AS (SELECT * FROM f) match AS (SELECT * FROM f)
SELECT * FROM match SELECT * FROM match
""" """
self.assertEqual({"foo"}, self.extract_tables(query)) self.assertEqual({Table("foo")}, self.extract_tables(query))