some late comments :)
[gostyle.git] / utils / db_cache.py
blob1fdf6afd912558e71f5fba8fd26e21663484a95d
1 from sqlalchemy import Table, Column, Integer, ForeignKey, Text, Date, PickleType, Float
2 from sqlalchemy.ext.declarative import declarative_base
3 from sqlalchemy import create_engine
4 from sqlalchemy.orm import sessionmaker
5 import time
6 import logging
7 import inspect
8 import types
9 import functools
11 import utils
logger = logging.getLogger(__name__)

# Assigned to ``__doc__`` explicitly: this text comes after the imports, where a
# bare string literal is a no-op expression and NOT the module docstring.
__doc__ = """
A simple caching scheme for pure functions, supporting pure functions as
args as well.

Changes of code do NOT make the cache invalid - so you should delete the
cache database yourself if you change any pure functions.
"""

# Caveat about default keyword arguments: changing the default values of a pure
# function only matters when that function is passed as an argument to another
# cached function, not when it is called directly. A direct call merges the
# defaults into its own cache key (see make_key), whereas a function passed as
# an argument contributes only its repr to the outer function's key, and that
# repr does not include the defaults - the outer function does not know which
# default params the inner one has.

# By default (without running init_cache) the cache is a plain dict, so cached
# results are not persistent across runs & processes.
cache_object = {}
Base = declarative_base()


class CacheLine(Base):
    """
    A single cache row: maps ``key`` to a pickled ``value`` and records the
    creation timestamp ``time``, which is the criterion for time expiration.
    """
    __tablename__ = 'cacheline'

    id = Column(Integer, primary_key=True)
    # creation timestamp (seconds since the epoch, as a float)
    time = Column(Float)
    # the cache key string; indexed because every lookup filters on it
    key = Column(Text, index=True)
    # the cached result, serialized with pickle
    value = Column(PickleType)

    def __str__(self):
        fields = (self.key, self.time, self.value)
        return "(%s, %s) -> %s" % fields

    def __repr__(self):
        return "CacheLine(%s)" % (str(self),)
class DBCacheObject(object):
    """ The cache uses the same interface as dict.

    Backed by the CacheLine table of an SQLAlchemy session: every assignment
    appends a new row, and a lookup returns the value of the newest row for
    the key that has not expired yet.
    """
    def __init__(self, db_session, expire):
        """
        db_session -- SQLAlchemy session used for all queries
        expire -- lifetime of cached values in seconds; a falsy value
                  (0 or None) means cached values never expire
        """
        self.session = db_session
        self.expire = expire

    def delete_expired(self):
        """Delete all rows older than ``expire`` seconds.

        When expiration is disabled (falsy ``expire``) nothing ever expires,
        so nothing is deleted. Without this guard ``expire == 0`` would
        compute a cutoff of "now" and wipe the whole cache.
        """
        if not self.expire:
            return
        expired_before = time.time() - self.expire
        self.session.query(CacheLine).filter(CacheLine.time < expired_before).delete()
        self._commit()

    def __getitem__(self, key):
        """Return the newest non-expired value stored under ``key``.

        Raises KeyError(key) when there is no such row.
        """
        q = self.session.query(CacheLine).filter(CacheLine.key == key)

        # if expiration is enabled, ignore expired rows
        if self.expire:
            expired_before = time.time() - self.expire
            q = q.filter(CacheLine.time > expired_before)

        # let the database pick the newest row instead of loading all of them
        newest = q.order_by(CacheLine.time.desc()).first()
        if newest is None:
            raise KeyError(key)
        return newest.value

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` as a new row stamped with the current time."""
        self.session.add(CacheLine(time=time.time(), key=key, value=value))
        self._commit()

    def _commit(self):
        """Commit the session; roll back and re-raise if the commit fails."""
        # SQLAlchemy requires a rollback after a failed flush/commit (e.g. a
        # value that cannot be pickled) before the session can be used again;
        # without it every later cache access would fail as well.
        try:
            self.session.commit()
        except Exception:
            self.session.rollback()
            raise
def delete_expired():
    """Delete expired rows from the global cache object.

    Only a DBCacheObject supports expiration; for any other cache object
    (e.g. the default dict) a warning is logged and nothing is done.
    """
    # Module logger instead of the root ``logging`` functions, consistent with
    # the rest of this module (and ``warn`` is a deprecated alias).
    if not isinstance(cache_object, DBCacheObject):
        logger.warning("Cannot remove expired elements from cache - not a DBCacheObject")
        return

    logger.info("Deleting expired cache rows...")
    cache_object.delete_expired()
def _print_all():
    """Debug helper: print every entry of the global cache to stdout.

    For a DBCacheObject all CacheLine rows are printed (expired ones
    included, no time filter is applied); for the dict cache the
    (key, value) pairs are printed.
    """
    if isinstance(cache_object, DBCacheObject):
        entries = cache_object.session.query(CacheLine).all()
    else:
        entries = cache_object.items()

    # Single-argument print(...) behaves the same on Python 2 and 3 and gives
    # the same output as the former ``print "\t", a`` statement.
    print("CACHE:")
    for entry in entries:
        print("\t%s" % (entry,))
106 # Pure function
class PureFunction(object):
    """PureFunction is a class that has nice function repr like
    <pure_function __main__.f> instead of the default repr
    <function f at 0x11e5320>.

    By using it, the user declares that calls to the same function with
    the same arguments will always (in time, across different processes, ..)
    have the same results and can be thus cached.
    """
    def __init__(self, f):
        """Wrap the plain Python function ``f``.

        Raises TypeError if ``f`` is not a plain function (builtins,
        bound methods and other callables are rejected).
        """
        # explicit check rather than ``assert``, so it also runs under ``python -O``
        if not isinstance(f, types.FunctionType):
            raise TypeError("PureFunction expects a plain function, got %r" % (f,))
        self.f = f
        # copy __name__, __doc__, __module__, ... of f onto this instance
        functools.update_wrapper(self, f)

    def getargspec(self):
        """Return ``(args, varargs, varkw, defaults)`` of the wrapped function."""
        getfullargspec = getattr(inspect, 'getfullargspec', None)
        if getfullargspec is None:
            # Python 2
            return inspect.getargspec(self.f)
        # inspect.getargspec() was removed in Python 3.11; keep the 4-tuple shape
        return tuple(getfullargspec(self.f)[:4])

    def get_default_kwargs(self):
        """Return a dict mapping each parameter that has a default to that default.

        Returns an empty dict when the function has no default arguments.
        """
        args, _varargs, _varkw, defaults = self.getargspec()
        if not defaults:
            return {}
        return dict(zip(args[-len(defaults):], defaults))

    def __call__(self, *args, **kwargs):
        logger.debug("calling %r", self)
        return self.f(*args, **kwargs)

    def __repr__(self):
        # The default kwargs are intentionally not part of the repr, so the
        # repr (and any cache key built from it) names the function only by
        # its origin.
        return '<pure_function %s>' % (utils.repr_origin(self.f),)


# to be used as a deco
declare_pure_function = PureFunction
def init_cache(filename='CACHE.db', expires=0, sqlalchemy_echo=False):
    """
    Initialize cache, sets up the global cache_object.

    filename -- specifies the sqlite dbfile to store the results to. With
                None the global cache object is left unchanged (the default
                dict, unless init_cache was called before)
    expires -- specifies expiration in seconds. If you set this to 0,
               cached data are valid forever
    sqlalchemy_echo -- whether to output sqlalchemy logs
    """
    global cache_object

    if filename is None:
        # By default, the cache object is a dict
        if expires:
            logger.warning('Dictionary cache object does not support time expiration of cached values!')
        return

    engine = create_engine('sqlite:///%s' % filename, echo=sqlalchemy_echo)
    # create the cacheline table if it does not exist yet
    Base.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)
    cache_object = DBCacheObject(Session(), expires)
def close_cache():
    """Close the database session of the global cache object, if it has one.

    Does nothing for the default dict cache, which has no session (this used
    to fail with AttributeError).
    """
    session = getattr(cache_object, 'session', None)
    if session is not None:
        session.close()
def make_key(f, f_args, f_kwargs):
    """Build the cache key (a string) for the call ``f(*f_args, **f_kwargs)``.

    The key is ``repr(f)`` followed by the reprs of the positional arguments
    and ``name=repr(value)`` for every parameter that has a default, where
    values from ``f_kwargs`` override the defaults. The keyword part is listed
    in signature order, so the key does not depend on dict iteration order.

    Raises TypeError if ``f`` is neither a PureFunction nor a plain function,
    or if ``f_kwargs`` names a parameter without a default value.
    """
    if isinstance(f, PureFunction):
        spec = f.getargspec()
    elif isinstance(f, types.FunctionType):
        getfullargspec = getattr(inspect, 'getfullargspec', None)
        if getfullargspec is None:
            spec = inspect.getargspec(f)  # Python 2
        else:
            # inspect.getargspec() was removed in Python 3.11
            spec = getfullargspec(f)[:4]
    else:
        raise TypeError("Unable to obtain arg specification for function : '%s'"%(repr(f)))

    args, _varargs, _varkw, defaults = spec
    # names of the parameters that have a default value, in signature order
    defaulted = list(args[-len(defaults):]) if defaults else []
    f_kwargs_joined = dict(zip(defaulted, defaults or ()))

    # Only parameters with defaults may be passed by keyword (presumably to
    # keep one canonical key per call). Explicit check instead of ``assert``,
    # which ``python -O`` strips.
    for name in f_kwargs:
        if name not in f_kwargs_joined:
            raise TypeError("Keyword argument '%s' of %r has no default value; "
                            "pass it positionally" % (name, f))
    f_kwargs_joined.update(f_kwargs)

    parts = [repr(val) for val in f_args]
    parts.extend('%s=%s' % (name, repr(f_kwargs_joined[name])) for name in defaulted)
    rep = "%s(%s)" % (repr(f), ', '.join(parts))

    ## XXX "normal temporary" objects
    # Default object reprs embed a memory address, so such keys differ between
    # processes and will not hit a persistent cache.
    if 'at 0x' in rep:
        logger.warning("Object(s) specified in '%s' do not have a proper repr.", rep)

    return rep
203 # The deco
def cache_result(fun):
    """Compute the key, look if the result of a computation is in the
    cache. If so, return it, otw run the function, cache the result and
    return it.

    fun -- a PureFunction or a plain function. For a PureFunction the
           returned wrapper is a PureFunction as well (keeping the
           <pure_function ...> repr used in cache keys).
    """
    def g_factory(f):
        def g(*args, **kwargs):
            key = make_key(f, args, kwargs)
            # cache_object is looked up at call time, so a later init_cache()
            # that rebinds it is picked up by already-decorated functions.
            try:
                cached = cache_object[key]
            except KeyError:
                pass
            else:
                logger.info("Returning CACHED for: '%s'", key)
                return cached

            # Computed outside the except clause, so exceptions raised by f
            # are not implicitly chained to the cache-miss KeyError.
            ret = f(*args, **kwargs)
            cache_object[key] = ret
            logger.info("CACHING for: '%s'", key)
            return ret
        return g

    # if we got PureFunction, the returned function should also be pure
    # please see the PureFunction.__doc__
    if isinstance(fun, PureFunction):
        g = g_factory(fun)
        functools.update_wrapper(g, fun.f)
        return PureFunction(g)

    return functools.wraps(fun)(g_factory(fun))
if __name__ == "__main__":
    # Manual smoke tests; run the module directly and watch the log output.
    logging.basicConfig()
    main_logger = logging.getLogger(__name__)
    main_logger.setLevel(logging.INFO)

    # in-memory DB with a 0.1 s expiration, so test1/test4 can observe expiry
    init_cache(filename=':memory:', expires=0.1)

    @cache_result
    @declare_pure_function
    def add(a, b):
        return a + b

    @cache_result
    @declare_pure_function
    def call_10(f):
        return f(10)

    @cache_result
    @declare_pure_function
    def multmap(l):
        # the timestamp makes a recomputed result distinguishable from a cached one
        return (functools.reduce(lambda x, y: x * y, l), time.time())

    def test1():
        multmap([1, 2, 3])
        multmap([1, 2, 3])
        print("sleep 0.1")
        time.sleep(0.1)
        multmap([1, 2, 3])

    def test2():
        """Stateless (pure) class and a pure function as arguments"""

        class Adder:
            """ The Adder must be `stateless` in a sense that results
            of __call__ will always produce the same results for the
            same args. Moreover the Adder must have __repr__ which has
            all the information to uniquely define the Adder instance -
            once again, so that the statement about __call__ holds.

            The user is responsible for the statelessness!
            (as with @declare_pure_function)
            """
            def __init__(self, offset):
                self.offset = offset

            def __call__(self, a, b=10):
                return a + self.offset

            def __repr__(self):
                return "Adder(offset=%s)" % self.offset

        a = Adder(2)

        @cache_result
        @declare_pure_function
        def my_map(f, l):
            # list(): on Python 3 map() is a lazy iterator, which must not be cached
            return list(map(f, l))

        my_map(a, range(10))
        my_map(a, range(10))

        @declare_pure_function
        def multiplicator(x, mult=2):
            return x * mult

        my_map(multiplicator, range(10))

        from utils import partial, partial_right

        my_map(partial_right(multiplicator, 2), range(10))
        my_map(partial(a, 2), range(10))
        my_map(partial_right(multiplicator, 2), range(10))
        my_map(partial(a, 2), range(10))

    def test3():
        """Test warning for nonpure functions as arguments"""
        @cache_result
        def h(x):
            return 2 * x

        h(10)
        print("")
        call_10(h)

    def test4():
        """Test timeout"""
        multmap([1, 2, 3])
        multmap([1, 2, 3, 4])
        multmap([1, 2, 3])
        _print_all()
        time.sleep(0.5)
        multmap([1, 2, 3])
        _print_all()
        delete_expired()
        _print_all()

    test1()
    #test2()