mirror of https://github.com/apache/superset.git
Add cache_key_wrapper to Jinja template processor (#7816)
This commit is contained in:
parent
f570b459f2
commit
4568b2a532
|
@ -87,6 +87,8 @@ Superset's Jinja context:
|
|||
|
||||
.. autofunction:: superset.jinja_context.filter_values
|
||||
|
||||
.. autofunction:: superset.jinja_context.CacheKeyWrapper.cache_key_wrapper
|
||||
|
||||
.. autoclass:: superset.jinja_context.PrestoTemplateProcessor
|
||||
:members:
|
||||
|
||||
|
|
|
@ -152,8 +152,13 @@ class QueryContext:
|
|||
|
||||
def get_df_payload(self, query_obj, **kwargs):
|
||||
"""Handles caching around the df payload retrieval"""
|
||||
extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj)
|
||||
cache_key = (
|
||||
query_obj.cache_key(datasource=self.datasource.uid, **kwargs)
|
||||
query_obj.cache_key(
|
||||
datasource=self.datasource.uid,
|
||||
extra_cache_keys=extra_cache_keys,
|
||||
**kwargs
|
||||
)
|
||||
if query_obj
|
||||
else None
|
||||
)
|
||||
|
|
|
@ -107,7 +107,7 @@ class QueryObject:
|
|||
|
||||
def cache_key(self, **extra):
|
||||
"""
|
||||
The cache key is made out of the key/values in `query_obj`, plus any
|
||||
The cache key is made out of the key/values from to_dict(), plus any
|
||||
other key/values in `extra`
|
||||
We remove datetime bounds that are hard values, and replace them with
|
||||
the user-provided inputs to bounds, which may be time-relative (as in
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
# under the License.
|
||||
# pylint: disable=C,R,W
|
||||
import json
|
||||
from typing import Any, List
|
||||
|
||||
from sqlalchemy import and_, Boolean, Column, Integer, String, Text
|
||||
from sqlalchemy.ext.declarative import declared_attr
|
||||
|
@ -73,9 +74,9 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
|
|||
)
|
||||
|
||||
# placeholder for a relationship to a derivative of BaseColumn
|
||||
columns = []
|
||||
columns: List[Any] = []
|
||||
# placeholder for a relationship to a derivative of BaseMetric
|
||||
metrics = []
|
||||
metrics: List[Any] = []
|
||||
|
||||
@property
|
||||
def uid(self):
|
||||
|
@ -329,6 +330,12 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
|
|||
obj.get("columns"), self.columns, self.column_class, "column_name"
|
||||
)
|
||||
|
||||
def get_extra_cache_keys(self, query_obj) -> List[Any]:
    """Return additional values to include when calculating a cache key.

    If a datasource needs to provide additional keys for the calculation
    of cache keys, it can do so by overriding this method. This default
    implementation contributes nothing.

    :param query_obj: the query object the cache key is calculated for
        (not used by this default implementation)
    :return: an empty list
    """
    return []
|
||||
|
||||
|
||||
class BaseColumn(AuditMixinNullable, ImportMixin):
|
||||
"""Interface for column"""
|
||||
|
@ -346,7 +353,7 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
|
|||
is_dttm = None
|
||||
|
||||
# [optional] Set this to support import/export functionality
|
||||
export_fields = []
|
||||
export_fields: List[Any] = []
|
||||
|
||||
def __repr__(self):
|
||||
return self.column_name
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
from collections import namedtuple, OrderedDict
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from flask import escape, Markup
|
||||
from flask_appbuilder import Model
|
||||
|
@ -61,7 +61,9 @@ from superset.utils import core as utils, import_datasource
|
|||
config = app.config
|
||||
metadata = Model.metadata # pylint: disable=no-member
|
||||
|
||||
SqlaQuery = namedtuple("SqlaQuery", ["sqla_query", "labels_expected"])
|
||||
SqlaQuery = namedtuple(
|
||||
"SqlaQuery", ["sqla_query", "labels_expected", "extra_cache_keys"]
|
||||
)
|
||||
QueryStringExtended = namedtuple("QueryStringExtended", ["sql", "labels_expected"])
|
||||
|
||||
|
||||
|
@ -618,6 +620,8 @@ class SqlaTable(Model, BaseDatasource):
|
|||
"columns": {col.column_name: col for col in self.columns},
|
||||
}
|
||||
template_kwargs.update(self.template_params_dict)
|
||||
extra_cache_keys: List[Any] = []
|
||||
template_kwargs["extra_cache_keys"] = extra_cache_keys
|
||||
template_processor = self.get_template_processor(**template_kwargs)
|
||||
db_engine_spec = self.database.db_engine_spec
|
||||
|
||||
|
@ -869,7 +873,9 @@ class SqlaTable(Model, BaseDatasource):
|
|||
qry = qry.where(top_groups)
|
||||
|
||||
return SqlaQuery(
|
||||
sqla_query=qry.select_from(tbl), labels_expected=labels_expected
|
||||
sqla_query=qry.select_from(tbl),
|
||||
labels_expected=labels_expected,
|
||||
extra_cache_keys=extra_cache_keys,
|
||||
)
|
||||
|
||||
def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict, cols):
|
||||
|
@ -1058,6 +1064,10 @@ class SqlaTable(Model, BaseDatasource):
|
|||
def default_query(qry):
|
||||
return qry.filter_by(is_sqllab_view=False)
|
||||
|
||||
def get_extra_cache_keys(self, query_obj) -> List[Any]:
    """Return the extra cache keys gathered while building the query.

    ``query_obj`` is unpacked as keyword arguments into
    ``self.get_sqla_query``, and the ``extra_cache_keys`` attribute of the
    result is returned.

    :param query_obj: mapping of keyword arguments for ``get_sqla_query``
    :return: the ``extra_cache_keys`` of the generated query
    """
    sqla_query = self.get_sqla_query(**query_obj)
    return sqla_query.extra_cache_keys
|
||||
|
||||
|
||||
sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm)
|
||||
sa.event.listen(SqlaTable, "after_update", security_manager.set_perm)
|
||||
|
|
|
@ -129,7 +129,43 @@ def filter_values(column: str, default: Optional[str] = None) -> List[str]:
|
|||
return []
|
||||
|
||||
|
||||
class BaseTemplateProcessor(object):
|
||||
class CacheKeyWrapper:
    """Dummy class that exposes a method used to store additional values used in
    calculation of query object cache keys"""

    def __init__(self, extra_cache_keys: Optional[List[Any]] = None):
        """Keep a reference to the list that wrapped keys are appended to.

        :param extra_cache_keys: list that receives every wrapped key; it is
            stored as given (not copied), so the caller sees the appended
            values. When ``None``, wrapped keys are not recorded.
        """
        self.extra_cache_keys = extra_cache_keys

    def cache_key_wrapper(self, key: Any) -> Any:
        """Adds values to a list that is added to the query object used for
        calculating a cache key.

        This is needed if the following applies:
            - Caching is enabled
            - The query is dynamically generated using a jinja template
            - A username or similar is used as a filter in the query

        Example when using a SQL query as a data source ::

            SELECT action, count(*) as times
            FROM logs
            WHERE logged_in_user = '{{ cache_key_wrapper(current_username()) }}'
            GROUP BY action

        This will ensure that the query results that were cached by `user_1` will
        **not** be seen by `user_2`, as the `cache_key` for the query will be
        different. ``cache_key_wrapper`` can be used similarly for regular table data
        sources by adding a `Custom SQL` filter.

        :param key: Any value that should be considered when calculating the cache key
        :return: the original value ``key`` passed to the function
        """
        # Without a backing list the wrapper is a plain pass-through.
        if self.extra_cache_keys is not None:
            self.extra_cache_keys.append(key)
        return key
|
||||
|
||||
|
||||
class BaseTemplateProcessor:
|
||||
"""Base class for database-specific jinja context
|
||||
|
||||
There's this bit of magic in ``process_template`` that instantiates only
|
||||
|
@ -146,7 +182,14 @@ class BaseTemplateProcessor(object):
|
|||
|
||||
engine: Optional[str] = None
|
||||
|
||||
def __init__(self, database=None, query=None, table=None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
database=None,
|
||||
query=None,
|
||||
table=None,
|
||||
extra_cache_keys: Optional[List[Any]] = None,
|
||||
**kwargs
|
||||
):
|
||||
self.database = database
|
||||
self.query = query
|
||||
self.schema = None
|
||||
|
@ -158,6 +201,7 @@ class BaseTemplateProcessor(object):
|
|||
"url_param": url_param,
|
||||
"current_user_id": current_user_id,
|
||||
"current_username": current_username,
|
||||
"cache_key_wrapper": CacheKeyWrapper(extra_cache_keys).cache_key_wrapper,
|
||||
"filter_values": filter_values,
|
||||
"form_data": {},
|
||||
}
|
||||
|
@ -189,7 +233,9 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
|
|||
engine = "presto"
|
||||
|
||||
@staticmethod
|
||||
def _schema_table(table_name: str, schema: str) -> Tuple[str, str]:
|
||||
def _schema_table(
|
||||
table_name: str, schema: Optional[str]
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
if "." in table_name:
|
||||
schema, table_name = table_name.split(".")
|
||||
return table_name, schema
|
||||
|
|
|
@ -364,6 +364,7 @@ class BaseViz(object):
|
|||
|
||||
cache_dict["time_range"] = self.form_data.get("time_range")
|
||||
cache_dict["datasource"] = self.datasource.uid
|
||||
cache_dict["extra_cache_keys"] = self.datasource.get_extra_cache_keys(query_obj)
|
||||
json_data = self.json_dumps(cache_dict, sort_keys=True)
|
||||
return hashlib.md5(json_data.encode("utf-8")).hexdigest()
|
||||
|
||||
|
|
|
@ -14,8 +14,10 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from superset.connectors.sqla.models import TableColumn
|
||||
from superset import db
|
||||
from superset.connectors.sqla.models import SqlaTable, TableColumn
|
||||
from superset.db_engine_specs.druid import DruidEngineSpec
|
||||
from superset.utils.core import get_main_database
|
||||
from .base_tests import SupersetTestCase
|
||||
|
||||
|
||||
|
@ -39,3 +41,20 @@ class DatabaseModelTestCase(SupersetTestCase):
|
|||
|
||||
col = TableColumn(column_name="foo", type="STRING")
|
||||
self.assertEquals(col.is_time, False)
|
||||
|
||||
def test_cache_key_wrapper(self):
    """Check that keys wrapped with ``cache_key_wrapper`` in both the table
    SQL and the ``where`` extra are returned by ``get_extra_cache_keys``."""
    query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user"
    table = SqlaTable(sql=query, database=get_main_database(db.session))
    query_obj = {
        "granularity": None,
        "from_dttm": None,
        "to_dttm": None,
        "groupby": ["user"],
        "metrics": [],
        "is_timeseries": False,
        "filter": [],
        "is_prequery": False,
        "extras": {"where": "(user != '{{ cache_key_wrapper('user_2') }}')"},
    }
    extra_cache_keys = table.get_extra_cache_keys(query_obj)
    self.assertListEqual(extra_cache_keys, ["user_1", "user_2"])
|
||||
|
|
Loading…
Reference in New Issue