From af7cdeba4d7eb38c51467a5987b2b3321cf23c69 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Sun, 17 Dec 2017 10:35:00 -0800 Subject: [PATCH] [Feature] enhanced memoized on get_sqla_engine and other functions (#3530) * added watch to memoized * added unit tests for memoized * code style changes --- superset/models/core.py | 7 ++-- superset/utils.py | 36 ++++++++++++++----- tests/utils_tests.py | 76 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 108 insertions(+), 11 deletions(-) diff --git a/superset/models/core.py b/superset/models/core.py index 2c6e8b015f..396db2dc32 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -629,6 +629,8 @@ class Database(Model, AuditMixinNullable, ImportMixin): effective_username = g.user.username return effective_username + @utils.memoized( + watch=('impersonate_user', 'sqlalchemy_uri_decrypted', 'extra')) def get_sqla_engine(self, schema=None, nullpool=False, user_name=None): extra = self.get_extra() url = make_url(self.sqlalchemy_uri_decrypted) @@ -662,10 +664,10 @@ class Database(Model, AuditMixinNullable, ImportMixin): return create_engine(url, **params) def get_reserved_words(self): - return self.get_sqla_engine().dialect.preparer.reserved_words + return self.get_dialect().preparer.reserved_words def get_quoter(self): - return self.get_sqla_engine().dialect.identifier_preparer.quote + return self.get_dialect().identifier_preparer.quote def get_df(self, sql, schema): sql = sql.strip().strip(';') @@ -813,6 +815,7 @@ class Database(Model, AuditMixinNullable, ImportMixin): return engine.has_table( table.table_name, table.schema or None) + @utils.memoized def get_dialect(self): sqla_url = url.make_url(self.sqlalchemy_uri_decrypted) return sqla_url.get_dialect()() diff --git a/superset/utils.py b/superset/utils.py index bae330b4af..afe2f419ad 100644 --- a/superset/utils.py +++ b/superset/utils.py @@ -91,38 +91,58 @@ def flasher(msg, severity=None): logging.info(msg) -class memoized(object): # noqa +class _memoized(object): # noqa """Decorator that caches a function's return value each time it is called If called later with the same arguments, the cached value is returned, and not re-evaluated. + + Define ``watch`` as a tuple of attribute names if this Decorator + should account for instance variable changes. """ - def __init__(self, func): + def __init__(self, func, watch=()): self.func = func self.cache = {} + self.is_method = False + self.watch = watch - def __call__(self, *args): + def __call__(self, *args, **kwargs): + key = [args, frozenset(kwargs.items())] + if self.is_method: + key.append(tuple([getattr(args[0], v, None) for v in self.watch])) + key = tuple(key) + if key in self.cache: + return self.cache[key] try: - return self.cache[args] - except KeyError: - value = self.func(*args) - self.cache[args] = value + value = self.func(*args, **kwargs) + self.cache[key] = value return value except TypeError: # uncachable -- for instance, passing a list as an argument. # Better to not cache than to blow up entirely. - return self.func(*args) + return self.func(*args, **kwargs) def __repr__(self): """Return the function's docstring.""" return self.func.__doc__ def __get__(self, obj, objtype): + if not self.is_method: + self.is_method = True """Support instance methods.""" return functools.partial(self.__call__, obj) +def memoized(func=None, watch=None): + if func: + return _memoized(func) + else: + def wrapper(f): + return _memoized(f, watch) + return wrapper + + def js_string_to_python(item): return None if item in ('null', 'undefined') else item diff --git a/tests/utils_tests.py b/tests/utils_tests.py index f6d1901d12..46d476632b 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -8,7 +8,7 @@ import numpy from superset.utils import ( base_json_conv, datetime_f, json_int_dttm_ser, json_iso_dttm_ser, - JSONEncodedDict, merge_extra_filters, parse_human_timedelta, + JSONEncodedDict, memoized, merge_extra_filters, parse_human_timedelta, SupersetException, validate_json, zlib_compress, zlib_decompress_to_string, ) @@ -219,3 +219,77 @@ class UtilsTestCase(unittest.TestCase): invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}' with self.assertRaises(SupersetException): validate_json(invalid) + + def test_memoized_on_functions(self): + watcher = {'val': 0} + + @memoized + def test_function(a, b, c): + watcher['val'] += 1 + return a * b * c + result1 = test_function(1, 2, 3) + result2 = test_function(1, 2, 3) + self.assertEquals(result1, result2) + self.assertEquals(watcher['val'], 1) + + def test_memoized_on_methods(self): + + class test_class: + def __init__(self, num): + self.num = num + self.watcher = 0 + + @memoized + def test_method(self, a, b, c): + self.watcher += 1 + return a * b * c * self.num + + instance = test_class(5) + result1 = instance.test_method(1, 2, 3) + result2 = instance.test_method(1, 2, 3) + self.assertEquals(result1, result2) + self.assertEquals(instance.watcher, 1) + instance.num = 10 + self.assertEquals(result2, instance.test_method(1, 2, 3)) + + def test_memoized_on_methods_with_watches(self): + + class test_class: + def __init__(self, x, y): + self.x = x + self.y = y + self.watcher = 0 + + @memoized(watch=('x', 'y')) + def test_method(self, a, b, c): + self.watcher += 1 + return a * b * c * self.x * self.y + + instance = test_class(3, 12) + result1 = instance.test_method(1, 2, 3) + result2 = instance.test_method(1, 2, 3) + self.assertEquals(result1, result2) + self.assertEquals(instance.watcher, 1) + result3 = instance.test_method(2, 3, 4) + self.assertEquals(instance.watcher, 2) + result4 = instance.test_method(2, 3, 4) + self.assertEquals(instance.watcher, 2) + self.assertEquals(result3, result4) + self.assertNotEqual(result3, result1) + instance.x = 1 + result5 = instance.test_method(2, 3, 4) + self.assertEqual(instance.watcher, 3) + self.assertNotEqual(result5, result4) + result6 = instance.test_method(2, 3, 4) + self.assertEqual(instance.watcher, 3) + self.assertEqual(result6, result5) + instance.x = 10 + instance.y = 10 + result7 = instance.test_method(2, 3, 4) + self.assertEqual(instance.watcher, 4) + self.assertNotEqual(result7, result6) + instance.x = 3 + instance.y = 12 + result8 = instance.test_method(1, 2, 3) + self.assertEqual(instance.watcher, 4) + self.assertEqual(result1, result8)