mirror of https://github.com/apache/superset.git
[sql] Adding lightweight Table class (#9649)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
parent
f7f60cc75d
commit
3b0f8e9c8a
|
@ -19,6 +19,7 @@ colorama==0.4.3 # via apache-superset (setup.py), flask-appbuilder
|
|||
contextlib2==0.6.0.post1 # via apache-superset (setup.py)
|
||||
croniter==0.3.31 # via apache-superset (setup.py)
|
||||
cryptography==2.8 # via apache-superset (setup.py)
|
||||
dataclasses==0.6 # via apache-superset (setup.py)
|
||||
decorator==4.4.1 # via retry
|
||||
defusedxml==0.6.0 # via python3-openid
|
||||
flask-appbuilder==2.3.2 # via apache-superset (setup.py)
|
||||
|
|
|
@ -45,7 +45,7 @@ combine_as_imports = true
|
|||
include_trailing_comma = true
|
||||
line_length = 88
|
||||
known_first_party = superset
|
||||
known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
|
||||
known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dataclasses,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils,sqlparse,werkzeug,wtforms,wtforms_json,yaml
|
||||
multi_line_output = 3
|
||||
order_by_type = false
|
||||
|
||||
|
|
1
setup.py
1
setup.py
|
@ -75,6 +75,7 @@ setup(
|
|||
"contextlib2",
|
||||
"croniter>=0.3.28",
|
||||
"cryptography>=2.4.2",
|
||||
"dataclasses<0.7",
|
||||
"flask>=1.1.0, <2.0.0",
|
||||
"flask-appbuilder>=2.3.2, <2.4.0",
|
||||
"flask-caching",
|
||||
|
|
|
@ -50,6 +50,7 @@ if TYPE_CHECKING:
|
|||
from superset.common.query_context import QueryContext
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
from superset.models.core import Database
|
||||
from superset.sql_parse import Table
|
||||
from superset.viz import BaseViz
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -290,26 +291,23 @@ class SupersetSecurityManager(SecurityManager):
|
|||
|
||||
return conf.get("PERMISSION_INSTRUCTIONS_LINK")
|
||||
|
||||
def get_table_access_error_msg(self, tables: List[str]) -> str:
|
||||
def get_table_access_error_msg(self, tables: Set["Table"]) -> str:
|
||||
"""
|
||||
Return the error message for the denied SQL tables.
|
||||
|
||||
Note the table names conform to the [[cluster.]schema.]table construct.
|
||||
|
||||
:param tables: The list of denied SQL table names
|
||||
:param tables: The set of denied SQL tables
|
||||
:returns: The error message
|
||||
"""
|
||||
quoted_tables = [f"`{t}`" for t in tables]
|
||||
|
||||
quoted_tables = [f"`{table}`" for table in tables]
|
||||
return f"""You need access to the following tables: {", ".join(quoted_tables)},
|
||||
`all_database_access` or `all_datasource_access` permission"""
|
||||
|
||||
def get_table_access_link(self, tables: List[str]) -> Optional[str]:
|
||||
def get_table_access_link(self, tables: Set["Table"]) -> Optional[str]:
|
||||
"""
|
||||
Return the access link for the denied SQL tables.
|
||||
|
||||
Note the table names conform to the [[cluster.]schema.]table construct.
|
||||
|
||||
:param tables: The list of denied SQL table names
|
||||
:param tables: The set of denied SQL tables
|
||||
:returns: The access URL
|
||||
"""
|
||||
|
||||
|
@ -318,23 +316,19 @@ class SupersetSecurityManager(SecurityManager):
|
|||
return conf.get("PERMISSION_INSTRUCTIONS_LINK")
|
||||
|
||||
def can_access_datasource(
|
||||
self, database: "Database", table_name: str, schema: Optional[str] = None
|
||||
) -> bool:
|
||||
return self._datasource_access_by_name(database, table_name, schema=schema)
|
||||
|
||||
def _datasource_access_by_name(
|
||||
self, database: "Database", table_name: str, schema: Optional[str] = None
|
||||
self, database: "Database", table: "Table", schema: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Return True if the user can access the SQL table, False otherwise.
|
||||
|
||||
:param database: The SQL database
|
||||
:param table_name: The SQL table name
|
||||
:param schema: The Superset schema
|
||||
:param table: The SQL table
|
||||
:param schema: The fallback SQL schema if not present in the table
|
||||
:returns: Whether the user can access the SQL table
|
||||
"""
|
||||
|
||||
from superset import db
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
if self.database_access(database) or self.all_datasource_access():
|
||||
return True
|
||||
|
@ -343,74 +337,33 @@ class SupersetSecurityManager(SecurityManager):
|
|||
if schema_perm and self.can_access("schema_access", schema_perm):
|
||||
return True
|
||||
|
||||
datasources = ConnectorRegistry.query_datasources_by_name(
|
||||
db.session, database, table_name, schema=schema
|
||||
datasources = SqlaTable.query_datasources_by_name(
|
||||
db.session, database, table.table, schema=table.schema or schema
|
||||
)
|
||||
for datasource in datasources:
|
||||
if self.can_access("datasource_access", datasource.perm):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_schema_and_table(
|
||||
self, table_in_query: str, schema: str
|
||||
) -> Tuple[str, str]:
|
||||
def rejected_tables(
|
||||
self, sql: str, database: "Database", schema: str
|
||||
) -> Set["Table"]:
|
||||
"""
|
||||
Return the SQL schema/table tuple associated with the table extracted from the
|
||||
SQL query.
|
||||
|
||||
Note the table name conforms to the [[cluster.]schema.]table construct.
|
||||
|
||||
:param table_in_query: The SQL table name
|
||||
:param schema: The fallback SQL schema if not present in the table name
|
||||
:returns: The SQL schema/table tuple
|
||||
"""
|
||||
|
||||
table_name_pieces = table_in_query.split(".")
|
||||
if len(table_name_pieces) == 3:
|
||||
return tuple(table_name_pieces[1:]) # type: ignore
|
||||
elif len(table_name_pieces) == 2:
|
||||
return tuple(table_name_pieces) # type: ignore
|
||||
return (schema, table_name_pieces[0])
|
||||
|
||||
def _datasource_access_by_fullname(
|
||||
self, database: "Database", table_in_query: str, schema: str
|
||||
) -> bool:
|
||||
"""
|
||||
Return True if the user can access the table extracted from the SQL query, False
|
||||
otherwise.
|
||||
|
||||
Note the table name conforms to the [[cluster.]schema.]table construct.
|
||||
|
||||
:param database: The Superset database
|
||||
:param table_in_query: The SQL table name
|
||||
:param schema: The fallback SQL schema, i.e., if not present in the table name
|
||||
:returns: Whether the user can access the SQL table
|
||||
"""
|
||||
|
||||
table_schema, table_name = self._get_schema_and_table(table_in_query, schema)
|
||||
return self._datasource_access_by_name(
|
||||
database, table_name, schema=table_schema
|
||||
)
|
||||
|
||||
def rejected_tables(self, sql: str, database: "Database", schema: str) -> List[str]:
|
||||
"""
|
||||
Return the list of rejected SQL table names.
|
||||
|
||||
Note the rejected table names conform to the [[cluster.]schema.]table construct.
|
||||
Return the list of rejected SQL tables.
|
||||
|
||||
:param sql: The SQL statement
|
||||
:param database: The SQL database
|
||||
:param schema: The SQL database schema
|
||||
:returns: The rejected table names
|
||||
:returns: The rejected tables
|
||||
"""
|
||||
|
||||
superset_query = sql_parse.ParsedQuery(sql)
|
||||
query = sql_parse.ParsedQuery(sql)
|
||||
|
||||
return [
|
||||
t
|
||||
for t in superset_query.tables
|
||||
if not self._datasource_access_by_fullname(database, t, schema)
|
||||
]
|
||||
return {
|
||||
table
|
||||
for table in query.tables
|
||||
if not self.can_access_datasource(database, table, schema)
|
||||
}
|
||||
|
||||
def get_public_role(self) -> Optional[Any]: # Optional[self.role_model]
|
||||
from superset import conf
|
||||
|
@ -493,7 +446,7 @@ class SupersetSecurityManager(SecurityManager):
|
|||
.filter(or_(SqlaTable.perm.in_(perms)))
|
||||
.distinct()
|
||||
)
|
||||
accessible_schemas.update([t.schema for t in tables])
|
||||
accessible_schemas.update([table.schema for table in tables])
|
||||
|
||||
return [s for s in schemas if s in accessible_schemas]
|
||||
|
||||
|
|
|
@ -16,8 +16,10 @@
|
|||
# under the License.
|
||||
import logging
|
||||
from typing import List, Optional, Set
|
||||
from urllib import parse
|
||||
|
||||
import sqlparse
|
||||
from dataclasses import dataclass
|
||||
from sqlparse.sql import (
|
||||
Function,
|
||||
Identifier,
|
||||
|
@ -57,10 +59,32 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
|
|||
return None
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
class Table:  # pylint: disable=too-few-public-methods
    """
    A fully qualified SQL table conforming to [[catalog.]schema.]table.

    The dataclass is declared with ``eq=True`` and ``frozen=True``, so instances
    compare by value, cannot be mutated after construction and are hashable,
    which allows them to be collected in sets.
    """

    # The table name; the only mandatory component.
    table: str
    # The schema qualifying the table, when present.
    schema: Optional[str] = None
    # The catalog qualifying the schema, when present.
    catalog: Optional[str] = None

    def __str__(self) -> str:
        """
        Return the fully qualified SQL table name.

        The present (truthy) components are joined with "." in
        catalog, schema, table order. Each component is percent-encoded with no
        safe characters, and any literal "." is additionally encoded as "%2E"
        (``parse.quote`` leaves "." untouched) so the separator stays unambiguous.

        :returns: The encoded, dot-separated table name
        """

        return ".".join(
            parse.quote(part, safe="").replace(".", "%2E")
            for part in [self.catalog, self.schema, self.table]
            if part
        )
|
||||
|
||||
|
||||
class ParsedQuery:
|
||||
def __init__(self, sql_statement: str):
|
||||
self.sql: str = sql_statement
|
||||
self._table_names: Set[str] = set()
|
||||
self._tables: Set[Table] = set()
|
||||
self._alias_names: Set[str] = set()
|
||||
self._limit: Optional[int] = None
|
||||
|
||||
|
@ -70,12 +94,15 @@ class ParsedQuery:
|
|||
self._limit = _extract_limit_from_query(statement)
|
||||
|
||||
@property
|
||||
def tables(self) -> Set[str]:
|
||||
if not self._table_names:
|
||||
def tables(self) -> Set[Table]:
|
||||
if not self._tables:
|
||||
for statement in self._parsed:
|
||||
self.__extract_from_token(statement)
|
||||
self._table_names = self._table_names - self._alias_names
|
||||
return self._table_names
|
||||
self._extract_from_token(statement)
|
||||
|
||||
self._tables = {
|
||||
table for table in self._tables if str(table) not in self._alias_names
|
||||
}
|
||||
return self._tables
|
||||
|
||||
@property
|
||||
def limit(self) -> Optional[int]:
|
||||
|
@ -105,13 +132,13 @@ class ParsedQuery:
|
|||
return statements
|
||||
|
||||
@staticmethod
|
||||
def __get_full_name(tlist: TokenList) -> Optional[str]:
|
||||
def _get_table(tlist: TokenList) -> Optional[Table]:
|
||||
"""
|
||||
Return the full unquoted table name if valid, i.e., conforms to the following
|
||||
[[cluster.]schema.]table construct.
|
||||
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
|
||||
construct.
|
||||
|
||||
:param tlist: The SQL tokens
|
||||
:returns: The valid full table name
|
||||
:returns: The table if the name conforms
|
||||
"""
|
||||
|
||||
# Strip the alias if present.
|
||||
|
@ -127,18 +154,18 @@ class ParsedQuery:
|
|||
|
||||
if (
|
||||
len(tokens) in (1, 3, 5)
|
||||
and all(imt(token, t=[Name, String]) for token in tokens[0::2])
|
||||
and all(imt(token, t=[Name, String]) for token in tokens[::2])
|
||||
and all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
|
||||
):
|
||||
return ".".join([remove_quotes(token.value) for token in tokens[0::2]])
|
||||
return Table(*[remove_quotes(token.value) for token in tokens[::-2]])
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def __is_identifier(token: Token) -> bool:
|
||||
def _is_identifier(token: Token) -> bool:
|
||||
return isinstance(token, (IdentifierList, Identifier))
|
||||
|
||||
def __process_tokenlist(self, token_list: TokenList):
|
||||
def _process_tokenlist(self, token_list: TokenList):
|
||||
"""
|
||||
Add table names to table set
|
||||
|
||||
|
@ -146,9 +173,9 @@ class ParsedQuery:
|
|||
"""
|
||||
# exclude subselects
|
||||
if "(" not in str(token_list):
|
||||
table_name = self.__get_full_name(token_list)
|
||||
if table_name and not table_name.startswith(CTE_PREFIX):
|
||||
self._table_names.add(table_name)
|
||||
table = self._get_table(token_list)
|
||||
if table and not table.table.startswith(CTE_PREFIX):
|
||||
self._tables.add(table)
|
||||
return
|
||||
|
||||
# store aliases
|
||||
|
@ -158,7 +185,7 @@ class ParsedQuery:
|
|||
# some aliases are not parsed properly
|
||||
if token_list.tokens[0].ttype == Name:
|
||||
self._alias_names.add(token_list.tokens[0].value)
|
||||
self.__extract_from_token(token_list)
|
||||
self._extract_from_token(token_list)
|
||||
|
||||
def as_create_table(
|
||||
self,
|
||||
|
@ -184,9 +211,9 @@ class ParsedQuery:
|
|||
exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}"
|
||||
return exec_sql
|
||||
|
||||
def __extract_from_token(self, token: Token): # pylint: disable=too-many-branches
|
||||
def _extract_from_token(self, token: Token): # pylint: disable=too-many-branches
|
||||
"""
|
||||
Populate self._table_names from token
|
||||
Populate self._tables from token
|
||||
|
||||
:param token: instance of Token or child class, e.g. TokenList, to be processed
|
||||
"""
|
||||
|
@ -196,8 +223,8 @@ class ParsedQuery:
|
|||
table_name_preceding_token = False
|
||||
|
||||
for item in token.tokens:
|
||||
if item.is_group and not self.__is_identifier(item):
|
||||
self.__extract_from_token(item)
|
||||
if item.is_group and not self._is_identifier(item):
|
||||
self._extract_from_token(item)
|
||||
|
||||
if item.ttype in Keyword and (
|
||||
item.normalized in PRECEDES_TABLE_NAME
|
||||
|
@ -212,15 +239,15 @@ class ParsedQuery:
|
|||
|
||||
if table_name_preceding_token:
|
||||
if isinstance(item, Identifier):
|
||||
self.__process_tokenlist(item)
|
||||
self._process_tokenlist(item)
|
||||
elif isinstance(item, IdentifierList):
|
||||
for token2 in item.get_identifiers():
|
||||
if isinstance(token2, TokenList):
|
||||
self.__process_tokenlist(token2)
|
||||
self._process_tokenlist(token2)
|
||||
elif isinstance(item, IdentifierList):
|
||||
for token2 in item.tokens:
|
||||
if not self.__is_identifier(token2):
|
||||
self.__extract_from_token(item)
|
||||
if not self._is_identifier(token2):
|
||||
self._extract_from_token(item)
|
||||
|
||||
def set_or_update_query_limit(self, new_limit: int) -> str:
|
||||
"""Returns the query with the specified limit.
|
||||
|
|
|
@ -85,7 +85,7 @@ from superset.security.analytics_db_safety import (
|
|||
check_sqlalchemy_uri,
|
||||
DBSecurityException,
|
||||
)
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
from superset.sql_validators import get_validator_by_name
|
||||
from superset.utils import core as utils, dashboard_import_export
|
||||
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
|
||||
|
@ -2083,7 +2083,9 @@ class Superset(BaseSupersetView):
|
|||
schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
|
||||
table_name = utils.parse_js_uri_path_item(table_name)
|
||||
# Check that the user can access the datasource
|
||||
if not self.appbuilder.sm.can_access_datasource(database, table_name, schema):
|
||||
if not self.appbuilder.sm.can_access_datasource(
|
||||
database, Table(table_name, schema), schema
|
||||
):
|
||||
stats_logger.incr(
|
||||
f"deprecated.{self.__class__.__name__}.select_star.permission_denied"
|
||||
)
|
||||
|
|
|
@ -22,6 +22,7 @@ from flask import g
|
|||
from flask_babel import lazy_gettext as _
|
||||
|
||||
from superset.models.core import Database
|
||||
from superset.sql_parse import Table
|
||||
from superset.utils.core import parse_js_uri_path_item
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -45,7 +46,7 @@ def check_datasource_access(f):
|
|||
return self.response_404()
|
||||
# Check that the user can access the datasource
|
||||
if not self.appbuilder.sm.can_access_datasource(
|
||||
database, table_name_parsed, schema_name_parsed
|
||||
database, Table(table_name_parsed, schema_name_parsed), schema_name_parsed
|
||||
):
|
||||
self.stats_logger.incr(
|
||||
f"permisssion_denied_{self.__class__.__name__}.select_star"
|
||||
|
|
|
@ -16,90 +16,102 @@
|
|||
# under the License.
|
||||
import unittest
|
||||
|
||||
from superset import sql_parse
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
|
||||
|
||||
class SupersetTestCase(unittest.TestCase):
|
||||
def extract_tables(self, query):
|
||||
sq = sql_parse.ParsedQuery(query)
|
||||
return sq.tables
|
||||
return ParsedQuery(query).tables
|
||||
|
||||
def test_table(self):
|
||||
self.assertEqual(str(Table("tbname")), "tbname")
|
||||
self.assertEqual(str(Table("tbname", "schemaname")), "schemaname.tbname")
|
||||
|
||||
self.assertEqual(
|
||||
str(Table("tbname", "schemaname", "catalogname")),
|
||||
"catalogname.schemaname.tbname",
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
str(Table("tb.name", "schema/name", "catalog\name")),
|
||||
"catalog%0Aame.schema%2Fname.tb%2Ename",
|
||||
)
|
||||
|
||||
def test_simple_select(self):
|
||||
query = "SELECT * FROM tbname"
|
||||
self.assertEqual({"tbname"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("tbname")}, self.extract_tables(query))
|
||||
|
||||
query = "SELECT * FROM tbname foo"
|
||||
self.assertEqual({"tbname"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("tbname")}, self.extract_tables(query))
|
||||
|
||||
query = "SELECT * FROM tbname AS foo"
|
||||
self.assertEqual({"tbname"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("tbname")}, self.extract_tables(query))
|
||||
|
||||
# underscores
|
||||
query = "SELECT * FROM tb_name"
|
||||
self.assertEqual({"tb_name"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("tb_name")}, self.extract_tables(query))
|
||||
|
||||
# quotes
|
||||
query = 'SELECT * FROM "tbname"'
|
||||
self.assertEqual({"tbname"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("tbname")}, self.extract_tables(query))
|
||||
|
||||
# unicode encoding
|
||||
query = 'SELECT * FROM "tb_name" WHERE city = "Lübeck"'
|
||||
self.assertEqual({"tb_name"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("tb_name")}, self.extract_tables(query))
|
||||
|
||||
# schema
|
||||
self.assertEqual(
|
||||
{"schemaname.tbname"},
|
||||
{Table("tbname", "schemaname")},
|
||||
self.extract_tables("SELECT * FROM schemaname.tbname"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
{"schemaname.tbname"},
|
||||
{Table("tbname", "schemaname")},
|
||||
self.extract_tables('SELECT * FROM "schemaname"."tbname"'),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
{"schemaname.tbname"},
|
||||
{Table("tbname", "schemaname")},
|
||||
self.extract_tables("SELECT * FROM schemaname.tbname foo"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
{"schemaname.tbname"},
|
||||
{Table("tbname", "schemaname")},
|
||||
self.extract_tables("SELECT * FROM schemaname.tbname AS foo"),
|
||||
)
|
||||
|
||||
# catalog
|
||||
self.assertEqual(
|
||||
{"clustername.schemaname.tbname"},
|
||||
self.extract_tables("SELECT * FROM clustername.schemaname.tbname"),
|
||||
{Table("tbname", "schemaname", "catalogname")},
|
||||
self.extract_tables("SELECT * FROM catalogname.schemaname.tbname"),
|
||||
)
|
||||
|
||||
# Ill-defined cluster/schema/table.
|
||||
self.assertEqual(set(), self.extract_tables("SELECT * FROM schemaname."))
|
||||
|
||||
self.assertEqual(
|
||||
set(), self.extract_tables("SELECT * FROM clustername.schemaname.")
|
||||
set(), self.extract_tables("SELECT * FROM catalogname.schemaname.")
|
||||
)
|
||||
|
||||
self.assertEqual(set(), self.extract_tables("SELECT * FROM clustername.."))
|
||||
self.assertEqual(set(), self.extract_tables("SELECT * FROM catalogname.."))
|
||||
|
||||
self.assertEqual(
|
||||
set(), self.extract_tables("SELECT * FROM clustername..tbname")
|
||||
set(), self.extract_tables("SELECT * FROM catalogname..tbname")
|
||||
)
|
||||
|
||||
# quotes
|
||||
query = "SELECT field1, field2 FROM tb_name"
|
||||
self.assertEqual({"tb_name"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("tb_name")}, self.extract_tables(query))
|
||||
|
||||
query = "SELECT t1.f1, t2.f2 FROM t1, t2"
|
||||
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
|
||||
|
||||
def test_select_named_table(self):
|
||||
query = "SELECT a.date, a.field FROM left_table a LIMIT 10"
|
||||
self.assertEqual({"left_table"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("left_table")}, self.extract_tables(query))
|
||||
|
||||
def test_reverse_select(self):
|
||||
query = "FROM t1 SELECT field"
|
||||
self.assertEqual({"t1"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1")}, self.extract_tables(query))
|
||||
|
||||
def test_subselect(self):
|
||||
query = """
|
||||
|
@ -111,7 +123,9 @@ class SupersetTestCase(unittest.TestCase):
|
|||
) sub, s2.t2
|
||||
WHERE sub.resolution = 'NONE'
|
||||
"""
|
||||
self.assertEqual({"s1.t1", "s2.t2"}, self.extract_tables(query))
|
||||
self.assertEqual(
|
||||
{Table("t1", "s1"), Table("t2", "s2")}, self.extract_tables(query)
|
||||
)
|
||||
|
||||
query = """
|
||||
SELECT sub.*
|
||||
|
@ -122,7 +136,7 @@ class SupersetTestCase(unittest.TestCase):
|
|||
) sub
|
||||
WHERE sub.resolution = 'NONE'
|
||||
"""
|
||||
self.assertEqual({"s1.t1"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1", "s1")}, self.extract_tables(query))
|
||||
|
||||
query = """
|
||||
SELECT * FROM t1
|
||||
|
@ -133,21 +147,24 @@ class SupersetTestCase(unittest.TestCase):
|
|||
WHERE ROW(5*t2.s1,77)=
|
||||
(SELECT 50,11*s1 FROM t4)));
|
||||
"""
|
||||
self.assertEqual({"t1", "t2", "t3", "t4"}, self.extract_tables(query))
|
||||
self.assertEqual(
|
||||
{Table("t1"), Table("t2"), Table("t3"), Table("t4")},
|
||||
self.extract_tables(query),
|
||||
)
|
||||
|
||||
def test_select_in_expression(self):
|
||||
query = "SELECT f1, (SELECT count(1) FROM t2) FROM t1"
|
||||
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
|
||||
|
||||
def test_union(self):
|
||||
query = "SELECT * FROM t1 UNION SELECT * FROM t2"
|
||||
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
|
||||
|
||||
query = "SELECT * FROM t1 UNION ALL SELECT * FROM t2"
|
||||
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
|
||||
|
||||
query = "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2"
|
||||
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
|
||||
|
||||
def test_select_from_values(self):
|
||||
query = "SELECT * FROM VALUES (13, 42)"
|
||||
|
@ -158,25 +175,25 @@ class SupersetTestCase(unittest.TestCase):
|
|||
SELECT ARRAY[1, 2, 3] AS my_array
|
||||
FROM t1 LIMIT 10
|
||||
"""
|
||||
self.assertEqual({"t1"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1")}, self.extract_tables(query))
|
||||
|
||||
def test_select_if(self):
|
||||
query = """
|
||||
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
|
||||
FROM t1 LIMIT 10
|
||||
"""
|
||||
self.assertEqual({"t1"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1")}, self.extract_tables(query))
|
||||
|
||||
# SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)?
|
||||
def test_show_tables(self):
|
||||
query = "SHOW TABLES FROM s1 like '%order%'"
|
||||
# TODO: figure out what should code do here
|
||||
self.assertEqual({"s1"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("s1")}, self.extract_tables(query))
|
||||
|
||||
# SHOW COLUMNS (FROM | IN) qualifiedName
|
||||
def test_show_columns(self):
|
||||
query = "SHOW COLUMNS FROM t1"
|
||||
self.assertEqual({"t1"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1")}, self.extract_tables(query))
|
||||
|
||||
def test_where_subquery(self):
|
||||
query = """
|
||||
|
@ -184,25 +201,25 @@ class SupersetTestCase(unittest.TestCase):
|
|||
FROM t1
|
||||
WHERE regionkey = (SELECT max(regionkey) FROM t2)
|
||||
"""
|
||||
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
|
||||
|
||||
query = """
|
||||
SELECT name
|
||||
FROM t1
|
||||
WHERE regionkey IN (SELECT regionkey FROM t2)
|
||||
"""
|
||||
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
|
||||
|
||||
query = """
|
||||
SELECT name
|
||||
FROM t1
|
||||
WHERE regionkey EXISTS (SELECT regionkey FROM t2)
|
||||
"""
|
||||
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
|
||||
|
||||
# DESCRIBE | DESC qualifiedName
|
||||
def test_describe(self):
|
||||
self.assertEqual({"t1"}, self.extract_tables("DESCRIBE t1"))
|
||||
self.assertEqual({Table("t1")}, self.extract_tables("DESCRIBE t1"))
|
||||
|
||||
# SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
|
||||
# (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
|
||||
|
@ -211,11 +228,11 @@ class SupersetTestCase(unittest.TestCase):
|
|||
SHOW PARTITIONS FROM orders
|
||||
WHERE ds >= '2013-01-01' ORDER BY ds DESC;
|
||||
"""
|
||||
self.assertEqual({"orders"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("orders")}, self.extract_tables(query))
|
||||
|
||||
def test_join(self):
|
||||
query = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
|
||||
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
|
||||
|
||||
# subquery + join
|
||||
query = """
|
||||
|
@ -229,7 +246,9 @@ class SupersetTestCase(unittest.TestCase):
|
|||
) b
|
||||
ON a.date = b.date
|
||||
"""
|
||||
self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
|
||||
self.assertEqual(
|
||||
{Table("left_table"), Table("right_table")}, self.extract_tables(query)
|
||||
)
|
||||
|
||||
query = """
|
||||
SELECT a.date, b.name FROM
|
||||
|
@ -242,7 +261,9 @@ class SupersetTestCase(unittest.TestCase):
|
|||
) b
|
||||
ON a.date = b.date
|
||||
"""
|
||||
self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
|
||||
self.assertEqual(
|
||||
{Table("left_table"), Table("right_table")}, self.extract_tables(query)
|
||||
)
|
||||
|
||||
query = """
|
||||
SELECT a.date, b.name FROM
|
||||
|
@ -255,7 +276,9 @@ class SupersetTestCase(unittest.TestCase):
|
|||
) b
|
||||
ON a.date = b.date
|
||||
"""
|
||||
self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
|
||||
self.assertEqual(
|
||||
{Table("left_table"), Table("right_table")}, self.extract_tables(query)
|
||||
)
|
||||
|
||||
query = """
|
||||
SELECT a.date, b.name FROM
|
||||
|
@ -268,7 +291,9 @@ class SupersetTestCase(unittest.TestCase):
|
|||
) b
|
||||
ON a.date = b.date
|
||||
"""
|
||||
self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
|
||||
self.assertEqual(
|
||||
{Table("left_table"), Table("right_table")}, self.extract_tables(query)
|
||||
)
|
||||
|
||||
# TODO: add SEMI join support, SQL Parse does not handle it.
|
||||
# query = """
|
||||
|
@ -296,13 +321,16 @@ class SupersetTestCase(unittest.TestCase):
|
|||
WHERE ROW(5*t3.s1,77)=
|
||||
(SELECT 50,11*s1 FROM t4)));
|
||||
"""
|
||||
self.assertEqual({"t1", "t3", "t4", "t6"}, self.extract_tables(query))
|
||||
self.assertEqual(
|
||||
{Table("t1"), Table("t3"), Table("t4"), Table("t6")},
|
||||
self.extract_tables(query),
|
||||
)
|
||||
|
||||
query = """
|
||||
SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS)
|
||||
AS S1) AS S2) AS S3;
|
||||
"""
|
||||
self.assertEqual({"EmployeeS"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("EmployeeS")}, self.extract_tables(query))
|
||||
|
||||
def test_with(self):
|
||||
query = """
|
||||
|
@ -312,7 +340,9 @@ class SupersetTestCase(unittest.TestCase):
|
|||
z AS (SELECT b AS c FROM t3)
|
||||
SELECT c FROM z;
|
||||
"""
|
||||
self.assertEqual({"t1", "t2", "t3"}, self.extract_tables(query))
|
||||
self.assertEqual(
|
||||
{Table("t1"), Table("t2"), Table("t3")}, self.extract_tables(query)
|
||||
)
|
||||
|
||||
query = """
|
||||
WITH
|
||||
|
@ -321,7 +351,7 @@ class SupersetTestCase(unittest.TestCase):
|
|||
z AS (SELECT b AS c FROM y)
|
||||
SELECT c FROM z;
|
||||
"""
|
||||
self.assertEqual({"t1"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1")}, self.extract_tables(query))
|
||||
|
||||
def test_reusing_aliases(self):
|
||||
query = """
|
||||
|
@ -329,22 +359,22 @@ class SupersetTestCase(unittest.TestCase):
|
|||
q2 as ( select key from src where key = '5')
|
||||
select * from (select key from q1) a;
|
||||
"""
|
||||
self.assertEqual({"src"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("src")}, self.extract_tables(query))
|
||||
|
||||
def test_multistatement(self):
|
||||
query = "SELECT * FROM t1; SELECT * FROM t2"
|
||||
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
|
||||
|
||||
query = "SELECT * FROM t1; SELECT * FROM t2;"
|
||||
self.assertEqual({"t1", "t2"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
|
||||
|
||||
def test_update_not_select(self):
|
||||
sql = sql_parse.ParsedQuery("UPDATE t1 SET col1 = NULL")
|
||||
sql = ParsedQuery("UPDATE t1 SET col1 = NULL")
|
||||
self.assertEqual(False, sql.is_select())
|
||||
self.assertEqual(False, sql.is_readonly())
|
||||
|
||||
def test_explain(self):
|
||||
sql = sql_parse.ParsedQuery("EXPLAIN SELECT 1")
|
||||
sql = ParsedQuery("EXPLAIN SELECT 1")
|
||||
|
||||
self.assertEqual(True, sql.is_explain())
|
||||
self.assertEqual(False, sql.is_select())
|
||||
|
@ -367,7 +397,12 @@ class SupersetTestCase(unittest.TestCase):
|
|||
ORDER BY "sum__m_example" DESC
|
||||
LIMIT 10;"""
|
||||
self.assertEqual(
|
||||
{"my_l_table", "my_b_table", "my_t_table", "inner_table"},
|
||||
{
|
||||
Table("my_l_table"),
|
||||
Table("my_b_table"),
|
||||
Table("my_t_table"),
|
||||
Table("inner_table"),
|
||||
},
|
||||
self.extract_tables(query),
|
||||
)
|
||||
|
||||
|
@ -375,13 +410,19 @@ class SupersetTestCase(unittest.TestCase):
|
|||
query = """SELECT *
|
||||
FROM table_a AS a, table_b AS b, table_c as c
|
||||
WHERE a.id = b.id and b.id = c.id"""
|
||||
self.assertEqual({"table_a", "table_b", "table_c"}, self.extract_tables(query))
|
||||
self.assertEqual(
|
||||
{Table("table_a"), Table("table_b"), Table("table_c")},
|
||||
self.extract_tables(query),
|
||||
)
|
||||
|
||||
def test_mixed_from_clause(self):
|
||||
query = """SELECT *
|
||||
FROM table_a AS a, (select * from table_b) AS b, table_c as c
|
||||
WHERE a.id = b.id and b.id = c.id"""
|
||||
self.assertEqual({"table_a", "table_b", "table_c"}, self.extract_tables(query))
|
||||
self.assertEqual(
|
||||
{Table("table_a"), Table("table_b"), Table("table_c")},
|
||||
self.extract_tables(query),
|
||||
)
|
||||
|
||||
def test_nested_selects(self):
|
||||
query = """
|
||||
|
@ -389,13 +430,17 @@ class SupersetTestCase(unittest.TestCase):
|
|||
from INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
|
||||
"""
|
||||
self.assertEqual({"INFORMATION_SCHEMA.COLUMNS"}, self.extract_tables(query))
|
||||
self.assertEqual(
|
||||
{Table("COLUMNS", "INFORMATION_SCHEMA")}, self.extract_tables(query)
|
||||
)
|
||||
query = """
|
||||
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
|
||||
from INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_NAME="bi_achivement_daily"),0x7e)));
|
||||
"""
|
||||
self.assertEqual({"INFORMATION_SCHEMA.COLUMNS"}, self.extract_tables(query))
|
||||
self.assertEqual(
|
||||
{Table("COLUMNS", "INFORMATION_SCHEMA")}, self.extract_tables(query)
|
||||
)
|
||||
|
||||
def test_complex_extract_tables3(self):
|
||||
query = """SELECT somecol AS somecol
|
||||
|
@ -431,7 +476,10 @@ class SupersetTestCase(unittest.TestCase):
|
|||
WHERE 2=2
|
||||
GROUP BY last_col
|
||||
LIMIT 50000;"""
|
||||
self.assertEqual({"a", "b", "c", "d", "e", "f"}, self.extract_tables(query))
|
||||
self.assertEqual(
|
||||
{Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")},
|
||||
self.extract_tables(query),
|
||||
)
|
||||
|
||||
def test_complex_cte_with_prefix(self):
|
||||
query = """
|
||||
|
@ -446,23 +494,23 @@ class SupersetTestCase(unittest.TestCase):
|
|||
GROUP BY SalesYear, SalesPersonID
|
||||
ORDER BY SalesPersonID, SalesYear;
|
||||
"""
|
||||
self.assertEqual({"SalesOrderHeader"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("SalesOrderHeader")}, self.extract_tables(query))
|
||||
|
||||
def test_get_query_with_new_limit_comment(self):
|
||||
sql = "SELECT * FROM birth_names -- SOME COMMENT"
|
||||
parsed = sql_parse.ParsedQuery(sql)
|
||||
parsed = ParsedQuery(sql)
|
||||
newsql = parsed.set_or_update_query_limit(1000)
|
||||
self.assertEqual(newsql, sql + "\nLIMIT 1000")
|
||||
|
||||
def test_get_query_with_new_limit_comment_with_limit(self):
|
||||
sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
|
||||
parsed = sql_parse.ParsedQuery(sql)
|
||||
parsed = ParsedQuery(sql)
|
||||
newsql = parsed.set_or_update_query_limit(1000)
|
||||
self.assertEqual(newsql, sql + "\nLIMIT 1000")
|
||||
|
||||
def test_get_query_with_new_limit_lower(self):
|
||||
sql = "SELECT * FROM birth_names LIMIT 555"
|
||||
parsed = sql_parse.ParsedQuery(sql)
|
||||
parsed = ParsedQuery(sql)
|
||||
newsql = parsed.set_or_update_query_limit(1000)
|
||||
# not applied as new limit is higher
|
||||
expected = "SELECT * FROM birth_names LIMIT 555"
|
||||
|
@ -470,7 +518,7 @@ class SupersetTestCase(unittest.TestCase):
|
|||
|
||||
def test_get_query_with_new_limit_upper(self):
|
||||
sql = "SELECT * FROM birth_names LIMIT 1555"
|
||||
parsed = sql_parse.ParsedQuery(sql)
|
||||
parsed = ParsedQuery(sql)
|
||||
newsql = parsed.set_or_update_query_limit(1000)
|
||||
# applied as new limit is lower
|
||||
expected = "SELECT * FROM birth_names LIMIT 1000"
|
||||
|
@ -481,7 +529,7 @@ class SupersetTestCase(unittest.TestCase):
|
|||
SELECT * FROM birth_names;
|
||||
SELECT * FROM birth_names LIMIT 1;
|
||||
"""
|
||||
parsed = sql_parse.ParsedQuery(multi_sql)
|
||||
parsed = ParsedQuery(multi_sql)
|
||||
statements = parsed.get_statements()
|
||||
self.assertEqual(len(statements), 2)
|
||||
expected = ["SELECT * FROM birth_names", "SELECT * FROM birth_names LIMIT 1"]
|
||||
|
@ -494,7 +542,7 @@ class SupersetTestCase(unittest.TestCase):
|
|||
SELECT * FROM birth_names;;;
|
||||
SELECT * FROM birth_names LIMIT 1
|
||||
"""
|
||||
parsed = sql_parse.ParsedQuery(multi_sql)
|
||||
parsed = ParsedQuery(multi_sql)
|
||||
statements = parsed.get_statements()
|
||||
self.assertEqual(len(statements), 4)
|
||||
expected = [
|
||||
|
@ -512,4 +560,4 @@ class SupersetTestCase(unittest.TestCase):
|
|||
match AS (SELECT * FROM f)
|
||||
SELECT * FROM match
|
||||
"""
|
||||
self.assertEqual({"foo"}, self.extract_tables(query))
|
||||
self.assertEqual({Table("foo")}, self.extract_tables(query))
|
||||
|
|
Loading…
Reference in New Issue