fix(mssql): support cte in virtual tables (#18567)

* Fix for handling regular CTE queries with MSSQL, #8074

* Moved the get_cte_query function from mssql.py to base.py so it can be used irrespective of the DB engine

* Fix for handling regular CTE queries with MSSQL, #8074

* Moved the get_cte_query function from mssql.py to base.py so it can be used irrespective of the DB engine

* Unit test added for the db engine CTE SQL parsing.

Unit test added for the db engine CTE SQL parsing.  Removed additional spaces from the CTE parsing SQL generation.

* implement in sqla model

* lint + cleanup

Co-authored-by: Ville Brofeldt <ville.v.brofeldt@gmail.com>
This commit is contained in:
Sujith Kumar S 2022-02-10 13:58:05 +05:30 committed by GitHub
parent 00eb6b1f57
commit b8aef10098
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 165 additions and 13 deletions

View File

@ -77,7 +77,7 @@ from superset.connectors.sqla.utils import (
get_physical_table_metadata,
get_virtual_table_metadata,
)
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression
from superset.exceptions import QueryObjectValidationError
from superset.jinja_context import (
BaseTemplateProcessor,
@ -107,6 +107,7 @@ VIRTUAL_TABLE_ALIAS = "virtual_table"
class SqlaQuery(NamedTuple):
applied_template_filters: List[str]
cte: Optional[str]
extra_cache_keys: List[Any]
labels_expected: List[str]
prequeries: List[str]
@ -562,6 +563,19 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def __repr__(self) -> str:
return self.name
@staticmethod
def _apply_cte(sql: str, cte: Optional[str]) -> str:
    """
    Prepend a CTE to the SELECT statement if one is defined

    :param sql: SELECT statement; any object convertible with ``str()`` (for
        example the result of ``qry.compile(...)``) is coerced to text
    :param cte: CTE statement, or ``None``/empty when the query has no CTE
    :return: the CTE, a newline and the SELECT statement when a CTE is given,
        otherwise the SELECT statement alone; always a ``str``
    """
    # Callers may pass the object returned by ``qry.compile(...)`` instead of
    # text. Coerce up front so the declared ``str`` return type holds whether
    # or not a CTE is present.
    sql = str(sql)
    return f"{cte}\n{sql}" if cte else sql
@property
def db_engine_spec(self) -> Type[BaseEngineSpec]:
return self.database.db_engine_spec
@ -743,12 +757,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
cols = {col.column_name: col for col in self.columns}
target_col = cols[column_name]
tp = self.get_template_processor()
tbl, cte = self.get_from_clause(tp)
qry = (
select([target_col.get_sqla_col()])
.select_from(self.get_from_clause(tp))
.distinct()
)
qry = select([target_col.get_sqla_col()]).select_from(tbl).distinct()
if limit:
qry = qry.limit(limit)
@ -756,7 +767,8 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
qry = qry.where(self.get_fetch_values_predicate())
engine = self.database.get_sqla_engine()
sql = "{}".format(qry.compile(engine, compile_kwargs={"literal_binds": True}))
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
sql = self._apply_cte(sql, cte)
sql = self.mutate_query_from_config(sql)
df = pd.read_sql_query(sql=sql, con=engine)
@ -778,6 +790,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended:
sqlaq = self.get_sqla_query(**query_obj)
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
sql = self._apply_cte(sql, sqlaq.cte)
sql = sqlparse.format(sql, reindent=True)
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
@ -800,13 +813,14 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
def get_from_clause(
self, template_processor: Optional[BaseTemplateProcessor] = None
) -> Union[TableClause, Alias]:
) -> Tuple[Union[TableClause, Alias], Optional[str]]:
"""
Return where to select the columns and metrics from. Either a physical table
or a virtual table with it's own subquery.
or a virtual table with it's own subquery. If the FROM is referencing a
CTE, the CTE is returned as the second value in the return tuple.
"""
if not self.is_virtual:
return self.get_sqla_table()
return self.get_sqla_table(), None
from_sql = self.get_rendered_sql(template_processor)
parsed_query = ParsedQuery(from_sql)
@ -817,7 +831,15 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
raise QueryObjectValidationError(
_("Virtual dataset query must be read-only")
)
return TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
cte = self.db_engine_spec.get_cte_query(from_sql)
from_clause = (
table(CTE_ALIAS)
if cte
else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
)
return from_clause, cte
def get_rendered_sql(
self, template_processor: Optional[BaseTemplateProcessor] = None
@ -1224,7 +1246,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
qry = sa.select(select_exprs)
tbl = self.get_from_clause(template_processor)
tbl, cte = self.get_from_clause(template_processor)
if groupby_all_columns:
qry = qry.group_by(*groupby_all_columns.values())
@ -1491,6 +1513,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
return SqlaQuery(
applied_template_filters=applied_template_filters,
cte=cte,
extra_cache_keys=extra_cache_keys,
labels_expected=labels_expected,
sqla_query=qry,

View File

@ -54,6 +54,7 @@ from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
from sqlalchemy.types import TypeEngine
from sqlparse.tokens import CTE
from typing_extensions import TypedDict
from superset import security_manager, sql_parse
@ -80,6 +81,9 @@ ColumnTypeMapping = Tuple[
logger = logging.getLogger()
CTE_ALIAS = "__cte"
class TimeGrain(NamedTuple):
name: str # TODO: redundant field, remove
label: str
@ -292,6 +296,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# But for backward compatibility, False by default
allows_hidden_cc_in_orderby = False
# Whether the engine allows a CTE inside a subquery.
# If True, a CTE-based query may be wrapped in a subquery;
# if False, it has to be emitted as a regular (top-level) CTE.
allows_cte_in_subquery = True
force_column_alias_quotes = False
arraysize = 0
max_column_name_length = 0
@ -663,6 +672,31 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
parsed_query = sql_parse.ParsedQuery(sql)
return parsed_query.set_or_update_query_limit(limit)
@classmethod
def get_cte_query(cls, sql: str) -> Optional[str]:
    """
    Convert the input CTE based SQL to the SQL for virtual table conversion

    Only engines with ``allows_cte_in_subquery`` set to False are rewritten;
    for every other engine, and for SQL that does not start with a CTE,
    ``None`` is returned.

    :param sql: SQL query
    :return: CTE with the main select query aliased as `__cte`, or None when
        no rewrite applies
    """
    if cls.allows_cte_in_subquery:
        return None

    # Empty or whitespace-only SQL parses to no statements; indexing [0]
    # would raise IndexError.
    statements = sqlparse.parse(sql)
    if not statements:
        return None
    stmt = statements[0]

    # The first meaningful token of a CTE query is the WITH keyword
    idx, first_token = stmt.token_next(-1, skip_ws=True, skip_cm=True)
    if not (first_token and first_token.ttype == CTE):
        return None

    # The token following WITH holds the CTE definition(s). A bare WITH has
    # nothing after it, in which case there is nothing to rewrite.
    idx, cte_definitions = stmt.token_next(idx)
    if cte_definitions is None:
        return None
    idx = stmt.token_index(cte_definitions) + 1

    # Everything after the CTE definitions is the main query, which becomes
    # the body of the extra CTE named by CTE_ALIAS
    remainder = "".join(str(part) for part in stmt.tokens[idx:]).strip()
    return f"WITH {cte_definitions.value},\n{CTE_ALIAS} AS (\n{remainder}\n)"
@classmethod
def df_to_sql(
cls,

View File

@ -47,6 +47,7 @@ class MssqlEngineSpec(BaseEngineSpec):
engine_name = "Microsoft SQL Server"
limit_method = LimitMethod.WRAP_SQL
max_column_name_length = 128
allows_cte_in_subquery = False
_time_grain_expressions = {
None: "{col}",

View File

@ -984,7 +984,7 @@ class TestCore(SupersetTestCase):
sql=commented_query,
database=get_example_database(),
)
rendered_query = str(table.get_from_clause())
rendered_query = str(table.get_from_clause()[0])
self.assertEqual(clean_query, rendered_query)
def test_slice_payload_no_datasource(self):

View File

@ -16,7 +16,11 @@
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from textwrap import dedent
import pytest
from flask.ctx import AppContext
from sqlalchemy.types import TypeEngine
def test_get_text_clause_with_colon(app_context: AppContext) -> None:
@ -56,3 +60,42 @@ def test_parse_sql_multi_statement(app_context: AppContext) -> None:
"SELECT foo FROM tbl1",
"SELECT bar FROM tbl2",
]
# Checks BaseEngineSpec.get_cte_query against three SQL inputs: a CTE query,
# a plain SELECT and a UNION query. Every case expects None, i.e. the base
# engine spec performs no CTE rewrite for any of them.
# NOTE(review): `original` is annotated as TypeEngine although every
# parametrized value is a SQL string, and `expected` is annotated as str
# although every value is None — confirm the intended annotations.
@pytest.mark.parametrize(
"original,expected",
[
# query starting with a CTE
(
dedent(
"""
with currency as
(
select 'INR' as cur
)
select * from currency
"""
),
None,
),
# plain SELECT without a CTE
("SELECT 1 as cnt", None,),
# UNION query without a CTE
(
dedent(
"""
select 'INR' as cur
union
select 'AUD' as cur
union
select 'USD' as cur
"""
),
None,
),
],
)
def test_cte_query_parsing(
app_context: AppContext, original: TypeEngine, expected: str
) -> None:
from superset.db_engine_specs.base import BaseEngineSpec
actual = BaseEngineSpec.get_cte_query(original)
assert actual == expected

View File

@ -180,6 +180,57 @@ def test_column_datatype_to_string(
assert actual == expected
# Checks MssqlEngineSpec.get_cte_query: a query starting with WITH is expected
# to come back with its CTE definitions kept and the trailing SELECT wrapped
# in an additional CTE named `__cte`; inputs without a CTE expect None.
# NOTE(review): `original` is annotated as TypeEngine although every
# parametrized value is a SQL string, and `expected` is annotated as str
# although two cases expect None — confirm the intended annotations.
@pytest.mark.parametrize(
"original,expected",
[
# two CTE definitions followed by the main SELECT
(
dedent(
"""
with currency as (
select 'INR' as cur
),
currency_2 as (
select 'EUR' as cur
)
select * from currency union all select * from currency_2
"""
),
dedent(
"""WITH currency as (
select 'INR' as cur
),
currency_2 as (
select 'EUR' as cur
),
__cte AS (
select * from currency union all select * from currency_2
)"""
),
),
# plain SELECT without a CTE
("SELECT 1 as cnt", None,),
# UNION query without a CTE
(
dedent(
"""
select 'INR' as cur
union
select 'AUD' as cur
union
select 'USD' as cur
"""
),
None,
),
],
)
def test_cte_query_parsing(
app_context: AppContext, original: TypeEngine, expected: str
) -> None:
from superset.db_engine_specs.mssql import MssqlEngineSpec
actual = MssqlEngineSpec.get_cte_query(original)
assert actual == expected
def test_extract_errors(app_context: AppContext) -> None:
"""
Test that custom error messages are extracted correctly.