mirror of https://github.com/apache/superset.git
feat: improve adhoc SQL validation (#19454)
* feat: improve adhoc SQL validation * Small changes * Add more unit tests
This commit is contained in:
parent
1a1322d3d9
commit
6828624f61
|
@ -899,7 +899,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
elif expression_type == utils.AdhocMetricExpressionType.SQL:
|
||||
tp = self.get_template_processor()
|
||||
expression = tp.process_template(cast(str, metric["sqlExpression"]))
|
||||
validate_adhoc_subquery(expression)
|
||||
expression = validate_adhoc_subquery(
|
||||
expression,
|
||||
self.database_id,
|
||||
self.schema,
|
||||
)
|
||||
try:
|
||||
expression = sanitize_clause(expression)
|
||||
except QueryClauseValidationException as ex:
|
||||
|
@ -928,7 +932,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
if template_processor and expression:
|
||||
expression = template_processor.process_template(expression)
|
||||
if expression:
|
||||
validate_adhoc_subquery(expression)
|
||||
expression = validate_adhoc_subquery(
|
||||
expression,
|
||||
self.database_id,
|
||||
self.schema,
|
||||
)
|
||||
try:
|
||||
expression = sanitize_clause(expression)
|
||||
except QueryClauseValidationException as ex:
|
||||
|
@ -982,9 +990,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
if is_alias_used_in_orderby(col):
|
||||
col.name = f"{col.name}__"
|
||||
|
||||
def _get_sqla_row_level_filters(
|
||||
def get_sqla_row_level_filters(
|
||||
self, template_processor: BaseTemplateProcessor
|
||||
) -> List[str]:
|
||||
) -> List[TextClause]:
|
||||
"""
|
||||
Return the appropriate row level security filters for
|
||||
this table and the current user.
|
||||
|
@ -992,7 +1000,6 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
:param BaseTemplateProcessor template_processor: The template
|
||||
processor to apply to the filters.
|
||||
:returns: A list of SQL clauses to be ANDed together.
|
||||
:rtype: List[str]
|
||||
"""
|
||||
all_filters: List[TextClause] = []
|
||||
filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
|
||||
|
@ -1145,6 +1152,12 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
col: Union[AdhocMetric, ColumnElement] = orig_col
|
||||
if isinstance(col, dict):
|
||||
col = cast(AdhocMetric, col)
|
||||
if col.get("sqlExpression"):
|
||||
col["sqlExpression"] = validate_adhoc_subquery(
|
||||
cast(str, col["sqlExpression"]),
|
||||
self.database_id,
|
||||
self.schema,
|
||||
)
|
||||
if utils.is_adhoc_metric(col):
|
||||
# add adhoc sort by column to columns_by_name if not exists
|
||||
col = self.adhoc_metric_to_sqla(col, columns_by_name)
|
||||
|
@ -1194,7 +1207,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
elif selected in columns_by_name:
|
||||
outer = columns_by_name[selected].get_sqla_col()
|
||||
else:
|
||||
validate_adhoc_subquery(selected)
|
||||
selected = validate_adhoc_subquery(
|
||||
selected,
|
||||
self.database_id,
|
||||
self.schema,
|
||||
)
|
||||
outer = literal_column(f"({selected})")
|
||||
outer = self.make_sqla_column_compatible(outer, selected)
|
||||
else:
|
||||
|
@ -1207,7 +1224,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
select_exprs.append(outer)
|
||||
elif columns:
|
||||
for selected in columns:
|
||||
validate_adhoc_subquery(selected)
|
||||
selected = validate_adhoc_subquery(
|
||||
selected,
|
||||
self.database_id,
|
||||
self.schema,
|
||||
)
|
||||
select_exprs.append(
|
||||
columns_by_name[selected].get_sqla_col()
|
||||
if selected in columns_by_name
|
||||
|
@ -1373,7 +1394,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
_("Invalid filter operation type: %(op)s", op=op)
|
||||
)
|
||||
if is_feature_enabled("ROW_LEVEL_SECURITY"):
|
||||
where_clause_and += self._get_sqla_row_level_filters(template_processor)
|
||||
where_clause_and += self.get_sqla_row_level_filters(template_processor)
|
||||
if extras:
|
||||
where = extras.get("where")
|
||||
if where:
|
||||
|
@ -1420,7 +1441,6 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
and db_engine_spec.allows_hidden_cc_in_orderby
|
||||
and col.name in [select_col.name for select_col in select_exprs]
|
||||
):
|
||||
validate_adhoc_subquery(str(col.expression))
|
||||
col = literal_column(col.name)
|
||||
direction = asc if ascending else desc
|
||||
qry = qry.order_by(direction(col))
|
||||
|
|
|
@ -33,7 +33,7 @@ from superset.exceptions import (
|
|||
)
|
||||
from superset.models.core import Database
|
||||
from superset.result_set import SupersetResultSet
|
||||
from superset.sql_parse import has_table_query, ParsedQuery, Table
|
||||
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, Table
|
||||
from superset.tables.models import Table as NewTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -136,29 +136,39 @@ def get_virtual_table_metadata(dataset: "SqlaTable") -> List[Dict[str, str]]:
|
|||
return cols
|
||||
|
||||
|
||||
def validate_adhoc_subquery(raw_sql: str) -> None:
|
||||
def validate_adhoc_subquery(
|
||||
sql: str,
|
||||
database_id: int,
|
||||
default_schema: str,
|
||||
) -> str:
|
||||
"""
|
||||
Check if adhoc SQL contains sub-queries or nested sub-queries with table
|
||||
:param raw_sql: adhoc sql expression
|
||||
Check if adhoc SQL contains sub-queries or nested sub-queries with table.
|
||||
|
||||
If sub-queries are allowed, the adhoc SQL is modified to insert any applicable RLS
|
||||
predicates to it.
|
||||
|
||||
:param sql: adhoc sql expression
|
||||
:raise SupersetSecurityException if sql contains sub-queries or
|
||||
nested sub-queries with table
|
||||
"""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset import is_feature_enabled
|
||||
|
||||
if is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
|
||||
return
|
||||
|
||||
for statement in sqlparse.parse(raw_sql):
|
||||
statements = []
|
||||
for statement in sqlparse.parse(sql):
|
||||
if has_table_query(statement):
|
||||
raise SupersetSecurityException(
|
||||
SupersetError(
|
||||
error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
|
||||
message=_("Custom SQL fields cannot contain sub-queries."),
|
||||
level=ErrorLevel.ERROR,
|
||||
if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
|
||||
raise SupersetSecurityException(
|
||||
SupersetError(
|
||||
error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
|
||||
message=_("Custom SQL fields cannot contain sub-queries."),
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
)
|
||||
)
|
||||
return
|
||||
statement = insert_rls(statement, database_id, default_schema)
|
||||
statements.append(statement)
|
||||
|
||||
return ";\n".join(str(statement) for statement in statements)
|
||||
|
||||
|
||||
def load_or_create_tables( # pylint: disable=too-many-arguments
|
||||
|
|
|
@ -18,10 +18,11 @@ import logging
|
|||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Set, Tuple
|
||||
from typing import cast, List, Optional, Set, Tuple
|
||||
from urllib import parse
|
||||
|
||||
import sqlparse
|
||||
from sqlalchemy import and_
|
||||
from sqlparse.sql import (
|
||||
Identifier,
|
||||
IdentifierList,
|
||||
|
@ -283,7 +284,7 @@ class ParsedQuery:
|
|||
return statements
|
||||
|
||||
@staticmethod
|
||||
def _get_table(tlist: TokenList) -> Optional[Table]:
|
||||
def get_table(tlist: TokenList) -> Optional[Table]:
|
||||
"""
|
||||
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
|
||||
construct.
|
||||
|
@ -324,7 +325,7 @@ class ParsedQuery:
|
|||
"""
|
||||
# exclude subselects
|
||||
if "(" not in str(token_list):
|
||||
table = self._get_table(token_list)
|
||||
table = self.get_table(token_list)
|
||||
if table and not table.table.startswith(CTE_PREFIX):
|
||||
self._tables.add(table)
|
||||
return
|
||||
|
@ -500,7 +501,7 @@ def has_table_query(token_list: TokenList) -> bool:
|
|||
state = InsertRLSState.SCANNING
|
||||
for token in token_list.tokens:
|
||||
|
||||
# # Recurse into child token list
|
||||
# Recurse into child token list
|
||||
if isinstance(token, TokenList) and has_table_query(token):
|
||||
return True
|
||||
|
||||
|
@ -523,7 +524,7 @@ def has_table_query(token_list: TokenList) -> bool:
|
|||
|
||||
def add_table_name(rls: TokenList, table: str) -> None:
|
||||
"""
|
||||
Modify a RLS expression ensuring columns are fully qualified.
|
||||
Modify a RLS expression inplace ensuring columns are fully qualified.
|
||||
"""
|
||||
tokens = rls.tokens[:]
|
||||
while tokens:
|
||||
|
@ -539,45 +540,70 @@ def add_table_name(rls: TokenList, table: str) -> None:
|
|||
tokens.extend(token.tokens)
|
||||
|
||||
|
||||
def matches_table_name(candidate: Token, table: str) -> bool:
|
||||
def get_rls_for_table(
|
||||
candidate: Token,
|
||||
database_id: int,
|
||||
default_schema: Optional[str],
|
||||
) -> Optional[TokenList]:
|
||||
"""
|
||||
Returns if the token represents a reference to the table.
|
||||
|
||||
Tables can be fully qualified with periods.
|
||||
|
||||
Note that in theory a table should be represented as an identifier, but due to
|
||||
sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets
|
||||
classified as a keyword.
|
||||
Given a table name, return any associated RLS predicates.
|
||||
"""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset import db
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
if not isinstance(candidate, Identifier):
|
||||
candidate = Identifier([Token(Name, candidate.value)])
|
||||
|
||||
target = sqlparse.parse(table)[0].tokens[0]
|
||||
if not isinstance(target, Identifier):
|
||||
target = Identifier([Token(Name, target.value)])
|
||||
table = ParsedQuery.get_table(candidate)
|
||||
if not table:
|
||||
return None
|
||||
|
||||
# match from right to left, splitting on the period, eg, schema.table == table
|
||||
for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]):
|
||||
if left.value != right.value:
|
||||
return False
|
||||
dataset = (
|
||||
db.session.query(SqlaTable)
|
||||
.filter(
|
||||
and_(
|
||||
SqlaTable.database_id == database_id,
|
||||
SqlaTable.schema == (table.schema or default_schema),
|
||||
SqlaTable.table_name == table.table,
|
||||
)
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
if not dataset:
|
||||
return None
|
||||
|
||||
return True
|
||||
template_processor = dataset.get_template_processor()
|
||||
# pylint: disable=protected-access
|
||||
predicate = " AND ".join(
|
||||
str(filter_)
|
||||
for filter_ in dataset.get_sqla_row_level_filters(template_processor)
|
||||
)
|
||||
if not predicate:
|
||||
return None
|
||||
|
||||
rls = sqlparse.parse(predicate)[0]
|
||||
add_table_name(rls, str(dataset))
|
||||
|
||||
return rls
|
||||
|
||||
|
||||
def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
|
||||
def insert_rls(
|
||||
token_list: TokenList,
|
||||
database_id: int,
|
||||
default_schema: Optional[str],
|
||||
) -> TokenList:
|
||||
"""
|
||||
Update a statement inplace applying an RLS associated with a given table.
|
||||
Update a statement inplace applying any associated RLS predicates.
|
||||
"""
|
||||
# make sure the identifier has the table name
|
||||
add_table_name(rls, table)
|
||||
|
||||
rls: Optional[TokenList] = None
|
||||
state = InsertRLSState.SCANNING
|
||||
for token in token_list.tokens:
|
||||
|
||||
# Recurse into child token list
|
||||
if isinstance(token, TokenList):
|
||||
i = token_list.tokens.index(token)
|
||||
token_list.tokens[i] = insert_rls(token, table, rls)
|
||||
token_list.tokens[i] = insert_rls(token, database_id, default_schema)
|
||||
|
||||
# Found a source keyword (FROM/JOIN)
|
||||
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
|
||||
|
@ -587,12 +613,14 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
|
|||
elif state == InsertRLSState.SEEN_SOURCE and (
|
||||
isinstance(token, Identifier) or token.ttype == Keyword
|
||||
):
|
||||
if matches_table_name(token, table):
|
||||
rls = get_rls_for_table(token, database_id, default_schema)
|
||||
if rls:
|
||||
state = InsertRLSState.FOUND_TABLE
|
||||
|
||||
# Found WHERE clause, insert RLS. Note that we insert it even it already exists,
|
||||
# to be on the safe side: it could be present in a clause like `1=1 OR RLS`.
|
||||
elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
|
||||
rls = cast(TokenList, rls)
|
||||
token.tokens[1:1] = [Token(Whitespace, " "), Token(Punctuation, "(")]
|
||||
token.tokens.extend(
|
||||
[
|
||||
|
|
|
@ -14,21 +14,24 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# pylint: disable=invalid-name, too-many-lines
|
||||
# pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines
|
||||
|
||||
import unittest
|
||||
from typing import Set
|
||||
from typing import Optional, Set
|
||||
|
||||
import pytest
|
||||
import sqlparse
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy import text
|
||||
from sqlparse.sql import Identifier, Token, TokenList
|
||||
from sqlparse.tokens import Name
|
||||
|
||||
from superset.exceptions import QueryClauseValidationException
|
||||
from superset.sql_parse import (
|
||||
add_table_name,
|
||||
get_rls_for_table,
|
||||
has_table_query,
|
||||
insert_rls,
|
||||
matches_table_name,
|
||||
ParsedQuery,
|
||||
sanitize_clause,
|
||||
strip_comments_from_sql,
|
||||
|
@ -1391,13 +1394,37 @@ def test_has_table_query(sql: str, expected: bool) -> None:
|
|||
),
|
||||
],
|
||||
)
|
||||
def test_insert_rls(sql: str, table: str, rls: str, expected: str) -> None:
|
||||
def test_insert_rls(
|
||||
mocker: MockerFixture, sql: str, table: str, rls: str, expected: str
|
||||
) -> None:
|
||||
"""
|
||||
Insert into a statement a given RLS condition associated with a table.
|
||||
"""
|
||||
statement = sqlparse.parse(sql)[0]
|
||||
condition = sqlparse.parse(rls)[0]
|
||||
assert str(insert_rls(statement, table, condition)).strip() == expected.strip()
|
||||
add_table_name(condition, table)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def get_rls_for_table(
|
||||
candidate: Token, database_id: int, default_schema: str
|
||||
) -> Optional[TokenList]:
|
||||
"""
|
||||
Return the RLS ``condition`` if ``candidate`` matches ``table``.
|
||||
"""
|
||||
# compare ignoring schema
|
||||
for left, right in zip(str(candidate).split(".")[::-1], table.split(".")[::-1]):
|
||||
if left != right:
|
||||
return None
|
||||
return condition
|
||||
|
||||
mocker.patch("superset.sql_parse.get_rls_for_table", new=get_rls_for_table)
|
||||
|
||||
statement = sqlparse.parse(sql)[0]
|
||||
assert (
|
||||
str(
|
||||
insert_rls(token_list=statement, database_id=1, default_schema="my_schema")
|
||||
).strip()
|
||||
== expected.strip()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -1415,16 +1442,29 @@ def test_add_table_name(rls: str, table: str, expected: str) -> None:
|
|||
assert str(condition) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"candidate,table,expected",
|
||||
[
|
||||
("table", "table", True),
|
||||
("schema.table", "table", True),
|
||||
("table", "schema.table", True),
|
||||
('schema."my table"', '"my table"', True),
|
||||
('schema."my.table"', '"my.table"', True),
|
||||
],
|
||||
)
|
||||
def test_matches_table_name(candidate: str, table: str, expected: bool) -> None:
|
||||
token = sqlparse.parse(candidate)[0].tokens[0]
|
||||
assert matches_table_name(token, table) == expected
|
||||
def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None:
|
||||
"""
|
||||
Tests for ``get_rls_for_table``.
|
||||
"""
|
||||
candidate = Identifier([Token(Name, "some_table")])
|
||||
db = mocker.patch("superset.db")
|
||||
dataset = db.session.query().filter().one_or_none()
|
||||
dataset.__str__.return_value = "some_table"
|
||||
|
||||
dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")]
|
||||
assert (
|
||||
str(get_rls_for_table(candidate, 1, "public"))
|
||||
== "some_table.organization_id = 1"
|
||||
)
|
||||
|
||||
dataset.get_sqla_row_level_filters.return_value = [
|
||||
text("organization_id = 1"),
|
||||
text("foo = 'bar'"),
|
||||
]
|
||||
assert (
|
||||
str(get_rls_for_table(candidate, 1, "public"))
|
||||
== "some_table.organization_id = 1 AND some_table.foo = 'bar'"
|
||||
)
|
||||
|
||||
dataset.get_sqla_row_level_filters.return_value = []
|
||||
assert get_rls_for_table(candidate, 1, "public") is None
|
||||
|
|
Loading…
Reference in New Issue