From 9f3d089655e314b1d1c6a547b573bb63dd0ab71b Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Mon, 12 Oct 2020 15:11:43 +0300 Subject: [PATCH] chore(sqla): assert query is single read-only statement (#11236) --- superset/connectors/sqla/models.py | 9 +++++++ tests/sqla_models_tests.py | 42 ++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index df07ccad67..934da17c73 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -64,6 +64,7 @@ from superset.jinja_context import ( from superset.models.annotations import Annotation from superset.models.core import Database from superset.models.helpers import AuditMixinNullable, QueryResult +from superset.sql_parse import ParsedQuery from superset.typing import Metric, QueryObjectDict from superset.utils import core as utils, import_datasource @@ -755,6 +756,14 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at ) from_sql = sqlparse.format(from_sql, strip_comments=True) + if len(sqlparse.split(from_sql)) > 1: + raise QueryObjectValidationError( + _("Virtual dataset query cannot consist of multiple statements") + ) + if not ParsedQuery(from_sql).is_readonly(): + raise QueryObjectValidationError( + _("Virtual dataset query must be read-only") + ) return TextAsFrom(sa.text(from_sql), []).alias("expr_qry") return self.get_sqla_table() diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py index 9e1884503d..de46e1c555 100644 --- a/tests/sqla_models_tests.py +++ b/tests/sqla_models_tests.py @@ -187,3 +187,45 @@ class TestDatabaseModel(SupersetTestCase): if get_example_database().backend != "presto": with pytest.raises(QueryObjectValidationError): table.get_sqla_query(**query_obj) + + def test_multiple_sql_statements_raises_exception(self): + base_query_obj = { + "granularity": None, + "from_dttm": None, + "to_dttm": None, + "groupby": ["grp"], + "metrics": [], + "is_timeseries": False, + "filter": [], + } + + table = SqlaTable( + table_name="test_has_extra_cache_keys_table", + sql="SELECT 'foo' as grp, 1 as num; SELECT 'bar' as grp, 2 as num", + database=get_example_database(), + ) + + query_obj = dict(**base_query_obj, extras={}) + with pytest.raises(QueryObjectValidationError): + table.get_sqla_query(**query_obj) + + def test_dml_statement_raises_exception(self): + base_query_obj = { + "granularity": None, + "from_dttm": None, + "to_dttm": None, + "groupby": ["grp"], + "metrics": [], + "is_timeseries": False, + "filter": [], + } + + table = SqlaTable( + table_name="test_has_extra_cache_keys_table", + sql="DELETE FROM foo", + database=get_example_database(), + ) + + query_obj = dict(**base_query_obj, extras={}) + with pytest.raises(QueryObjectValidationError): + table.get_sqla_query(**query_obj)