mirror of
https://github.com/apache/superset.git
synced 2024-09-17 11:09:47 -04:00
fix: escape bind-like strings in virtual table query (#17111)
This commit is contained in:
parent
83a2f8346e
commit
434b5767c9
@ -65,8 +65,8 @@ from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, S
|
|||||||
from sqlalchemy.orm.mapper import Mapper
|
from sqlalchemy.orm.mapper import Mapper
|
||||||
from sqlalchemy.schema import UniqueConstraint
|
from sqlalchemy.schema import UniqueConstraint
|
||||||
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
|
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
|
||||||
from sqlalchemy.sql.elements import ColumnClause
|
from sqlalchemy.sql.elements import ColumnClause, TextClause
|
||||||
from sqlalchemy.sql.expression import Label, Select, TextAsFrom, TextClause
|
from sqlalchemy.sql.expression import Label, Select, TextAsFrom
|
||||||
from sqlalchemy.sql.selectable import Alias, TableClause
|
from sqlalchemy.sql.selectable import Alias, TableClause
|
||||||
|
|
||||||
from superset import app, db, is_feature_enabled, security_manager
|
from superset import app, db, is_feature_enabled, security_manager
|
||||||
@ -809,6 +809,8 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||||||
)
|
)
|
||||||
) from ex
|
) from ex
|
||||||
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
|
sql = sqlparse.format(sql.strip("\t\r\n; "), strip_comments=True)
|
||||||
|
# we need to escape strings that SQLAlchemy interprets as bind parameters
|
||||||
|
sql = utils.escape_sqla_query_binds(sql)
|
||||||
if not sql:
|
if not sql:
|
||||||
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
|
raise QueryObjectValidationError(_("Virtual dataset query cannot be empty"))
|
||||||
if len(sqlparse.split(sql)) > 1:
|
if len(sqlparse.split(sql)) > 1:
|
||||||
|
@ -80,6 +80,7 @@ from sqlalchemy import event, exc, select, Text
|
|||||||
from sqlalchemy.dialects.mysql import MEDIUMTEXT
|
from sqlalchemy.dialects.mysql import MEDIUMTEXT
|
||||||
from sqlalchemy.engine import Connection, Engine
|
from sqlalchemy.engine import Connection, Engine
|
||||||
from sqlalchemy.engine.reflection import Inspector
|
from sqlalchemy.engine.reflection import Inspector
|
||||||
|
from sqlalchemy.sql.elements import TextClause
|
||||||
from sqlalchemy.sql.type_api import Variant
|
from sqlalchemy.sql.type_api import Variant
|
||||||
from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine
|
from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine
|
||||||
from typing_extensions import TypedDict, TypeGuard
|
from typing_extensions import TypedDict, TypeGuard
|
||||||
@ -131,6 +132,8 @@ JS_MAX_INTEGER = 9007199254740991 # Largest int Java Script can handle 2^53-1
|
|||||||
|
|
||||||
InputType = TypeVar("InputType")
|
InputType = TypeVar("InputType")
|
||||||
|
|
||||||
|
BIND_PARAM_REGEX = TextClause._bind_params_regex # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
class LenientEnum(Enum):
|
class LenientEnum(Enum):
|
||||||
"""Enums with a `get` method that convert a enum value to `Enum` if it is a
|
"""Enums with a `get` method that convert a enum value to `Enum` if it is a
|
||||||
@ -1784,3 +1787,29 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
|
|||||||
if limit != 0:
|
if limit != 0:
|
||||||
return min(max_limit, limit)
|
return min(max_limit, limit)
|
||||||
return max_limit
|
return max_limit
|
||||||
|
|
||||||
|
|
||||||
|
def escape_sqla_query_binds(sql: str) -> str:
|
||||||
|
"""
|
||||||
|
Replace strings in a query that SQLAlchemy would otherwise interpret as
|
||||||
|
bind parameters.
|
||||||
|
|
||||||
|
:param sql: unescaped query string
|
||||||
|
:return: escaped query string
|
||||||
|
>>> escape_sqla_query_binds("select ':foo'")
|
||||||
|
"select '\\\\:foo'"
|
||||||
|
>>> escape_sqla_query_binds("select 'foo'::TIMESTAMP")
|
||||||
|
"select 'foo'::TIMESTAMP"
|
||||||
|
>>> escape_sqla_query_binds("select ':foo :bar'::TIMESTAMP")
|
||||||
|
"select '\\\\:foo \\\\:bar'::TIMESTAMP"
|
||||||
|
>>> escape_sqla_query_binds("select ':foo :foo :bar'::TIMESTAMP")
|
||||||
|
"select '\\\\:foo \\\\:foo \\\\:bar'::TIMESTAMP"
|
||||||
|
"""
|
||||||
|
matches = BIND_PARAM_REGEX.finditer(sql)
|
||||||
|
processed_binds = set()
|
||||||
|
for match in matches:
|
||||||
|
bind = match.group(0)
|
||||||
|
if bind not in processed_binds:
|
||||||
|
sql = sql.replace(bind, bind.replace(":", "\\:"))
|
||||||
|
processed_binds.add(bind)
|
||||||
|
return sql
|
||||||
|
Loading…
Reference in New Issue
Block a user