mirror of https://github.com/apache/superset.git
fix(mssql): support cte in virtual tables (#18567)
* Fix for handling regular CTE queries with MSSQL,#8074 * Moved the get_cte_query function from mssql.py to base.py for using irrespetcive of dbengine * Fix for handling regular CTE queries with MSSQL,#8074 * Moved the get_cte_query function from mssql.py to base.py for using irrespetcive of dbengine * Unit test added for the db engine CTE SQL parsing. Unit test added for the db engine CTE SQL parsing. Removed additional spaces from the CTE parsing SQL generation. * implement in sqla model * lint + cleanup Co-authored-by: Ville Brofeldt <ville.v.brofeldt@gmail.com>
This commit is contained in:
parent
00eb6b1f57
commit
b8aef10098
|
@ -77,7 +77,7 @@ from superset.connectors.sqla.utils import (
|
|||
get_physical_table_metadata,
|
||||
get_virtual_table_metadata,
|
||||
)
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.jinja_context import (
|
||||
BaseTemplateProcessor,
|
||||
|
@ -107,6 +107,7 @@ VIRTUAL_TABLE_ALIAS = "virtual_table"
|
|||
|
||||
class SqlaQuery(NamedTuple):
|
||||
applied_template_filters: List[str]
|
||||
cte: Optional[str]
|
||||
extra_cache_keys: List[Any]
|
||||
labels_expected: List[str]
|
||||
prequeries: List[str]
|
||||
|
@ -562,6 +563,19 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
def __repr__(self) -> str:
|
||||
return self.name
|
||||
|
||||
@staticmethod
|
||||
def _apply_cte(sql: str, cte: Optional[str]) -> str:
|
||||
"""
|
||||
Append a CTE before the SELECT statement if defined
|
||||
|
||||
:param sql: SELECT statement
|
||||
:param cte: CTE statement
|
||||
:return:
|
||||
"""
|
||||
if cte:
|
||||
sql = f"{cte}\n{sql}"
|
||||
return sql
|
||||
|
||||
@property
|
||||
def db_engine_spec(self) -> Type[BaseEngineSpec]:
|
||||
return self.database.db_engine_spec
|
||||
|
@ -743,12 +757,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
cols = {col.column_name: col for col in self.columns}
|
||||
target_col = cols[column_name]
|
||||
tp = self.get_template_processor()
|
||||
tbl, cte = self.get_from_clause(tp)
|
||||
|
||||
qry = (
|
||||
select([target_col.get_sqla_col()])
|
||||
.select_from(self.get_from_clause(tp))
|
||||
.distinct()
|
||||
)
|
||||
qry = select([target_col.get_sqla_col()]).select_from(tbl).distinct()
|
||||
if limit:
|
||||
qry = qry.limit(limit)
|
||||
|
||||
|
@ -756,7 +767,8 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
qry = qry.where(self.get_fetch_values_predicate())
|
||||
|
||||
engine = self.database.get_sqla_engine()
|
||||
sql = "{}".format(qry.compile(engine, compile_kwargs={"literal_binds": True}))
|
||||
sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
|
||||
sql = self._apply_cte(sql, cte)
|
||||
sql = self.mutate_query_from_config(sql)
|
||||
|
||||
df = pd.read_sql_query(sql=sql, con=engine)
|
||||
|
@ -778,6 +790,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended:
|
||||
sqlaq = self.get_sqla_query(**query_obj)
|
||||
sql = self.database.compile_sqla_query(sqlaq.sqla_query)
|
||||
sql = self._apply_cte(sql, sqlaq.cte)
|
||||
sql = sqlparse.format(sql, reindent=True)
|
||||
sql = self.mutate_query_from_config(sql)
|
||||
return QueryStringExtended(
|
||||
|
@ -800,13 +813,14 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
|
||||
def get_from_clause(
|
||||
self, template_processor: Optional[BaseTemplateProcessor] = None
|
||||
) -> Union[TableClause, Alias]:
|
||||
) -> Tuple[Union[TableClause, Alias], Optional[str]]:
|
||||
"""
|
||||
Return where to select the columns and metrics from. Either a physical table
|
||||
or a virtual table with it's own subquery.
|
||||
or a virtual table with it's own subquery. If the FROM is referencing a
|
||||
CTE, the CTE is returned as the second value in the return tuple.
|
||||
"""
|
||||
if not self.is_virtual:
|
||||
return self.get_sqla_table()
|
||||
return self.get_sqla_table(), None
|
||||
|
||||
from_sql = self.get_rendered_sql(template_processor)
|
||||
parsed_query = ParsedQuery(from_sql)
|
||||
|
@ -817,7 +831,15 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
raise QueryObjectValidationError(
|
||||
_("Virtual dataset query must be read-only")
|
||||
)
|
||||
return TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
|
||||
|
||||
cte = self.db_engine_spec.get_cte_query(from_sql)
|
||||
from_clause = (
|
||||
table(CTE_ALIAS)
|
||||
if cte
|
||||
else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
|
||||
)
|
||||
|
||||
return from_clause, cte
|
||||
|
||||
def get_rendered_sql(
|
||||
self, template_processor: Optional[BaseTemplateProcessor] = None
|
||||
|
@ -1224,7 +1246,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
|
||||
qry = sa.select(select_exprs)
|
||||
|
||||
tbl = self.get_from_clause(template_processor)
|
||||
tbl, cte = self.get_from_clause(template_processor)
|
||||
|
||||
if groupby_all_columns:
|
||||
qry = qry.group_by(*groupby_all_columns.values())
|
||||
|
@ -1491,6 +1513,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
|
||||
return SqlaQuery(
|
||||
applied_template_filters=applied_template_filters,
|
||||
cte=cte,
|
||||
extra_cache_keys=extra_cache_keys,
|
||||
labels_expected=labels_expected,
|
||||
sqla_query=qry,
|
||||
|
|
|
@ -54,6 +54,7 @@ from sqlalchemy.orm import Session
|
|||
from sqlalchemy.sql import quoted_name, text
|
||||
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom, TextClause
|
||||
from sqlalchemy.types import TypeEngine
|
||||
from sqlparse.tokens import CTE
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from superset import security_manager, sql_parse
|
||||
|
@ -80,6 +81,9 @@ ColumnTypeMapping = Tuple[
|
|||
logger = logging.getLogger()
|
||||
|
||||
|
||||
CTE_ALIAS = "__cte"
|
||||
|
||||
|
||||
class TimeGrain(NamedTuple):
|
||||
name: str # TODO: redundant field, remove
|
||||
label: str
|
||||
|
@ -292,6 +296,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
# But for backward compatibility, False by default
|
||||
allows_hidden_cc_in_orderby = False
|
||||
|
||||
# Whether allow CTE as subquery or regular CTE
|
||||
# If True, then it will allow in subquery ,
|
||||
# if False it will allow as regular CTE
|
||||
allows_cte_in_subquery = True
|
||||
|
||||
force_column_alias_quotes = False
|
||||
arraysize = 0
|
||||
max_column_name_length = 0
|
||||
|
@ -663,6 +672,31 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
parsed_query = sql_parse.ParsedQuery(sql)
|
||||
return parsed_query.set_or_update_query_limit(limit)
|
||||
|
||||
@classmethod
|
||||
def get_cte_query(cls, sql: str) -> Optional[str]:
|
||||
"""
|
||||
Convert the input CTE based SQL to the SQL for virtual table conversion
|
||||
|
||||
:param sql: SQL query
|
||||
:return: CTE with the main select query aliased as `__cte`
|
||||
|
||||
"""
|
||||
if not cls.allows_cte_in_subquery:
|
||||
stmt = sqlparse.parse(sql)[0]
|
||||
|
||||
# The first meaningful token for CTE will be with WITH
|
||||
idx, token = stmt.token_next(-1, skip_ws=True, skip_cm=True)
|
||||
if not (token and token.ttype == CTE):
|
||||
return None
|
||||
idx, token = stmt.token_next(idx)
|
||||
idx = stmt.token_index(token) + 1
|
||||
|
||||
# extract rest of the SQLs after CTE
|
||||
remainder = "".join(str(token) for token in stmt.tokens[idx:]).strip()
|
||||
return f"WITH {token.value},\n{CTE_ALIAS} AS (\n{remainder}\n)"
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def df_to_sql(
|
||||
cls,
|
||||
|
|
|
@ -47,6 +47,7 @@ class MssqlEngineSpec(BaseEngineSpec):
|
|||
engine_name = "Microsoft SQL Server"
|
||||
limit_method = LimitMethod.WRAP_SQL
|
||||
max_column_name_length = 128
|
||||
allows_cte_in_subquery = False
|
||||
|
||||
_time_grain_expressions = {
|
||||
None: "{col}",
|
||||
|
|
|
@ -984,7 +984,7 @@ class TestCore(SupersetTestCase):
|
|||
sql=commented_query,
|
||||
database=get_example_database(),
|
||||
)
|
||||
rendered_query = str(table.get_from_clause())
|
||||
rendered_query = str(table.get_from_clause()[0])
|
||||
self.assertEqual(clean_query, rendered_query)
|
||||
|
||||
def test_slice_payload_no_datasource(self):
|
||||
|
|
|
@ -16,7 +16,11 @@
|
|||
# under the License.
|
||||
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
|
||||
|
||||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
from flask.ctx import AppContext
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
|
||||
def test_get_text_clause_with_colon(app_context: AppContext) -> None:
|
||||
|
@ -56,3 +60,42 @@ def test_parse_sql_multi_statement(app_context: AppContext) -> None:
|
|||
"SELECT foo FROM tbl1",
|
||||
"SELECT bar FROM tbl2",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"original,expected",
|
||||
[
|
||||
(
|
||||
dedent(
|
||||
"""
|
||||
with currency as
|
||||
(
|
||||
select 'INR' as cur
|
||||
)
|
||||
select * from currency
|
||||
"""
|
||||
),
|
||||
None,
|
||||
),
|
||||
("SELECT 1 as cnt", None,),
|
||||
(
|
||||
dedent(
|
||||
"""
|
||||
select 'INR' as cur
|
||||
union
|
||||
select 'AUD' as cur
|
||||
union
|
||||
select 'USD' as cur
|
||||
"""
|
||||
),
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_cte_query_parsing(
|
||||
app_context: AppContext, original: TypeEngine, expected: str
|
||||
) -> None:
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
actual = BaseEngineSpec.get_cte_query(original)
|
||||
assert actual == expected
|
||||
|
|
|
@ -180,6 +180,57 @@ def test_column_datatype_to_string(
|
|||
assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"original,expected",
|
||||
[
|
||||
(
|
||||
dedent(
|
||||
"""
|
||||
with currency as (
|
||||
select 'INR' as cur
|
||||
),
|
||||
currency_2 as (
|
||||
select 'EUR' as cur
|
||||
)
|
||||
select * from currency union all select * from currency_2
|
||||
"""
|
||||
),
|
||||
dedent(
|
||||
"""WITH currency as (
|
||||
select 'INR' as cur
|
||||
),
|
||||
currency_2 as (
|
||||
select 'EUR' as cur
|
||||
),
|
||||
__cte AS (
|
||||
select * from currency union all select * from currency_2
|
||||
)"""
|
||||
),
|
||||
),
|
||||
("SELECT 1 as cnt", None,),
|
||||
(
|
||||
dedent(
|
||||
"""
|
||||
select 'INR' as cur
|
||||
union
|
||||
select 'AUD' as cur
|
||||
union
|
||||
select 'USD' as cur
|
||||
"""
|
||||
),
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_cte_query_parsing(
|
||||
app_context: AppContext, original: TypeEngine, expected: str
|
||||
) -> None:
|
||||
from superset.db_engine_specs.mssql import MssqlEngineSpec
|
||||
|
||||
actual = MssqlEngineSpec.get_cte_query(original)
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_extract_errors(app_context: AppContext) -> None:
|
||||
"""
|
||||
Test that custom error messages are extracted correctly.
|
||||
|
|
Loading…
Reference in New Issue