[sql] Adding lightweight Table class (#9649)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
John Bodley 2020-04-30 08:38:02 -07:00 committed by GitHub
parent f7f60cc75d
commit 3b0f8e9c8a
8 changed files with 202 additions and 169 deletions

View File

@ -19,6 +19,7 @@ colorama==0.4.3 # via apache-superset (setup.py), flask-appbuilder
contextlib2==0.6.0.post1 # via apache-superset (setup.py)
croniter==0.3.31 # via apache-superset (setup.py)
cryptography==2.8 # via apache-superset (setup.py)
dataclasses==0.6 # via apache-superset (setup.py)
decorator==4.4.1 # via retry
defusedxml==0.6.0 # via python3-openid
flask-appbuilder==2.3.2 # via apache-superset (setup.py)

View File

@ -45,7 +45,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dataclasses,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
multi_line_output = 3
order_by_type = false

View File

@ -75,6 +75,7 @@ setup(
"contextlib2",
"croniter>=0.3.28",
"cryptography>=2.4.2",
"dataclasses<0.7",
"flask>=1.1.0, <2.0.0",
"flask-appbuilder>=2.3.2, <2.4.0",
"flask-caching",

View File

@ -50,6 +50,7 @@ if TYPE_CHECKING:
from superset.common.query_context import QueryContext
from superset.connectors.base.models import BaseDatasource
from superset.models.core import Database
from superset.sql_parse import Table
from superset.viz import BaseViz
logger = logging.getLogger(__name__)
@ -290,26 +291,23 @@ class SupersetSecurityManager(SecurityManager):
return conf.get("PERMISSION_INSTRUCTIONS_LINK")
def get_table_access_error_msg(self, tables: List[str]) -> str:
def get_table_access_error_msg(self, tables: Set["Table"]) -> str:
"""
Return the error message for the denied SQL tables.
Note the table names conform to the [[cluster.]schema.]table construct.
:param tables: The list of denied SQL table names
:param tables: The set of denied SQL tables
:returns: The error message
"""
quoted_tables = [f"`{t}`" for t in tables]
quoted_tables = [f"`{table}`" for table in tables]
return f"""You need access to the following tables: {", ".join(quoted_tables)},
`all_database_access` or `all_datasource_access` permission"""
def get_table_access_link(self, tables: List[str]) -> Optional[str]:
def get_table_access_link(self, tables: Set["Table"]) -> Optional[str]:
"""
Return the access link for the denied SQL tables.
Note the table names conform to the [[cluster.]schema.]table construct.
:param tables: The list of denied SQL table names
:param tables: The set of denied SQL tables
:returns: The access URL
"""
@ -318,23 +316,19 @@ class SupersetSecurityManager(SecurityManager):
return conf.get("PERMISSION_INSTRUCTIONS_LINK")
def can_access_datasource(
self, database: "Database", table_name: str, schema: Optional[str] = None
) -> bool:
return self._datasource_access_by_name(database, table_name, schema=schema)
def _datasource_access_by_name(
self, database: "Database", table_name: str, schema: Optional[str] = None
self, database: "Database", table: "Table", schema: Optional[str] = None
) -> bool:
"""
Return True if the user can access the SQL table, False otherwise.
:param database: The SQL database
:param table_name: The SQL table name
:param schema: The Superset schema
:param table: The SQL table
:param schema: The fallback SQL schema if not present in the table
:returns: Whether the user can access the SQL table
"""
from superset import db
from superset.connectors.sqla.models import SqlaTable
if self.database_access(database) or self.all_datasource_access():
return True
@ -343,74 +337,33 @@ class SupersetSecurityManager(SecurityManager):
if schema_perm and self.can_access("schema_access", schema_perm):
return True
datasources = ConnectorRegistry.query_datasources_by_name(
db.session, database, table_name, schema=schema
datasources = SqlaTable.query_datasources_by_name(
db.session, database, table.table, schema=table.schema or schema
)
for datasource in datasources:
if self.can_access("datasource_access", datasource.perm):
return True
return False
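
The table.schema or schema expression above gives a schema embedded in the Table precedence over the fallback schema argument; a minimal sketch of that behavior (assumes superset.sql_parse is importable; not part of this file):

from superset.sql_parse import Table

fallback_schema = "main"  # hypothetical fallback supplied by the caller
assert (Table("t1", "other").schema or fallback_schema) == "other"
assert (Table("t1").schema or fallback_schema) == "main"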
def _get_schema_and_table(
self, table_in_query: str, schema: str
) -> Tuple[str, str]:
def rejected_tables(
self, sql: str, database: "Database", schema: str
) -> Set["Table"]:
"""
Return the SQL schema/table tuple associated with the table extracted from the
SQL query.
Note the table name conforms to the [[cluster.]schema.]table construct.
:param table_in_query: The SQL table name
:param schema: The fallback SQL schema if not present in the table name
:returns: The SQL schema/table tuple
"""
table_name_pieces = table_in_query.split(".")
if len(table_name_pieces) == 3:
return tuple(table_name_pieces[1:]) # type: ignore
elif len(table_name_pieces) == 2:
return tuple(table_name_pieces) # type: ignore
return (schema, table_name_pieces[0])
def _datasource_access_by_fullname(
self, database: "Database", table_in_query: str, schema: str
) -> bool:
"""
Return True if the user can access the table extracted from the SQL query, False
otherwise.
Note the table name conforms to the [[cluster.]schema.]table construct.
:param database: The Superset database
:param table_in_query: The SQL table name
:param schema: The fallback SQL schema, i.e., if not present in the table name
:returns: Whether the user can access the SQL table
"""
table_schema, table_name = self._get_schema_and_table(table_in_query, schema)
return self._datasource_access_by_name(
database, table_name, schema=table_schema
)
def rejected_tables(self, sql: str, database: "Database", schema: str) -> List[str]:
"""
Return the list of rejected SQL table names.
Note the rejected table names conform to the [[cluster.]schema.]table construct.
Return the set of rejected SQL tables.
:param sql: The SQL statement
:param database: The SQL database
:param schema: The SQL database schema
:returns: The rejected table names
:returns: The rejected tables
"""
superset_query = sql_parse.ParsedQuery(sql)
query = sql_parse.ParsedQuery(sql)
return [
t
for t in superset_query.tables
if not self._datasource_access_by_fullname(database, t, schema)
]
return {
table
for table in query.tables
if not self.can_access_datasource(database, table, schema)
}
def get_public_role(self) -> Optional[Any]: # Optional[self.role_model]
from superset import conf
@ -493,7 +446,7 @@ class SupersetSecurityManager(SecurityManager):
.filter(or_(SqlaTable.perm.in_(perms)))
.distinct()
)
accessible_schemas.update([t.schema for t in tables])
accessible_schemas.update([table.schema for table in tables])
return [s for s in schemas if s in accessible_schemas]
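
A short sketch of what rejected_tables now iterates over: ParsedQuery.tables yields hashable Table objects instead of dotted strings, so qualified and unqualified names stay distinct when each is checked with can_access_datasource. The snippet below mirrors the unit tests at the end of this commit and is illustrative only:

from superset.sql_parse import ParsedQuery, Table

assert ParsedQuery("SELECT * FROM schemaname.tbname").tables == {
    Table("tbname", "schemaname")
}
assert ParsedQuery("SELECT * FROM t1 JOIN t2 ON t1.a = t2.a").tables == {
    Table("t1"),
    Table("t2"),
}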

View File

@ -16,8 +16,10 @@
# under the License.
import logging
from typing import List, Optional, Set
from urllib import parse
import sqlparse
from dataclasses import dataclass
from sqlparse.sql import (
Function,
Identifier,
@ -57,10 +59,32 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
return None
@dataclass(eq=True, frozen=True)
class Table: # pylint: disable=too-few-public-methods
"""
A fully qualified SQL table conforming to [[catalog.]schema.]table.
"""
table: str
schema: Optional[str] = None
catalog: Optional[str] = None
def __str__(self) -> str:
"""
Return the fully qualified SQL table name.
"""
return ".".join(
parse.quote(part, safe="").replace(".", "%2E")
for part in [self.catalog, self.schema, self.table]
if part
)
class ParsedQuery:
def __init__(self, sql_statement: str):
self.sql: str = sql_statement
self._table_names: Set[str] = set()
self._tables: Set[Table] = set()
self._alias_names: Set[str] = set()
self._limit: Optional[int] = None
@ -70,12 +94,15 @@ class ParsedQuery:
self._limit = _extract_limit_from_query(statement)
@property
def tables(self) -> Set[str]:
if not self._table_names:
def tables(self) -> Set[Table]:
if not self._tables:
for statement in self._parsed:
self.__extract_from_token(statement)
self._table_names = self._table_names - self._alias_names
return self._table_names
self._extract_from_token(statement)
self._tables = {
table for table in self._tables if str(table) not in self._alias_names
}
return self._tables
@property
def limit(self) -> Optional[int]:
@ -105,13 +132,13 @@ class ParsedQuery:
return statements
@staticmethod
def __get_full_name(tlist: TokenList) -> Optional[str]:
def _get_table(tlist: TokenList) -> Optional[Table]:
"""
Return the full unquoted table name if valid, i.e., conforms to the following
[[cluster.]schema.]table construct.
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
construct.
:param tlist: The SQL tokens
:returns: The valid full table name
:returns: The table if the name conforms
"""
# Strip the alias if present.
@ -127,18 +154,18 @@ class ParsedQuery:
if (
len(tokens) in (1, 3, 5)
and all(imt(token, t=[Name, String]) for token in tokens[0::2])
and all(imt(token, t=[Name, String]) for token in tokens[::2])
and all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
):
return ".".join([remove_quotes(token.value) for token in tokens[0::2]])
return Table(*[remove_quotes(token.value) for token in tokens[::-2]])
return None
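
One subtle detail above is the tokens[::-2] slice: the name tokens of a dotted identifier occupy the even positions, and walking them in reverse lines the parts up with the Table(table, schema, catalog) constructor. A standalone illustration using plain lists in place of sqlparse tokens:

# Token shape for a fully qualified name; every second element taken in
# reverse order is (table, schema, catalog), matching Table's parameters.
parts = ["catalogname", ".", "schemaname", ".", "tbname"]
assert parts[::-2] == ["tbname", "schemaname", "catalogname"]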
@staticmethod
def __is_identifier(token: Token) -> bool:
def _is_identifier(token: Token) -> bool:
return isinstance(token, (IdentifierList, Identifier))
def __process_tokenlist(self, token_list: TokenList):
def _process_tokenlist(self, token_list: TokenList):
"""
Add table names to table set
@ -146,9 +173,9 @@ class ParsedQuery:
"""
# exclude subselects
if "(" not in str(token_list):
table_name = self.__get_full_name(token_list)
if table_name and not table_name.startswith(CTE_PREFIX):
self._table_names.add(table_name)
table = self._get_table(token_list)
if table and not table.table.startswith(CTE_PREFIX):
self._tables.add(table)
return
# store aliases
@ -158,7 +185,7 @@ class ParsedQuery:
# some aliases are not parsed properly
if token_list.tokens[0].ttype == Name:
self._alias_names.add(token_list.tokens[0].value)
self.__extract_from_token(token_list)
self._extract_from_token(token_list)
def as_create_table(
self,
@ -184,9 +211,9 @@ class ParsedQuery:
exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}"
return exec_sql
def __extract_from_token(self, token: Token): # pylint: disable=too-many-branches
def _extract_from_token(self, token: Token): # pylint: disable=too-many-branches
"""
Populate self._table_names from token
Populate self._tables from token
:param token: instance of Token or child class, e.g. TokenList, to be processed
"""
@ -196,8 +223,8 @@ class ParsedQuery:
table_name_preceding_token = False
for item in token.tokens:
if item.is_group and not self.__is_identifier(item):
self.__extract_from_token(item)
if item.is_group and not self._is_identifier(item):
self._extract_from_token(item)
if item.ttype in Keyword and (
item.normalized in PRECEDES_TABLE_NAME
@ -212,15 +239,15 @@ class ParsedQuery:
if table_name_preceding_token:
if isinstance(item, Identifier):
self.__process_tokenlist(item)
self._process_tokenlist(item)
elif isinstance(item, IdentifierList):
for token2 in item.get_identifiers():
if isinstance(token2, TokenList):
self.__process_tokenlist(token2)
self._process_tokenlist(token2)
elif isinstance(item, IdentifierList):
for token2 in item.tokens:
if not self.__is_identifier(token2):
self.__extract_from_token(item)
if not self._is_identifier(token2):
self._extract_from_token(item)
def set_or_update_query_limit(self, new_limit: int) -> str:
"""Returns the query with the specified limit.

View File

@ -85,7 +85,7 @@ from superset.security.analytics_db_safety import (
check_sqlalchemy_uri,
DBSecurityException,
)
from superset.sql_parse import ParsedQuery
from superset.sql_parse import ParsedQuery, Table
from superset.sql_validators import get_validator_by_name
from superset.utils import core as utils, dashboard_import_export
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
@ -2083,7 +2083,9 @@ class Superset(BaseSupersetView):
schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
table_name = utils.parse_js_uri_path_item(table_name)
# Check that the user can access the datasource
if not self.appbuilder.sm.can_access_datasource(database, table_name, schema):
if not self.appbuilder.sm.can_access_datasource(
database, Table(table_name, schema), schema
):
stats_logger.incr(
f"deprecated.{self.__class__.__name__}.select_star.permission_denied"
)

View File

@ -22,6 +22,7 @@ from flask import g
from flask_babel import lazy_gettext as _
from superset.models.core import Database
from superset.sql_parse import Table
from superset.utils.core import parse_js_uri_path_item
logger = logging.getLogger(__name__)
@ -45,7 +46,7 @@ def check_datasource_access(f):
return self.response_404()
# Check that the user can access the datasource
if not self.appbuilder.sm.can_access_datasource(
database, table_name_parsed, schema_name_parsed
database, Table(table_name_parsed, schema_name_parsed), schema_name_parsed
):
self.stats_logger.incr(
f"permisssion_denied_{self.__class__.__name__}.select_star"

View File

@ -16,90 +16,102 @@
# under the License.
import unittest
from superset import sql_parse
from superset.sql_parse import ParsedQuery, Table
class SupersetTestCase(unittest.TestCase):
def extract_tables(self, query):
sq = sql_parse.ParsedQuery(query)
return sq.tables
return ParsedQuery(query).tables
def test_table(self):
self.assertEqual(str(Table("tbname")), "tbname")
self.assertEqual(str(Table("tbname", "schemaname")), "schemaname.tbname")
self.assertEqual(
str(Table("tbname", "schemaname", "catalogname")),
"catalogname.schemaname.tbname",
)
self.assertEqual(
str(Table("tb.name", "schema/name", "catalog\name")),
"catalog%0Aame.schema%2Fname.tb%2Ename",
)
def test_simple_select(self):
query = "SELECT * FROM tbname"
self.assertEqual({"tbname"}, self.extract_tables(query))
self.assertEqual({Table("tbname")}, self.extract_tables(query))
query = "SELECT * FROM tbname foo"
self.assertEqual({"tbname"}, self.extract_tables(query))
self.assertEqual({Table("tbname")}, self.extract_tables(query))
query = "SELECT * FROM tbname AS foo"
self.assertEqual({"tbname"}, self.extract_tables(query))
self.assertEqual({Table("tbname")}, self.extract_tables(query))
# underscores
query = "SELECT * FROM tb_name"
self.assertEqual({"tb_name"}, self.extract_tables(query))
self.assertEqual({Table("tb_name")}, self.extract_tables(query))
# quotes
query = 'SELECT * FROM "tbname"'
self.assertEqual({"tbname"}, self.extract_tables(query))
self.assertEqual({Table("tbname")}, self.extract_tables(query))
# unicode encoding
query = 'SELECT * FROM "tb_name" WHERE city = "Lübeck"'
self.assertEqual({"tb_name"}, self.extract_tables(query))
self.assertEqual({Table("tb_name")}, self.extract_tables(query))
# schema
self.assertEqual(
{"schemaname.tbname"},
{Table("tbname", "schemaname")},
self.extract_tables("SELECT * FROM schemaname.tbname"),
)
self.assertEqual(
{"schemaname.tbname"},
{Table("tbname", "schemaname")},
self.extract_tables('SELECT * FROM "schemaname"."tbname"'),
)
self.assertEqual(
{"schemaname.tbname"},
{Table("tbname", "schemaname")},
self.extract_tables("SELECT * FROM schemaname.tbname foo"),
)
self.assertEqual(
{"schemaname.tbname"},
{Table("tbname", "schemaname")},
self.extract_tables("SELECT * FROM schemaname.tbname AS foo"),
)
# cluster
self.assertEqual(
{"clustername.schemaname.tbname"},
self.extract_tables("SELECT * FROM clustername.schemaname.tbname"),
{Table("tbname", "schemaname", "catalogname")},
self.extract_tables("SELECT * FROM catalogname.schemaname.tbname"),
)
# Ill-defined cluster/schema/table.
self.assertEqual(set(), self.extract_tables("SELECT * FROM schemaname."))
self.assertEqual(
set(), self.extract_tables("SELECT * FROM clustername.schemaname.")
set(), self.extract_tables("SELECT * FROM catalogname.schemaname.")
)
self.assertEqual(set(), self.extract_tables("SELECT * FROM clustername.."))
self.assertEqual(set(), self.extract_tables("SELECT * FROM catalogname.."))
self.assertEqual(
set(), self.extract_tables("SELECT * FROM clustername..tbname")
set(), self.extract_tables("SELECT * FROM catalogname..tbname")
)
# quotes
query = "SELECT field1, field2 FROM tb_name"
self.assertEqual({"tb_name"}, self.extract_tables(query))
self.assertEqual({Table("tb_name")}, self.extract_tables(query))
query = "SELECT t1.f1, t2.f2 FROM t1, t2"
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
def test_select_named_table(self):
query = "SELECT a.date, a.field FROM left_table a LIMIT 10"
self.assertEqual({"left_table"}, self.extract_tables(query))
self.assertEqual({Table("left_table")}, self.extract_tables(query))
def test_reverse_select(self):
query = "FROM t1 SELECT field"
self.assertEqual({"t1"}, self.extract_tables(query))
self.assertEqual({Table("t1")}, self.extract_tables(query))
def test_subselect(self):
query = """
@ -111,7 +123,9 @@ class SupersetTestCase(unittest.TestCase):
) sub, s2.t2
WHERE sub.resolution = 'NONE'
"""
self.assertEqual({"s1.t1", "s2.t2"}, self.extract_tables(query))
self.assertEqual(
{Table("t1", "s1"), Table("t2", "s2")}, self.extract_tables(query)
)
query = """
SELECT sub.*
@ -122,7 +136,7 @@ class SupersetTestCase(unittest.TestCase):
) sub
WHERE sub.resolution = 'NONE'
"""
self.assertEqual({"s1.t1"}, self.extract_tables(query))
self.assertEqual({Table("t1", "s1")}, self.extract_tables(query))
query = """
SELECT * FROM t1
@ -133,21 +147,24 @@ class SupersetTestCase(unittest.TestCase):
WHERE ROW(5*t2.s1,77)=
(SELECT 50,11*s1 FROM t4)));
"""
self.assertEqual({"t1", "t2", "t3", "t4"}, self.extract_tables(query))
self.assertEqual(
{Table("t1"), Table("t2"), Table("t3"), Table("t4")},
self.extract_tables(query),
)
def test_select_in_expression(self):
query = "SELECT f1, (SELECT count(1) FROM t2) FROM t1"
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
def test_union(self):
query = "SELECT * FROM t1 UNION SELECT * FROM t2"
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = "SELECT * FROM t1 UNION ALL SELECT * FROM t2"
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2"
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
def test_select_from_values(self):
query = "SELECT * FROM VALUES (13, 42)"
@ -158,25 +175,25 @@ class SupersetTestCase(unittest.TestCase):
SELECT ARRAY[1, 2, 3] AS my_array
FROM t1 LIMIT 10
"""
self.assertEqual({"t1"}, self.extract_tables(query))
self.assertEqual({Table("t1")}, self.extract_tables(query))
def test_select_if(self):
query = """
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
FROM t1 LIMIT 10
"""
self.assertEqual({"t1"}, self.extract_tables(query))
self.assertEqual({Table("t1")}, self.extract_tables(query))
# SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)?
def test_show_tables(self):
query = "SHOW TABLES FROM s1 like '%order%'"
# TODO: figure out what the code should do here
self.assertEqual({"s1"}, self.extract_tables(query))
self.assertEqual({Table("s1")}, self.extract_tables(query))
# SHOW COLUMNS (FROM | IN) qualifiedName
def test_show_columns(self):
query = "SHOW COLUMNS FROM t1"
self.assertEqual({"t1"}, self.extract_tables(query))
self.assertEqual({Table("t1")}, self.extract_tables(query))
def test_where_subquery(self):
query = """
@ -184,25 +201,25 @@ class SupersetTestCase(unittest.TestCase):
FROM t1
WHERE regionkey = (SELECT max(regionkey) FROM t2)
"""
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = """
SELECT name
FROM t1
WHERE regionkey IN (SELECT regionkey FROM t2)
"""
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = """
SELECT name
FROM t1
WHERE regionkey EXISTS (SELECT regionkey FROM t2)
"""
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
# DESCRIBE | DESC qualifiedName
def test_describe(self):
self.assertEqual({"t1"}, self.extract_tables("DESCRIBE t1"))
self.assertEqual({Table("t1")}, self.extract_tables("DESCRIBE t1"))
# SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
# (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
@ -211,11 +228,11 @@ class SupersetTestCase(unittest.TestCase):
SHOW PARTITIONS FROM orders
WHERE ds >= '2013-01-01' ORDER BY ds DESC;
"""
self.assertEqual({"orders"}, self.extract_tables(query))
self.assertEqual({Table("orders")}, self.extract_tables(query))
def test_join(self):
query = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
# subquery + join
query = """
@ -229,7 +246,9 @@ class SupersetTestCase(unittest.TestCase):
) b
ON a.date = b.date
"""
self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
self.assertEqual(
{Table("left_table"), Table("right_table")}, self.extract_tables(query)
)
query = """
SELECT a.date, b.name FROM
@ -242,7 +261,9 @@ class SupersetTestCase(unittest.TestCase):
) b
ON a.date = b.date
"""
self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
self.assertEqual(
{Table("left_table"), Table("right_table")}, self.extract_tables(query)
)
query = """
SELECT a.date, b.name FROM
@ -255,7 +276,9 @@ class SupersetTestCase(unittest.TestCase):
) b
ON a.date = b.date
"""
self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
self.assertEqual(
{Table("left_table"), Table("right_table")}, self.extract_tables(query)
)
query = """
SELECT a.date, b.name FROM
@ -268,7 +291,9 @@ class SupersetTestCase(unittest.TestCase):
) b
ON a.date = b.date
"""
self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
self.assertEqual(
{Table("left_table"), Table("right_table")}, self.extract_tables(query)
)
# TODO: add SEMI join support, SQL Parse does not handle it.
# query = """
@ -296,13 +321,16 @@ class SupersetTestCase(unittest.TestCase):
WHERE ROW(5*t3.s1,77)=
(SELECT 50,11*s1 FROM t4)));
"""
self.assertEqual({"t1", "t3", "t4", "t6"}, self.extract_tables(query))
self.assertEqual(
{Table("t1"), Table("t3"), Table("t4"), Table("t6")},
self.extract_tables(query),
)
query = """
SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS)
AS S1) AS S2) AS S3;
"""
self.assertEqual({"EmployeeS"}, self.extract_tables(query))
self.assertEqual({Table("EmployeeS")}, self.extract_tables(query))
def test_with(self):
query = """
@ -312,7 +340,9 @@ class SupersetTestCase(unittest.TestCase):
z AS (SELECT b AS c FROM t3)
SELECT c FROM z;
"""
self.assertEqual({"t1", "t2", "t3"}, self.extract_tables(query))
self.assertEqual(
{Table("t1"), Table("t2"), Table("t3")}, self.extract_tables(query)
)
query = """
WITH
@ -321,7 +351,7 @@ class SupersetTestCase(unittest.TestCase):
z AS (SELECT b AS c FROM y)
SELECT c FROM z;
"""
self.assertEqual({"t1"}, self.extract_tables(query))
self.assertEqual({Table("t1")}, self.extract_tables(query))
def test_reusing_aliases(self):
query = """
@ -329,22 +359,22 @@ class SupersetTestCase(unittest.TestCase):
q2 as ( select key from src where key = '5')
select * from (select key from q1) a;
"""
self.assertEqual({"src"}, self.extract_tables(query))
self.assertEqual({Table("src")}, self.extract_tables(query))
def test_multistatement(self):
query = "SELECT * FROM t1; SELECT * FROM t2"
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = "SELECT * FROM t1; SELECT * FROM t2;"
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
def test_update_not_select(self):
sql = sql_parse.ParsedQuery("UPDATE t1 SET col1 = NULL")
sql = ParsedQuery("UPDATE t1 SET col1 = NULL")
self.assertEqual(False, sql.is_select())
self.assertEqual(False, sql.is_readonly())
def test_explain(self):
sql = sql_parse.ParsedQuery("EXPLAIN SELECT 1")
sql = ParsedQuery("EXPLAIN SELECT 1")
self.assertEqual(True, sql.is_explain())
self.assertEqual(False, sql.is_select())
@ -367,7 +397,12 @@ class SupersetTestCase(unittest.TestCase):
ORDER BY "sum__m_example" DESC
LIMIT 10;"""
self.assertEqual(
{"my_l_table", "my_b_table", "my_t_table", "inner_table"},
{
Table("my_l_table"),
Table("my_b_table"),
Table("my_t_table"),
Table("inner_table"),
},
self.extract_tables(query),
)
@ -375,13 +410,19 @@ class SupersetTestCase(unittest.TestCase):
query = """SELECT *
FROM table_a AS a, table_b AS b, table_c as c
WHERE a.id = b.id and b.id = c.id"""
self.assertEqual({"table_a", "table_b", "table_c"}, self.extract_tables(query))
self.assertEqual(
{Table("table_a"), Table("table_b"), Table("table_c")},
self.extract_tables(query),
)
def test_mixed_from_clause(self):
query = """SELECT *
FROM table_a AS a, (select * from table_b) AS b, table_c as c
WHERE a.id = b.id and b.id = c.id"""
self.assertEqual({"table_a", "table_b", "table_c"}, self.extract_tables(query))
self.assertEqual(
{Table("table_a"), Table("table_b"), Table("table_c")},
self.extract_tables(query),
)
def test_nested_selects(self):
query = """
@ -389,13 +430,17 @@ class SupersetTestCase(unittest.TestCase):
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
"""
self.assertEqual({"INFORMATION_SCHEMA.COLUMNS"}, self.extract_tables(query))
self.assertEqual(
{Table("COLUMNS", "INFORMATION_SCHEMA")}, self.extract_tables(query)
)
query = """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achivement_daily"),0x7e)));
"""
self.assertEqual({"INFORMATION_SCHEMA.COLUMNS"}, self.extract_tables(query))
self.assertEqual(
{Table("COLUMNS", "INFORMATION_SCHEMA")}, self.extract_tables(query)
)
def test_complex_extract_tables3(self):
query = """SELECT somecol AS somecol
@ -431,7 +476,10 @@ class SupersetTestCase(unittest.TestCase):
WHERE 2=2
GROUP BY last_col
LIMIT 50000;"""
self.assertEqual({"a", "b", "c", "d", "e", "f"}, self.extract_tables(query))
self.assertEqual(
{Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")},
self.extract_tables(query),
)
def test_complex_cte_with_prefix(self):
query = """
@ -446,23 +494,23 @@ class SupersetTestCase(unittest.TestCase):
GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear;
"""
self.assertEqual({"SalesOrderHeader"}, self.extract_tables(query))
self.assertEqual({Table("SalesOrderHeader")}, self.extract_tables(query))
def test_get_query_with_new_limit_comment(self):
sql = "SELECT * FROM birth_names -- SOME COMMENT"
parsed = sql_parse.ParsedQuery(sql)
parsed = ParsedQuery(sql)
newsql = parsed.set_or_update_query_limit(1000)
self.assertEqual(newsql, sql + "\nLIMIT 1000")
def test_get_query_with_new_limit_comment_with_limit(self):
sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
parsed = sql_parse.ParsedQuery(sql)
parsed = ParsedQuery(sql)
newsql = parsed.set_or_update_query_limit(1000)
self.assertEqual(newsql, sql + "\nLIMIT 1000")
def test_get_query_with_new_limit_lower(self):
sql = "SELECT * FROM birth_names LIMIT 555"
parsed = sql_parse.ParsedQuery(sql)
parsed = ParsedQuery(sql)
newsql = parsed.set_or_update_query_limit(1000)
# not applied as new limit is higher
expected = "SELECT * FROM birth_names LIMIT 555"
@ -470,7 +518,7 @@ class SupersetTestCase(unittest.TestCase):
def test_get_query_with_new_limit_upper(self):
sql = "SELECT * FROM birth_names LIMIT 1555"
parsed = sql_parse.ParsedQuery(sql)
parsed = ParsedQuery(sql)
newsql = parsed.set_or_update_query_limit(1000)
# applied as new limit is lower
expected = "SELECT * FROM birth_names LIMIT 1000"
@ -481,7 +529,7 @@ class SupersetTestCase(unittest.TestCase):
SELECT * FROM birth_names;
SELECT * FROM birth_names LIMIT 1;
"""
parsed = sql_parse.ParsedQuery(multi_sql)
parsed = ParsedQuery(multi_sql)
statements = parsed.get_statements()
self.assertEqual(len(statements), 2)
expected = ["SELECT * FROM birth_names", "SELECT * FROM birth_names LIMIT 1"]
@ -494,7 +542,7 @@ class SupersetTestCase(unittest.TestCase):
SELECT * FROM birth_names;;;
SELECT * FROM birth_names LIMIT 1
"""
parsed = sql_parse.ParsedQuery(multi_sql)
parsed = ParsedQuery(multi_sql)
statements = parsed.get_statements()
self.assertEqual(len(statements), 4)
expected = [
@ -512,4 +560,4 @@ class SupersetTestCase(unittest.TestCase):
match AS (SELECT * FROM f)
SELECT * FROM match
"""
self.assertEqual({"foo"}, self.extract_tables(query))
self.assertEqual({Table("foo")}, self.extract_tables(query))