mirror of https://github.com/apache/superset.git
fix: avoid escaping bind-like params containing colons (#17419)
* fix: avoid escaping bind-like params containing colons * fix query for mysql * address comments
This commit is contained in:
parent
aa8040ec9b
commit
ad8a7c42f9
|
@ -65,7 +65,7 @@ from sqlalchemy.engine.base import Connection
|
|||
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.schema import UniqueConstraint
|
||||
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
|
||||
from sqlalchemy.sql import column, ColumnElement, literal_column, table
|
||||
from sqlalchemy.sql.elements import ColumnClause, TextClause
|
||||
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
|
||||
from sqlalchemy.sql.selectable import Alias, TableClause
|
||||
|
@ -103,6 +103,16 @@ logger = logging.getLogger(__name__)
|
|||
VIRTUAL_TABLE_ALIAS = "virtual_table"
|
||||
|
||||
|
||||
def text(clause: str) -> TextClause:
    """
    SQLAlchemy wrapper that backslash-escapes every colon in a text clause

    :param clause: clause potentially containing colons
    :return: text clause with escaped colons
    """
    # splitting on ":" and re-joining with "\:" prefixes each colon with a
    # backslash, exactly like a global replace would
    escaped_clause = "\\:".join(clause.split(":"))
    return sa.text(escaped_clause)
|
||||
|
||||
|
||||
class SqlaQuery(NamedTuple):
|
||||
applied_template_filters: List[str]
|
||||
extra_cache_keys: List[Any]
|
||||
|
@ -806,7 +816,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
raise QueryObjectValidationError(
|
||||
_("Virtual dataset query must be read-only")
|
||||
)
|
||||
return TextAsFrom(sa.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
|
||||
return TextAsFrom(text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
|
||||
|
||||
def get_rendered_sql(
|
||||
self, template_processor: Optional[BaseTemplateProcessor] = None
|
||||
|
@ -827,8 +837,6 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
)
|
||||
) from ex
|
||||
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
|
||||
# we need to escape strings that SQLAlchemy interprets as bind parameters
|
||||
sql = utils.escape_sqla_query_binds(sql)
|
||||
if not sql:
|
||||
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
|
||||
if len(sqlparse.split(sql)) > 1:
|
||||
|
@ -1286,7 +1294,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
msg=ex.message,
|
||||
)
|
||||
) from ex
|
||||
where_clause_and += [sa.text("({})".format(where))]
|
||||
where_clause_and += [text(f"({where})")]
|
||||
having = extras.get("having")
|
||||
if having:
|
||||
try:
|
||||
|
@ -1298,7 +1306,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
msg=ex.message,
|
||||
)
|
||||
) from ex
|
||||
having_clause_and += [sa.text("({})".format(having))]
|
||||
having_clause_and += [text(f"({having})")]
|
||||
if apply_fetch_values_predicate and self.fetch_values_predicate:
|
||||
qry = qry.where(self.get_fetch_values_predicate())
|
||||
if granularity:
|
||||
|
|
|
@ -81,7 +81,6 @@ from sqlalchemy import event, exc, inspect, select, Text
|
|||
from sqlalchemy.dialects.mysql import MEDIUMTEXT
|
||||
from sqlalchemy.engine import Connection, Engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.sql.elements import TextClause
|
||||
from sqlalchemy.sql.type_api import Variant
|
||||
from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine
|
||||
from typing_extensions import TypedDict, TypeGuard
|
||||
|
@ -132,8 +131,6 @@ JS_MAX_INTEGER = 9007199254740991 # Largest int Java Script can handle 2^53-1
|
|||
|
||||
InputType = TypeVar("InputType")
|
||||
|
||||
BIND_PARAM_REGEX = TextClause._bind_params_regex # pylint: disable=protected-access
|
||||
|
||||
|
||||
class LenientEnum(Enum):
|
||||
"""Enums with a `get` method that convert a enum value to `Enum` if it is a
|
||||
|
@ -1771,29 +1768,3 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
|
|||
if limit != 0:
|
||||
return min(max_limit, limit)
|
||||
return max_limit
|
||||
|
||||
|
||||
def escape_sqla_query_binds(sql: str) -> str:
    """
    Replace strings in a query that SQLAlchemy would otherwise interpret as
    bind parameters.

    Every match of ``BIND_PARAM_REGEX`` is escaped at the exact position where
    it was matched. Escaping positionally (instead of replacing the matched
    substring everywhere in the query) guarantees that a token which is a
    prefix of another token, e.g. ``:foo`` and ``:foobar``, is never escaped
    twice.

    :param sql: unescaped query string
    :return: escaped query string
    >>> escape_sqla_query_binds("select ':foo'")
    "select '\\\\:foo'"
    >>> escape_sqla_query_binds("select 'foo'::TIMESTAMP")
    "select 'foo'::TIMESTAMP"
    >>> escape_sqla_query_binds("select ':foo :bar'::TIMESTAMP")
    "select '\\\\:foo \\\\:bar'::TIMESTAMP"
    >>> escape_sqla_query_binds("select ':foo :foo :bar'::TIMESTAMP")
    "select '\\\\:foo \\\\:foo \\\\:bar'::TIMESTAMP"
    >>> escape_sqla_query_binds("select ':foo :foobar'")
    "select '\\\\:foo \\\\:foobar'"
    """
    # a single regex substitution touches only the spans the pattern matched,
    # so text outside a match is left untouched
    return BIND_PARAM_REGEX.sub(
        lambda match: match.group(0).replace(":", "\\:"), sql
    )
|
||||
|
|
|
@ -20,7 +20,7 @@ import json
|
|||
import unittest
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from unittest import mock
|
||||
from zipfile import is_zipfile, ZipFile
|
||||
|
||||
|
@ -42,6 +42,7 @@ from tests.integration_tests.fixtures.world_bank_dashboard import (
|
|||
load_world_bank_dashboard_with_slices,
|
||||
)
|
||||
from tests.integration_tests.test_app import app
|
||||
from superset import security_manager
|
||||
from superset.charts.commands.data import ChartDataCommand
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
from superset.errors import SupersetErrorType
|
||||
|
@ -56,6 +57,7 @@ from superset.utils.core import (
|
|||
get_example_database,
|
||||
get_example_default_schema,
|
||||
get_main_database,
|
||||
AdhocMetricExpressionType,
|
||||
)
|
||||
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
|
||||
|
||||
|
@ -2033,3 +2035,58 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
|
|||
self.assertEqual(
|
||||
set(column for column in data[0].keys()), {"state", "name", "sum__num"}
|
||||
)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_virtual_table_with_colons(self):
    """
    Chart data API: test query with literal colon characters in query, metrics,
    where clause and filters
    """
    self.login(username="admin")
    owner = self.get_user("admin").id
    user = db.session.query(security_manager.user_model).get(owner)

    # virtual dataset whose SQL contains colon-prefixed string literals
    table = SqlaTable(
        table_name="virtual_table_1",
        schema=get_example_default_schema(),
        owners=[user],
        database=get_example_database(),
        sql="select ':foo' as foo, ':bar:' as bar, state, num from birth_names",
    )
    db.session.add(table)
    db.session.commit()

    try:
        table.fetch_metadata()

        request_payload = get_query_context("birth_names")
        request_payload["datasource"] = {
            "type": "table",
            "id": table.id,
        }
        query = request_payload["queries"][0]
        query["columns"] = ["foo", "bar", "state"]
        query["where"] = "':abc' != ':xyz:qwerty'"
        query["orderby"] = None
        query["metrics"] = [
            {
                "expressionType": AdhocMetricExpressionType.SQL,
                "sqlExpression": "sum(case when state = ':asdf' then 0 else 1 end)",
                "label": "count",
            }
        ]
        query["filters"] = [{"col": "foo", "op": "!=", "val": ":qwerty:"}]

        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
    finally:
        # remove the temporary dataset even when the request or metadata
        # fetch raises, so it cannot leak into the shared test database
        db.session.delete(table)
        db.session.commit()

    assert rv.status_code == 200
    response_payload = json.loads(rv.data.decode("utf-8"))
    result = response_payload["result"][0]
    data = result["data"]
    assert set(data[0]) == {"foo", "bar", "state", "count"}
    # make sure results and query parameters are unescaped
    assert {row["foo"] for row in data} == {":foo"}
    assert {row["bar"] for row in data} == {":bar:"}
    assert "':asdf'" in result["query"]
    assert "':xyz:qwerty'" in result["query"]
    assert "':qwerty:'" in result["query"]
|
||||
|
|
Loading…
Reference in New Issue