feat: improve adhoc SQL validation (#19454)

* feat: improve adhoc SQL validation

* Small changes

* Add more unit tests
Beto Dealmeida 2022-03-31 11:55:19 -07:00 committed by GitHub
parent 1a1322d3d9
commit 6828624f61
4 changed files with 170 additions and 72 deletions


@ -899,7 +899,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
elif expression_type == utils.AdhocMetricExpressionType.SQL:
tp = self.get_template_processor()
expression = tp.process_template(cast(str, metric["sqlExpression"]))
validate_adhoc_subquery(expression)
expression = validate_adhoc_subquery(
expression,
self.database_id,
self.schema,
)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
@ -928,7 +932,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
validate_adhoc_subquery(expression)
expression = validate_adhoc_subquery(
expression,
self.database_id,
self.schema,
)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
@ -982,9 +990,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
if is_alias_used_in_orderby(col):
col.name = f"{col.name}__"
def _get_sqla_row_level_filters(
def get_sqla_row_level_filters(
self, template_processor: BaseTemplateProcessor
) -> List[str]:
) -> List[TextClause]:
"""
Return the appropriate row level security filters for
this table and the current user.
@ -992,7 +1000,6 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
:param BaseTemplateProcessor template_processor: The template
processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
:rtype: List[str]
"""
all_filters: List[TextClause] = []
filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
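The row-level-security helper is now public and returns TextClause objects rather than strings. A rough illustration of how a caller consumes it (this helper function is hypothetical and not part of the commit; the real consumer is the where_clause_and logic further down in this file):

from sqlalchemy import and_

def apply_rls_filters(qry, datasource):
    # Hypothetical sketch: AND the dataset's RLS clauses into an existing SQLA query.
    template_processor = datasource.get_template_processor()
    rls_filters = datasource.get_sqla_row_level_filters(template_processor)
    if rls_filters:
        qry = qry.where(and_(*rls_filters))
    return qry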
@ -1145,6 +1152,12 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
col: Union[AdhocMetric, ColumnElement] = orig_col
if isinstance(col, dict):
col = cast(AdhocMetric, col)
if col.get("sqlExpression"):
col["sqlExpression"] = validate_adhoc_subquery(
cast(str, col["sqlExpression"]),
self.database_id,
self.schema,
)
if utils.is_adhoc_metric(col):
# add adhoc sort by column to columns_by_name if not exists
col = self.adhoc_metric_to_sqla(col, columns_by_name)
@ -1194,7 +1207,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
elif selected in columns_by_name:
outer = columns_by_name[selected].get_sqla_col()
else:
validate_adhoc_subquery(selected)
selected = validate_adhoc_subquery(
selected,
self.database_id,
self.schema,
)
outer = literal_column(f"({selected})")
outer = self.make_sqla_column_compatible(outer, selected)
else:
@ -1207,7 +1224,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
select_exprs.append(outer)
elif columns:
for selected in columns:
validate_adhoc_subquery(selected)
selected = validate_adhoc_subquery(
selected,
self.database_id,
self.schema,
)
select_exprs.append(
columns_by_name[selected].get_sqla_col()
if selected in columns_by_name
@ -1373,7 +1394,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
_("Invalid filter operation type: %(op)s", op=op)
)
if is_feature_enabled("ROW_LEVEL_SECURITY"):
where_clause_and += self._get_sqla_row_level_filters(template_processor)
where_clause_and += self.get_sqla_row_level_filters(template_processor)
if extras:
where = extras.get("where")
if where:
@ -1420,7 +1441,6 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
and db_engine_spec.allows_hidden_cc_in_orderby
and col.name in [select_col.name for select_col in select_exprs]
):
validate_adhoc_subquery(str(col.expression))
col = literal_column(col.name)
direction = asc if ascending else desc
qry = qry.order_by(direction(col))


@ -33,7 +33,7 @@ from superset.exceptions import (
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
from superset.sql_parse import has_table_query, ParsedQuery, Table
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, Table
from superset.tables.models import Table as NewTable
if TYPE_CHECKING:
@ -136,29 +136,39 @@ def get_virtual_table_metadata(dataset: "SqlaTable") -> List[Dict[str, str]]:
return cols
def validate_adhoc_subquery(raw_sql: str) -> None:
def validate_adhoc_subquery(
sql: str,
database_id: int,
default_schema: str,
) -> str:
"""
Check if adhoc SQL contains sub-queries or nested sub-queries with table
:param raw_sql: adhoc sql expression
Check if adhoc SQL contains sub-queries or nested sub-queries with table.
If sub-queries are allowed, the adhoc SQL is modified to insert any applicable RLS
predicates into it.
:param sql: adhoc sql expression
:raise SupersetSecurityException if sql contains sub-queries or
nested sub-queries with table
"""
# pylint: disable=import-outside-toplevel
from superset import is_feature_enabled
if is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
return
for statement in sqlparse.parse(raw_sql):
statements = []
for statement in sqlparse.parse(sql):
if has_table_query(statement):
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
message=_("Custom SQL fields cannot contain sub-queries."),
level=ErrorLevel.ERROR,
if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
raise SupersetSecurityException(
SupersetError(
error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
message=_("Custom SQL fields cannot contain sub-queries."),
level=ErrorLevel.ERROR,
)
)
)
return
statement = insert_rls(statement, database_id, default_schema)
statements.append(statement)
return ";\n".join(str(statement) for statement in statements)
def load_or_create_tables( # pylint: disable=too-many-arguments


@ -18,10 +18,11 @@ import logging
import re
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Set, Tuple
from typing import cast, List, Optional, Set, Tuple
from urllib import parse
import sqlparse
from sqlalchemy import and_
from sqlparse.sql import (
Identifier,
IdentifierList,
@ -283,7 +284,7 @@ class ParsedQuery:
return statements
@staticmethod
def _get_table(tlist: TokenList) -> Optional[Table]:
def get_table(tlist: TokenList) -> Optional[Table]:
"""
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
construct.
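The former _get_table static method is made public so that get_rls_for_table (below) can reuse it. A small sketch of what it returns, assuming sqlparse groups the qualified name into a single Identifier:

import sqlparse
from superset.sql_parse import ParsedQuery

token_list = sqlparse.parse("my_schema.my_table")[0].tokens[0]  # an Identifier
table = ParsedQuery.get_table(token_list)
print(table.schema, table.table)  # my_schema my_table
print(str(table))                 # my_schema.my_table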
@ -324,7 +325,7 @@ class ParsedQuery:
"""
# exclude subselects
if "(" not in str(token_list):
table = self._get_table(token_list)
table = self.get_table(token_list)
if table and not table.table.startswith(CTE_PREFIX):
self._tables.add(table)
return
@ -500,7 +501,7 @@ def has_table_query(token_list: TokenList) -> bool:
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# # Recurse into child token list
# Recurse into child token list
if isinstance(token, TokenList) and has_table_query(token):
return True
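For context, has_table_query is what the adhoc validator uses to decide whether an expression reads from a table; a quick illustration with arbitrary SQL strings:

import sqlparse
from superset.sql_parse import has_table_query

assert has_table_query(sqlparse.parse("SELECT * FROM some_table")[0])
assert not has_table_query(sqlparse.parse("COUNT(*)")[0])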
@ -523,7 +524,7 @@ def has_table_query(token_list: TokenList) -> bool:
def add_table_name(rls: TokenList, table: str) -> None:
"""
Modify a RLS expression ensuring columns are fully qualified.
Modify a RLS expression inplace ensuring columns are fully qualified.
"""
tokens = rls.tokens[:]
while tokens:
@ -539,45 +540,70 @@ def add_table_name(rls: TokenList, table: str) -> None:
tokens.extend(token.tokens)
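add_table_name itself is unchanged apart from the docstring. A quick illustration of the in-place rewrite it performs, with values in the spirit of test_add_table_name further down:

import sqlparse
from superset.sql_parse import add_table_name

condition = sqlparse.parse("id=42")[0]
add_table_name(condition, "users")
print(str(condition))  # users.id=42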
def matches_table_name(candidate: Token, table: str) -> bool:
def get_rls_for_table(
candidate: Token,
database_id: int,
default_schema: Optional[str],
) -> Optional[TokenList]:
"""
Returns if the token represents a reference to the table.
Tables can be fully qualified with periods.
Note that in theory a table should be represented as an identifier, but due to
sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets
classified as a keyword.
Given a table name, return any associated RLS predicates.
"""
# pylint: disable=import-outside-toplevel
from superset import db
from superset.connectors.sqla.models import SqlaTable
if not isinstance(candidate, Identifier):
candidate = Identifier([Token(Name, candidate.value)])
target = sqlparse.parse(table)[0].tokens[0]
if not isinstance(target, Identifier):
target = Identifier([Token(Name, target.value)])
table = ParsedQuery.get_table(candidate)
if not table:
return None
# match from right to left, splitting on the period, eg, schema.table == table
for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]):
if left.value != right.value:
return False
dataset = (
db.session.query(SqlaTable)
.filter(
and_(
SqlaTable.database_id == database_id,
SqlaTable.schema == (table.schema or default_schema),
SqlaTable.table_name == table.table,
)
)
.one_or_none()
)
if not dataset:
return None
return True
template_processor = dataset.get_template_processor()
# pylint: disable=protected-access
predicate = " AND ".join(
str(filter_)
for filter_ in dataset.get_sqla_row_level_filters(template_processor)
)
if not predicate:
return None
rls = sqlparse.parse(predicate)[0]
add_table_name(rls, str(dataset))
return rls
def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
def insert_rls(
token_list: TokenList,
database_id: int,
default_schema: Optional[str],
) -> TokenList:
"""
Update a statement inplace applying an RLS associated with a given table.
Update a statement inplace applying any associated RLS predicates.
"""
# make sure the identifier has the table name
add_table_name(rls, table)
rls: Optional[TokenList] = None
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# Recurse into child token list
if isinstance(token, TokenList):
i = token_list.tokens.index(token)
token_list.tokens[i] = insert_rls(token, table, rls)
token_list.tokens[i] = insert_rls(token, database_id, default_schema)
# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
@ -587,12 +613,14 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, Identifier) or token.ttype == Keyword
):
if matches_table_name(token, table):
rls = get_rls_for_table(token, database_id, default_schema)
if rls:
state = InsertRLSState.FOUND_TABLE
# Found WHERE clause, insert RLS. Note that we insert it even if it already exists,
# to be on the safe side: it could be present in a clause like `1=1 OR RLS`.
elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
rls = cast(TokenList, rls)
token.tokens[1:1] = [Token(Whitespace, " "), Token(Punctuation, "(")]
token.tokens.extend(
[

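Since insert_rls now looks up the RLS predicate itself, a standalone illustration needs the dataset lookup stubbed out, much like the updated test below. A hedged sketch; the table and predicate are made up and the exact output spacing may differ:

import sqlparse
from unittest import mock

from superset.sql_parse import insert_rls

statement = sqlparse.parse("SELECT * FROM orders")[0]
predicate = sqlparse.parse("orders.org_id = 1")[0]

# Stub the dataset lookup so no database or app context is needed for this sketch.
with mock.patch("superset.sql_parse.get_rls_for_table", return_value=predicate):
    rewritten = insert_rls(statement, database_id=1, default_schema="public")

print(str(rewritten))  # roughly: SELECT * FROM orders WHERE orders.org_id = 1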

@ -14,21 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-lines
# pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines
import unittest
from typing import Set
from typing import Optional, Set
import pytest
import sqlparse
from pytest_mock import MockerFixture
from sqlalchemy import text
from sqlparse.sql import Identifier, Token, TokenList
from sqlparse.tokens import Name
from superset.exceptions import QueryClauseValidationException
from superset.sql_parse import (
add_table_name,
get_rls_for_table,
has_table_query,
insert_rls,
matches_table_name,
ParsedQuery,
sanitize_clause,
strip_comments_from_sql,
@ -1391,13 +1394,37 @@ def test_has_table_query(sql: str, expected: bool) -> None:
),
],
)
def test_insert_rls(sql: str, table: str, rls: str, expected: str) -> None:
def test_insert_rls(
mocker: MockerFixture, sql: str, table: str, rls: str, expected: str
) -> None:
"""
Insert into a statement a given RLS condition associated with a table.
"""
statement = sqlparse.parse(sql)[0]
condition = sqlparse.parse(rls)[0]
assert str(insert_rls(statement, table, condition)).strip() == expected.strip()
add_table_name(condition, table)
# pylint: disable=unused-argument
def get_rls_for_table(
candidate: Token, database_id: int, default_schema: str
) -> Optional[TokenList]:
"""
Return the RLS ``condition`` if ``candidate`` matches ``table``.
"""
# compare ignoring schema
for left, right in zip(str(candidate).split(".")[::-1], table.split(".")[::-1]):
if left != right:
return None
return condition
mocker.patch("superset.sql_parse.get_rls_for_table", new=get_rls_for_table)
statement = sqlparse.parse(sql)[0]
assert (
str(
insert_rls(token_list=statement, database_id=1, default_schema="my_schema")
).strip()
== expected.strip()
)
@pytest.mark.parametrize(
@ -1415,16 +1442,29 @@ def test_add_table_name(rls: str, table: str, expected: str) -> None:
assert str(condition) == expected
@pytest.mark.parametrize(
"candidate,table,expected",
[
("table", "table", True),
("schema.table", "table", True),
("table", "schema.table", True),
('schema."my table"', '"my table"', True),
('schema."my.table"', '"my.table"', True),
],
)
def test_matches_table_name(candidate: str, table: str, expected: bool) -> None:
token = sqlparse.parse(candidate)[0].tokens[0]
assert matches_table_name(token, table) == expected
def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None:
"""
Tests for ``get_rls_for_table``.
"""
candidate = Identifier([Token(Name, "some_table")])
db = mocker.patch("superset.db")
dataset = db.session.query().filter().one_or_none()
dataset.__str__.return_value = "some_table"
dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")]
assert (
str(get_rls_for_table(candidate, 1, "public"))
== "some_table.organization_id = 1"
)
dataset.get_sqla_row_level_filters.return_value = [
text("organization_id = 1"),
text("foo = 'bar'"),
]
assert (
str(get_rls_for_table(candidate, 1, "public"))
== "some_table.organization_id = 1 AND some_table.foo = 'bar'"
)
dataset.get_sqla_row_level_filters.return_value = []
assert get_rls_for_table(candidate, 1, "public") is None