[Feature] enhanced memoized on get_sqla_engine and other functions (#3530)

* added watch to memoized

* added unit tests for memoized

* code style changes
This commit is contained in:
Jeff Niu 2017-12-17 10:35:00 -08:00 committed by Maxime Beauchemin
parent 500e6256c0
commit af7cdeba4d
3 changed files with 108 additions and 11 deletions

View File

@ -629,6 +629,8 @@ class Database(Model, AuditMixinNullable, ImportMixin):
effective_username = g.user.username effective_username = g.user.username
return effective_username return effective_username
@utils.memoized(
watch=('impersonate_user', 'sqlalchemy_uri_decrypted', 'extra'))
def get_sqla_engine(self, schema=None, nullpool=False, user_name=None): def get_sqla_engine(self, schema=None, nullpool=False, user_name=None):
extra = self.get_extra() extra = self.get_extra()
url = make_url(self.sqlalchemy_uri_decrypted) url = make_url(self.sqlalchemy_uri_decrypted)
@ -662,10 +664,10 @@ class Database(Model, AuditMixinNullable, ImportMixin):
return create_engine(url, **params) return create_engine(url, **params)
def get_reserved_words(self): def get_reserved_words(self):
return self.get_sqla_engine().dialect.preparer.reserved_words return self.get_dialect().preparer.reserved_words
def get_quoter(self): def get_quoter(self):
return self.get_sqla_engine().dialect.identifier_preparer.quote return self.get_dialect().identifier_preparer.quote
def get_df(self, sql, schema): def get_df(self, sql, schema):
sql = sql.strip().strip(';') sql = sql.strip().strip(';')
@ -813,6 +815,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
return engine.has_table( return engine.has_table(
table.table_name, table.schema or None) table.table_name, table.schema or None)
@utils.memoized
def get_dialect(self): def get_dialect(self):
sqla_url = url.make_url(self.sqlalchemy_uri_decrypted) sqla_url = url.make_url(self.sqlalchemy_uri_decrypted)
return sqla_url.get_dialect()() return sqla_url.get_dialect()()

View File

@ -91,38 +91,58 @@ def flasher(msg, severity=None):
logging.info(msg) logging.info(msg)
class memoized(object): # noqa class _memoized(object): # noqa
"""Decorator that caches a function's return value each time it is called """Decorator that caches a function's return value each time it is called
If called later with the same arguments, the cached value is returned, and If called later with the same arguments, the cached value is returned, and
not re-evaluated. not re-evaluated.
Define ``watch`` as a tuple of attribute names if this Decorator
should account for instance variable changes.
""" """
def __init__(self, func): def __init__(self, func, watch=()):
self.func = func self.func = func
self.cache = {} self.cache = {}
self.is_method = False
self.watch = watch
def __call__(self, *args): def __call__(self, *args, **kwargs):
key = [args, frozenset(kwargs.items())]
if self.is_method:
key.append(tuple([getattr(args[0], v, None) for v in self.watch]))
key = tuple(key)
if key in self.cache:
return self.cache[key]
try: try:
return self.cache[args] value = self.func(*args, **kwargs)
except KeyError: self.cache[key] = value
value = self.func(*args)
self.cache[args] = value
return value return value
except TypeError: except TypeError:
# uncachable -- for instance, passing a list as an argument. # uncachable -- for instance, passing a list as an argument.
# Better to not cache than to blow up entirely. # Better to not cache than to blow up entirely.
return self.func(*args) return self.func(*args, **kwargs)
def __repr__(self): def __repr__(self):
"""Return the function's docstring.""" """Return the function's docstring."""
return self.func.__doc__ return self.func.__doc__
def __get__(self, obj, objtype): def __get__(self, obj, objtype):
if not self.is_method:
self.is_method = True
"""Support instance methods.""" """Support instance methods."""
return functools.partial(self.__call__, obj) return functools.partial(self.__call__, obj)
def memoized(func=None, watch=None):
if func:
return _memoized(func)
else:
def wrapper(f):
return _memoized(f, watch)
return wrapper
def js_string_to_python(item): def js_string_to_python(item):
return None if item in ('null', 'undefined') else item return None if item in ('null', 'undefined') else item

View File

@ -8,7 +8,7 @@ import numpy
from superset.utils import ( from superset.utils import (
base_json_conv, datetime_f, json_int_dttm_ser, json_iso_dttm_ser, base_json_conv, datetime_f, json_int_dttm_ser, json_iso_dttm_ser,
JSONEncodedDict, merge_extra_filters, parse_human_timedelta, JSONEncodedDict, memoized, merge_extra_filters, parse_human_timedelta,
SupersetException, validate_json, zlib_compress, zlib_decompress_to_string, SupersetException, validate_json, zlib_compress, zlib_decompress_to_string,
) )
@ -219,3 +219,77 @@ class UtilsTestCase(unittest.TestCase):
invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}' invalid = '{"a": 5, "b": [1, 5, ["g", "h]]}'
with self.assertRaises(SupersetException): with self.assertRaises(SupersetException):
validate_json(invalid) validate_json(invalid)
def test_memoized_on_functions(self):
watcher = {'val': 0}
@memoized
def test_function(a, b, c):
watcher['val'] += 1
return a * b * c
result1 = test_function(1, 2, 3)
result2 = test_function(1, 2, 3)
self.assertEquals(result1, result2)
self.assertEquals(watcher['val'], 1)
def test_memoized_on_methods(self):
class test_class:
def __init__(self, num):
self.num = num
self.watcher = 0
@memoized
def test_method(self, a, b, c):
self.watcher += 1
return a * b * c * self.num
instance = test_class(5)
result1 = instance.test_method(1, 2, 3)
result2 = instance.test_method(1, 2, 3)
self.assertEquals(result1, result2)
self.assertEquals(instance.watcher, 1)
instance.num = 10
self.assertEquals(result2, instance.test_method(1, 2, 3))
def test_memoized_on_methods_with_watches(self):
class test_class:
def __init__(self, x, y):
self.x = x
self.y = y
self.watcher = 0
@memoized(watch=('x', 'y'))
def test_method(self, a, b, c):
self.watcher += 1
return a * b * c * self.x * self.y
instance = test_class(3, 12)
result1 = instance.test_method(1, 2, 3)
result2 = instance.test_method(1, 2, 3)
self.assertEquals(result1, result2)
self.assertEquals(instance.watcher, 1)
result3 = instance.test_method(2, 3, 4)
self.assertEquals(instance.watcher, 2)
result4 = instance.test_method(2, 3, 4)
self.assertEquals(instance.watcher, 2)
self.assertEquals(result3, result4)
self.assertNotEqual(result3, result1)
instance.x = 1
result5 = instance.test_method(2, 3, 4)
self.assertEqual(instance.watcher, 3)
self.assertNotEqual(result5, result4)
result6 = instance.test_method(2, 3, 4)
self.assertEqual(instance.watcher, 3)
self.assertEqual(result6, result5)
instance.x = 10
instance.y = 10
result7 = instance.test_method(2, 3, 4)
self.assertEqual(instance.watcher, 4)
self.assertNotEqual(result7, result6)
instance.x = 3
instance.y = 12
result8 = instance.test_method(1, 2, 3)
self.assertEqual(instance.watcher, 4)
self.assertEqual(result1, result8)