fix(sqla-query): order by aggregations in Presto and Hive (#13739)

This commit is contained in:
Jesse Yang 2021-04-01 18:10:17 -07:00 committed by GitHub
parent 762101018b
commit 4789074309
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 315 additions and 108 deletions

View File

@ -16,6 +16,7 @@
# under the License.
import json
import logging
import re
from collections import defaultdict, OrderedDict
from contextlib import closing
from dataclasses import dataclass, field # pylint: disable=wrong-import-order
@ -50,6 +51,7 @@ from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.expression import Label, Select, TextAsFrom, TextClause
from sqlalchemy.sql.selectable import Alias, TableClause
from sqlalchemy.types import TypeEngine
from superset import app, db, is_feature_enabled, security_manager
@ -70,7 +72,7 @@ from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery
from superset.typing import AdhocMetric, Metric, OrderBy, QueryObjectDict
from superset.utils import core as utils
from superset.utils.core import GenericDataType
from superset.utils.core import GenericDataType, remove_duplicates
config = app.config
metadata = Model.metadata # pylint: disable=no-member
@ -465,7 +467,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False)
fetch_values_predicate = Column(String(1000))
owners = relationship(owner_class, secondary=sqlatable_user, backref="tables")
database = relationship(
database: Database = relationship(
"Database",
backref=backref("tables", cascade="all, delete-orphan"),
foreign_keys=[database_id],
@ -507,22 +509,6 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
"MAX": sa.func.MAX,
}
def make_sqla_column_compatible(
self, sqla_col: Column, label: Optional[str] = None
) -> Column:
"""Takes a sqlalchemy column object and adds label info if supported by engine.
:param sqla_col: sqlalchemy column instance
:param label: alias/label that column is expected to have
:return: either a sql alchemy column or label instance if supported by engine
"""
label_expected = label or sqla_col.name
db_engine_spec = self.database.db_engine_spec
# add quotes to tables
if db_engine_spec.allows_alias_in_select:
label = db_engine_spec.make_label_compatible(label_expected)
sqla_col = sqla_col.label(label)
return sqla_col
def __repr__(self) -> str:
return self.name
@ -708,11 +694,10 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
def data(self) -> Dict[str, Any]:
data_ = super().data
if self.type == "table":
grains = self.database.grains() or []
if grains:
grains = [(g.duration, g.name) for g in grains]
data_["granularity_sqla"] = utils.choicify(self.dttm_cols)
data_["time_grain_sqla"] = grains
data_["time_grain_sqla"] = [
(g.duration, g.name) for g in self.database.grains() or []
]
data_["main_dttm_col"] = self.main_dttm_col
data_["fetch_values_predicate"] = self.fetch_values_predicate
data_["template_params"] = self.template_params
@ -800,7 +785,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
all_queries = query_str_ext.prequeries + [query_str_ext.sql]
return ";\n\n".join(all_queries) + ";"
def get_sqla_table(self) -> table:
def get_sqla_table(self) -> TableClause:
tbl = table(self.table_name)
if self.schema:
tbl.schema = self.schema
@ -808,7 +793,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
def get_from_clause(
self, template_processor: Optional[BaseTemplateProcessor] = None
) -> Union[table, TextAsFrom]:
) -> Union[TableClause, Alias]:
"""
Return where to select the columns and metrics from. Either a physical table
or a virtual table with it's own subquery.
@ -882,6 +867,51 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
return self.make_sqla_column_compatible(sqla_metric, label)
def make_sqla_column_compatible(
    self, sqla_col: Column, label: Optional[str] = None
) -> Column:
    """Attach an engine-compatible alias to a sqlalchemy column when possible.

    :param sqla_col: sqlalchemy column instance
    :param label: alias/label the column is expected to have
    :return: the labeled column if the engine supports aliases in SELECT,
        otherwise the column unchanged
    """
    spec = self.database.db_engine_spec
    if spec.allows_alias_in_select:
        # Quote/truncate the alias as required by the engine before labeling.
        desired = label or sqla_col.name
        return sqla_col.label(spec.make_label_compatible(desired))
    return sqla_col
def make_orderby_compatible(
    self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement]
) -> None:
    """
    Rename SELECT aliases that would shadow source columns in ORDER BY.

    Some databases (e.g. Presto) cannot resolve a source column inside an
    ORDER BY expression when a SELECT alias carries the same name.  For
    such engines we append ``__`` to each conflicting SELECT alias; the
    final output columns still use the original names because they are
    restored via ``labels_expected`` after querying.
    """
    if self.database.db_engine_spec.allows_alias_to_source_column:
        # The engine disambiguates aliases itself; nothing to rename.
        return

    def conflicts_with_orderby(expr: ColumnElement) -> bool:
        # Only labelled (aliased) SELECT expressions can shadow a column.
        if not isinstance(expr, Label):
            return False
        # Look for the alias inside any parenthesized (i.e. function or
        # aggregation) part of an ORDER BY expression.  The ORDER BY
        # expressions are stringified lazily on each check on purpose, so
        # a rename of a shared Label is observed by later checks.
        pattern = re.compile(
            rf"\(.*\b{re.escape(expr.name)}\b.*\)", re.IGNORECASE
        )
        return any(pattern.search(str(ob)) for ob in orderby_exprs)

    for expr in select_exprs:
        if conflicts_with_orderby(expr):
            expr.name = f"{expr.name}__"
def _get_sqla_row_level_filters(
self, template_processor: BaseTemplateProcessor
) -> List[str]:
@ -995,9 +1025,8 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
# To ensure correct handling of the ORDER BY labeling we need to reference the
# metric instance if defined in the SELECT clause.
metrics_exprs_by_label = {
m.name: m for m in metrics_exprs # pylint: disable=protected-access
}
metrics_exprs_by_label = {m.name: m for m in metrics_exprs}
metrics_exprs_by_expr = {str(m): m for m in metrics_exprs}
# Since orderby may use adhoc metrics, too; we need to process them first
orderby_exprs: List[ColumnElement] = []
@ -1007,21 +1036,25 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
if utils.is_adhoc_metric(col):
# add adhoc sort by column to columns_by_name if not exists
col = self.adhoc_metric_to_sqla(col, columns_by_name)
# if the adhoc metric has been defined before
# use the existing instance.
col = metrics_exprs_by_expr.get(str(col), col)
need_groupby = True
elif col in columns_by_name:
col = columns_by_name[col].get_sqla_col()
elif col in metrics_exprs_by_label:
col = metrics_exprs_by_label[col]
need_groupby = True
elif col in metrics_by_name:
col = metrics_by_name[col].get_sqla_col()
need_groupby = True
elif col in metrics_exprs_by_label:
col = metrics_exprs_by_label[col]
if isinstance(col, ColumnElement):
orderby_exprs.append(col)
else:
# Could not convert a column reference to a valid ColumnElement
raise QueryObjectValidationError(
_("Unknown column used in orderby: %(col)", col=orig_col)
_("Unknown column used in orderby: %(col)s", col=orig_col)
)
select_exprs: List[Union[Column, Label]] = []
@ -1093,11 +1126,21 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
dttm_col.get_time_filter(from_dttm, to_dttm, time_range_endpoints)
)
select_exprs += metrics_exprs
labels_expected = [c.name for c in select_exprs]
select_exprs = db_engine_spec.make_select_compatible(
groupby_exprs_with_timestamp.values(), select_exprs
# Always remove duplicates by column name, as sometimes `metrics_exprs`
# can have the same name as a groupby column (e.g. when users use
# raw columns as custom SQL adhoc metric).
select_exprs = remove_duplicates(
select_exprs + metrics_exprs, key=lambda x: x.name
)
# Expected output columns
labels_expected = [c.name for c in select_exprs]
# ORDER BY columns are "hidden" columns; some databases require them
# to always be present in SELECT if an aggregation function is used
if not db_engine_spec.allows_hidden_ordeby_agg:
select_exprs = remove_duplicates(select_exprs + orderby_exprs)
qry = sa.select(select_exprs)
tbl = self.get_from_clause(template_processor)
@ -1213,12 +1256,13 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
qry = qry.where(and_(*where_clause_and))
qry = qry.having(and_(*having_clause_and))
self.make_orderby_compatible(select_exprs, orderby_exprs)
for col, (orig_col, ascending) in zip(orderby_exprs, orderby):
if (
db_engine_spec.allows_alias_in_orderby
and col.name in metrics_exprs_by_label
):
col = Label(col.name, metrics_exprs_by_label[col.name])
if not db_engine_spec.allows_alias_in_orderby and isinstance(col, Label):
# if engine does not allow using SELECT alias in ORDER BY
# revert to the underlying column
col = col.element
direction = asc if ascending else desc
qry = qry.order_by(direction(col))
@ -1315,6 +1359,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
result.df, dimensions, groupby_exprs_sans_timestamp
)
qry = qry.where(top_groups)
qry = qry.select_from(tbl)
if is_rowcount:
if not db_engine_spec.allows_subqueries:
raise QueryObjectValidationError(
@ -1322,10 +1369,9 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
)
label = "rowcount"
col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label)
qry = select([col]).select_from(qry.select_from(tbl).alias("rowcount_qry"))
qry = select([col]).select_from(qry.alias("rowcount_qry"))
labels_expected = [label]
else:
qry = qry.select_from(tbl)
return SqlaQuery(
extra_cache_keys=extra_cache_keys,
labels_expected=labels_expected,

View File

@ -49,7 +49,7 @@ from sqlalchemy.engine.url import URL
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom
from sqlalchemy.sql.expression import ColumnClause, Select, TextAsFrom
from sqlalchemy.types import String, TypeEngine, UnicodeText
from superset import app, security_manager, sql_parse
@ -137,7 +137,18 @@ class LimitMethod: # pylint: disable=too-few-public-methods
class BaseEngineSpec: # pylint: disable=too-many-public-methods
"""Abstract class for database engine specific configurations"""
"""Abstract class for database engine specific configurations
Attributes:
allows_alias_to_source_column: Whether the engine is able to pick the
source column for aggregation clauses
used in ORDER BY when a column in SELECT
has an alias that is the same as a source
column.
allows_hidden_ordeby_agg: Whether the engine allows ORDER BY to
directly use aggregation clauses, without
having to add the same aggregation in SELECT.
"""
engine = "base" # str as defined in sqlalchemy.engine.engine
engine_aliases: Optional[Tuple[str]] = None
@ -241,6 +252,15 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
allows_alias_in_select = True
allows_alias_in_orderby = True
allows_sql_comments = True
# Whether ORDER BY clause can use aliases created in SELECT
# that are the same as a source column
allows_alias_to_source_column = True
# Whether ORDER BY expressions may reference aggregations that do not
# also appear in SELECT; if True, they need not be added to the
# SELECT clause.
allows_hidden_ordeby_agg = True
force_column_alias_quotes = False
arraysize = 0
max_column_name_length = 0
@ -441,20 +461,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
)
)
@classmethod
def make_select_compatible(
cls, groupby_exprs: Dict[str, ColumnElement], select_exprs: List[ColumnElement]
) -> List[ColumnElement]:
"""
Some databases will just return the group-by field into the select, but don't
allow the group-by field to be put into the select list.
:param groupby_exprs: mapping between column name and column object
:param select_exprs: all columns in the select clause
:return: columns to be included in the final select clause
"""
return select_exprs
@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None

View File

@ -81,6 +81,9 @@ class HiveEngineSpec(PrestoEngineSpec):
engine = "hive"
engine_name = "Apache Hive"
max_column_name_length = 767
allows_alias_to_source_column = True
allows_hidden_ordeby_agg = False
# pylint: disable=line-too-long
_time_grain_expressions = {
None: "{col}",

View File

@ -14,9 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, List, Optional
from typing import Dict, Optional
from sqlalchemy.sql.expression import ColumnClause, ColumnElement
from sqlalchemy.sql.expression import ColumnClause
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
@ -112,9 +112,3 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
time_expr = f"DATETIMECONVERT({{col}}, '{tf}', '{tf}', '{granularity}')"
return TimestampExpression(time_expr, col)
@classmethod
def make_select_compatible(
cls, groupby_exprs: Dict[str, ColumnElement], select_exprs: List[ColumnElement]
) -> List[ColumnElement]:
return select_exprs

View File

@ -128,6 +128,7 @@ def get_children(column: Dict[str, str]) -> List[Dict[str, str]]:
class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-methods
engine = "presto"
engine_name = "Presto"
allows_alias_to_source_column = False
_time_grain_expressions = {
None: "{col}",

View File

@ -1631,6 +1631,22 @@ def find_duplicates(items: Iterable[InputType]) -> List[InputType]:
return [item for item, count in collections.Counter(items).items() if count > 1]
def remove_duplicates(
    items: Iterable[InputType], key: Optional[Callable[[InputType], Any]] = None
) -> List[InputType]:
    """Remove duplicate items from an iterable, preserving first-seen order.

    :param items: the iterable to deduplicate (items must be hashable when
        ``key`` is not provided)
    :param key: optional callable mapping each item to the value used for
        uniqueness checks; when omitted, the items themselves are compared
    :return: a list containing only the first occurrence of each (keyed) item
    """
    if not key:
        # dict preserves insertion order, so this dedupes while keeping the
        # original ordering; the previous `.keys()` call was redundant.
        return list(dict.fromkeys(items))
    seen = set()
    result = []
    for item in items:
        item_key = key(item)
        if item_key not in seen:
            seen.add(item_key)
            result.append(item)
    return result
def normalize_dttm_col(
df: pd.DataFrame,
timestamp_format: Optional[str],

View File

@ -23,7 +23,7 @@ from sqlalchemy.engine import Engine
from tests.test_app import app
from superset import db
from superset.utils.core import get_example_database
from superset.utils.core import get_example_database, json_dumps_w_dates
CTAS_SCHEMA_NAME = "sqllab_test_db"
@ -73,13 +73,22 @@ def drop_from_schema(engine: Engine, schema_name: str):
def setup_presto_if_needed():
backend = app.config["SQLALCHEMY_EXAMPLES_URI"].split("://")[0]
database = get_example_database()
extra = database.get_extra()
if backend == "presto":
# decrease poll interval for tests
presto_poll_interval = app.config["PRESTO_POLL_INTERVAL"]
extra = f'{{"engine_params": {{"connect_args": {{"poll_interval": {presto_poll_interval}}}}}}}'
database = get_example_database()
database.extra = extra
db.session.commit()
extra = {
**extra,
"engine_params": {
"connect_args": {"poll_interval": app.config["PRESTO_POLL_INTERVAL"]}
},
}
else:
# remove `poll_interval` from databases that do not support it
extra = {**extra, "engine_params": {}}
database.extra = json_dumps_w_dates(extra)
db.session.commit()
if backend in {"presto", "hive"}:
database = get_example_database()

View File

@ -82,7 +82,10 @@ class TestExportDatabasesCommand(SupersetTestCase):
"schemas_allowed_for_csv_upload": [],
}
if backend() == "presto":
expected_extra = {"engine_params": {"connect_args": {"poll_interval": 0.1}}}
expected_extra = {
**expected_extra,
"engine_params": {"connect_args": {"poll_interval": 0.1}},
}
assert core_files.issubset(set(contents.keys()))

View File

@ -31,13 +31,13 @@ query_birth_names = {
},
"groupby": ["name"],
"metrics": [{"label": "sum__num"}],
"order_desc": True,
"orderby": [["sum__num", False]],
"row_limit": 100,
"granularity": "ds",
"time_range": "100 years ago : now",
"timeseries_limit": 0,
"timeseries_limit_metric": None,
"order_desc": True,
"filters": [
{"col": "gender", "op": "==", "val": "boy"},
{"col": "num", "op": "IS NOT NULL"},
@ -49,8 +49,57 @@ query_birth_names = {
}
QUERY_OBJECTS: Dict[str, Dict[str, object]] = {
"birth_names": {**query_birth_names, "is_timeseries": False,},
"birth_names:include_time": {**query_birth_names, "groupby": [DTTM_ALIAS, "name"],},
"birth_names": query_birth_names,
# `:suffix` are overrides only
"birth_names:include_time": {"groupby": [DTTM_ALIAS, "name"],},
"birth_names:orderby_dup_alias": {
"metrics": [
{
"expressionType": "SIMPLE",
"column": {"column_name": "num_girls", "type": "BIGINT(20)"},
"aggregate": "SUM",
"label": "num_girls",
},
{
"expressionType": "SIMPLE",
"column": {"column_name": "num_boys", "type": "BIGINT(20)"},
"aggregate": "SUM",
"label": "num_boys",
},
],
"orderby": [
[
{
"expressionType": "SIMPLE",
"column": {"column_name": "num_girls", "type": "BIGINT(20)"},
"aggregate": "SUM",
# the same underlying expression, but different label
"label": "SUM(num_girls)",
},
False,
],
# reference the ambiguous alias in SIMPLE metric
[
{
"expressionType": "SIMPLE",
"column": {"column_name": "num_boys", "type": "BIGINT(20)"},
"aggregate": "AVG",
"label": "AVG(num_boys)",
},
False,
],
# reference the ambiguous alias in CUSTOM SQL metric
[
{
"expressionType": "SQL",
"sqlExpression": "MAX(CASE WHEN num_boys > 0 THEN 1 ELSE 0 END)",
"label": "MAX(CASE WHEN...",
},
True,
],
],
},
"birth_names:only_orderby_has_metric": {"metrics": [],},
}
ANNOTATION_LAYERS = {
@ -150,7 +199,17 @@ def get_query_object(
) -> Dict[str, Any]:
if query_name not in QUERY_OBJECTS:
raise Exception(f"QueryObject fixture not defined for datasource: {query_name}")
query_object = copy.deepcopy(QUERY_OBJECTS[query_name])
obj = QUERY_OBJECTS[query_name]
# apply overrides
if ":" in query_name:
parent_query_name = query_name.split(":")[0]
obj = {
**QUERY_OBJECTS[parent_query_name],
**obj,
}
query_object = copy.deepcopy(obj)
if add_postprocessing_operations:
query_object["post_processing"] = _get_postprocessing_operation(query_name)
return query_object

View File

@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import re
from typing import Any, Dict
import pytest
@ -24,9 +25,9 @@ from superset.common.query_context import QueryContext
from superset.common.query_object import QueryObject
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import cache_manager
from superset.models.cache import CacheKey
from superset.utils.core import (
AdhocMetricExpressionType,
backend,
ChartDataResultFormat,
ChartDataResultType,
TimeRangeEndpoint,
@ -36,6 +37,17 @@ from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with
from tests.fixtures.query_context import get_query_context
def get_sql_text(payload: Dict[str, Any]) -> str:
payload["result_type"] = ChartDataResultType.QUERY.value
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
assert len(responses) == 1
response = responses["queries"][0]
assert len(response) == 2
assert response["language"] == "sql"
return response["query"]
class TestQueryContext(SupersetTestCase):
def test_schema_deserialization(self):
"""
@ -301,14 +313,7 @@ class TestQueryContext(SupersetTestCase):
"""
self.login(username="admin")
payload = get_query_context("birth_names")
payload["result_type"] = ChartDataResultType.QUERY.value
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
assert len(responses) == 1
response = responses["queries"][0]
assert len(response) == 2
sql_text = response["query"]
assert response["language"] == "sql"
sql_text = get_sql_text(payload)
assert "SELECT" in sql_text
assert re.search(r'[`"\[]?num[`"\]]? IS NOT NULL', sql_text)
assert re.search(
@ -318,37 +323,102 @@ class TestQueryContext(SupersetTestCase):
)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_fetch_values_predicate_in_query(self):
def test_handle_sort_by_metrics(self):
"""
Ensure that fetch values predicate is added to query
Should properly handle sort by metrics in various scenarios.
"""
self.login(username="admin")
payload = get_query_context("birth_names")
payload["result_type"] = ChartDataResultType.QUERY.value
payload["queries"][0]["apply_fetch_values_predicate"] = True
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
assert len(responses) == 1
response = responses["queries"][0]
assert len(response) == 2
assert response["language"] == "sql"
assert "123 = 123" in response["query"]
sql_text = get_sql_text(get_query_context("birth_names"))
if backend() == "hive":
# should have no duplicate `SUM(num)`
assert "SUM(num) AS `sum__num`," not in sql_text
assert "SUM(num) AS `sum__num`" in sql_text
# the alias should be in ORDER BY
assert "ORDER BY `sum__num` DESC" in sql_text
else:
assert re.search(r'ORDER BY [`"\[]?sum__num[`"\]]? DESC', sql_text)
sql_text = get_sql_text(
get_query_context("birth_names:only_orderby_has_metric")
)
if backend() == "hive":
assert "SUM(num) AS `sum__num`," not in sql_text
assert "SUM(num) AS `sum__num`" in sql_text
assert "ORDER BY `sum__num` DESC" in sql_text
else:
assert re.search(
r'ORDER BY SUM\([`"\[]?num[`"\]]?\) DESC', sql_text, re.IGNORECASE
)
sql_text = get_sql_text(get_query_context("birth_names:orderby_dup_alias"))
# Check SELECT clauses
if backend() == "presto":
# presto cannot have ambiguous alias in order by, so selected column
# alias is renamed.
assert 'sum("num_boys") AS "num_boys__"' in sql_text
else:
assert re.search(
r'SUM\([`"\[]?num_boys[`"\]]?\) AS [`\"\[]?num_boys[`"\]]?',
sql_text,
re.IGNORECASE,
)
# Check ORDER BY clauses
if backend() == "hive":
# Hive must add additional SORT BY metrics to SELECT
assert re.search(
r"MAX\(CASE.*END\) AS `MAX\(CASE WHEN...`",
sql_text,
re.IGNORECASE | re.DOTALL,
)
# An additional column that has the same expression as an existing
# metric but a different label should not be added
assert "sum(`num_girls`) AS `SUM(num_girls)`" not in sql_text
# Should reference all ORDER BY columns by aliases
assert "ORDER BY `num_girls` DESC," in sql_text
assert "`AVG(num_boys)` DESC," in sql_text
assert "`MAX(CASE WHEN...` ASC" in sql_text
else:
if backend() == "presto":
# since the selected `num_boys` is renamed to `num_boys__`
# it must be referenced as an expression
assert re.search(
r'ORDER BY SUM\([`"\[]?num_girls[`"\]]?\) DESC',
sql_text,
re.IGNORECASE,
)
else:
# Should reference the adhoc metric by alias when possible
assert re.search(
r'ORDER BY [`"\[]?num_girls[`"\]]? DESC', sql_text, re.IGNORECASE,
)
# ORDER BY only columns should always be expressions
assert re.search(
r'AVG\([`"\[]?num_boys[`"\]]?\) DESC', sql_text, re.IGNORECASE,
)
assert re.search(
r"MAX\(CASE.*END\) ASC", sql_text, re.IGNORECASE | re.DOTALL
)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_fetch_values_predicate_not_in_query(self):
def test_fetch_values_predicate(self):
"""
Ensure that fetch values predicate is not added to query
Ensure that fetch values predicate is added to query if needed
"""
self.login(username="admin")
payload = get_query_context("birth_names")
payload["result_type"] = ChartDataResultType.QUERY.value
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
assert len(responses) == 1
response = responses["queries"][0]
assert len(response) == 2
assert response["language"] == "sql"
assert "123 = 123" not in response["query"]
sql_text = get_sql_text(payload)
assert "123 = 123" not in sql_text
payload["queries"][0]["apply_fetch_values_predicate"] = True
sql_text = get_sql_text(payload)
assert "123 = 123" in sql_text
def test_query_object_unknown_fields(self):
"""