fix: avoid escaping bind-like params containing colons (#17419)

* fix: avoid escaping bind-like params containing colons

* fix query for mysql

* address comments
This commit is contained in:
Ville Brofeldt 2021-11-13 09:01:49 +02:00 committed by GitHub
parent aa8040ec9b
commit ad8a7c42f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 36 deletions

View File

@ -65,7 +65,7 @@ from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
from sqlalchemy.sql import column, ColumnElement, literal_column, table
from sqlalchemy.sql.elements import ColumnClause, TextClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause
@ -103,6 +103,16 @@ logger = logging.getLogger(__name__)
# Alias given to the sub-select that wraps a virtual dataset's SQL query.
VIRTUAL_TABLE_ALIAS = "virtual_table"
def text(clause: str) -> TextClause:
    """
    Build a SQLAlchemy text clause from a raw SQL fragment, backslash-escaping
    every colon in the fragment first.

    :param clause: SQL fragment that may contain literal colons
    :return: text clause created from the colon-escaped fragment
    """
    escaped_clause = clause.replace(":", "\\:")
    return sa.text(escaped_clause)
class SqlaQuery(NamedTuple):
applied_template_filters: List[str]
extra_cache_keys: List[Any]
@ -806,7 +816,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
raise QueryObjectValidationError(
_("Virtual dataset query must be read-only")
)
return TextAsFrom(sa.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
return TextAsFrom(text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
def get_rendered_sql(
self, template_processor: Optional[BaseTemplateProcessor] = None
@ -827,8 +837,6 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
)
) from ex
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
# we need to escape strings that SQLAlchemy interprets as bind parameters
sql = utils.escape_sqla_query_binds(sql)
if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
if len(sqlparse.split(sql)) > 1:
@ -1286,7 +1294,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
msg=ex.message,
)
) from ex
where_clause_and += [sa.text("({})".format(where))]
where_clause_and += [text(f"({where})")]
having = extras.get("having")
if having:
try:
@ -1298,7 +1306,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
msg=ex.message,
)
) from ex
having_clause_and += [sa.text("({})".format(having))]
having_clause_and += [text(f"({having})")]
if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())
if granularity:

View File

@ -81,7 +81,6 @@ from sqlalchemy import event, exc, inspect, select, Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.type_api import Variant
from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine
from typing_extensions import TypedDict, TypeGuard
@ -132,8 +131,6 @@ JS_MAX_INTEGER = 9007199254740991 # Largest int Java Script can handle 2^53-1
# Generic type variable; its users are outside this excerpt.
InputType = TypeVar("InputType")
# Private regex borrowed from SQLAlchemy's TextClause; escape_sqla_query_binds
# uses it to locate the substrings SQLAlchemy would treat as bind parameters.
BIND_PARAM_REGEX = TextClause._bind_params_regex # pylint: disable=protected-access
class LenientEnum(Enum):
"""Enums with a `get` method that convert a enum value to `Enum` if it is a
@ -1771,29 +1768,3 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
if limit != 0:
return min(max_limit, limit)
return max_limit
def escape_sqla_query_binds(sql: str) -> str:
    """
    Replace strings in a query that SQLAlchemy would otherwise interpret as
    bind parameters.

    Every substring matched by ``BIND_PARAM_REGEX`` gets its colons prefixed
    with a backslash. The substitution happens in a single pass and only at the
    matched positions, so a bind whose text is a prefix of another bind
    (``:foo`` and ``:foobar``) is never escaped twice.

    :param sql: unescaped query string
    :return: escaped query string
    >>> escape_sqla_query_binds("select ':foo'")
    "select '\\\\:foo'"
    >>> escape_sqla_query_binds("select 'foo'::TIMESTAMP")
    "select 'foo'::TIMESTAMP"
    >>> escape_sqla_query_binds("select ':foo :bar'::TIMESTAMP")
    "select '\\\\:foo \\\\:bar'::TIMESTAMP"
    >>> escape_sqla_query_binds("select ':foo :foo :bar'::TIMESTAMP")
    "select '\\\\:foo \\\\:foo \\\\:bar'::TIMESTAMP"
    >>> escape_sqla_query_binds("select ':foo', ':foobar'")
    "select '\\\\:foo', '\\\\:foobar'"
    """
    # Substituting at the match positions (instead of str.replace over the
    # whole query for each distinct match) keeps already-escaped text intact.
    return BIND_PARAM_REGEX.sub(
        lambda match: match.group(0).replace(":", "\\:"), sql
    )

View File

@ -20,7 +20,7 @@ import json
import unittest
from datetime import datetime
from io import BytesIO
from typing import Optional
from typing import Optional, List
from unittest import mock
from zipfile import is_zipfile, ZipFile
@ -42,6 +42,7 @@ from tests.integration_tests.fixtures.world_bank_dashboard import (
load_world_bank_dashboard_with_slices,
)
from tests.integration_tests.test_app import app
from superset import security_manager
from superset.charts.commands.data import ChartDataCommand
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.errors import SupersetErrorType
@ -56,6 +57,7 @@ from superset.utils.core import (
get_example_database,
get_example_default_schema,
get_main_database,
AdhocMetricExpressionType,
)
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
@ -2033,3 +2035,58 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
self.assertEqual(
set(column for column in data[0].keys()), {"state", "name", "sum__num"}
)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_virtual_table_with_colons(self):
    """
    Chart data API: test query with literal colon characters in query, metrics,
    where clause and filters
    """
    self.login(username="admin")
    owner = self.get_user("admin").id
    user = db.session.query(security_manager.user_model).get(owner)

    # Temporary virtual dataset whose SQL contains colon-bearing string literals.
    table = SqlaTable(
        table_name="virtual_table_1",
        schema=get_example_default_schema(),
        owners=[user],
        database=get_example_database(),
        sql="select ':foo' as foo, ':bar:' as bar, state, num from birth_names",
    )
    db.session.add(table)
    db.session.commit()

    try:
        table.fetch_metadata()

        request_payload = get_query_context("birth_names")
        request_payload["datasource"] = {
            "type": "table",
            "id": table.id,
        }
        query = request_payload["queries"][0]
        query["columns"] = ["foo", "bar", "state"]
        query["where"] = "':abc' != ':xyz:qwerty'"
        query["orderby"] = None
        query["metrics"] = [
            {
                "expressionType": AdhocMetricExpressionType.SQL,
                "sqlExpression": "sum(case when state = ':asdf' then 0 else 1 end)",
                "label": "count",
            }
        ]
        query["filters"] = [{"col": "foo", "op": "!=", "val": ":qwerty:"}]
        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
    finally:
        # Always drop the temporary dataset, even when the request above
        # raises, so it does not leak into other tests sharing the database.
        db.session.delete(table)
        db.session.commit()

    assert rv.status_code == 200
    response_payload = json.loads(rv.data.decode("utf-8"))
    result = response_payload["result"][0]
    data = result["data"]
    assert set(data[0].keys()) == {"foo", "bar", "state", "count"}
    # make sure results and query parameters are unescaped
    assert {row["foo"] for row in data} == {":foo"}
    assert {row["bar"] for row in data} == {":bar:"}
    assert "':asdf'" in result["query"]
    assert "':xyz:qwerty'" in result["query"]
    assert "':qwerty:'" in result["query"]