From 434b5767c910d984e2b39655999f96afd00b84a6 Mon Sep 17 00:00:00 2001
From: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
Date: Thu, 14 Oct 2021 17:59:28 +0200
Subject: [PATCH] fix: escape bind-like strings in virtual table query (#17111)

---
Review note: escape_sqla_query_binds now escapes each regex match in place
instead of calling str.replace() on every occurrence of the matched text, so
`::` casts, already escaped binds and longer bind names are no longer
rewritten. The blob hashes on the "index" line of superset/utils/core.py
predate this revision.

 superset/connectors/sqla/models.py |  6 ++++--
 superset/utils/core.py             | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index a27a4a1631..0f6b705278 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -65,8 +65,8 @@ from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, S
 from sqlalchemy.orm.mapper import Mapper
 from sqlalchemy.schema import UniqueConstraint
 from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
-from sqlalchemy.sql.elements import ColumnClause
-from sqlalchemy.sql.expression import Label, Select, TextAsFrom, TextClause
+from sqlalchemy.sql.elements import ColumnClause, TextClause
+from sqlalchemy.sql.expression import Label, Select, TextAsFrom
 from sqlalchemy.sql.selectable import Alias, TableClause
 
 from superset import app, db, is_feature_enabled, security_manager
@@ -809,6 +809,8 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                     )
                 ) from ex
         sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
+        # we need to escape strings that SQLAlchemy interprets as bind parameters
+        sql = utils.escape_sqla_query_binds(sql)
         if not sql:
             raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
         if len(sqlparse.split(sql)) > 1:
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 71d59348c1..9f30c64d91 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -80,6 +80,7 @@ from sqlalchemy import event, exc, select, Text
 from sqlalchemy.dialects.mysql import MEDIUMTEXT
 from sqlalchemy.engine import Connection, Engine
 from sqlalchemy.engine.reflection import Inspector
+from sqlalchemy.sql.elements import TextClause
 from sqlalchemy.sql.type_api import Variant
 from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine
 from typing_extensions import TypedDict, TypeGuard
@@ -131,6 +132,8 @@ JS_MAX_INTEGER = 9007199254740991  # Largest int Java Script can handle 2^53-1
 
 InputType = TypeVar("InputType")
 
+BIND_PARAM_REGEX = TextClause._bind_params_regex  # pylint: disable=protected-access
+
 
 class LenientEnum(Enum):
     """Enums with a `get` method that convert a enum value to `Enum` if it is a
@@ -1784,3 +1787,32 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
     if limit != 0:
         return min(max_limit, limit)
     return max_limit
+
+
+def escape_sqla_query_binds(sql: str) -> str:
+    """
+    Replace strings in a query that SQLAlchemy would otherwise interpret as
+    bind parameters.
+
+    Only the exact spans matched by ``BIND_PARAM_REGEX`` are escaped; text that
+    merely contains the same characters (the type name of a ``::`` cast, or a
+    longer bind-like name) is left untouched.
+
+    :param sql: unescaped query string
+    :return: escaped query string
+    >>> escape_sqla_query_binds("select ':foo'")
+    "select '\\\\:foo'"
+    >>> escape_sqla_query_binds("select 'foo'::TIMESTAMP")
+    "select 'foo'::TIMESTAMP"
+    >>> escape_sqla_query_binds("select ':foo :bar'::TIMESTAMP")
+    "select '\\\\:foo \\\\:bar'::TIMESTAMP"
+    >>> escape_sqla_query_binds("select ':foo :foo :bar'::TIMESTAMP")
+    "select '\\\\:foo \\\\:foo \\\\:bar'::TIMESTAMP"
+    >>> escape_sqla_query_binds("select ':foo :foobar'")
+    "select '\\\\:foo \\\\:foobar'"
+    >>> escape_sqla_query_binds("select ':text'::text")
+    "select '\\\\:text'::text"
+    """
+    # substitute match by match: a global str.replace() would also rewrite
+    # occurrences the regex deliberately skips, e.g. the type in a `::` cast
+    return BIND_PARAM_REGEX.sub(lambda match: match.group(0).replace(":", "\\:"), sql)