[Feature] enhanced memoized on get_sqla_engine and other functions (#3530)

* added watch to memoized

* added unit tests for memoized

* code style changes
This commit is contained in:
Jeff Niu 2017-12-17 10:35:00 -08:00 committed by Maxime Beauchemin
parent 500e6256c0
commit af7cdeba4d
3 changed files with 108 additions and 11 deletions

View File

@ -629,6 +629,8 @@ class Database(Model, AuditMixinNullable, ImportMixin):
effective_username = g.user.username
return effective_username
@utils.memoized(
watch=('impersonate_user', 'sqlalchemy_uri_decrypted', 'extra'))
def get_sqla_engine(self, schema=None, nullpool=False, user_name=None):
extra = self.get_extra()
url = make_url(self.sqlalchemy_uri_decrypted)
@ -662,10 +664,10 @@ class Database(Model, AuditMixinNullable, ImportMixin):
return create_engine(url, **params)
def get_reserved_words(self):
return self.get_sqla_engine().dialect.preparer.reserved_words
return self.get_dialect().preparer.reserved_words
def get_quoter(self):
return self.get_sqla_engine().dialect.identifier_preparer.quote
return self.get_dialect().identifier_preparer.quote
def get_df(self, sql, schema):
sql = sql.strip().strip(';')
@ -813,6 +815,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
return engine.has_table(
table.table_name, table.schema or None)
@utils.memoized
def get_dialect(self):
sqla_url = url.make_url(self.sqlalchemy_uri_decrypted)
return sqla_url.get_dialect()()

View File

@ -91,38 +91,58 @@ def flasher(msg, severity=None):
logging.info(msg)
class memoized(object): # noqa
class _memoized(object): # noqa
"""Decorator that caches a function's return value each time it is called
If called later with the same arguments, the cached value is returned, and
not re-evaluated.
Define ``watch`` as a tuple of attribute names if this Decorator
should account for instance variable changes.
"""
def __init__(self, func):
def __init__(self, func, watch=()):
self.func = func
self.cache = {}
self.is_method = False
self.watch = watch
def __call__(self, *args):
def __call__(self, *args, **kwargs):
key = [args, frozenset(kwargs.items())]
if self.is_method:
key.append(tuple([getattr(args[0], v, None) for v in self.watch]))
key = tuple(key)
if key in self.cache:
return self.cache[key]
try:
return self.cache[args]
except KeyError:
value = self.func(*args)
self.cache[args] = value
value = self.func(*args, **kwargs)
self.cache[key] = value
return value
except TypeError:
# uncachable -- for instance, passing a list as an argument.
# Better to not cache than to blow up entirely.
return self.func(*args)
return self.func(*args, **kwargs)
def __repr__(self):
"""Return the function's docstring."""
return self.func.__doc__
def __get__(self, obj, objtype):
if not self.is_method:
self.is_method = True
"""Support instance methods."""
return functools.partial(self.__call__, obj)
def memoized(func=None, watch=None):
if func:
return _memoized(func)
else:
def wrapper(f):
return _memoized(f, watch)
return wrapper
def js_string_to_python(item):
return None if item in ('null', 'undefined') else item

View File

@ -8,7 +8,7 @@ import numpy
from superset.utils import (
base_json_conv, datetime_f, json_int_dttm_ser, json_iso_dttm_ser,
JSONEncodedDict, merge_extra_filters, parse_human_timedelta,
JSONEncodedDict, memoized, merge_extra_filters, parse_human_timedelta,
SupersetException, validate_json, zlib_compress, zlib_decompress_to_string,
)
@ -219,3 +219,77 @@ class UtilsTestCase(unittest.TestCase):
invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}'
with self.assertRaises(SupersetException):
validate_json(invalid)
def test_memoized_on_functions(self):
watcher = {'val': 0}
@memoized
def test_function(a, b, c):
watcher['val'] += 1
return a * b * c
result1 = test_function(1, 2, 3)
result2 = test_function(1, 2, 3)
self.assertEquals(result1, result2)
self.assertEquals(watcher['val'], 1)
def test_memoized_on_methods(self):
class test_class:
def __init__(self, num):
self.num = num
self.watcher = 0
@memoized
def test_method(self, a, b, c):
self.watcher += 1
return a * b * c * self.num
instance = test_class(5)
result1 = instance.test_method(1, 2, 3)
result2 = instance.test_method(1, 2, 3)
self.assertEquals(result1, result2)
self.assertEquals(instance.watcher, 1)
instance.num = 10
self.assertEquals(result2, instance.test_method(1, 2, 3))
def test_memoized_on_methods_with_watches(self):
class test_class:
def __init__(self, x, y):
self.x = x
self.y = y
self.watcher = 0
@memoized(watch=('x', 'y'))
def test_method(self, a, b, c):
self.watcher += 1
return a * b * c * self.x * self.y
instance = test_class(3, 12)
result1 = instance.test_method(1, 2, 3)
result2 = instance.test_method(1, 2, 3)
self.assertEquals(result1, result2)
self.assertEquals(instance.watcher, 1)
result3 = instance.test_method(2, 3, 4)
self.assertEquals(instance.watcher, 2)
result4 = instance.test_method(2, 3, 4)
self.assertEquals(instance.watcher, 2)
self.assertEquals(result3, result4)
self.assertNotEqual(result3, result1)
instance.x = 1
result5 = instance.test_method(2, 3, 4)
self.assertEqual(instance.watcher, 3)
self.assertNotEqual(result5, result4)
result6 = instance.test_method(2, 3, 4)
self.assertEqual(instance.watcher, 3)
self.assertEqual(result6, result5)
instance.x = 10
instance.y = 10
result7 = instance.test_method(2, 3, 4)
self.assertEqual(instance.watcher, 4)
self.assertNotEqual(result7, result6)
instance.x = 3
instance.y = 12
result8 = instance.test_method(1, 2, 3)
self.assertEqual(instance.watcher, 4)
self.assertEqual(result1, result8)