diff --git a/docs/sqllab.rst b/docs/sqllab.rst
index 5fe24ad973..c60123fa71 100644
--- a/docs/sqllab.rst
+++ b/docs/sqllab.rst
@@ -87,6 +87,8 @@ Superset's Jinja context:
 
 .. autofunction:: superset.jinja_context.filter_values
 
+.. autofunction:: superset.jinja_context.CacheKeyWrapper.cache_key_wrapper
+
 .. autoclass:: superset.jinja_context.PrestoTemplateProcessor
     :members:
 
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 499240586e..e8f7f62419 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -152,8 +152,13 @@ class QueryContext:
 
     def get_df_payload(self, query_obj, **kwargs):
         """Handles caching around the df paylod retrieval"""
+        extra_cache_keys = self.datasource.get_extra_cache_keys(query_obj)
         cache_key = (
-            query_obj.cache_key(datasource=self.datasource.uid, **kwargs)
+            query_obj.cache_key(
+                datasource=self.datasource.uid,
+                extra_cache_keys=extra_cache_keys,
+                **kwargs
+            )
             if query_obj
             else None
         )
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index f4130319c2..7d72aa5eb7 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -107,7 +107,7 @@ class QueryObject:
 
     def cache_key(self, **extra):
         """
-        The cache key is made out of the key/values in `query_obj`, plus any
+        The cache key is made out of the key/values from to_dict(), plus any
         other key/values in `extra`
         We remove datetime bounds that are hard values, and replace them with
         the use-provided inputs to bounds, which may be time-relative (as in
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index e90d5d3fcd..da7fec5d6a 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=C,R,W
 import json
+from typing import Any, List
 
 from sqlalchemy import and_, Boolean, Column, Integer, String, Text
 from sqlalchemy.ext.declarative import declared_attr
@@ -73,9 +74,9 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
     )
 
     # placeholder for a relationship to a derivative of BaseColumn
-    columns = []
+    columns: List[Any] = []
     # placeholder for a relationship to a derivative of BaseMetric
-    metrics = []
+    metrics: List[Any] = []
 
     @property
     def uid(self):
@@ -329,6 +330,12 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
             obj.get("columns"), self.columns, self.column_class, "column_name"
         )
 
+    def get_extra_cache_keys(self, query_obj) -> List[Any]:
+        """If a datasource needs to provide additional keys for the calculation of
+        cache keys, they can be provided by overriding this method.
+        """
+        return []
+
 
 class BaseColumn(AuditMixinNullable, ImportMixin):
     """Interface for column"""
@@ -346,7 +353,7 @@ class BaseColumn(AuditMixinNullable, ImportMixin):
     is_dttm = None
 
     # [optional] Set this to support import/export functionality
-    export_fields = []
+    export_fields: List[Any] = []
 
     def __repr__(self):
         return self.column_name
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index c9eed47df0..dfedca3a86 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -18,7 +18,7 @@
 from collections import namedtuple, OrderedDict
 from datetime import datetime
 import logging
-from typing import Optional, Union
+from typing import Any, List, Optional, Union
 
 from flask import escape, Markup
 from flask_appbuilder import Model
@@ -61,7 +61,9 @@ from superset.utils import core as utils, import_datasource
 
 config = app.config
 metadata = Model.metadata  # pylint: disable=no-member
 
-SqlaQuery = namedtuple("SqlaQuery", ["sqla_query", "labels_expected"])
+SqlaQuery = namedtuple(
+    "SqlaQuery", ["sqla_query", "labels_expected", "extra_cache_keys"]
+)
 QueryStringExtended = namedtuple("QueryStringExtended", ["sql", "labels_expected"])
 
@@ -618,6 +620,8 @@ class SqlaTable(Model, BaseDatasource):
             "columns": {col.column_name: col for col in self.columns},
         }
         template_kwargs.update(self.template_params_dict)
+        extra_cache_keys: List[Any] = []
+        template_kwargs["extra_cache_keys"] = extra_cache_keys
         template_processor = self.get_template_processor(**template_kwargs)
         db_engine_spec = self.database.db_engine_spec
 
@@ -869,7 +873,9 @@ class SqlaTable(Model, BaseDatasource):
                 qry = qry.where(top_groups)
 
         return SqlaQuery(
-            sqla_query=qry.select_from(tbl), labels_expected=labels_expected
+            sqla_query=qry.select_from(tbl),
+            labels_expected=labels_expected,
+            extra_cache_keys=extra_cache_keys,
         )
 
     def _get_timeseries_orderby(self, timeseries_limit_metric, metrics_dict, cols):
@@ -1058,6 +1064,10 @@ class SqlaTable(Model, BaseDatasource):
     def default_query(qry):
         return qry.filter_by(is_sqllab_view=False)
 
+    def get_extra_cache_keys(self, query_obj) -> List[Any]:
+        sqla_query = self.get_sqla_query(**query_obj)
+        return sqla_query.extra_cache_keys
+
 
 sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm)
 sa.event.listen(SqlaTable, "after_update", security_manager.set_perm)
diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index cfb75937c4..97b4dfe30c 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -129,7 +129,43 @@ def filter_values(column: str, default: Optional[str] = None) -> List[str]:
     return []
 
 
-class BaseTemplateProcessor(object):
+class CacheKeyWrapper:
+    """Dummy class that exposes a method for storing additional values used in
+    the calculation of query object cache keys"""
+
+    def __init__(self, extra_cache_keys: Optional[List[Any]] = None):
+        self.extra_cache_keys = extra_cache_keys
+
+    def cache_key_wrapper(self, key: Any) -> Any:
+        """Adds values to a list that is added to the query object used for
+        calculating a cache key.
+
+        This is needed if the following applies:
+            - Caching is enabled
+            - The query is dynamically generated using a jinja template
+            - A username or similar is used as a filter in the query
+
+        Example when using a SQL query as a data source ::
+
+            SELECT action, count(*) as times
+            FROM logs
+            WHERE logged_in_user = '{{ cache_key_wrapper(current_username()) }}'
+            GROUP BY action
+
+        This will ensure that the query results that were cached by `user_1` will
+        **not** be seen by `user_2`, as the `cache_key` for the query will be
+        different. ``cache_key_wrapper`` can be used similarly for regular table data
+        sources by adding a `Custom SQL` filter.
+
+        :param key: Any value that should be considered when calculating the cache key
+        :return: the original value ``key`` passed to the function
+        """
+        if self.extra_cache_keys is not None:
+            self.extra_cache_keys.append(key)
+        return key
+
+
+class BaseTemplateProcessor:
     """Base class for database-specific jinja context
 
     There's this bit of magic in ``process_template`` that instantiates only
@@ -146,7 +182,14 @@ class BaseTemplateProcessor(object):
 
     engine: Optional[str] = None
 
-    def __init__(self, database=None, query=None, table=None, **kwargs):
+    def __init__(
+        self,
+        database=None,
+        query=None,
+        table=None,
+        extra_cache_keys: Optional[List[Any]] = None,
+        **kwargs
+    ):
         self.database = database
         self.query = query
         self.schema = None
@@ -158,6 +201,7 @@
             "url_param": url_param,
             "current_user_id": current_user_id,
             "current_username": current_username,
+            "cache_key_wrapper": CacheKeyWrapper(extra_cache_keys).cache_key_wrapper,
             "filter_values": filter_values,
             "form_data": {},
         }
@@ -189,7 +233,9 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
     engine = "presto"
 
     @staticmethod
-    def _schema_table(table_name: str, schema: str) -> Tuple[str, str]:
+    def _schema_table(
+        table_name: str, schema: Optional[str]
+    ) -> Tuple[str, Optional[str]]:
         if "." in table_name:
             schema, table_name = table_name.split(".")
         return table_name, schema
diff --git a/superset/viz.py b/superset/viz.py
index d81e16e066..52075be02c 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -364,6 +364,7 @@ class BaseViz(object):
 
         cache_dict["time_range"] = self.form_data.get("time_range")
         cache_dict["datasource"] = self.datasource.uid
+        cache_dict["extra_cache_keys"] = self.datasource.get_extra_cache_keys(query_obj)
         json_data = self.json_dumps(cache_dict, sort_keys=True)
         return hashlib.md5(json_data.encode("utf-8")).hexdigest()
 
diff --git a/tests/sqla_models_tests.py b/tests/sqla_models_tests.py
index 93a285909e..f089628798 100644
--- a/tests/sqla_models_tests.py
+++ b/tests/sqla_models_tests.py
@@ -14,8 +14,10 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from superset.connectors.sqla.models import TableColumn
+from superset import db
+from superset.connectors.sqla.models import SqlaTable, TableColumn
 from superset.db_engine_specs.druid import DruidEngineSpec
+from superset.utils.core import get_main_database
 
 from .base_tests import SupersetTestCase
 
@@ -39,3 +41,20 @@ class DatabaseModelTestCase(SupersetTestCase):
 
         col = TableColumn(column_name="foo", type="STRING")
         self.assertEquals(col.is_time, False)
+
+    def test_cache_key_wrapper(self):
+        query = "SELECT '{{ cache_key_wrapper('user_1') }}' as user"
+        table = SqlaTable(sql=query, database=get_main_database(db.session))
+        query_obj = {
+            "granularity": None,
+            "from_dttm": None,
+            "to_dttm": None,
+            "groupby": ["user"],
+            "metrics": [],
+            "is_timeseries": False,
+            "filter": [],
+            "is_prequery": False,
+            "extras": {"where": "(user != '{{ cache_key_wrapper('user_2') }}')"},
+        }
+        extra_cache_keys = table.get_extra_cache_keys(query_obj)
+        self.assertListEqual(extra_cache_keys, ["user_1", "user_2"])
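A minimal sketch (not part of the patch, and assuming the patched ``superset.jinja_context`` module is importable) of how the pieces above fit together: the template processor registers ``CacheKeyWrapper.cache_key_wrapper`` in the Jinja context, the wrapper returns the wrapped value unchanged while recording it in ``extra_cache_keys``, and ``SqlaTable.get_extra_cache_keys()`` then exposes that list so it can be folded into the cache key (see ``QueryContext.get_df_payload`` and ``BaseViz.cache_key`` above) ::

    from superset.jinja_context import CacheKeyWrapper

    extra_cache_keys = []
    wrapper = CacheKeyWrapper(extra_cache_keys)

    # Rendering "{{ cache_key_wrapper(current_username()) }}" in a virtual table's
    # SQL substitutes the value into the query and records it for the cache key.
    rendered = "SELECT action FROM logs WHERE user = '{}'".format(
        wrapper.cache_key_wrapper("user_1")
    )

    assert rendered == "SELECT action FROM logs WHERE user = 'user_1'"
    assert extra_cache_keys == ["user_1"]  # folded into the query's cache key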