Add cache_key_wrapper to Jinja template processor (#7816)

Ville Brofeldt 2019-07-20 19:12:35 +03:00 committed by Maxime Beauchemin
parent f570b459f2
commit 4568b2a532
8 changed files with 102 additions and 12 deletions

View File

@@ -87,6 +87,8 @@ Superset's Jinja context:
.. autofunction:: superset.jinja_context.filter_values
.. autofunction:: superset.jinja_context.CacheKeyWrapper.cache_key_wrapper
.. autoclass:: superset.jinja_context.PrestoTemplateProcessor
:members:

View File

@@ -152,8 +152,13 @@ class QueryContext:
def get_df_payload(self, query_obj, **kwargs):
"""Handles caching around the df paylod retrieval"""
extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj)
cache_key = (
query_obj.cache_key(datasource=self.datasource.uid, **kwargs)
query_obj.cache_key(
datasource=self.datasource.uid,
extra_cache_keys=extra_cache_keys,
**kwargs
)
if query_obj
else None
)

View File

@@ -107,7 +107,7 @@ class QueryObject:
def cache_key(self, **extra):
"""
The cache key is made out of the key/values in `query_obj`, plus any
The cache key is made out of the key/values from to_dict(), plus any
other key/values in `extra`
We remove datetime bounds that are hard values, and replace them with
the user-provided inputs to bounds, which may be time-relative (as in
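
A hedged sketch of what this means in practice, assuming a dict-like to_dict() output; the bound handling and hashing below are illustrative, not the exact QueryObject code:

import hashlib
import json

def cache_key_sketch(query_dict: dict, **extra) -> str:
    # Merge the serialized query object with any extra key/values,
    # e.g. the extra_cache_keys collected from the datasource.
    cache_dict = dict(query_dict, **extra)
    # Drop the resolved datetime bounds; the real method substitutes the
    # user-provided (possibly relative) time range so that "Last week"
    # hashes the same way from one day to the next.
    cache_dict.pop("from_dttm", None)
    cache_dict.pop("to_dttm", None)
    json_data = json.dumps(cache_dict, sort_keys=True, default=str)
    return hashlib.md5(json_data.encode("utf-8")).hexdigest()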

View File

@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=C,R,W
import json
from typing import Any, List
from sqlalchemy import and_, Boolean, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declared_attr
@@ -73,9 +74,9 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
)
# placeholder for a relationship to a derivative of BaseColumn
columns = []
columns: List[Any] = []
# placeholder for a relationship to a derivative of BaseMetric
metrics = []
metrics: List[Any] = []
@property
def uid(self):
@@ -329,6 +330,12 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
obj.get("columns"), self.columns, self.column_class, "column_name"
)
def get_extra_cache_keys(self, query_obj) -> List[Any]:
""" If a datasource needs to provide additional keys for calculation of
cache keys, those can be provided via this method
"""
return []
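
A hypothetical subclass (not part of this commit) illustrating how a connector could use the hook to scope cached results to the logged-in user:

from typing import Any, List
from flask import g
from superset.connectors.base.models import BaseDatasource

class UserScopedDatasource(BaseDatasource):  # illustrative only
    def get_extra_cache_keys(self, query_obj) -> List[Any]:
        # Fold the current username into every cache key for this datasource
        # so that no two users ever share a cached result.
        user = getattr(g, "user", None)
        return [user.username] if user else []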
class BaseColumn(AuditMixinNullable, ImportMixin):
"""Interface for column"""
@@ -346,7 +353,7 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
is_dttm = None
# [optional] Set this to support import/export functionality
export_fields = []
export_fields: List[Any] = []
def __repr__(self):
return self.column_name

View File

@@ -18,7 +18,7 @@
from collections import namedtuple, OrderedDict
from datetime import datetime
import logging
from typing import Optional, Union
from typing import Any, List, Optional, Union
from flask import escape, Markup
from flask_appbuilder import Model
@@ -61,7 +61,9 @@ from superset.utils import core as utils, import_datasource
config = app.config
metadata = Model.metadata # pylint: disable=no-member
SqlaQuery = namedtuple("SqlaQuery", ["sqla_query", "labels_expected"])
SqlaQuery = namedtuple(
"SqlaQuery", ["sqla_query", "labels_expected", "extra_cache_keys"]
)
QueryStringExtended = namedtuple("QueryStringExtended", ["sql", "labels_expected"])
@@ -618,6 +620,8 @@ class SqlaTable(Model, BaseDatasource):
"columns": {col.column_name: col for col in self.columns},
}
template_kwargs.update(self.template_params_dict)
extra_cache_keys: List[Any] = []
template_kwargs["extra_cache_keys"] = extra_cache_keys
template_processor = self.get_template_processor(**template_kwargs)
db_engine_spec = self.database.db_engine_spec
@@ -869,7 +873,9 @@ class SqlaTable(Model, BaseDatasource):
qry = qry.where(top_groups)
return SqlaQuery(
sqla_query=qry.select_from(tbl), labels_expected=labels_expected
sqla_query=qry.select_from(tbl),
labels_expected=labels_expected,
extra_cache_keys=extra_cache_keys,
)
def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict, cols):
@@ -1058,6 +1064,10 @@ class SqlaTable(Model, BaseDatasource):
def default_query(qry):
return qry.filter_by(is_sqllab_view=False)
def get_extra_cache_keys(self, query_obj) -> List[Any]:
sqla_query = self.get_sqla_query(**query_obj)
return sqla_query.extra_cache_keys
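
Because get_sqla_query renders the table's Jinja templates, building the query is what populates the shared extra_cache_keys list. The mechanism, reduced to plain jinja2 (a sketch, not Superset code):

from jinja2 import Template

extra_cache_keys: list = []

def cache_key_wrapper(key):
    # Record the value for the cache key and return it unchanged, so the
    # rendered SQL is identical to what it would be without the wrapper.
    extra_cache_keys.append(key)
    return key

sql = "SELECT * FROM logs WHERE user = '{{ cache_key_wrapper('user_1') }}'"
rendered = Template(sql).render(cache_key_wrapper=cache_key_wrapper)

assert rendered == "SELECT * FROM logs WHERE user = 'user_1'"
assert extra_cache_keys == ["user_1"]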
sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm)
sa.event.listen(SqlaTable, "after_update", security_manager.set_perm)

View File

@@ -129,7 +129,43 @@ def filter_values(column: str, default: Optional[str] = None) -> List[str]:
return []
class BaseTemplateProcessor(object):
class CacheKeyWrapper:
""" Dummy class that exposes a method used to store additional values used in
calculation of query object cache keys"""
def __init__(self, extra_cache_keys: Optional[List[Any]] = None):
self.extra_cache_keys = extra_cache_keys
def cache_key_wrapper(self, key: Any) -> Any:
""" Adds values to a list that is added to the query object used for calculating
a cache key.
This is needed if the following applies:
- Caching is enabled
- The query is dynamically generated using a jinja template
- A username or similar is used as a filter in the query
Example when using a SQL query as a data source ::
SELECT action, count(*) as times
FROM logs
WHERE logged_in_user = '{{ cache_key_wrapper(current_username()) }}'
GROUP BY action
This will ensure that the query results that were cached by `user_1` will
**not** be seen by `user_2`, as the `cache_key` for the query will be
different. ``cache_key_wrapper`` can be used similarly for regular table data
sources by adding a `Custom SQL` filter.
:param key: Any value that should be considered when calculating the cache key
:return: the original value ``key`` passed to the function
"""
if self.extra_cache_keys is not None:
self.extra_cache_keys.append(key)
return key
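
Behaviour of the wrapper in isolation, using only the class defined above:

keys: list = []
wrapper = CacheKeyWrapper(keys)

# The wrapped value is returned unchanged, so templates can wrap any expression
# without altering the SQL they generate.
assert wrapper.cache_key_wrapper("user_1") == "user_1"
assert keys == ["user_1"]

# With no backing list the call degrades to a plain pass-through.
assert CacheKeyWrapper().cache_key_wrapper(42) == 42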
class BaseTemplateProcessor:
"""Base class for database-specific jinja context
There's this bit of magic in ``process_template`` that instantiates only
@@ -146,7 +182,14 @@ class BaseTemplateProcessor(object):
engine: Optional[str] = None
def __init__(self, database=None, query=None, table=None, **kwargs):
def __init__(
self,
database=None,
query=None,
table=None,
extra_cache_keys: Optional[List[Any]] = None,
**kwargs
):
self.database = database
self.query = query
self.schema = None
@@ -158,6 +201,7 @@ class BaseTemplateProcessor(object):
"url_param": url_param,
"current_user_id": current_user_id,
"current_username": current_username,
"cache_key_wrapper": CacheKeyWrapper(extra_cache_keys).cache_key_wrapper,
"filter_values": filter_values,
"form_data": {},
}
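
Putting it together at the processor level; a sketch that assumes the process_template entry point referenced in the class docstring renders a SQL string against this context (in real usage the template would typically call current_username() rather than hard-code a name):

extra_cache_keys: list = []
tp = BaseTemplateProcessor(extra_cache_keys=extra_cache_keys)
sql = tp.process_template(
    "SELECT action FROM logs "
    "WHERE logged_in_user = '{{ cache_key_wrapper('user_1') }}'"
)
# extra_cache_keys now holds ["user_1"]; get_sqla_query returns the same list,
# so the value ends up in the query object's cache key.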
@@ -189,7 +233,9 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
engine = "presto"
@staticmethod
def _schema_table(table_name: str, schema: str) -> Tuple[str, str]:
def _schema_table(
table_name: str, schema: Optional[str]
) -> Tuple[str, Optional[str]]:
if "." in table_name:
schema, table_name = table_name.split(".")
return table_name, schema

View File

@@ -364,6 +364,7 @@ class BaseViz(object):
cache_dict["time_range"] = self.form_data.get("time_range")
cache_dict["datasource"] = self.datasource.uid
cache_dict["extra_cache_keys"] = self.datasource.get_extra_cache_keys(query_obj)
json_data = self.json_dumps(cache_dict, sort_keys=True)
return hashlib.md5(json_data.encode("utf-8")).hexdigest()
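
The practical effect of the added extra_cache_keys entry, shown with the same sorted-JSON/md5 recipe as the surrounding method (dict contents are illustrative):

import hashlib
import json

def digest(cache_dict: dict) -> str:
    json_data = json.dumps(cache_dict, sort_keys=True)
    return hashlib.md5(json_data.encode("utf-8")).hexdigest()

base = {"datasource": "1__table", "time_range": "Last week"}
# Same chart and time range, but different collected cache keys per user:
assert digest({**base, "extra_cache_keys": ["user_1"]}) != digest(
    {**base, "extra_cache_keys": ["user_2"]}
)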

View File

@@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.connectors.sqla.models import TableColumn
from superset import db
from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.db_engine_specs.druid import DruidEngineSpec
from superset.utils.core import get_main_database
from .base_tests import SupersetTestCase
@@ -39,3 +41,20 @@ class DatabaseModelTestCase(SupersetTestCase):
col = TableColumn(column_name="foo", type="STRING")
self.assertEquals(col.is_time, False)
def test_cache_key_wrapper(self):
query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user"
table = SqlaTable(sql=query, database=get_main_database(db.session))
query_obj = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": ["user"],
"metrics": [],
"is_timeseries": False,
"filter": [],
"is_prequery": False,
"extras": {"where": "(user != '{{ cache_key_wrapper('user_2') }}')"},
}
extra_cache_keys = table.get_extra_cache_keys(query_obj)
self.assertListEqual(extra_cache_keys, ["user_1", "user_2"])