From 514eda82fbada573b99c5eba892f811ac50bb771 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Thu, 20 Jun 2024 14:30:03 -0700 Subject: [PATCH] fix: don't strip SQL comments in Explore - 2nd try (#28753) --- .../superset-python-integrationtest.yml | 2 +- superset/connectors/sqla/models.py | 2 +- superset/db_engine_specs/base.py | 5 +- superset/models/helpers.py | 8 +- tests/integration_tests/conftest.py | 81 ++++++++++++++----- tests/integration_tests/core_tests.py | 4 +- tests/integration_tests/datasource_tests.py | 8 +- .../integration_tests/query_context_tests.py | 31 +++++++ tests/integration_tests/sqla_models_tests.py | 2 +- 9 files changed, 109 insertions(+), 34 deletions(-) diff --git a/.github/workflows/superset-python-integrationtest.yml b/.github/workflows/superset-python-integrationtest.yml index 2569a471b3..3f43bef88c 100644 --- a/.github/workflows/superset-python-integrationtest.yml +++ b/.github/workflows/superset-python-integrationtest.yml @@ -24,7 +24,7 @@ jobs: mysql+mysqldb://superset:superset@127.0.0.1:13306/superset?charset=utf8mb4&binary_prefix=true services: mysql: - image: mysql:5.7 + image: mysql:8.0 env: MYSQL_ROOT_PASSWORD: root ports: diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 11ad95bb44..6d8d87a506 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1450,7 +1450,7 @@ class SqlaTable( if not self.is_virtual: return self.get_sqla_table(), None - from_sql = self.get_rendered_sql(template_processor) + from_sql = self.get_rendered_sql(template_processor) + "\n" parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine) if not ( parsed_query.is_unknown() diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index a9a5cf7655..159c510fe9 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1133,9 +1133,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cte = None sql_remainder = None sql = sql.strip(" \t\n;") - sql_statement = sqlparse.format(sql, strip_comments=True) query_limit: int | None = sql_parse.extract_top_from_query( - sql_statement, cls.top_keywords + sql, cls.top_keywords ) if not limit: final_limit = query_limit @@ -1144,7 +1143,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods else: final_limit = limit if not cls.allows_cte_in_subquery: - cte, sql_remainder = sql_parse.get_cte_remainder_query(sql_statement) + cte, sql_remainder = sql_parse.get_cte_remainder_query(sql) if cte: str_statement = str(sql_remainder) cte = cte + "\n" diff --git a/superset/models/helpers.py b/superset/models/helpers.py index a044ff75e1..7b211f98b1 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1070,8 +1070,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods """ Render sql with template engine (Jinja). """ - - sql = self.sql + sql = self.sql.strip("\t\r\n; ") if template_processor: try: sql = template_processor.process_template(sql) @@ -1083,13 +1082,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods ) ) from ex - script = SQLScript(sql.strip("\t\r\n; "), engine=self.db_engine_spec.engine) + script = SQLScript(sql, engine=self.db_engine_spec.engine) if len(script.statements) > 1: raise QueryObjectValidationError( _("Virtual dataset query cannot consist of multiple statements") ) - sql = script.statements[0].format(comments=False) if not sql: raise QueryObjectValidationError(_("Virtual dataset query cannot be empty")) return sql @@ -1106,7 +1104,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods CTE, the CTE is returned as the second value in the return tuple. """ - from_sql = self.get_rendered_sql(template_processor) + from_sql = self.get_rendered_sql(template_processor) + "\n" parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine) if not ( parsed_query.is_unknown() diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 84c5793105..f180da9aed 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -19,6 +19,7 @@ from __future__ import annotations import contextlib import functools import os +from textwrap import dedent from typing import Any, Callable, TYPE_CHECKING from unittest.mock import patch @@ -295,25 +296,67 @@ def virtual_dataset(): dataset = SqlaTable( table_name="virtual_dataset", sql=( - "SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5, 1 as col6 " - "UNION ALL " - "SELECT 1, 'b', 1.1, NULL, '2000-01-02 00:00:00', NULL " - "UNION ALL " - "SELECT 2 as col1, 'c' as col2, 1.2, NULL, '2000-01-03 00:00:00', 3 " - "UNION ALL " - "SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00', 4 " - "UNION ALL " - "SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00', 5 " - "UNION ALL " - "SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00', 6 " - "UNION ALL " - "SELECT 6 as col1, 'g' as col2, 1.6, NULL, '2000-01-07 00:00:00', 7 " - "UNION ALL " - "SELECT 7 as col1, 'h' as col2, 1.7, NULL, '2000-01-08 00:00:00', 8 " - "UNION ALL " - "SELECT 8 as col1, 'i' as col2, 1.8, NULL, '2000-01-09 00:00:00', 9 " - "UNION ALL " - "SELECT 9 as col1, 'j' as col2, 1.9, NULL, '2000-01-10 00:00:00', 10" + dedent("""\ + SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5, 1 as col6 + UNION ALL + SELECT 1, 'b', 1.1, NULL, '2000-01-02 00:00:00', NULL + UNION ALL + SELECT 2 as col1, 'c' as col2, 1.2, NULL, '2000-01-03 00:00:00', 3 + UNION ALL + SELECT 3 as col1, 'd' as col2, 1.3, NULL, '2000-01-04 00:00:00', 4 + UNION ALL + SELECT 4 as col1, 'e' as col2, 1.4, NULL, '2000-01-05 00:00:00', 5 + UNION ALL + SELECT 5 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00', 6 + UNION ALL + SELECT 6 as col1, 'g' as col2, 1.6, NULL, '2000-01-07 00:00:00', 7 + UNION ALL + SELECT 7 as col1, 'h' as col2, 1.7, NULL, '2000-01-08 00:00:00', 8 + UNION ALL + SELECT 8 as col1, 'i' as col2, 1.8, NULL, '2000-01-09 00:00:00', 9 + UNION ALL + SELECT 9 as col1, 'j' as col2, 1.9, NULL, '2000-01-10 00:00:00', 10 + """) + ), + database=get_example_database(), + ) + TableColumn(column_name="col1", type="INTEGER", table=dataset) + TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset) + TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset) + TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset) + # Different database dialect datetime type is not consistent, so temporarily use varchar + TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset) + TableColumn(column_name="col6", type="INTEGER", table=dataset) + + SqlMetric(metric_name="count", expression="count(*)", table=dataset) + db.session.add(dataset) + db.session.commit() + + yield dataset + + db.session.delete(dataset) + db.session.commit() + + +@pytest.fixture +def virtual_dataset_with_comments(): + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + + dataset = SqlaTable( + table_name="virtual_dataset_with_comments", + sql=( + dedent("""\ + --COMMENT + /*COMMENT*/ + WITH cte as (--COMMENT + SELECT 2 as col1, /*COMMENT*/'j' as col2, 1.9, NULL, '2000-01-10 00:00:00', 10 + ) + SELECT 0 as col1, 'a' as col2, 1.0 as col3, NULL as col4, '2000-01-01 00:00:00' as col5, 1 as col6 + \n /* COMMENT */ \n + UNION ALL/*COMMENT*/ + SELECT 1 as col1, 'f' as col2, 1.5, NULL, '2000-01-06 00:00:00', 6 --COMMENT + UNION ALL--COMMENT + SELECT * FROM cte --COMMENT""") ), database=get_example_database(), ) diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index d085beba78..9166d54958 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -539,7 +539,9 @@ class TestCore(SupersetTestCase): database=get_example_database(), ) rendered_query = str(table.get_from_clause()[0]) - self.assertEqual(clean_query, rendered_query) + assert "comment 1" in rendered_query + assert "comment 2" in rendered_query + assert "FROM tbl" in rendered_query def test_slice_payload_no_datasource(self): form_data = { diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 1b7fcb733b..718b6d2d98 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -538,10 +538,12 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset): assert "coltypes" in rv2.json["result"] assert "data" in rv2.json["result"] - eager_samples = virtual_dataset.database.get_df( - f"select * from ({virtual_dataset.sql}) as tbl" - f' limit {app.config["SAMPLES_ROW_LIMIT"]}' + sql = ( + f"select * from ({virtual_dataset.sql}) as tbl " + f'limit {app.config["SAMPLES_ROW_LIMIT"]}' ) + eager_samples = virtual_dataset.database.get_df(sql) + # the col3 is Decimal eager_samples["col3"] = eager_samples["col3"].apply(float) eager_samples = eager_samples.to_dict(orient="records") diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index 9c18b5e07c..2fcd6d2048 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -1168,3 +1168,34 @@ OFFSET 0 re.search(r"WHERE\n col6 >= .*2001-10-01", sqls[1]) and re.search(r"AND col6 < .*2002-10-01", sqls[1]) ) is not None + + +def test_virtual_dataset_with_comments(app_context, virtual_dataset_with_comments): + qc = QueryContextFactory().create( + datasource={ + "type": virtual_dataset_with_comments.type, + "id": virtual_dataset_with_comments.id, + }, + queries=[ + { + "columns": ["col1", "col2"], + "metrics": ["count"], + "post_processing": [ + { + "operation": "pivot", + "options": { + "aggregates": {"count": {"operator": "mean"}}, + "columns": ["col2"], + "index": ["col1"], + }, + }, + {"operation": "flatten"}, + ], + } + ], + result_type=ChartDataResultType.FULL, + force=True, + ) + query_object = qc.queries[0] + df = qc.get_df_payload(query_object)["df"] + assert len(df) == 3 diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index f0fa70bc02..f5569b1c83 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -253,7 +253,7 @@ class TestDatabaseModel(SupersetTestCase): query = table.database.compile_sqla_query(sqla_query.sqla_query) # assert virtual dataset - assert "SELECT\n 'user_abc' AS user,\n 'xyz_P1D' AS time_grain" in query + assert "SELECT 'user_abc' as user, 'xyz_P1D' as time_grain" in query # assert dataset calculated column assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query # assert adhoc column