chore(sqla): assert query is single read-only statement (#11236)

This commit is contained in:
Ville Brofeldt 2020-10-12 15:11:43 +03:00 committed by GitHub
parent e647286393
commit 9f3d089655
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 0 deletions

View File

@ -64,6 +64,7 @@ from superset.jinja_context import (
from superset.models.annotations import Annotation
from superset.models.core import Database
from superset.models.helpers import AuditMixinNullable, QueryResult
from superset.sql_parse import ParsedQuery
from superset.typing import Metric, QueryObjectDict
from superset.utils import core as utils, import_datasource
@ -755,6 +756,14 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
)
from_sql = sqlparse.format(from_sql, strip_comments=True)
if len(sqlparse.split(from_sql)) > 1:
raise QueryObjectValidationError(
_("Virtual dataset query cannot consist of multiple statements")
)
if not ParsedQuery(from_sql).is_readonly():
raise QueryObjectValidationError(
_("Virtual dataset query must be read-only")
)
return TextAsFrom(sa.text(from_sql), []).alias("expr_qry")
return self.get_sqla_table()

View File

@ -187,3 +187,45 @@ class TestDatabaseModel(SupersetTestCase):
if get_example_database().backend != "presto":
with pytest.raises(QueryObjectValidationError):
table.get_sqla_query(**query_obj)
def test_multiple_sql_statements_raises_exception(self):
base_query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["grp"],
"metrics": [],
"is_timeseries": False,
"filter": [],
}
table = SqlaTable(
table_name="test_has_extra_cache_keys_table",
sql="SELECT 'foo' as grp, 1 as num; SELECT 'bar' as grp, 2 as num",
database=get_example_database(),
)
query_obj = dict(**base_query_obj, extras={})
with pytest.raises(QueryObjectValidationError):
table.get_sqla_query(**query_obj)
def test_dml_statement_raises_exception(self):
base_query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["grp"],
"metrics": [],
"is_timeseries": False,
"filter": [],
}
table = SqlaTable(
table_name="test_has_extra_cache_keys_table",
sql="DELETE FROM foo",
database=get_example_database(),
)
query_obj = dict(**base_query_obj, extras={})
with pytest.raises(QueryObjectValidationError):
table.get_sqla_query(**query_obj)