From 2e172d77cf87ef9b4850f5fa183a2b1815bb6e58 Mon Sep 17 00:00:00 2001
From: Maxime Beauchemin
Date: Wed, 7 Feb 2018 14:49:19 -0800
Subject: [PATCH] Fix caching issues (#4316)

---
 superset/assets/package.json |   4 +-
 superset/models/core.py      |   3 +-
 superset/views/core.py       |  18 ++--
 superset/viz.py              | 171 ++++++++++++++++++++++++-----------
 4 files changed, 132 insertions(+), 64 deletions(-)

diff --git a/superset/assets/package.json b/superset/assets/package.json
index c944ad2fa0..abc978c079 100644
--- a/superset/assets/package.json
+++ b/superset/assets/package.json
@@ -93,8 +93,8 @@
     "react-sortable-hoc": "^0.6.7",
     "react-split-pane": "^0.1.66",
     "react-syntax-highlighter": "^5.7.0",
-    "react-virtualized": "^9.3.0",
-    "react-virtualized-select": "^2.4.0",
+    "react-virtualized": "9.3.0",
+    "react-virtualized-select": "2.4.0",
     "reactable": "^0.14.1",
     "redux": "^3.5.2",
     "redux-localstorage": "^0.4.1",
diff --git a/superset/models/core.py b/superset/models/core.py
index 142482bdbd..cfd6d75203 100644
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -230,7 +230,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
         name = escape(self.slice_name)
         return Markup('<a href="{url}">{name}</a>'.format(**locals()))
 
-    def get_viz(self):
+    def get_viz(self, force=False):
         """Creates :py:class:viz.BaseViz object from the url_params_multidict.
 
         :return: object of the 'viz_type' type that is taken from the
@@ -246,6 +246,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
         return viz_types[slice_params.get('viz_type')](
             self.datasource,
             form_data=slice_params,
+            force=force,
         )
 
     @classmethod
diff --git a/superset/views/core.py b/superset/views/core.py
index 44b6b2a5a4..7fb96a9bc1 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -954,7 +954,9 @@ class Superset(BaseSupersetView):
             slice_id=None,
             form_data=None,
             datasource_type=None,
-            datasource_id=None):
+            datasource_id=None,
+            force=False,
+    ):
         if slice_id:
             slc = (
                 db.session.query(models.Slice)
@@ -969,6 +971,7 @@ class Superset(BaseSupersetView):
             viz_obj = viz.viz_types[viz_type](
                 datasource,
                 form_data=form_data,
+                force=force,
             )
             return viz_obj
 
@@ -1017,7 +1020,9 @@ class Superset(BaseSupersetView):
             viz_obj = self.get_viz(
                 datasource_type=datasource_type,
                 datasource_id=datasource_id,
-                form_data=form_data)
+                form_data=form_data,
+                force=force,
+            )
         except Exception as e:
             logging.exception(e)
             return json_error_response(
@@ -1038,7 +1043,7 @@ class Superset(BaseSupersetView):
             return self.get_query_string_response(viz_obj)
 
         try:
-            payload = viz_obj.get_payload(force=force)
+            payload = viz_obj.get_payload()
         except Exception as e:
             logging.exception(e)
             return json_error_response(utils.error_msg_from_exception(e))
@@ -1082,9 +1087,10 @@ class Superset(BaseSupersetView):
         viz_obj = viz.viz_types['table'](
             datasource,
             form_data=form_data,
+            force=False,
         )
         try:
-            payload = viz_obj.get_payload(force=False)
+            payload = viz_obj.get_payload()
         except Exception as e:
             logging.exception(e)
             return json_error_response(utils.error_msg_from_exception(e))
@@ -1876,8 +1882,8 @@ class Superset(BaseSupersetView):
 
         for slc in slices:
             try:
-                obj = slc.get_viz()
-                obj.get_json(force=True)
+                obj = slc.get_viz(force=True)
+                obj.get_json()
             except Exception as e:
                 return json_error_response(utils.error_msg_from_exception(e))
         return json_success(json.dumps(
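
Taken together, the views/core.py hunks above move cache control from each payload
call to viz construction: `force` now rides along from `get_viz()` into the viz
object, and `get_payload()`/`get_json()` lose their per-call `force` arguments.
A minimal sketch of the resulting call pattern (the slice lookup is illustrative,
not part of this patch):

    # Sketch: refreshing one slice under the new API; `slc` is assumed to be
    # a models.Slice instance already loaded from the SQLAlchemy session.
    viz_obj = slc.get_viz(force=True)  # force is bound to the viz object once
    payload = viz_obj.get_payload()    # no per-call force argument anymore
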
diff --git a/superset/viz.py b/superset/viz.py
index ebd0c788b9..d66884a3b6 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -18,10 +18,9 @@ import logging
 import math
 import traceback
 import uuid
-import zlib
 
 from dateutil import relativedelta as rdelta
-from flask import request
+from flask import escape, request
 from flask_babel import lazy_gettext as _
 import geohash
 from markdown import markdown
@@ -30,8 +29,8 @@ import pandas as pd
 from pandas.tseries.frequencies import to_offset
 import polyline
 import simplejson as json
-from six import PY3, string_types, text_type
-from six.moves import reduce
+from six import string_types, text_type
+from six.moves import cPickle as pkl, reduce
 
 from superset import app, cache, get_manifest_file, utils
 from superset.utils import DTTM_ALIAS, merge_extra_filters
@@ -49,8 +48,9 @@ class BaseViz(object):
     credits = ''
     is_timeseries = False
     default_fillna = 0
+    cache_type = 'df'
 
-    def __init__(self, datasource, form_data):
+    def __init__(self, datasource, form_data, force=False):
         if not datasource:
             raise Exception(_('Viz is missing a datasource'))
         self.datasource = datasource
@@ -67,6 +67,40 @@ class BaseViz(object):
 
         self.status = None
         self.error_message = None
+        self.force = force
+
+        # Keeping track of whether some data came from cache,
+        # which is useful to trigger the <CachedLabel /> indicator
+        # in cases where a visualization is made of many queries
+        # (FilterBox for instance)
+        self._some_from_cache = False
+        self._any_cache_key = None
+        self._any_cached_dttm = None
+
+        self.run_extra_queries()
+
+    def run_extra_queries(self):
+        """Lifecycle method to use when more than one query is needed
+
+        In rare-ish cases, a visualization may need to execute multiple
+        queries. That is the case for FilterBox, or for time comparison
+        in the Line chart, for instance.
+
+        In those cases, we need to make sure these queries run before the
+        main `get_payload` method gets called, so that the overall caching
+        metadata can be right. The way it works here is that if any of
+        the previous `get_df_payload` calls hit the cache, the main
+        payload's metadata will reflect that.
+
+        The multi-query support may need more work to become a first-class
+        use case in the framework, and for the UI to reflect the subtleties
+        (show that only some of the queries were served from cache for
+        instance). In the meantime, since multi-query is rare, we treat
+        it with a bit of a hack. Note that the hack became necessary
+        when moving from caching the visualization's data itself to caching
+        the underlying query(ies).
+        """
+        pass
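
The hook above is easiest to see with a concrete subclass. A hedged sketch of the
intended pattern (the class name and the 'region' column are invented for
illustration): pre-run the secondary query through `get_df_payload()` so it
participates in caching, stash the result, and only combine in `get_data()`:

    # Sketch of a viz needing one extra query; names are hypothetical.
    class MyComparisonViz(BaseViz):

        def run_extra_queries(self):
            # Runs from BaseViz.__init__, i.e. before get_payload(),
            # so the cache metadata can account for this query too.
            qry = self.query_obj()
            qry['groupby'] = ['region']  # a variation on the main query
            self.extra_df = self.get_df_payload(query_obj=qry).get('df')

        def get_data(self, df):
            # Only shape and merge data here; no querying at this point.
            return {
                'main': df.to_dict(orient='records'),
                'extra': self.extra_df.to_dict(orient='records'),
            }
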
 
     def get_fillna_for_col(self, col):
         """Returns the value for use as filler for a specific Column.type"""
@@ -225,9 +259,9 @@ class BaseViz(object):
             return self.datasource.database.cache_timeout
         return config.get('CACHE_DEFAULT_TIMEOUT')
 
-    def get_json(self, force=False):
+    def get_json(self):
         return json.dumps(
-            self.get_payload(force),
+            self.get_payload(),
             default=utils.json_int_dttm_ser, ignore_nan=True)
 
     def cache_key(self, query_obj):
@@ -249,64 +283,73 @@ class BaseViz(object):
         json_data = self.json_dumps(cache_dict, sort_keys=True)
         return hashlib.md5(json_data.encode('utf-8')).hexdigest()
 
-    def get_payload(self, force=False):
-        """Handles caching around the json payload retrieval"""
-        query_obj = self.query_obj()
+    def get_payload(self, query_obj=None):
+        """Returns a payload of metadata and data"""
+        payload = self.get_df_payload(query_obj)
+        df = payload['df']
+        if df is not None:
+            payload['data'] = self.get_data(df)
+        del payload['df']
+        return payload
+
+    def get_df_payload(self, query_obj=None):
+        """Handles caching around the df payload retrieval"""
+        if not query_obj:
+            query_obj = self.query_obj()
+
         cache_key = self.cache_key(query_obj) if query_obj else None
-        cached_dttm = None
-        data = None
+        logging.info('Cache key: {}'.format(cache_key))
+        is_loaded = False
         stacktrace = None
-        rowcount = None
-        if cache_key and cache and not force:
+        df = None
+        cached_dttm = datetime.utcnow().isoformat().split('.')[0]
+        if cache_key and cache and not self.force:
             cache_value = cache.get(cache_key)
             if cache_value:
                 stats_logger.incr('loaded_from_cache')
-                is_cached = True
                 try:
-                    cache_value = zlib.decompress(cache_value)
-                    if PY3:
-                        cache_value = cache_value.decode('utf-8')
-                    cache_value = json.loads(cache_value)
-                    data = cache_value['data']
-                    cached_dttm = cache_value['dttm']
+                    cache_value = pkl.loads(cache_value)
+                    df = cache_value['df']
+                    is_loaded = True
+                    self._any_cache_key = cache_key
+                    self._any_cached_dttm = cache_value['dttm']
                 except Exception as e:
+                    logging.exception(e)
                     logging.error('Error reading cache: ' +
                                   utils.error_msg_from_exception(e))
-                    data = None
                 logging.info('Serving from cache')
 
-        if not data:
-            stats_logger.incr('loaded_from_source')
-            is_cached = False
+        if query_obj and not is_loaded:
             try:
                 df = self.get_df(query_obj)
-                if not self.error_message:
-                    data = self.get_data(df)
-                rowcount = len(df.index) if df is not None else 0
+                stats_logger.incr('loaded_from_source')
+                is_loaded = True
             except Exception as e:
                 logging.exception(e)
                 if not self.error_message:
-                    self.error_message = str(e)
+                    self.error_message = escape('{}'.format(e))
                 self.status = utils.QueryStatus.FAILED
-                data = None
                 stacktrace = traceback.format_exc()
 
             if (
-                    data and
+                    is_loaded and
                     cache_key and
                     cache and
                     self.status != utils.QueryStatus.FAILED):
-                cached_dttm = datetime.utcnow().isoformat().split('.')[0]
                 try:
-                    cache_value = self.json_dumps({
-                        'data': data,
-                        'dttm': cached_dttm,
-                    })
-                    if PY3:
-                        cache_value = bytes(cache_value, 'utf-8')
+                    cache_value = dict(
+                        dttm=cached_dttm,
+                        df=df if df is not None else None,
+                    )
+                    cache_value = pkl.dumps(
+                        cache_value, protocol=pkl.HIGHEST_PROTOCOL)
+
+                    logging.info('Caching {} chars at key {}'.format(
+                        len(cache_value), cache_key))
+
+                    stats_logger.incr('set_cache_key')
                     cache.set(
                         cache_key,
-                        zlib.compress(cache_value),
+                        cache_value,
                         timeout=self.cache_timeout)
                 except Exception as e:
                     # cache.set call can fail if the backend is down or if
                     # the key is too large or whatever other reasons
                     logging.warning('Could not cache key {}'.format(cache_key))
                     logging.exception(e)
                     cache.delete(cache_key)
 
         return {
-            'cache_key': cache_key,
-            'cached_dttm': cached_dttm,
+            'cache_key': self._any_cache_key,
+            'cached_dttm': self._any_cached_dttm,
             'cache_timeout': self.cache_timeout,
-            'data': data,
+            'df': df,
             'error': self.error_message,
             'form_data': self.form_data,
-            'is_cached': is_cached,
+            'is_cached': self._any_cache_key is not None,
             'query': self.query,
             'status': self.status,
             'stacktrace': stacktrace,
-            'rowcount': rowcount,
+            'rowcount': len(df.index) if df is not None else 0,
         }
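
Note the storage format change in `get_df_payload()`: the cache now holds a pickled
dict carrying the DataFrame itself, instead of a zlib-compressed JSON string of the
post-processed data. The roundtrip can be exercised standalone (the values below
are made up):

    # Standalone illustration of the new cache value format.
    import pandas as pd
    from six.moves import cPickle as pkl

    cache_value = dict(dttm='2018-02-07T14:49:19', df=pd.DataFrame({'x': [1, 2]}))
    blob = pkl.dumps(cache_value, protocol=pkl.HIGHEST_PROTOCOL)  # bytes for cache.set
    restored = pkl.loads(blob)
    assert restored['df'].equals(cache_value['df'])  # the DataFrame survives intact
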
 
     def json_dumps(self, obj, sort_keys=False):
@@ -415,7 +458,11 @@ class TableViz(BaseViz):
 
     def get_data(self, df):
         fd = self.form_data
-        if not self.should_be_timeseries() and DTTM_ALIAS in df:
+        if (
+                not self.should_be_timeseries() and
+                df is not None and
+                DTTM_ALIAS in df
+        ):
             del df[DTTM_ALIAS]
 
         # Sum up and compute percentages for all percent metrics
@@ -1062,12 +1109,10 @@ class NVD3TimeSeriesViz(NVD3Viz):
             df = df[num_period_compare:]
         return df
 
-    def get_data(self, df):
+    def run_extra_queries(self):
         fd = self.form_data
-        df = self.process_data(df)
-        chart_data = self.to_series(df)
-
         time_compare = fd.get('time_compare')
+        self.extra_chart_data = None
         if time_compare:
             query_object = self.query_obj()
             delta = utils.parse_human_timedelta(time_compare)
@@ -1081,12 +1126,20 @@ class NVD3TimeSeriesViz(NVD3Viz):
             query_object['from_dttm'] -= delta
             query_object['to_dttm'] -= delta
 
-            df2 = self.get_df(query_object)
+            df2 = self.get_df_payload(query_object).get('df')
             df2[DTTM_ALIAS] += delta
             df2 = self.process_data(df2)
-            chart_data += self.to_series(
+            self.extra_chart_data = self.to_series(
                 df2, classed='superset', title_suffix='---')
+
+    def get_data(self, df):
+        df = self.process_data(df)
+        chart_data = self.to_series(df)
+
+        if self.extra_chart_data:
+            chart_data += self.extra_chart_data
         chart_data = sorted(chart_data, key=lambda x: x['key'])
+
         return chart_data
 
 
@@ -1564,10 +1617,20 @@ class FilterBoxViz(BaseViz):
     verbose_name = _('Filters')
     is_timeseries = False
     credits = 'a Superset original'
+    cache_type = 'get_data'
 
     def query_obj(self):
         return None
 
+    def run_extra_queries(self):
+        qry = self.filter_query_obj()
+        filters = [g for g in self.form_data['groupby']]
+        self.dataframes = {}
+        for flt in filters:
+            qry['groupby'] = [flt]
+            df = self.get_df_payload(query_obj=qry).get('df')
+            self.dataframes[flt] = df
+
     def filter_query_obj(self):
         qry = super(FilterBoxViz, self).query_obj()
         groupby = self.form_data.get('groupby')
@@ -1578,12 +1641,10 @@ class FilterBoxViz(BaseViz):
         return qry
 
     def get_data(self, df):
-        qry = self.filter_query_obj()
-        filters = [g for g in self.form_data['groupby']]
         d = {}
+        filters = [g for g in self.form_data['groupby']]
         for flt in filters:
-            qry['groupby'] = [flt]
-            df = super(FilterBoxViz, self).get_df(qry)
+            df = self.dataframes[flt]
             d[flt] = [{
                 'id': row[0],
                 'text': row[0],
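
To make the FilterBox handoff concrete, here is a self-contained sketch (the data
and iteration details are made up; only the id/text keys visible in the hunk are
reproduced): run_extra_queries() leaves behind one dataframe per filter column, and
get_data() becomes a pure lookup over those cached dataframes:

    # Self-contained sketch of the new FilterBox data flow.
    import pandas as pd

    dataframes = {  # what run_extra_queries() leaves behind: one df per filter
        'country': pd.DataFrame({'country': ['FR', 'US'], 'count': [10, 20]}),
    }
    d = {}
    for flt, df in dataframes.items():
        # get_data() now only reshapes dataframes; no querying happens here
        d[flt] = [{'id': row[0], 'text': row[0]}
                  for row in df.itertuples(index=False)]
    assert d['country'][0] == {'id': 'FR', 'text': 'FR'}
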