mirror of https://github.com/apache/superset.git
fix(sqla-query): order by aggregations in Presto and Hive (#13739)
This commit is contained in:
parent
762101018b
commit
4789074309
|
@ -16,6 +16,7 @@
|
|||
# under the License.
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict, OrderedDict
|
||||
from contextlib import closing
|
||||
from dataclasses import dataclass, field # pylint: disable=wrong-import-order
|
||||
|
@ -50,6 +51,7 @@ from sqlalchemy.schema import UniqueConstraint
|
|||
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
|
||||
from sqlalchemy.sql.elements import ColumnClause
|
||||
from sqlalchemy.sql.expression import Label, Select, TextAsFrom, TextClause
|
||||
from sqlalchemy.sql.selectable import Alias, TableClause
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
from superset import app, db, is_feature_enabled, security_manager
|
||||
|
@ -70,7 +72,7 @@ from superset.result_set import SupersetResultSet
|
|||
from superset.sql_parse import ParsedQuery
|
||||
from superset.typing import AdhocMetric, Metric, OrderBy, QueryObjectDict
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import GenericDataType
|
||||
from superset.utils.core import GenericDataType, remove_duplicates
|
||||
|
||||
config = app.config
|
||||
metadata = Model.metadata # pylint: disable=no-member
|
||||
|
@ -465,7 +467,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False)
|
||||
fetch_values_predicate = Column(String(1000))
|
||||
owners = relationship(owner_class, secondary=sqlatable_user, backref="tables")
|
||||
database = relationship(
|
||||
database: Database = relationship(
|
||||
"Database",
|
||||
backref=backref("tables", cascade="all, delete-orphan"),
|
||||
foreign_keys=[database_id],
|
||||
|
@ -507,22 +509,6 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
"MAX": sa.func.MAX,
|
||||
}
|
||||
|
||||
def make_sqla_column_compatible(
|
||||
self, sqla_col: Column, label: Optional[str] = None
|
||||
) -> Column:
|
||||
"""Takes a sqlalchemy column object and adds label info if supported by engine.
|
||||
:param sqla_col: sqlalchemy column instance
|
||||
:param label: alias/label that column is expected to have
|
||||
:return: either a sql alchemy column or label instance if supported by engine
|
||||
"""
|
||||
label_expected = label or sqla_col.name
|
||||
db_engine_spec = self.database.db_engine_spec
|
||||
# add quotes to tables
|
||||
if db_engine_spec.allows_alias_in_select:
|
||||
label = db_engine_spec.make_label_compatible(label_expected)
|
||||
sqla_col = sqla_col.label(label)
|
||||
return sqla_col
|
||||
|
||||
def __repr__(self) -> str:
    """Return this object's ``name`` attribute as its representation."""
    return self.name
|
||||
|
||||
|
@ -708,11 +694,10 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
def data(self) -> Dict[str, Any]:
|
||||
data_ = super().data
|
||||
if self.type == "table":
|
||||
grains = self.database.grains() or []
|
||||
if grains:
|
||||
grains = [(g.duration, g.name) for g in grains]
|
||||
data_["granularity_sqla"] = utils.choicify(self.dttm_cols)
|
||||
data_["time_grain_sqla"] = grains
|
||||
data_["time_grain_sqla"] = [
|
||||
(g.duration, g.name) for g in self.database.grains() or []
|
||||
]
|
||||
data_["main_dttm_col"] = self.main_dttm_col
|
||||
data_["fetch_values_predicate"] = self.fetch_values_predicate
|
||||
data_["template_params"] = self.template_params
|
||||
|
@ -800,7 +785,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
|
||||
return ";\n\n".join(all_queries) + ";"
|
||||
|
||||
def get_sqla_table(self) -> table:
|
||||
def get_sqla_table(self) -> TableClause:
|
||||
tbl = table(self.table_name)
|
||||
if self.schema:
|
||||
tbl.schema = self.schema
|
||||
|
@ -808,7 +793,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
|
||||
def get_from_clause(
|
||||
self, template_processor: Optional[BaseTemplateProcessor] = None
|
||||
) -> Union[table, TextAsFrom]:
|
||||
) -> Union[TableClause, Alias]:
|
||||
"""
|
||||
Return where to select the columns and metrics from. Either a physical table
|
||||
or a virtual table with its own subquery.
|
||||
|
@ -882,6 +867,51 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
|
||||
return self.make_sqla_column_compatible(sqla_metric, label)
|
||||
|
||||
def make_sqla_column_compatible(
    self, sqla_col: Column, label: Optional[str] = None
) -> Column:
    """Takes a sqlalchemy column object and adds label info if supported by engine.

    :param sqla_col: sqlalchemy column instance
    :param label: alias/label that column is expected to have; when empty the
        column's own ``name`` is used instead
    :return: either a sql alchemy column or label instance if supported by engine
    """
    label_expected = label or sqla_col.name
    db_engine_spec = self.database.db_engine_spec
    if not db_engine_spec.allows_alias_in_select:
        # The engine spec disallows aliases in SELECT: return the column untouched.
        return sqla_col
    # Let the engine spec adapt the label before attaching it to the column.
    compatible_label = db_engine_spec.make_label_compatible(label_expected)
    return sqla_col.label(compatible_label)
|
||||
|
||||
def make_orderby_compatible(
    self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement]
) -> None:
    """
    Rename SELECT aliases that would clash with names used inside `ORDER BY`.

    Some databases (e.g. Presto) cannot resolve the source column in an
    `ORDER BY` expression when a `SELECT` alias has the same name as that
    source column. For engines whose spec reports
    `allows_alias_to_source_column` as false, every labelled select
    expression whose alias shows up inside parentheses in any of the
    `ORDER BY` expressions gets `__` appended to its alias. The labels are
    modified in place; nothing is returned.

    :param select_exprs: expressions of the SELECT clause; `Label` items may
        be renamed
    :param orderby_exprs: expressions of the ORDER BY clause, only inspected
        through their string form
    """
    if self.database.db_engine_spec.allows_alias_to_source_column:
        return

    for selected in select_exprs:
        # Only labelled expressions carry an alias that could conflict.
        if not isinstance(selected, Label):
            continue
        # Match the alias as a whole word somewhere between parentheses,
        # i.e. used as an argument of a function/aggregation call.
        alias_in_call = re.compile(
            f"\\(.*\\b{re.escape(selected.name)}\\b.*\\)", re.IGNORECASE
        )
        for ordered in orderby_exprs:
            if alias_in_call.search(str(ordered)):
                # Per the original author's note, the final output columns
                # still use the original names via `labels_expected`.
                selected.name = f"{selected.name}__"
                break
|
||||
|
||||
def _get_sqla_row_level_filters(
|
||||
self, template_processor: BaseTemplateProcessor
|
||||
) -> List[str]:
|
||||
|
@ -995,9 +1025,8 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
|
||||
# To ensure correct handling of the ORDER BY labeling we need to reference the
|
||||
# metric instance if defined in the SELECT clause.
|
||||
metrics_exprs_by_label = {
|
||||
m.name: m for m in metrics_exprs # pylint: disable=protected-access
|
||||
}
|
||||
metrics_exprs_by_label = {m.name: m for m in metrics_exprs}
|
||||
metrics_exprs_by_expr = {str(m): m for m in metrics_exprs}
|
||||
|
||||
# Since orderby may use adhoc metrics, too; we need to process them first
|
||||
orderby_exprs: List[ColumnElement] = []
|
||||
|
@ -1007,21 +1036,25 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
if utils.is_adhoc_metric(col):
|
||||
# add adhoc sort by column to columns_by_name if not exists
|
||||
col = self.adhoc_metric_to_sqla(col, columns_by_name)
|
||||
# if the adhoc metric has been defined before
|
||||
# use the existing instance.
|
||||
col = metrics_exprs_by_expr.get(str(col), col)
|
||||
need_groupby = True
|
||||
elif col in columns_by_name:
|
||||
col = columns_by_name[col].get_sqla_col()
|
||||
elif col in metrics_exprs_by_label:
|
||||
col = metrics_exprs_by_label[col]
|
||||
need_groupby = True
|
||||
elif col in metrics_by_name:
|
||||
col = metrics_by_name[col].get_sqla_col()
|
||||
need_groupby = True
|
||||
elif col in metrics_exprs_by_label:
|
||||
col = metrics_exprs_by_label[col]
|
||||
|
||||
if isinstance(col, ColumnElement):
|
||||
orderby_exprs.append(col)
|
||||
else:
|
||||
# Could not convert a column reference to valid ColumnElement
|
||||
raise QueryObjectValidationError(
|
||||
_("Unknown column used in orderby: %(col)", col=orig_col)
|
||||
_("Unknown column used in orderby: %(col)s", col=orig_col)
|
||||
)
|
||||
|
||||
select_exprs: List[Union[Column, Label]] = []
|
||||
|
@ -1093,11 +1126,21 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
dttm_col.get_time_filter(from_dttm, to_dttm, time_range_endpoints)
|
||||
)
|
||||
|
||||
select_exprs += metrics_exprs
|
||||
labels_expected = [c.name for c in select_exprs]
|
||||
select_exprs = db_engine_spec.make_select_compatible(
|
||||
groupby_exprs_with_timestamp.values(), select_exprs
|
||||
# Always remove duplicates by column name, as sometimes `metrics_exprs`
|
||||
# can have the same name as a groupby column (e.g. when users use
|
||||
# raw columns as custom SQL adhoc metric).
|
||||
select_exprs = remove_duplicates(
|
||||
select_exprs + metrics_exprs, key=lambda x: x.name
|
||||
)
|
||||
|
||||
# Expected output columns
|
||||
labels_expected = [c.name for c in select_exprs]
|
||||
|
||||
# Order by columns are "hidden" columns, some databases require them
|
||||
# always be present in SELECT if an aggregation function is used
|
||||
if not db_engine_spec.allows_hidden_ordeby_agg:
|
||||
select_exprs = remove_duplicates(select_exprs + orderby_exprs)
|
||||
|
||||
qry = sa.select(select_exprs)
|
||||
|
||||
tbl = self.get_from_clause(template_processor)
|
||||
|
@ -1213,12 +1256,13 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
qry = qry.where(and_(*where_clause_and))
|
||||
qry = qry.having(and_(*having_clause_and))
|
||||
|
||||
self.make_orderby_compatible(select_exprs, orderby_exprs)
|
||||
|
||||
for col, (orig_col, ascending) in zip(orderby_exprs, orderby):
|
||||
if (
|
||||
db_engine_spec.allows_alias_in_orderby
|
||||
and col.name in metrics_exprs_by_label
|
||||
):
|
||||
col = Label(col.name, metrics_exprs_by_label[col.name])
|
||||
if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label):
|
||||
# if engine does not allow using SELECT alias in ORDER BY
|
||||
# revert to the underlying column
|
||||
col = col.element
|
||||
direction = asc if ascending else desc
|
||||
qry = qry.order_by(direction(col))
|
||||
|
||||
|
@ -1315,6 +1359,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
result.df, dimensions, groupby_exprs_sans_timestamp
|
||||
)
|
||||
qry = qry.where(top_groups)
|
||||
|
||||
qry = qry.select_from(tbl)
|
||||
|
||||
if is_rowcount:
|
||||
if not db_engine_spec.allows_subqueries:
|
||||
raise QueryObjectValidationError(
|
||||
|
@ -1322,10 +1369,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
|
|||
)
|
||||
label = "rowcount"
|
||||
col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label)
|
||||
qry = select([col]).select_from(qry.select_from(tbl).alias("rowcount_qry"))
|
||||
qry = select([col]).select_from(qry.alias("rowcount_qry"))
|
||||
labels_expected = [label]
|
||||
else:
|
||||
qry = qry.select_from(tbl)
|
||||
|
||||
return SqlaQuery(
|
||||
extra_cache_keys=extra_cache_keys,
|
||||
labels_expected=labels_expected,
|
||||
|
|
|
@ -49,7 +49,7 @@ from sqlalchemy.engine.url import URL
|
|||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import quoted_name, text
|
||||
from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom
|
||||
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom
|
||||
from sqlalchemy.types import String, TypeEngine, UnicodeText
|
||||
|
||||
from superset import app, security_manager, sql_parse
|
||||
|
@ -137,7 +137,18 @@ class LimitMethod: # pylint: disable=too-few-public-methods
|
|||
|
||||
|
||||
class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
"""Abstract class for database engine specific configurations"""
|
||||
"""Abstract class for database engine specific configurations
|
||||
|
||||
Attributes:
|
||||
allows_alias_to_source_column: Whether the engine is able to pick the
|
||||
source column for aggregation clauses
|
||||
used in ORDER BY when a column in SELECT
|
||||
has an alias that is the same as a source
|
||||
column.
|
||||
allows_hidden_ordeby_agg: Whether the engine allows ORDER BY to
|
||||
directly use aggregation clauses, without
|
||||
having to add the same aggregation in SELECT.
|
||||
"""
|
||||
|
||||
engine = "base" # str as defined in sqlalchemy.engine.engine
|
||||
engine_aliases: Optional[Tuple[str]] = None
|
||||
|
@ -241,6 +252,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
allows_alias_in_select = True
|
||||
allows_alias_in_orderby = True
|
||||
allows_sql_comments = True
|
||||
|
||||
# Whether ORDER BY clause can use aliases created in SELECT
|
||||
# that are the same as a source column
|
||||
allows_alias_to_source_column = True
|
||||
|
||||
# Whether aggregations used in ORDER BY may be left out of SELECT;
|
||||
# if False, they are also added to the SELECT clause.
|
||||
allows_hidden_ordeby_agg = True
|
||||
|
||||
force_column_alias_quotes = False
|
||||
arraysize = 0
|
||||
max_column_name_length = 0
|
||||
|
@ -441,20 +461,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def make_select_compatible(
|
||||
cls, groupby_exprs: Dict[str, ColumnElement], select_exprs: List[ColumnElement]
|
||||
) -> List[ColumnElement]:
|
||||
"""
|
||||
Some databases will just return the group-by field into the select, but don't
|
||||
allow the group-by field to be put into the select list.
|
||||
|
||||
:param groupby_exprs: mapping between column name and column object
|
||||
:param select_exprs: all columns in the select clause
|
||||
:return: columns to be included in the final select clause
|
||||
"""
|
||||
return select_exprs
|
||||
|
||||
@classmethod
|
||||
def fetch_data(
|
||||
cls, cursor: Any, limit: Optional[int] = None
|
||||
|
|
|
@ -81,6 +81,9 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
engine = "hive"
|
||||
engine_name = "Apache Hive"
|
||||
max_column_name_length = 767
|
||||
allows_alias_to_source_column = True
|
||||
allows_hidden_ordeby_agg = False
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
_time_grain_expressions = {
|
||||
None: "{col}",
|
||||
|
|
|
@ -14,9 +14,9 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from sqlalchemy.sql.expression import ColumnClause, ColumnElement
|
||||
from sqlalchemy.sql.expression import ColumnClause
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
|
||||
|
||||
|
@ -112,9 +112,3 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
|||
time_expr = f"DATETIMECONVERT({{col}}, '{tf}', '{tf}', '{granularity}')"
|
||||
|
||||
return TimestampExpression(time_expr, col)
|
||||
|
||||
@classmethod
|
||||
def make_select_compatible(
|
||||
cls, groupby_exprs: Dict[str, ColumnElement], select_exprs: List[ColumnElement]
|
||||
) -> List[ColumnElement]:
|
||||
return select_exprs
|
||||
|
|
|
@ -128,6 +128,7 @@ def get_children(column: Dict[str, str]) -> List[Dict[str, str]]:
|
|||
class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-methods
|
||||
engine = "presto"
|
||||
engine_name = "Presto"
|
||||
allows_alias_to_source_column = False
|
||||
|
||||
_time_grain_expressions = {
|
||||
None: "{col}",
|
||||
|
|
|
@ -1631,6 +1631,22 @@ def find_duplicates(items: Iterable[InputType]) -> List[InputType]:
|
|||
return [item for item, count in collections.Counter(items).items() if count > 1]
|
||||
|
||||
|
||||
def remove_duplicates(
    items: Iterable[InputType], key: Optional[Callable[[InputType], Any]] = None
) -> List[InputType]:
    """Remove duplicate items in an iterable, keeping the first occurrences.

    The relative order of the retained items is the order in which they first
    appear in ``items``.

    :param items: the items to de-duplicate; they (or their keys, when ``key``
        is given) are stored in a dict/set and therefore must be hashable
    :param key: optional function returning the value by which items are
        compared; when omitted the items themselves are compared
    :return: a new list holding only the first item seen for each distinct key
    """
    if key is None:
        # dict keys are unique and keep insertion order
        return list(dict.fromkeys(items))
    seen = set()
    result = []
    for item in items:
        item_key = key(item)
        if item_key not in seen:
            seen.add(item_key)
            result.append(item)
    return result
|
||||
|
||||
|
||||
def normalize_dttm_col(
|
||||
df: pd.DataFrame,
|
||||
timestamp_format: Optional[str],
|
||||
|
|
|
@ -23,7 +23,7 @@ from sqlalchemy.engine import Engine
|
|||
from tests.test_app import app
|
||||
|
||||
from superset import db
|
||||
from superset.utils.core import get_example_database
|
||||
from superset.utils.core import get_example_database, json_dumps_w_dates
|
||||
|
||||
|
||||
CTAS_SCHEMA_NAME = "sqllab_test_db"
|
||||
|
@ -73,13 +73,22 @@ def drop_from_schema(engine: Engine, schema_name: str):
|
|||
|
||||
def setup_presto_if_needed():
|
||||
backend = app.config["SQLALCHEMY_EXAMPLES_URI"].split("://")[0]
|
||||
database = get_example_database()
|
||||
extra = database.get_extra()
|
||||
|
||||
if backend == "presto":
|
||||
# decrease poll interval for tests
|
||||
presto_poll_interval = app.config["PRESTO_POLL_INTERVAL"]
|
||||
extra = f'{{"engine_params": {{"connect_args": {{"poll_interval": {presto_poll_interval}}}}}}}'
|
||||
database = get_example_database()
|
||||
database.extra = extra
|
||||
db.session.commit()
|
||||
extra = {
|
||||
**extra,
|
||||
"engine_params": {
|
||||
"connect_args": {"poll_interval": app.config["PRESTO_POLL_INTERVAL"]}
|
||||
},
|
||||
}
|
||||
else:
|
||||
# remove `poll_interval` from databases that do not support it
|
||||
extra = {**extra, "engine_params": {}}
|
||||
database.extra = json_dumps_w_dates(extra)
|
||||
db.session.commit()
|
||||
|
||||
if backend in {"presto", "hive"}:
|
||||
database = get_example_database()
|
||||
|
|
|
@ -82,7 +82,10 @@ class TestExportDatabasesCommand(SupersetTestCase):
|
|||
"schemas_allowed_for_csv_upload": [],
|
||||
}
|
||||
if backend() == "presto":
|
||||
expected_extra = {"engine_params": {"connect_args": {"poll_interval": 0.1}}}
|
||||
expected_extra = {
|
||||
**expected_extra,
|
||||
"engine_params": {"connect_args": {"poll_interval": 0.1}},
|
||||
}
|
||||
|
||||
assert core_files.issubset(set(contents.keys()))
|
||||
|
||||
|
|
|
@ -31,13 +31,13 @@ query_birth_names = {
|
|||
},
|
||||
"groupby": ["name"],
|
||||
"metrics": [{"label": "sum__num"}],
|
||||
"order_desc": True,
|
||||
"orderby": [["sum__num", False]],
|
||||
"row_limit": 100,
|
||||
"granularity": "ds",
|
||||
"time_range": "100 years ago : now",
|
||||
"timeseries_limit": 0,
|
||||
"timeseries_limit_metric": None,
|
||||
"order_desc": True,
|
||||
"filters": [
|
||||
{"col": "gender", "op": "==", "val": "boy"},
|
||||
{"col": "num", "op": "IS NOT NULL"},
|
||||
|
@ -49,8 +49,57 @@ query_birth_names = {
|
|||
}
|
||||
|
||||
QUERY_OBJECTS: Dict[str, Dict[str, object]] = {
|
||||
"birth_names": {**query_birth_names, "is_timeseries": False,},
|
||||
"birth_names:include_time": {**query_birth_names, "groupby": [DTTM_ALIAS, "name"],},
|
||||
"birth_names": query_birth_names,
|
||||
# `:suffix` are overrides only
|
||||
"birth_names:include_time": {"groupby": [DTTM_ALIAS, "name"],},
|
||||
"birth_names:orderby_dup_alias": {
|
||||
"metrics": [
|
||||
{
|
||||
"expressionType": "SIMPLE",
|
||||
"column": {"column_name": "num_girls", "type": "BIGINT(20)"},
|
||||
"aggregate": "SUM",
|
||||
"label": "num_girls",
|
||||
},
|
||||
{
|
||||
"expressionType": "SIMPLE",
|
||||
"column": {"column_name": "num_boys", "type": "BIGINT(20)"},
|
||||
"aggregate": "SUM",
|
||||
"label": "num_boys",
|
||||
},
|
||||
],
|
||||
"orderby": [
|
||||
[
|
||||
{
|
||||
"expressionType": "SIMPLE",
|
||||
"column": {"column_name": "num_girls", "type": "BIGINT(20)"},
|
||||
"aggregate": "SUM",
|
||||
# the same underlying expression, but different label
|
||||
"label": "SUM(num_girls)",
|
||||
},
|
||||
False,
|
||||
],
|
||||
# reference the ambiguous alias in SIMPLE metric
|
||||
[
|
||||
{
|
||||
"expressionType": "SIMPLE",
|
||||
"column": {"column_name": "num_boys", "type": "BIGINT(20)"},
|
||||
"aggregate": "AVG",
|
||||
"label": "AVG(num_boys)",
|
||||
},
|
||||
False,
|
||||
],
|
||||
# reference the ambiguous alias in CUSTOM SQL metric
|
||||
[
|
||||
{
|
||||
"expressionType": "SQL",
|
||||
"sqlExpression": "MAX(CASE WHEN num_boys > 0 THEN 1 ELSE 0 END)",
|
||||
"label": "MAX(CASE WHEN...",
|
||||
},
|
||||
True,
|
||||
],
|
||||
],
|
||||
},
|
||||
"birth_names:only_orderby_has_metric": {"metrics": [],},
|
||||
}
|
||||
|
||||
ANNOTATION_LAYERS = {
|
||||
|
@ -150,7 +199,17 @@ def get_query_object(
|
|||
) -> Dict[str, Any]:
|
||||
if query_name not in QUERY_OBJECTS:
|
||||
raise Exception(f"QueryObject fixture not defined for datasource: {query_name}")
|
||||
query_object = copy.deepcopy(QUERY_OBJECTS[query_name])
|
||||
obj = QUERY_OBJECTS[query_name]
|
||||
|
||||
# apply overrides
|
||||
if ":" in query_name:
|
||||
parent_query_name = query_name.split(":")[0]
|
||||
obj = {
|
||||
**QUERY_OBJECTS[parent_query_name],
|
||||
**obj,
|
||||
}
|
||||
|
||||
query_object = copy.deepcopy(obj)
|
||||
if add_postprocessing_operations:
|
||||
query_object["post_processing"] = _get_postprocessing_operation(query_name)
|
||||
return query_object
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -24,9 +25,9 @@ from superset.common.query_context import QueryContext
|
|||
from superset.common.query_object import QueryObject
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.extensions import cache_manager
|
||||
from superset.models.cache import CacheKey
|
||||
from superset.utils.core import (
|
||||
AdhocMetricExpressionType,
|
||||
backend,
|
||||
ChartDataResultFormat,
|
||||
ChartDataResultType,
|
||||
TimeRangeEndpoint,
|
||||
|
@ -36,6 +37,17 @@ from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with
|
|||
from tests.fixtures.query_context import get_query_context
|
||||
|
||||
|
||||
def get_sql_text(payload: Dict[str, Any]) -> str:
    """Load ``payload`` as a QUERY-type chart data request and return its SQL.

    ``payload`` is modified in place: its ``result_type`` is overwritten with
    ``ChartDataResultType.QUERY.value`` before it is loaded through
    ``ChartDataQueryContextSchema``. The helper asserts that the resulting
    payload has a single entry, and that the first query response has exactly
    two entries with ``language`` equal to ``"sql"``.

    :param payload: query context payload to load and execute
    :return: the ``query`` text of the first query response
    """
    payload["result_type"] = ChartDataResultType.QUERY.value
    result = ChartDataQueryContextSchema().load(payload).get_payload()
    assert len(result) == 1
    first_query = result["queries"][0]
    assert len(first_query) == 2
    assert first_query["language"] == "sql"
    return first_query["query"]
|
||||
|
||||
|
||||
class TestQueryContext(SupersetTestCase):
|
||||
def test_schema_deserialization(self):
|
||||
"""
|
||||
|
@ -301,14 +313,7 @@ class TestQueryContext(SupersetTestCase):
|
|||
"""
|
||||
self.login(username="admin")
|
||||
payload = get_query_context("birth_names")
|
||||
payload["result_type"] = ChartDataResultType.QUERY.value
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
responses = query_context.get_payload()
|
||||
assert len(responses) == 1
|
||||
response = responses["queries"][0]
|
||||
assert len(response) == 2
|
||||
sql_text = response["query"]
|
||||
assert response["language"] == "sql"
|
||||
sql_text = get_sql_text(payload)
|
||||
assert "SELECT" in sql_text
|
||||
assert re.search(r'[`"\[]?num[`"\]]? IS NOT NULL', sql_text)
|
||||
assert re.search(
|
||||
|
@ -318,37 +323,102 @@ class TestQueryContext(SupersetTestCase):
|
|||
)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_fetch_values_predicate_in_query(self):
|
||||
def test_handle_sort_by_metrics(self):
|
||||
"""
|
||||
Ensure that fetch values predicate is added to query
|
||||
Should properly handle sort by metrics in various scenarios.
|
||||
"""
|
||||
self.login(username="admin")
|
||||
payload = get_query_context("birth_names")
|
||||
payload["result_type"] = ChartDataResultType.QUERY.value
|
||||
payload["queries"][0]["apply_fetch_values_predicate"] = True
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
responses = query_context.get_payload()
|
||||
assert len(responses) == 1
|
||||
response = responses["queries"][0]
|
||||
assert len(response) == 2
|
||||
assert response["language"] == "sql"
|
||||
assert "123 = 123" in response["query"]
|
||||
|
||||
sql_text = get_sql_text(get_query_context("birth_names"))
|
||||
if backend() == "hive":
|
||||
# should have no duplicate `SUM(num)`
|
||||
assert "SUM(num) AS `sum__num`," not in sql_text
|
||||
assert "SUM(num) AS `sum__num`" in sql_text
|
||||
# the alias should be in ORDER BY
|
||||
assert "ORDER BY `sum__num` DESC" in sql_text
|
||||
else:
|
||||
assert re.search(r'ORDER BY [`"\[]?sum__num[`"\]]? DESC', sql_text)
|
||||
|
||||
sql_text = get_sql_text(
|
||||
get_query_context("birth_names:only_orderby_has_metric")
|
||||
)
|
||||
if backend() == "hive":
|
||||
assert "SUM(num) AS `sum__num`," not in sql_text
|
||||
assert "SUM(num) AS `sum__num`" in sql_text
|
||||
assert "ORDER BY `sum__num` DESC" in sql_text
|
||||
else:
|
||||
assert re.search(
|
||||
r'ORDER BY SUM\([`"\[]?num[`"\]]?\) DESC', sql_text, re.IGNORECASE
|
||||
)
|
||||
|
||||
sql_text = get_sql_text(get_query_context("birth_names:orderby_dup_alias"))
|
||||
|
||||
# Check SELECT clauses
|
||||
if backend() == "presto":
|
||||
# presto cannot have ambiguous alias in order by, so selected column
|
||||
# alias is renamed.
|
||||
assert 'sum("num_boys") AS "num_boys__"' in sql_text
|
||||
else:
|
||||
assert re.search(
|
||||
r'SUM\([`"\[]?num_boys[`"\]]?\) AS [`\"\[]?num_boys[`"\]]?',
|
||||
sql_text,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Check ORDER BY clauses
|
||||
if backend() == "hive":
|
||||
# Hive must add additional SORT BY metrics to SELECT
|
||||
assert re.search(
|
||||
r"MAX\(CASE.*END\) AS `MAX\(CASE WHEN...`",
|
||||
sql_text,
|
||||
re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
|
||||
# The additional column with the same expression but a different label
|
||||
# as an existing metric should not be added
|
||||
assert "sum(`num_girls`) AS `SUM(num_girls)`" not in sql_text
|
||||
|
||||
# Should reference all ORDER BY columns by aliases
|
||||
assert "ORDER BY `num_girls` DESC," in sql_text
|
||||
assert "`AVG(num_boys)` DESC," in sql_text
|
||||
assert "`MAX(CASE WHEN...` ASC" in sql_text
|
||||
else:
|
||||
if backend() == "presto":
|
||||
# since the selected `num_boys` is renamed to `num_boys__`
|
||||
# it must be referenced as an expression
|
||||
assert re.search(
|
||||
r'ORDER BY SUM\([`"\[]?num_girls[`"\]]?\) DESC',
|
||||
sql_text,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
else:
|
||||
# Should reference the adhoc metric by alias when possible
|
||||
assert re.search(
|
||||
r'ORDER BY [`"\[]?num_girls[`"\]]? DESC', sql_text, re.IGNORECASE,
|
||||
)
|
||||
|
||||
# ORDER BY only columns should always be expressions
|
||||
assert re.search(
|
||||
r'AVG\([`"\[]?num_boys[`"\]]?\) DESC', sql_text, re.IGNORECASE,
|
||||
)
|
||||
assert re.search(
|
||||
r"MAX\(CASE.*END\) ASC", sql_text, re.IGNORECASE | re.DOTALL
|
||||
)
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_fetch_values_predicate_not_in_query(self):
|
||||
def test_fetch_values_predicate(self):
|
||||
"""
|
||||
Ensure that fetch values predicate is not added to query
|
||||
Ensure that fetch values predicate is added to query if needed
|
||||
"""
|
||||
self.login(username="admin")
|
||||
|
||||
payload = get_query_context("birth_names")
|
||||
payload["result_type"] = ChartDataResultType.QUERY.value
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
responses = query_context.get_payload()
|
||||
assert len(responses) == 1
|
||||
response = responses["queries"][0]
|
||||
assert len(response) == 2
|
||||
assert response["language"] == "sql"
|
||||
assert "123 = 123" not in response["query"]
|
||||
sql_text = get_sql_text(payload)
|
||||
assert "123 = 123" not in sql_text
|
||||
|
||||
payload["queries"][0]["apply_fetch_values_predicate"] = True
|
||||
sql_text = get_sql_text(payload)
|
||||
assert "123 = 123" in sql_text
|
||||
|
||||
def test_query_object_unknown_fields(self):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue