Implemented "mark all as read".
[straw.git] / test / mock.py
blob7c20056c1afa6271aadb089b8d7e56db223d2a9f
2 # (c) Dave Kirby 2001 - 2005
3 # mock@thedeveloperscoach.com
5 # Original call interceptor and call assertion code by Phil Dawes (pdawes@users.sourceforge.net)
6 # Call interceptor code enhanced by Bruce Cropley (cropleyb@yahoo.com.au)
8 # This Python module and associated files are released under the FreeBSD
9 # license. Essentially, you can do what you like with it except pretend you wrote
10 # it yourself.
13 # Copyright (c) 2005, Dave Kirby
15 # All rights reserved.
17 # Redistribution and use in source and binary forms, with or without
18 # modification, are permitted provided that the following conditions are met:
20 # * Redistributions of source code must retain the above copyright
21 # notice, this list of conditions and the following disclaimer.
23 # * Redistributions in binary form must reproduce the above copyright
24 # notice, this list of conditions and the following disclaimer in the
25 # documentation and/or other materials provided with the distribution.
27 # * Neither the name of this library nor the names of its
28 # contributors may be used to endorse or promote products derived from
29 # this software without specific prior written permission.
31 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
32 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
33 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
34 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
35 # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
36 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
37 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
38 # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
39 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
40 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
42 # mock@thedeveloperscoach.com
45 """
46 Mock object library for Python. Mock objects can be used when unit testing
47 to remove a dependency on another production class. They are typically used
48 when the dependency would either pull in lots of other classes, or
49 significantly slow down the execution of the test.
50 They are also used to create exceptional conditions that cannot otherwise
51 be easily triggered in the class under test.
52 """
# Version string of this mock library.
__version__ = "0.1.0"
56 # Added in Python 2.1
57 import inspect
58 import re
class MockInterfaceError(Exception):
    """
    Raised when a Mock built with a realClass is used in a way the real
    class would not allow: a return value or a call for a method the real
    class does not have, or a call whose arguments would not fit the real
    method's parameter list.
    """
class Mock:
    """
    The Mock class emulates any other class for testing purposes.
    All method calls are stored for later examination.
    """

    def __init__(self, returnValues=None, realClass=None):
        """
        The Mock class constructor takes a dictionary of method names and
        the values they return. Methods that are not in the returnValues
        dictionary will return None.
        You may also supply a class whose interface is being mocked.
        All calls will be checked to see if they appear in the original
        interface. Any calls to methods not appearing in the real class
        will raise a MockInterfaceError. Any calls that would fail due to
        non-matching parameter lists will also raise a MockInterfaceError.
        Both of these help to prevent the Mock class getting out of sync
        with the class it is Mocking.
        """
        # method name -> list of MockCall objects for that name, in call order
        self.mockCalledMethods = {}
        # every MockCall, in call order
        self.mockAllCalledMethods = []
        self.mockReturnValues = returnValues or {}
        # method name -> list of (testFn, after, until) tuples
        self.mockExpectations = {}
        # method name -> routine of realClass; None disables interface checks
        self.realClassMethods = None
        if realClass:
            self.realClassMethods = dict(inspect.getmembers(realClass, inspect.isroutine))
            for retMethod in self.mockReturnValues:
                if retMethod not in self.realClassMethods:
                    raise MockInterfaceError("Return value supplied for method '%s' that was not in the original class" % retMethod)
        self._setupSubclassMethodInterceptors()

    def _setupSubclassMethodInterceptors(self):
        """
        Replace each routine that a Mock subclass adds with a recording
        MockCallable which delegates to the handcrafted implementation.
        """
        methods = inspect.getmembers(self.__class__, inspect.isroutine)
        # Use isroutine for the base set too. On Python 3 (and for new-style
        # classes) functions looked up on a class are not "methods", so
        # inspect.ismethod would return nothing. Then every Mock API method
        # would wrongly be intercepted as if a subclass had defined it.
        baseMethods = dict(inspect.getmembers(Mock, inspect.isroutine))
        for name, _ in methods:
            # Don't record calls to methods of Mock base class.
            if name not in baseMethods:
                self.__dict__[name] = MockCallable(name, self, handcrafted=True)

    def __getattr__(self, name):
        # Only reached for attributes not found normally: treat any such
        # name as a mocked method.
        return MockCallable(name, self)

    def mockAddReturnValues(self, **methodReturnValues):
        """Add or replace canned return values, given as name=value pairs."""
        self.mockReturnValues.update(methodReturnValues)

    def mockSetExpectation(self, name, testFn, after=0, until=0):
        """
        Register testFn(mock, call, index) to be asserted on calls to `name`.

        The expectation is only evaluated while the number of calls made to
        `name` (including the current one) is greater than `after` and, unless
        `until` is 0, less than `until`.
        """
        self.mockExpectations.setdefault(name, []).append((testFn, after, until))

    def _checkInterfaceCall(self, name, callParams, callKwParams):
        """
        Check that a call to a method of the given name to the original
        class with the given parameters would not fail. If it would fail,
        raise a MockInterfaceError.
        Based on the Python 2.3.3 Reference Manual section 5.3.4: Calls.
        """
        if self.realClassMethods is None:
            return
        if name not in self.realClassMethods:
            raise MockInterfaceError("Calling mock method '%s' that was not found in the original class" % name)

        func = self.realClassMethods[name]
        # inspect.getargspec was removed in Python 3.11. Prefer getfullargspec
        # where it exists; its first four fields match getargspec's.
        # NOTE(review): keyword-only parameters (Python 3) are not considered.
        getSpec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        try:
            spec = getSpec(func)
        except TypeError:
            # func is not a Python function. It is probably a builtin,
            # such as __repr__ or __coerce__. TODO: Checking?
            # For now assume params are OK.
            return
        args, varargs, varkw, defaults = tuple(spec)[:4]

        # callParams doesn't include self; args does include self.
        numPosCallParams = 1 + len(callParams)

        if numPosCallParams > len(args) and not varargs:
            raise MockInterfaceError("Original %s() takes at most %s arguments (%s given)" %
                                     (name, len(args), numPosCallParams))

        # Get the number of positional arguments that appear in the call,
        # also check for duplicate parameters and unknown parameters
        numPosSeen = _getNumPosSeenAndCheck(numPosCallParams, callKwParams, args, varkw)

        lenArgsNoDefaults = len(args) - len(defaults or [])
        if numPosSeen < lenArgsNoDefaults:
            raise MockInterfaceError("Original %s() takes at least %s arguments (%s given)" % (name, lenArgsNoDefaults, numPosSeen))

    def mockGetAllCalls(self):
        """
        Return a list of MockCall objects,
        representing all the methods in the order they were called.
        """
        return self.mockAllCalledMethods
    getAllCalls = mockGetAllCalls  # deprecated - kept for backward compatibility

    def mockGetNamedCalls(self, methodName):
        """
        Return a list of MockCall objects,
        representing all the calls to the named method in the order they were called.
        """
        return self.mockCalledMethods.get(methodName, [])
    getNamedCalls = mockGetNamedCalls  # deprecated - kept for backward compatibility

    def mockCheckCall(self, index, name, *args, **kwargs):
        '''test that the index-th call had the specified name and parameters'''
        call = self.mockAllCalledMethods[index]
        assert name == call.getName(), "%r != %r" % (name, call.getName())
        call.checkArgs(*args, **kwargs)
171 def _getNumPosSeenAndCheck(numPosCallParams, callKwParams, args, varkw):
173 Positional arguments can appear as call parameters either named as
174 a named (keyword) parameter, or just as a value to be matched by
175 position. Count the positional arguments that are given by either
176 keyword or position, and check for duplicate specifications.
177 Also check for arguments specified by keyword that do not appear
178 in the method's parameter list.
180 posSeen = {}
181 for arg in args[:numPosCallParams]:
182 posSeen[arg] = True
183 for kwp in callKwParams:
184 if posSeen.has_key(kwp):
185 raise MockInterfaceError("%s appears as both a positional and named parameter." % kwp)
186 if kwp in args:
187 posSeen[kwp] = True
188 elif not varkw:
189 raise MockInterfaceError("Original method does not have a parameter '%s'" % kwp)
190 return len(posSeen)
class MockCall:
    """
    MockCall records the name and parameters of a call to an instance
    of a Mock class. Instances of MockCall are created by the Mock class,
    but can be inspected later as part of the test.
    """

    def __init__(self, name, params, kwparams):
        self.name = name          # name of the method that was called
        self.params = params      # positional arguments of the call
        self.kwparams = kwparams  # keyword arguments of the call

    def checkArgs(self, *args, **kwargs):
        """Assert that this call was made with exactly these arguments."""
        assert args == self.params, "%r != %r" % (args, self.params)
        assert kwargs == self.kwparams, "%r != %r" % (kwargs, self.kwparams)

    def getParam(self, n):
        """
        Return one recorded argument: by position if n is an int, or by
        keyword name if n is a str.

        Raises IndexError for any other type of n (and the usual
        IndexError/KeyError if the position or name was not supplied).
        """
        if isinstance(n, int):
            return self.params[n]
        elif isinstance(n, str):
            return self.kwparams[n]
        else:
            # Call form works on both Python 2 and 3; the old
            # "raise E, msg" statement is a SyntaxError on Python 3.
            raise IndexError('illegal index type for getParam')

    def getNumParams(self):
        """Return the number of positional arguments of the call."""
        return len(self.params)

    def getNumKwParams(self):
        """Return the number of keyword arguments of the call."""
        return len(self.kwparams)

    def getName(self):
        """Return the name of the called method."""
        return self.name

    def __str__(self):
        # Pretty-print the call as name(p1, p2, k1=v1, ...). Keyword arguments
        # are sorted by name so the text is deterministic. list(...).sort()
        # works on both Python 2 and 3 (items() is a view on Python 3).
        parts = [repr(p) for p in self.params]
        items = list(self.kwparams.items())
        items.sort()
        parts.extend([k + '=' + repr(v) for k, v in items])
        return self.name + '(' + ', '.join(parts) + ')'

    def __repr__(self):
        return self.__str__()
class MockCallable:
    """
    Stand-in for one method of a Mock instance.

    Calling it checks the call against the real class's interface (when the
    Mock was given one), records it as a MockCall, runs any expectations
    registered for this method name, and finally produces a result. The
    result comes either from a handcrafted method on the Mock subclass or
    from the mock's table of canned return values.
    """

    def __init__(self, name, mock, handcrafted=False):
        # name        -- the intercepted method's name
        # mock        -- the Mock instance that owns this callable
        # handcrafted -- True if a Mock subclass implements the method itself
        self.name = name
        self.mock = mock
        self.handcrafted = handcrafted

    def __call__(self, *params, **kwparams):
        self.mock._checkInterfaceCall(self.name, params, kwparams)
        recorded = self.recordCall(params, kwparams)
        self.checkExpectations(recorded, params, kwparams)
        return self.makeCall(params, kwparams)

    def recordCall(self, params, kwparams):
        """
        Append a new MockCall to the mock's per-method-name list and to its
        list of every call, keeping call order in both; return the MockCall.
        """
        newCall = MockCall(self.name, params, kwparams)
        perName = self.mock.mockCalledMethods.setdefault(self.name, [])
        perName.append(newCall)
        self.mock.mockAllCalledMethods.append(newCall)
        return newCall

    def makeCall(self, params, kwparams):
        """
        Produce the result of the call.

        Not handcrafted: return the canned value for this name (None if
        absent). A ReturnValuesBase instance yields its next value instead.
        Handcrafted: find the implementation on the Mock subclass and call it
        with the mock as self; raise NotImplementedError if none is found.
        """
        if not self.handcrafted:
            canned = self.mock.mockReturnValues.get(self.name)
            if isinstance(canned, ReturnValuesBase):
                return canned.next()
            return canned
        positional = (self.mock,) + params
        impl = _findFunc(self.mock.__class__, self.name)
        if not impl:
            raise NotImplementedError
        return impl(*positional, **kwparams)

    def checkExpectations(self, thisCall, params, kwparams):
        """
        Assert every expectation registered for this method name whose window
        contains the current call count. The count must be strictly greater
        than `after` and strictly less than `until`; until == 0 means no
        upper bound.
        """
        if self.name not in self.mock.mockExpectations:
            return
        callsMade = len(self.mock.mockCalledMethods[self.name])
        for testFn, after, until in self.mock.mockExpectations[self.name]:
            inWindow = callsMade > after and (until == 0 or callsMade < until)
            if not inWindow:
                continue
            # The last index of the overall call list is the current call.
            assert testFn(self.mock, thisCall, len(self.mock.mockAllCalledMethods) - 1), \
                'Expectation failed: ' + str(thisCall)
290 def _findFunc(cl, name):
291 """ Depth first search for a method with a given name. """
292 if cl.__dict__.has_key(name):
293 return cl.__dict__[name]
294 for base in cl.__bases__:
295 func = _findFunc(base, name)
296 if func:
297 return func
298 return None
class ReturnValuesBase:
    """
    Base for canned return values that differ from call to call.

    Subclasses store an iterator in self.iter. When such an object is a
    Mock's return value, each call to the mocked method consumes the next
    item. Running out raises AssertionError("No more return values")
    rather than StopIteration.
    """

    def next(self):
        """Return the next value, or raise AssertionError once exhausted."""
        # Python 3 iterators implement __next__, Python 2 iterators implement
        # next(); pick whichever this iterator provides.
        advance = getattr(self.iter, '__next__', None) or self.iter.next
        try:
            return advance()
        except StopIteration:
            raise AssertionError("No more return values")

    # Python 3 name of the iterator-protocol method, so next(obj) works.
    __next__ = next

    def __iter__(self):
        return self


class ReturnValues(ReturnValuesBase):
    """Return each of the given values in turn, one per call."""

    def __init__(self, *values):
        self.iter = iter(values)


class ReturnIterator(ReturnValuesBase):
    """Return successive items of an existing iterable, one per call."""

    def __init__(self, iterator):
        self.iter = iter(iterator)
def expectParams(*params, **keywords):
    """
    Build an expectation that holds only when the call's positional arguments
    equal `params` and its keyword arguments equal `keywords`.
    """
    def matchesExactly(mockObj, callObj, idx):
        positionalOk = callObj.params == params
        return positionalOk and callObj.kwparams == keywords
    return matchesExactly
def expectAfter(*methods):
    """
    Build an expectation that holds only if every method named in `methods`
    was already called on the mock before the call being checked.
    """
    def calledBefore(mockObj, callObj, idx):
        history = [call.getName() for call in mockObj.mockGetAllCalls()]
        # The newest entry is the call being checked right now; ignore it.
        earlier = history[:-1]
        missing = [name for name in methods if name not in earlier]
        return not missing
    return calledBefore
def expectException(exception, *args, **kwargs):
    """
    Build an expectation that never passes. Each time it is evaluated it
    raises exception(*args, **kwargs), simulating a failing method.
    """
    def raiser(mockObj, callObj, idx):
        raise exception(*args, **kwargs)
    return raiser
def expectParam(paramIdx, cond):
    '''Build an expectation that holds when the call's argument paramIdx
    (an int position or a str keyword name) satisfies `cond`.
    cond is a function that takes a single argument, the value to test.
    '''
    def check(mockObj, callObj, idx):
        return cond(callObj.getParam(paramIdx))
    return check
# Condition factories for expectParam: each returns a one-argument predicate
# that compares the tested parameter against `value`.

def EQ(value):
    """Predicate: param == value."""
    return lambda param: param == value

def NE(value):
    """Predicate: param != value."""
    return lambda param: param != value

def GT(value):
    """Predicate: param > value."""
    return lambda param: param > value

def LT(value):
    """Predicate: param < value."""
    return lambda param: param < value

def GE(value):
    """Predicate: param >= value."""
    return lambda param: param >= value

def LE(value):
    """Predicate: param <= value."""
    return lambda param: param <= value
# Logical combinators for predicates. AND and OR stop at the first condition
# that decides the result, so later conditions are not evaluated.

def AND(*condlist):
    """Predicate: every condition holds for param (True when condlist is empty)."""
    def allHold(param):
        return all(cond(param) for cond in condlist)
    return allHold

def OR(*condlist):
    """Predicate: at least one condition holds for param (False when empty)."""
    def anyHolds(param):
        return any(cond(param) for cond in condlist)
    return anyHolds

def NOT(cond):
    """Predicate: cond does not hold for param."""
    return lambda param: not cond(param)
def MATCHES(regex, *args, **kwargs):
    """
    Predicate: the regular expression matches param.

    The match uses re.match semantics, so it is anchored at the START of the
    string only; add a trailing '$' to the pattern to require a full match.
    Extra args/kwargs (e.g. flags) are passed to re.compile. The pattern is
    compiled once, when the predicate is built.
    """
    compiled_regex = re.compile(regex, *args, **kwargs)
    def testFn(param):
        # Compare against None by identity (PEP 8) rather than with !=.
        return compiled_regex.match(param) is not None
    return testFn
def SEQ(*sequence):
    """
    Predicate that applies the given conditions one per evaluation.

    The first evaluation uses sequence[0], the second uses sequence[1], and
    so on. Once every condition has been used, further evaluations raise
    AssertionError('SEQ exhausted').
    """
    iterator = iter(sequence)
    # Python 3 iterators expose __next__, Python 2 iterators expose next();
    # resolve the right bound method once, when the predicate is built.
    advance = getattr(iterator, '__next__', None) or iterator.next
    def testFn(param):
        try:
            cond = advance()
        except StopIteration:
            raise AssertionError('SEQ exhausted')
        return cond(param)
    return testFn
# Identity, type and membership predicate factories for expectParam.

def IS(instance):
    """Predicate: param is the very same object as `instance`."""
    return lambda param: param is instance

def ISINSTANCE(class_):
    """Predicate: param is an instance of class_."""
    return lambda param: isinstance(param, class_)

def ISSUBCLASS(class_):
    """Predicate: param is a subclass of class_."""
    return lambda param: issubclass(param, class_)

def CONTAINS(val):
    """Predicate: val is an element of param."""
    return lambda param: val in param

def IN(container):
    """Predicate: param is an element of container."""
    return lambda param: param in container

def HASATTR(attr):
    """Predicate: param has an attribute called attr."""
    return lambda param: hasattr(param, attr)

def HASMETHOD(method):
    """Predicate: param has a callable attribute called method."""
    def hasCallable(param):
        if not hasattr(param, method):
            return False
        return callable(getattr(param, method))
    return hasCallable

# Predicate: param is callable (the builtin itself already has that shape).
CALLABLE = callable