fix: avoid escaping bind-like params containing colons (#17419)

* fix: avoid escaping bind-like params containing colons

* fix query for mysql

* address comments
This commit is contained in:
Ville Brofeldt 2021-11-13 09:01:49 +02:00 committed by GitHub
parent aa8040ec9b
commit ad8a7c42f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 36 deletions

View File

@ -65,7 +65,7 @@ from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.schema import UniqueConstraint from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text from sqlalchemy.sql import column, ColumnElement, literal_column, table
from sqlalchemy.sql.elements import ColumnClause, TextClause from sqlalchemy.sql.elements import ColumnClause, TextClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom from sqlalchemy.sql.expression import Label, Select, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause from sqlalchemy.sql.selectable import Alias, TableClause
@ -103,6 +103,16 @@ logger = logging.getLogger(__name__)
VIRTUAL_TABLE_ALIAS = "virtual_table" VIRTUAL_TABLE_ALIAS = "virtual_table"
def text(clause: str) -> TextClause:
    """
    Build a SQLAlchemy text clause from a string whose colons are all
    backslash-escaped first.

    :param clause: clause potentially containing colons
    :return: text clause created by ``sa.text`` from the escaped string
    """
    # Splitting on ":" and re-joining with the escaped form prefixes every
    # colon in the clause with a backslash.
    escaped_clause = "\\:".join(clause.split(":"))
    return sa.text(escaped_clause)
class SqlaQuery(NamedTuple): class SqlaQuery(NamedTuple):
applied_template_filters: List[str] applied_template_filters: List[str]
extra_cache_keys: List[Any] extra_cache_keys: List[Any]
@ -806,7 +816,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
raise QueryObjectValidationError( raise QueryObjectValidationError(
_("Virtual dataset query must be read-only") _("Virtual dataset query must be read-only")
) )
return TextAsFrom(sa.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS) return TextAsFrom(text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
def get_rendered_sql( def get_rendered_sql(
self, template_processor: Optional[BaseTemplateProcessor] = None self, template_processor: Optional[BaseTemplateProcessor] = None
@ -827,8 +837,6 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
) )
) from ex ) from ex
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True) sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
# we need to escape strings that SQLAlchemy interprets as bind parameters
sql = utils.escape_sqla_query_binds(sql)
if not sql: if not sql:
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty")) raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
if len(sqlparse.split(sql)) > 1: if len(sqlparse.split(sql)) > 1:
@ -1286,7 +1294,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
msg=ex.message, msg=ex.message,
) )
) from ex ) from ex
where_clause_and += [sa.text("({})".format(where))] where_clause_and += [text(f"({where})")]
having = extras.get("having") having = extras.get("having")
if having: if having:
try: try:
@ -1298,7 +1306,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
msg=ex.message, msg=ex.message,
) )
) from ex ) from ex
having_clause_and += [sa.text("({})".format(having))] having_clause_and += [text(f"({having})")]
if apply_fetch_values_predicate and self.fetch_values_predicate: if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate()) qry = qry.where(self.get_fetch_values_predicate())
if granularity: if granularity:

View File

@ -81,7 +81,6 @@ from sqlalchemy import event, exc, inspect, select, Text
from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.engine import Connection, Engine from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.type_api import Variant from sqlalchemy.sql.type_api import Variant
from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine
from typing_extensions import TypedDict, TypeGuard from typing_extensions import TypedDict, TypeGuard
@ -132,8 +131,6 @@ JS_MAX_INTEGER = 9007199254740991 # Largest int Java Script can handle 2^53-1
InputType = TypeVar("InputType") InputType = TypeVar("InputType")
BIND_PARAM_REGEX = TextClause._bind_params_regex # pylint: disable=protected-access
class LenientEnum(Enum): class LenientEnum(Enum):
"""Enums with a `get` method that convert a enum value to `Enum` if it is a """Enums with a `get` method that convert a enum value to `Enum` if it is a
@ -1771,29 +1768,3 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
if limit != 0: if limit != 0:
return min(max_limit, limit) return min(max_limit, limit)
return max_limit return max_limit
def escape_sqla_query_binds(sql: str) -> str:
    """
    Replace strings in a query that SQLAlchemy would otherwise interpret as
    bind parameters.

    Only the exact spans matched by ``BIND_PARAM_REGEX`` are rewritten, in a
    single pass over the input. Text that merely looks like a matched bind but
    was not itself matched by the regex is left untouched, and a bind that is a
    prefix of another bind cannot cause the longer one to be escaped twice.

    :param sql: unescaped query string
    :return: escaped query string
    >>> escape_sqla_query_binds("select ':foo'")
    "select '\\\\:foo'"
    >>> escape_sqla_query_binds("select 'foo'::TIMESTAMP")
    "select 'foo'::TIMESTAMP"
    >>> escape_sqla_query_binds("select ':foo :bar'::TIMESTAMP")
    "select '\\\\:foo \\\\:bar'::TIMESTAMP"
    >>> escape_sqla_query_binds("select ':foo :foo :bar'::TIMESTAMP")
    "select '\\\\:foo \\\\:foo \\\\:bar'::TIMESTAMP"
    """
    # Substituting per match (instead of a global str.replace per distinct
    # bind) keeps the escaping confined to what the regex actually matched.
    return BIND_PARAM_REGEX.sub(
        lambda match: match.group(0).replace(":", "\\:"), sql
    )

View File

@ -20,7 +20,7 @@ import json
import unittest import unittest
from datetime import datetime from datetime import datetime
from io import BytesIO from io import BytesIO
from typing import Optional from typing import Optional, List
from unittest import mock from unittest import mock
from zipfile import is_zipfile, ZipFile from zipfile import is_zipfile, ZipFile
@ -42,6 +42,7 @@ from tests.integration_tests.fixtures.world_bank_dashboard import (
load_world_bank_dashboard_with_slices, load_world_bank_dashboard_with_slices,
) )
from tests.integration_tests.test_app import app from tests.integration_tests.test_app import app
from superset import security_manager
from superset.charts.commands.data import ChartDataCommand from superset.charts.commands.data import ChartDataCommand
from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.errors import SupersetErrorType from superset.errors import SupersetErrorType
@ -56,6 +57,7 @@ from superset.utils.core import (
get_example_database, get_example_database,
get_example_default_schema, get_example_default_schema,
get_main_database, get_main_database,
AdhocMetricExpressionType,
) )
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
@ -2033,3 +2035,58 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
self.assertEqual( self.assertEqual(
set(column for column in data[0].keys()), {"state", "name", "sum__num"} set(column for column in data[0].keys()), {"state", "name", "sum__num"}
) )
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_virtual_table_with_colons(self):
    """
    Chart data API: test query with literal colon characters in query, metrics,
    where clause and filters
    """
    self.login(username="admin")
    owner = self.get_user("admin").id
    user = db.session.query(security_manager.user_model).get(owner)

    # Temporary virtual dataset whose SQL contains colon-prefixed literals.
    table = SqlaTable(
        table_name="virtual_table_1",
        schema=get_example_default_schema(),
        owners=[user],
        database=get_example_database(),
        sql="select ':foo' as foo, ':bar:' as bar, state, num from birth_names",
    )
    db.session.add(table)
    db.session.commit()

    try:
        table.fetch_metadata()

        request_payload = get_query_context("birth_names")
        request_payload["datasource"] = {
            "type": "table",
            "id": table.id,
        }
        query = request_payload["queries"][0]
        query["columns"] = ["foo", "bar", "state"]
        query["where"] = "':abc' != ':xyz:qwerty'"
        query["orderby"] = None
        query["metrics"] = [
            {
                "expressionType": AdhocMetricExpressionType.SQL,
                "sqlExpression": "sum(case when state = ':asdf' then 0 else 1 end)",
                "label": "count",
            }
        ]
        query["filters"] = [{"col": "foo", "op": "!=", "val": ":qwerty:"}]
        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
    finally:
        # Always remove the temporary dataset, even when the request raises,
        # so it is not left behind in the database for subsequent tests.
        db.session.delete(table)
        db.session.commit()

    assert rv.status_code == 200
    response_payload = json.loads(rv.data.decode("utf-8"))
    result = response_payload["result"][0]
    data = result["data"]
    assert set(data[0]) == {"foo", "bar", "state", "count"}
    # make sure results and query parameters are unescaped
    assert {row["foo"] for row in data} == {":foo"}
    assert {row["bar"] for row in data} == {":bar:"}
    assert "':asdf'" in result["query"]
    assert "':xyz:qwerty'" in result["query"]
    assert "':qwerty:'" in result["query"]