fix(filter-indicator): show filters handled by jinja as applied (#17140)

This commit is contained in:
Ville Brofeldt 2021-10-18 19:28:05 +02:00 committed by GitHub
parent 565ee2318d
commit d7834f17e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 48 additions and 20 deletions

View File

@ -96,6 +96,7 @@ def _get_full(
datasource = _get_datasource(query_context, query_obj)
result_type = query_obj.result_type or query_context.result_type
payload = query_context.get_df_payload(query_obj, force_cached=force_cached)
applied_template_filters = payload.get("applied_template_filters", [])
df = payload["df"]
status = payload["status"]
if status != QueryStatus.FAILED:
@ -113,12 +114,14 @@ def _get_full(
datasource, query_obj.applied_time_extras
)
payload["applied_filters"] = [
{"column": col} for col in filter_columns if col in columns
{"column": col}
for col in filter_columns
if col in columns or col in applied_template_filters
] + applied_time_columns
payload["rejected_filters"] = [
{"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, "column": col}
for col in filter_columns
if col not in columns
if col not in columns and col not in applied_template_filters
] + rejected_time_columns
if result_type == ChartDataResultType.RESULTS and status != QueryStatus.FAILED:

View File

@ -485,6 +485,7 @@ class QueryContext:
"cached_dttm": cache.cache_dttm,
"cache_timeout": self.cache_timeout,
"df": cache.df,
"applied_template_filters": cache.applied_template_filters,
"annotation_data": cache.annotation_data,
"error": cache.error_message,
"is_cached": cache.is_cached,

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from flask_caching import Cache
from pandas import DataFrame
@ -51,6 +51,7 @@ class QueryCacheManager:
df: DataFrame = DataFrame(),
query: str = "",
annotation_data: Optional[Dict[str, Any]] = None,
applied_template_filters: Optional[List[str]] = None,
status: Optional[str] = None,
error_message: Optional[str] = None,
is_loaded: bool = False,
@ -62,6 +63,7 @@ class QueryCacheManager:
self.df = df
self.query = query
self.annotation_data = {} if annotation_data is None else annotation_data
self.applied_template_filters = applied_template_filters or []
self.status = status
self.error_message = error_message
@ -88,6 +90,7 @@ class QueryCacheManager:
try:
self.status = query_result.status
self.query = query_result.query
self.applied_template_filters = query_result.applied_template_filters
self.error_message = query_result.error_message
self.df = query_result.df
self.annotation_data = {} if annotation_data is None else annotation_data
@ -101,6 +104,7 @@ class QueryCacheManager:
value = {
"df": self.df,
"query": self.query,
"applied_template_filters": self.applied_template_filters,
"annotation_data": self.annotation_data,
}
if self.is_loaded and key and self.status != QueryStatus.FAILED:
@ -141,6 +145,9 @@ class QueryCacheManager:
query_cache.df = cache_value["df"]
query_cache.query = cache_value["query"]
query_cache.annotation_data = cache_value.get("annotation_data", {})
query_cache.applied_template_filters = cache_value.get(
"applied_template_filters", []
)
query_cache.status = QueryStatus.SUCCESS
query_cache.is_loaded = True
query_cache.is_cached = cache_value is not None

View File

@ -103,6 +103,7 @@ VIRTUAL_TABLE_ALIAS = "virtual_table"
class SqlaQuery(NamedTuple):
applied_template_filters: List[str]
extra_cache_keys: List[Any]
labels_expected: List[str]
prequeries: List[str]
@ -110,6 +111,7 @@ class SqlaQuery(NamedTuple):
class QueryStringExtended(NamedTuple):
applied_template_filters: Optional[List[str]]
labels_expected: List[str]
prequeries: List[str]
sql: str
@ -755,7 +757,10 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
sql = sqlparse.format(sql, reindent=True)
sql = self.mutate_query_from_config(sql)
return QueryStringExtended(
labels_expected=sqlaq.labels_expected, sql=sql, prequeries=sqlaq.prequeries
applied_template_filters=sqlaq.applied_template_filters,
labels_expected=sqlaq.labels_expected,
prequeries=sqlaq.prequeries,
sql=sql,
)
def get_query_str(self, query_obj: QueryObjectDict) -> str:
@ -978,7 +983,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
extra_cache_keys: List[Any] = []
template_kwargs["extra_cache_keys"] = extra_cache_keys
removed_filters: List[str] = []
applied_template_filters: List[str] = []
template_kwargs["removed_filters"] = removed_filters
template_kwargs["applied_filters"] = applied_template_filters
template_processor = self.get_template_processor(**template_kwargs)
db_engine_spec = self.db_engine_spec
prequeries: List[str] = []
@ -1394,6 +1401,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
labels_expected = [label]
return SqlaQuery(
applied_template_filters=applied_template_filters,
extra_cache_keys=extra_cache_keys,
labels_expected=labels_expected,
sqla_query=qry,
@ -1491,6 +1499,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
error_message = utils.error_msg_from_exception(ex)
return QueryResult(
applied_template_filters=query_str_ext.applied_template_filters,
status=status,
df=df,
duration=datetime.now() - qry_start_dttm,

View File

@ -96,10 +96,12 @@ class ExtraCache:
def __init__(
self,
extra_cache_keys: Optional[List[Any]] = None,
applied_filters: Optional[List[str]] = None,
removed_filters: Optional[List[str]] = None,
dialect: Optional[Dialect] = None,
):
self.extra_cache_keys = extra_cache_keys
self.applied_filters = applied_filters if applied_filters is not None else []
self.removed_filters = removed_filters if removed_filters is not None else []
self.dialect = dialect
@ -323,6 +325,9 @@ class ExtraCache:
if remove_filter:
if column not in self.removed_filters:
self.removed_filters.append(column)
if column not in self.applied_filters:
self.applied_filters.append(column)
if op in (
FilterOperator.IN.value,
FilterOperator.NOT_IN.value,
@ -408,6 +413,7 @@ class BaseTemplateProcessor:
table: Optional["SqlaTable"] = None,
extra_cache_keys: Optional[List[Any]] = None,
removed_filters: Optional[List[str]] = None,
applied_filters: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
self._database = database
@ -418,6 +424,7 @@ class BaseTemplateProcessor:
elif table:
self._schema = table.schema
self._extra_cache_keys = extra_cache_keys
self._applied_filters = applied_filters
self._removed_filters = removed_filters
self._context: Dict[str, Any] = {}
self._env = SandboxedEnvironment(undefined=DebugUndefined)
@ -446,6 +453,7 @@ class JinjaTemplateProcessor(BaseTemplateProcessor):
super().set_context(**kwargs)
extra_cache = ExtraCache(
extra_cache_keys=self._extra_cache_keys,
applied_filters=self._applied_filters,
removed_filters=self._removed_filters,
dialect=self._database.get_dialect(),
)

View File

@ -442,6 +442,7 @@ class QueryResult: # pylint: disable=too-few-public-methods
df: pd.DataFrame,
query: str,
duration: timedelta,
applied_template_filters: Optional[List[str]] = None,
status: str = QueryStatus.SUCCESS,
error_message: Optional[str] = None,
errors: Optional[List[Dict[str, Any]]] = None,
@ -449,6 +450,7 @@ class QueryResult: # pylint: disable=too-few-public-methods
self.df = df
self.query = query
self.duration = duration
self.applied_template_filters = applied_template_filters or []
self.status = status
self.error_message = error_message
self.errors = errors or []

View File

@ -102,11 +102,6 @@ METRIC_KEYS = [
"size",
]
# This regex is to get user defined filter column name, which is the first param in the
# filter_values function. See the definition of filter_values template:
# https://github.com/apache/superset/blob/24ad6063d736c1f38ad6f962e586b9b1a21946af/superset/jinja_context.py#L63
FILTER_VALUES_REGEX = re.compile(r"filter_values\(['\"](\w+)['\"]\,")
class BaseViz: # pylint: disable=too-many-public-methods
@ -143,6 +138,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
self.status: Optional[str] = None
self.error_msg = ""
self.results: Optional[QueryResult] = None
self.applied_template_filters: List[str] = []
self.errors: List[Dict[str, Any]] = []
self.force = force
self._force_cached = force_cached
@ -270,6 +266,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
# The datasource here can be different backend but the interface is common
self.results = self.datasource.query(query_obj)
self.applied_template_filters = self.results.applied_template_filters or []
self.query = self.results.query
self.status = self.results.status
self.errors = self.results.errors
@ -459,14 +456,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
filters = self.form_data.get("filters", [])
filter_columns = [flt.get("col") for flt in filters]
columns = set(self.datasource.column_names)
filter_values_columns = []
# if using virtual datasource, check filter_values
if self.datasource.sql:
filter_values_columns = (
re.findall(FILTER_VALUES_REGEX, self.datasource.sql)
) or []
applied_template_filters = self.applied_template_filters or []
applied_time_extras = self.form_data.get("applied_time_extras", {})
applied_time_columns, rejected_time_columns = utils.get_time_filter_status(
self.datasource, applied_time_extras
@ -474,18 +464,18 @@ class BaseViz: # pylint: disable=too-many-public-methods
payload["applied_filters"] = [
{"column": col}
for col in filter_columns
if col in columns or col in filter_values_columns
if col in columns or col in applied_template_filters
] + applied_time_columns
payload["rejected_filters"] = [
{"reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE, "column": col}
for col in filter_columns
if col not in columns and col not in filter_values_columns
if col not in columns and col not in applied_template_filters
] + rejected_time_columns
if df is not None:
payload["colnames"] = list(df.columns)
return payload
def get_df_payload(
def get_df_payload( # pylint: disable=too-many-statements
self, query_obj: Optional[QueryObjectDict] = None, **kwargs: Any
) -> Dict[str, Any]:
"""Handles caching around the df payload retrieval"""
@ -504,6 +494,9 @@ class BaseViz: # pylint: disable=too-many-public-methods
try:
df = cache_value["df"]
self.query = cache_value["query"]
self.applied_template_filters = cache_value.get(
"applied_template_filters", []
)
self.status = QueryStatus.SUCCESS
is_loaded = True
stats_logger.incr("loaded_from_cache")

View File

@ -74,6 +74,7 @@ class TestJinja2Context(SupersetTestCase):
):
cache = ExtraCache()
self.assertEqual(cache.filter_values("name"), ["foo"])
self.assertEqual(cache.applied_filters, ["name"])
with app.test_request_context(
data={
@ -94,6 +95,7 @@ class TestJinja2Context(SupersetTestCase):
):
cache = ExtraCache()
self.assertEqual(cache.filter_values("name"), ["foo", "bar"])
self.assertEqual(cache.applied_filters, ["name"])
def test_get_filters_adhoc_filters(self) -> None:
with app.test_request_context(
@ -118,6 +120,7 @@ class TestJinja2Context(SupersetTestCase):
cache.get_filters("name"), [{"op": "IN", "col": "name", "val": ["foo"]}]
)
self.assertEqual(cache.removed_filters, list())
self.assertEqual(cache.applied_filters, ["name"])
with app.test_request_context(
data={
@ -166,6 +169,7 @@ class TestJinja2Context(SupersetTestCase):
[{"op": "IN", "col": "name", "val": ["foo", "bar"]}],
)
self.assertEqual(cache.removed_filters, ["name"])
self.assertEqual(cache.applied_filters, ["name"])
def test_filter_values_extra_filters(self) -> None:
with app.test_request_context(
@ -177,6 +181,7 @@ class TestJinja2Context(SupersetTestCase):
):
cache = ExtraCache()
self.assertEqual(cache.filter_values("name"), ["foo"])
self.assertEqual(cache.applied_filters, ["name"])
def test_url_param_default(self) -> None:
with app.test_request_context():