cache: utilities to print and clear it (#656)
[sympy.git] / sympy / core / cache.py
blob25f612fe5d0f015b9f2cbf3749065bdf2ceed3da
1 """ Caching facility for SymPy """
def mycopy(obj, level=0):
    """Return a structural copy of obj.

    Lists, tuples and dicts are rebuilt recursively through their own class,
    so no nested list/tuple/dict is shared between obj and the copy. Any
    other object is returned as-is (i.e. shared).

    The level argument is accepted for backward compatibility but unused.
    """
    if isinstance(obj, dict):
        clone = obj.__class__()
        for key, value in obj.items():
            clone[mycopy(key)] = mycopy(value)
        return clone
    if isinstance(obj, (list, tuple)):
        return obj.__class__([mycopy(element) for element in obj])
    return obj
# TODO: refactor CACHE & friends into class?

# Global cache registry: the caching decorators append one (item, store)
# pair each, where item is the decorated callable and store is either a
# single dict or a tuple of dicts.
CACHE = []


def print_cache():
    """Print the content of every registered cache.

    Each entry starts with str(item) framed above and below by a row of '='
    of the same length, followed by one ' key :<TAB>value' line per cached
    pair. When the store is a tuple of dicts, every dict is preceded by a
    '*** <index> ***' banner.
    """
    for owner, stores in CACHE:
        title = str(owner)
        rule = '=' * len(title)
        print(rule)
        print(title)
        print(rule)

        numbered = isinstance(stores, tuple)
        if not numbered:
            stores = (stores,)

        for index, table in enumerate(stores):
            if numbered:
                print('\n*** %i ***\n' % index)
            for key, value in table.items():
                print(' %s :\t%s' % (key, value))


def clear_cache():
    """Empty every registered cache dict in place; CACHE itself is kept."""
    for _owner, stores in CACHE:
        if isinstance(stores, tuple):
            tables = stores
        else:
            tables = (stores,)
        for table in tables:
            table.clear()
def cache_it_fast(func):
    """Memoizing decorator that hands out copies of cached results.

    The first call with a given argument combination stores func's result;
    every call (the first one included) returns mycopy() of the stored
    value, so lists, tuples and dicts in the result can be mutated by the
    caller without corrupting the cache.

    The cache dict is attached to func as _cache_it_cache and registered in
    the global CACHE list. Positional arguments form the key directly;
    keyword arguments are appended as ('name=', value) pairs ordered by
    name, so all arguments must be hashable.
    """
    store = {}
    func._cache_it_cache = store
    CACHE.append((func, store))
    missing = object()  # sentinel: no stored result can be this object

    def wrapper(*args, **kw_args):
        key = args
        if kw_args:
            key = key + tuple([(name + '=', kw_args[name])
                               for name in sorted(kw_args)])
        result = store.get(key, missing)
        if result is missing:
            result = func(*args, **kw_args)
            store[key] = result
        return mycopy(result)

    return wrapper
def cache_it_immutable(func):
    """Memoizing decorator for functions whose results are never mutated.

    The cached object itself is returned on every hit (no copy is made), so
    callers must treat results as immutable. The cache dict is attached to
    func as _cache_it_cache and registered in the global CACHE list.
    Positional arguments form the key directly; keyword arguments are
    appended as ('name=', value) pairs ordered by name.
    """
    results = {}
    func._cache_it_cache = results
    CACHE.append((func, results))
    not_found = object()  # sentinel distinct from any cached value

    def wrapper(*args, **kw_args):
        if not kw_args:
            key = args
        else:
            named = [(kw + '=', val) for kw, val in sorted(kw_args.items())]
            key = args + tuple(named)
        hit = results.get(key, not_found)
        if hit is not not_found:
            return hit
        outcome = func(*args, **kw_args)
        results[key] = outcome
        return outcome

    return wrapper
def cache_it_debug(func):
    """Memoizing decorator that detects in-place mutation of cached results.

    Results are cached per argument combination and returned through
    mycopy(). In addition, repr() of each cached result, rendered at Basic
    repr level 0, is recorded. On every cache hit the stored result is
    rendered again and compared with the recorded text; a difference means
    the cached object was mutated, and an AssertionError is raised (when
    assertions are enabled).

    The result cache and the repr cache are attached to func as
    _cache_it_cache and _cache_it_cache_repr and registered together, as a
    tuple, in the global CACHE list. Positional arguments form the key
    directly; keyword arguments are appended as ('name=', value) pairs
    ordered by name.
    """
    func._cache_it_cache = func_cache_it_cache = {}
    func._cache_it_cache_repr = func_cache_it_cache_repr = {}
    CACHE.append((func, (func_cache_it_cache, func_cache_it_cache_repr)))

    def repr0(obj):
        # repr() at Basic repr level 0. The previous level is restored in a
        # finally clause so an exception raised by repr() cannot leave the
        # global setting stuck at 0.
        # NOTE(review): assumes set_repr_level() returns the previous level,
        # as the save/restore pairing implies -- confirm in Basic.
        old_level = Basic.set_repr_level(0)
        try:
            return repr(obj)
        finally:
            Basic.set_repr_level(old_level)

    def wrapper(*args, **kw_args):
        k = args
        if kw_args:
            k = k + tuple([(name + '=', kw_args[name])
                           for name in sorted(kw_args)])
        try:
            r = func_cache_it_cache[k]
        except KeyError:
            r = func(*args, **kw_args)
            # Render before storing anything, so a failing repr() cannot
            # leave a result cached without its reference text.
            s = repr0(r)
            func_cache_it_cache[k] = r
            func_cache_it_cache_repr[k] = s
        else:
            s = func_cache_it_cache_repr[k]
            new_s = repr0(r)
            # Check that the cached value has not changed. The message must
            # not assume a positional argument exists (keyword-only calls).
            assert new_s == s, repr((func, s, r,
                                     args and args[0].__class__ or None))
        return mycopy(r)

    return wrapper
# Default memoizing decorator.
# To track down code that mutates cached results, rebind it to the checking
# variant instead (about twice as slow):
#     cache_it = cache_it_debug
cache_it = cache_it_fast
def cache_it_nondummy(func):
    """Memoizing decorator that bypasses the cache for 'dummy' calls.

    A call passing a truthy ``dummy`` keyword argument goes straight to func
    and is never cached (presumably because such calls must produce a fresh
    object each time). Every other call is memoized and the cached object
    itself is returned on hits (no copy). The cache dict is attached to func
    as _cache_it_cache and registered in the global CACHE list; keyword
    arguments are appended to the positional key as ('name=', value) pairs
    ordered by name.
    """
    memo = {}
    func._cache_it_cache = memo
    CACHE.append((func, memo))
    absent = object()  # sentinel distinct from any cached value

    def wrapper(*args, **kw_args):
        if kw_args.get('dummy'):
            return func(*args, **kw_args)
        key = args
        if kw_args:
            pairs = [(kw + '=', kw_args[kw]) for kw in sorted(kw_args)]
            key = key + tuple(pairs)
        found = memo.get(key, absent)
        if found is absent:
            found = func(*args, **kw_args)
            memo[key] = found
        return found

    return wrapper
class MemoizerArg:
    """ Type specification (and optional converter) for one argument.

    See Memoizer.

    allowed_types -- a type, a tuple/list of types, or a string naming an
        attribute of Basic (strings may also appear inside the tuple/list).
        Strings are resolved lazily by fix_allowed_types(), so templates can
        be built before Basic is defined.
    converter -- optional callable applied to a value that passed the type
        check; its result is what the decorated function receives.
    name -- optional keyword-argument name; Memoizer registers the template
        under this name so the argument may also be passed by keyword.
    """

    def __init__(self, allowed_types, converter = None, name = None):
        self._allowed_types = allowed_types
        self.converter = converter
        self.name = name

    def fix_allowed_types(self, have_been_here=None):
        """Resolve the allowed-types spec into self.allowed_types, once.

        have_been_here -- optional dict mapping id() to a "done" flag, kept
        for backward compatibility; when given it is consulted and updated.
        """
        # Completion is tracked on the instance itself: an id()-keyed memo
        # shared across calls can be fooled when a new object reuses the id
        # of a garbage-collected one, leaving it without allowed_types.
        if getattr(self, '_allowed_types_fixed', False):
            return
        i = id(self)
        if have_been_here is not None and have_been_here.get(i):
            return
        allowed_types = self._allowed_types
        if isinstance(allowed_types, str):
            self.allowed_types = getattr(Basic, allowed_types)
        elif isinstance(allowed_types, (tuple, list)):
            new_allowed_types = []
            for t in allowed_types:
                if isinstance(t, str):
                    t = getattr(Basic, t)
                new_allowed_types.append(t)
            self.allowed_types = tuple(new_allowed_types)
        else:
            self.allowed_types = allowed_types
        self._allowed_types_fixed = True
        if have_been_here is not None:
            have_been_here[i] = True

    def process(self, obj, func, index = None):
        """Type-check obj and return it, converted if a converter is set.

        func is the decorated function and is used only to build the error
        message. index tells what obj is: None for a return value, an
        integer for a positional argument, a string for a keyword argument.
        fix_allowed_types() must have been called first.

        Raises ValueError when obj is not of an allowed type, and
        NotImplementedError for any other kind of index.
        """
        if isinstance(obj, self.allowed_types):
            if self.converter is not None:
                obj = self.converter(obj)
            return obj
        # __code__/__name__ are the portable spellings; very old Python 2
        # functions only provide func_code.
        code = getattr(func, '__code__', None)
        if code is None:
            code = func.func_code
        func_src = '%s:%s:function %s' % (code.co_filename, code.co_firstlineno, func.__name__)
        try:
            integer_types = (int, long)
        except NameError:  # Python 3 has no separate long type
            integer_types = (int,)
        if index is None:
            raise ValueError('%s return value must be of type %r but got %r' % (func_src, self.allowed_types, obj))
        if isinstance(index, integer_types):
            raise ValueError('%s %s-th argument must be of type %r but got %r' % (func_src, index, self.allowed_types, obj))
        if isinstance(index, str):
            raise ValueError('%s %r keyword argument must be of type %r but got %r' % (func_src, index, self.allowed_types, obj))
        raise NotImplementedError(repr((index, type(index))))


class Memoizer:
    """ Memoizer function decorator generator.

    Features:
      - checks that function arguments have allowed types
      - optionally applies converters to arguments
      - caches the results of function calls
      - optionally applies a converter to function values

    Usage:

      @Memoizer(<allowed types for argument 0>,
                MemoizerArg(<allowed types for argument 1>),
                MemoizerArg(<allowed types for argument 2>, <convert argument before function call>),
                MemoizerArg(<allowed types for argument 3>, <convert argument before function call>, name=<kw argument name>),
                <kw argument name>=<allowed types or MemoizerArg>,
                return_value_converter = <None or converter function, usually makes a copy>
                )
      def function(<arguments>, <kw_arguments>):
          ...

    Details:
      - if an allowed type is given as a string, Basic must have an attribute
        with that name, which is then used as the allowed type --- this is
        needed for applying the Memoizer decorator to Basic methods before
        Basic itself is defined.
      - keyword templates may be MemoizerArg instances or plain allowed-type
        specs, exactly like positional templates.
      - keyword arguments are keyed in name order, so the order in which
        they are passed does not affect caching.

    Restrictions:
      - arguments must be immutable (hashable)
      - when function values are mutable then one must use
        return_value_converter to deep copy the returned values

    Ref: http://en.wikipedia.org/wiki/Memoization
    """

    def __init__(self, *arg_templates, **kw_arg_templates):
        def as_template(t):
            # Wrap a bare allowed-types spec in a MemoizerArg.
            if isinstance(t, MemoizerArg):
                return t
            return MemoizerArg(t)

        return_value_converter = kw_arg_templates.pop('return_value_converter', None)
        self.arg_templates = tuple([as_template(t) for t in arg_templates])
        # Keyword templates get the same wrapping as positional ones; stored
        # bare, they would crash in fix_allowed_types() and process().
        self.kw_arg_templates = {}
        for k, t in kw_arg_templates.items():
            self.kw_arg_templates[k] = as_template(t)
        for template in self.arg_templates:
            if template.name is not None:
                self.kw_arg_templates[template.name] = template
        if return_value_converter is None:
            self.return_value_converter = lambda obj: obj
        else:
            self.return_value_converter = return_value_converter

    def fix_allowed_types(self, have_been_here=None):
        """Resolve the allowed types of all argument templates, once.

        have_been_here -- optional dict mapping id() to a "done" flag, kept
        for backward compatibility; when given it is consulted and updated.
        """
        # Per-instance flag instead of an id()-keyed memo (ids get reused).
        if getattr(self, '_allowed_types_fixed', False):
            return
        i = id(self)
        if have_been_here is not None and have_been_here.get(i):
            return
        for t in self.arg_templates:
            t.fix_allowed_types()
        for t in self.kw_arg_templates.values():
            t.fix_allowed_types()
        self._allowed_types_fixed = True
        if have_been_here is not None:
            have_been_here[i] = True

    def __call__(self, func):
        """Decorate func with argument checking and result caching.

        Both the raw and the converted argument keys map to the result in a
        single cache dict; equal hashable results are shared through
        value_cache. Both dicts are registered in the global CACHE list.
        """
        cache = {}
        value_cache = {}
        CACHE.append((func, (cache, value_cache)))
        missing = object()  # sentinel distinct from any cached value

        def compute(key, args, kw_items):
            # Raw-argument cache miss: check/convert the arguments, retry the
            # cache under the converted key and call func only if needed.
            self.fix_allowed_types()
            new_args = tuple([template.process(a, func, i)
                              for i, (a, template) in enumerate(zip(args, self.arg_templates))])
            # zip() stops at the shorter sequence, so positional arguments
            # without a template would otherwise be dropped silently.
            assert len(args) == len(new_args)
            new_kw_args = {}
            for k, v in kw_items:
                new_kw_args[k] = self.kw_arg_templates[k].process(v, func, k)
            new_key = (new_args, tuple(sorted(new_kw_args.items())))
            r = cache.get(new_key, missing)
            if r is missing:
                r = func(*new_args, **new_kw_args)
                try:
                    # Share one object among equal results.
                    r = value_cache.setdefault(r, r)
                except TypeError:
                    pass  # unhashable results are cached but not shared
                cache[new_key] = r
            # Remember the raw key too, so the next identical call skips the
            # argument processing.
            cache[key] = r
            return r

        def wrapper(*args, **kw_args):
            # Sort keyword items by name so the key ignores passing order.
            kw_items = tuple(sorted(kw_args.items()))
            key = (args, kw_items)
            r = cache.get(key, missing)
            if r is missing:
                r = compute(key, args, kw_items)
            # The converter runs outside any lookup, so a KeyError it raises
            # propagates instead of being mistaken for a cache miss.
            return self.return_value_converter(r)

        return wrapper