2 # (c) Dave Kirby 2001 - 2005
3 # mock@thedeveloperscoach.com
5 # Original call interceptor and call assertion code by Phil Dawes (pdawes@users.sourceforge.net)
6 # Call interceptor code enhanced by Bruce Cropley (cropleyb@yahoo.com.au)
8 # This Python module and associated files are released under the FreeBSD
9 # license. Essentially, you can do what you like with it except pretend you wrote
13 # Copyright (c) 2005, Dave Kirby
15 # All rights reserved.
17 # Redistribution and use in source and binary forms, with or without
18 # modification, are permitted provided that the following conditions are met:
20 # * Redistributions of source code must retain the above copyright
21 # notice, this list of conditions and the following disclaimer.
23 # * Redistributions in binary form must reproduce the above copyright
24 # notice, this list of conditions and the following disclaimer in the
25 # documentation and/or other materials provided with the distribution.
27 # * Neither the name of this library nor the names of its
28 # contributors may be used to endorse or promote products derived from
29 # this software without specific prior written permission.
31 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
32 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
33 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
34 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
35 # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
36 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
37 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
38 # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
39 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
40 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
42 # mock@thedeveloperscoach.com
"""
Mock object library for Python. Mock objects can be used when unit testing
to remove a dependency on another production class. They are typically used
when the dependency would either pull in lots of other classes, or
significantly slow down the execution of the test.
They are also used to create exceptional conditions that cannot otherwise
be easily triggered in the class under test.
"""
class MockInterfaceError(Exception):
    """
    Raised when a Mock is set up or called in a way that does not match the
    interface of the real class it was told to mock (unknown method name,
    or an argument list the real method would reject).
    """
class Mock:
    """
    The Mock class emulates any other class for testing purposes.
    All method calls are stored for later examination.
    """

    def __init__(self, returnValues=None, realClass=None):
        """
        The Mock class constructor takes a dictionary of method names and
        the values they return.  Methods that are not in the returnValues
        dictionary will return None.
        You may also supply a class whose interface is being mocked.
        All calls will be checked to see if they appear in the original
        interface. Any calls to methods not appearing in the real class
        will raise a MockInterfaceError.  Any calls that would fail due to
        non-matching parameter lists will also raise a MockInterfaceError.
        Both of these help to prevent the Mock class getting out of sync
        with the class it is Mocking.
        """
        self.mockCalledMethods = {}      # method name -> list of MockCall, in call order
        self.mockAllCalledMethods = []   # every MockCall, in call order
        self.mockReturnValues = returnValues or {}
        self.mockExpectations = {}       # method name -> list of (testFn, after, until)
        # None means "no real class supplied": interface checking is disabled.
        self.realClassMethods = None
        if realClass is not None:
            self.realClassMethods = dict(inspect.getmembers(realClass, inspect.isroutine))
            for retMethod in self.mockReturnValues:
                if retMethod not in self.realClassMethods:
                    raise MockInterfaceError("Return value supplied for method '%s' that was not in the original class" % retMethod)
        self._setupSubclassMethodInterceptors()

    def _setupSubclassMethodInterceptors(self):
        """
        For every routine a Mock subclass adds (any routine name not also
        found on Mock itself), put a handcrafted MockCallable of the same
        name into this instance's __dict__ so calls to it are intercepted.
        """
        methods = inspect.getmembers(self.__class__, inspect.isroutine)
        # isroutine (not ismethod): on Python 3, functions looked up on a class
        # are not "methods" to inspect.ismethod, so ismethod would match
        # nothing and Mock's own API would be intercepted too.
        baseMethods = dict(inspect.getmembers(Mock, inspect.isroutine))
        for name, _ in methods:
            # Don't record calls to methods of Mock base class.
            if name not in baseMethods:
                self.__dict__[name] = MockCallable(name, self, handcrafted=True)

    def __getattr__(self, name):
        # Only reached when normal lookup fails: any unknown attribute is
        # treated as a mocked method.
        return MockCallable(name, self)

    def mockAddReturnValues(self, **methodReturnValues):
        """Add or replace return values, given as methodName=value keywords."""
        self.mockReturnValues.update(methodReturnValues)

    def mockSetExpectation(self, name, testFn, after=0, until=0):
        """
        Register testFn(mock, call, index) to be checked on calls to the
        method called `name`. It is checked while the number of calls made
        to that method exceeds `after` and, if `until` is non-zero, is less
        than `until`.
        """
        self.mockExpectations.setdefault(name, []).append((testFn, after, until))

    def _checkInterfaceCall(self, name, callParams, callKwParams):
        """
        Check that a call to a method of the given name to the original
        class with the given parameters would not fail. If it would fail,
        raise a MockInterfaceError.
        Based on the Python 2.3.3 Reference Manual section 5.3.4: Calls.
        """
        if self.realClassMethods is None:
            return  # no real class supplied: nothing to check against
        if name not in self.realClassMethods:
            raise MockInterfaceError("Calling mock method '%s' that was not found in the original class" % name)

        func = self.realClassMethods[name]
        # inspect.getargspec was removed in Python 3.11; prefer getfullargspec
        # when it exists. Only the first four fields are used either way.
        getargspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        try:
            args, varargs, varkw, defaults = getargspec(func)[:4]
        except TypeError:
            # func is not a Python function. It is probably a builtin,
            # such as __repr__ or __coerce__. TODO: Checking?
            # For now assume params are OK.
            return

        # callParams doesn't include self; args does include self.
        numPosCallParams = 1 + len(callParams)

        if numPosCallParams > len(args) and not varargs:
            raise MockInterfaceError("Original %s() takes at most %s arguments (%s given)" %
                (name, len(args), numPosCallParams))

        # Get the number of positional arguments that appear in the call,
        # also check for duplicate parameters and unknown parameters
        numPosSeen = _getNumPosSeenAndCheck(numPosCallParams, callKwParams, args, varkw)

        lenArgsNoDefaults = len(args) - len(defaults or [])
        if numPosSeen < lenArgsNoDefaults:
            raise MockInterfaceError("Original %s() takes at least %s arguments (%s given)" % (name, lenArgsNoDefaults, numPosSeen))

    def mockGetAllCalls(self):
        """
        Return a list of MockCall objects,
        representing all the methods in the order they were called.
        """
        return self.mockAllCalledMethods
    getAllCalls = mockGetAllCalls  # deprecated - kept for backward compatibility

    def mockGetNamedCalls(self, methodName):
        """
        Return a list of MockCall objects,
        representing all the calls to the named method in the order they were called.
        """
        return self.mockCalledMethods.get(methodName, [])
    getNamedCalls = mockGetNamedCalls  # deprecated - kept for backward compatibility

    def mockCheckCall(self, index, name, *args, **kwargs):
        '''test that the index-th call had the specified name and parameters'''
        call = self.mockAllCalledMethods[index]
        assert name == call.getName(), "%r != %r" % (name, call.getName())
        call.checkArgs(*args, **kwargs)
171 def _getNumPosSeenAndCheck(numPosCallParams
, callKwParams
, args
, varkw
):
173 Positional arguments can appear as call parameters either named as
174 a named (keyword) parameter, or just as a value to be matched by
175 position. Count the positional arguments that are given by either
176 keyword or position, and check for duplicate specifications.
177 Also check for arguments specified by keyword that do not appear
178 in the method's parameter list.
181 for arg
in args
[:numPosCallParams
]:
183 for kwp
in callKwParams
:
184 if posSeen
.has_key(kwp
):
185 raise MockInterfaceError("%s appears as both a positional and named parameter." % kwp
)
189 raise MockInterfaceError("Original method does not have a parameter '%s'" % kwp
)
class MockCall:
    """
    MockCall records the name and parameters of a call to an instance
    of a Mock class. Instances of MockCall are created by the Mock class,
    but can be inspected later as part of the test.
    """

    def __init__(self, name, params, kwparams):
        self.name = name          # name of the method that was called
        self.params = params      # positional arguments of the call
        self.kwparams = kwparams  # keyword arguments of the call

    def checkArgs(self, *args, **kwargs):
        """Assert that this call received exactly these positional and keyword arguments."""
        assert args == self.params, "%r != %r" % (args, self.params)
        assert kwargs == self.kwparams, "%r != %r" % (kwargs, self.kwparams)

    def getParam(self, n):
        """
        Return one argument of the call: the n-th positional argument when n
        is an int, or the keyword argument named n when n is a str.
        Raises IndexError for any other type of n.
        """
        if isinstance(n, int):
            return self.params[n]
        if isinstance(n, str):
            return self.kwparams[n]
        raise IndexError('illegal index type for getParam')

    def getNumParams(self):
        """Return the number of positional arguments."""
        return len(self.params)

    def getNumKwParams(self):
        """Return the number of keyword arguments."""
        return len(self.kwparams)

    def getName(self):
        """Return the name of the called method."""
        return self.name

    def __str__(self):
        """Pretty-print the call as name(pos, ..., key=value, ...)."""
        parts = [repr(p) for p in self.params]
        # Sort keywords so the output does not depend on dict ordering.
        parts.extend(k + '=' + repr(v) for k, v in sorted(self.kwparams.items()))
        return self.name + '(' + ', '.join(parts) + ')'

    def __repr__(self):
        return self.__str__()
class MockCallable:
    """
    Intercepts the call and records it, then delegates to either the mock's
    dictionary of mock return values that was passed in to the constructor,
    or a handcrafted method of a Mock subclass.
    """

    def __init__(self, name, mock, handcrafted=False):
        self.name = name                # method name being intercepted
        self.mock = mock                # Mock instance holding the call records
        self.handcrafted = handcrafted  # True: delegate to the Mock subclass's own method

    def __call__(self, *params, **kwparams):
        self.mock._checkInterfaceCall(self.name, params, kwparams)
        thisCall = self.recordCall(params, kwparams)
        self.checkExpectations(thisCall, params, kwparams)
        return self.makeCall(params, kwparams)

    def recordCall(self, params, kwparams):
        """
        Record the MockCall in an ordered list of all calls, and an ordered
        list of calls for that method name. Returns the new MockCall.
        """
        thisCall = MockCall(self.name, params, kwparams)
        self.mock.mockCalledMethods.setdefault(self.name, []).append(thisCall)
        self.mock.mockAllCalledMethods.append(thisCall)
        return thisCall

    def makeCall(self, params, kwparams):
        """
        Produce the call's result. A handcrafted callable runs the method of
        the same name found on the mock's class (raising NotImplementedError
        if there is none); otherwise the configured return value is used,
        drawing the next value when it is a ReturnValuesBase script.
        """
        if self.handcrafted:
            allPosParams = (self.mock,) + params
            func = _findFunc(self.mock.__class__, self.name)
            if not func:
                raise NotImplementedError
            return func(*allPosParams, **kwparams)
        returnVal = self.mock.mockReturnValues.get(self.name)
        if isinstance(returnVal, ReturnValuesBase):
            returnVal = returnVal.next()
        return returnVal

    def checkExpectations(self, thisCall, params, kwparams):
        """
        Evaluate each expectation registered for this method whose
        (after, until) window contains the current call count; a falsy
        result fails with AssertionError.
        """
        if self.name not in self.mock.mockExpectations:
            return
        # When reached via __call__, thisCall is already recorded, so the
        # count includes the current call.
        callsMade = len(self.mock.mockCalledMethods[self.name])
        for (expectation, after, until) in self.mock.mockExpectations[self.name]:
            if callsMade > after and (until == 0 or callsMade < until):
                assert expectation(self.mock, thisCall, len(self.mock.mockAllCalledMethods) - 1), 'Expectation failed: ' + str(thisCall)
290 def _findFunc(cl
, name
):
291 """ Depth first search for a method with a given name. """
292 if cl
.__dict
__.has_key(name
):
293 return cl
.__dict
__[name
]
294 for base
in cl
.__bases
__:
295 func
= _findFunc(base
, name
)
class ReturnValuesBase:
    """
    Base for scripted return values. Subclasses set self.iter to an
    iterator; each next() call hands out its next item.
    """

    def next(self):
        """
        Return the next scripted value.
        Raises AssertionError("No more return values") once exhausted.
        """
        try:
            # Builtin next() works on Python 2.6+ and 3; iterator.next() is 2-only.
            return next(self.iter)
        except StopIteration:
            raise AssertionError("No more return values")
class ReturnValues(ReturnValuesBase):
    """Scripted return values given explicitly: ReturnValues(a, b, c)
    hands out a, then b, then c, one per next() call."""

    def __init__(self, *values):
        # The positional values arrive as a tuple; keep an iterator over it.
        self.iter = iter(values)
class ReturnIterator(ReturnValuesBase):
    """Scripted return values drawn from any iterable (list, generator, ...),
    one item per next() call."""

    def __init__(self, iterator):
        # iter() accepts both iterables and iterators, so either may be passed.
        self.iter = iter(iterator)
def expectParams(*params, **keywords):
    '''check that the callObj is called with specified params and keywords

    Returns an expectation function fn(mockObj, callObj, idx) that is true
    when callObj.params and callObj.kwparams equal exactly the given
    positional and keyword arguments.
    '''
    def fn(mockObj, callObj, idx):
        return callObj.params == params and callObj.kwparams == keywords
    return fn
def expectAfter(*methods):
    '''check that the function is only called after all the functions in 'methods'

    Returns an expectation function fn(mockObj, callObj, idx) that is true
    when every name in `methods` appears among the calls recorded on
    mockObj before the current one.
    '''
    def fn(mockObj, callObj, idx):
        calledMethods = [method.getName() for method in mockObj.mockGetAllCalls()]
        #skip last entry, since that is the current call
        calledMethods = calledMethods[:-1]
        return all(method in calledMethods for method in methods)
    return fn
def expectException(exception, *args, **kwargs):
    ''' raise an exception when the method is called

    Returns an expectation function that, whenever it is evaluated, raises
    exception(*args, **kwargs).
    '''
    def fn(mockObj, callObj, idx):
        raise exception(*args, **kwargs)
    return fn
def expectParam(paramIdx, cond):
    '''check that the callObj is called with parameter specified by paramIdx (a position index or keyword)
    fulfills the condition specified by cond.
    cond is a function that takes a single argument, the value to test.

    Returns an expectation function fn(mockObj, callObj, idx) whose result
    is cond(callObj.getParam(paramIdx)).
    '''
    def fn(mockObj, callObj, idx):
        param = callObj.getParam(paramIdx)
        return cond(param)
    return fn
# NOTE(review): this span is damaged. The lines below are bodies of several
# small predicate helpers, but their enclosing `def` lines (outer factory and
# inner test function) are missing from this copy, so `param`, `value`,
# `condlist` and `cond` are unbound here. The leading numbers (362, 367, 382,
# ...) look like line numbers from the listing this was copied from; the gaps
# between them suggest whole helpers are absent as well. Restore this span
# from the canonical source rather than patching it by hand.
362 return param
== value
367 return param
!= value
382 return param
>= value
387 return param
<= value
392 for cond
in condlist
:
400 for cond
in condlist
:
408 return not cond(param
)
def MATCHES(regex, *args, **kwargs):
    """
    Return a predicate, suitable as a `cond` for expectParam, that is true
    when `regex` matches at the start of its argument.
    Extra arguments (e.g. flags) are passed through to re.compile.
    """
    compiled_regex = re.compile(regex, *args, **kwargs)

    def testFn(param):
        return compiled_regex.match(param) is not None
    return testFn
# NOTE(review): damaged span. This looks like the body of a helper that hands
# out one condition from `sequence` per check and fails with 'SEQ exhausted'
# when none remain, but its `def` lines and the `try:` before `except` are
# missing from this copy, so `sequence` and `param` are unbound. Also note
# `iterator.next()` is the Python 2 iterator protocol; Python 3 needs
# `next(iterator)`. Restore from the canonical source.
418 iterator
= iter(sequence
)
421 cond
= iterator
.next()
422 except StopIteration:
423 raise AssertionError('SEQ exhausted')
# NOTE(review): damaged span. An identity test (`param is instance`) whose
# enclosing `def` lines are missing from this copy, so `param` and
# `instance` are unbound here. Restore from the canonical source.
429 return param
is instance
def ISINSTANCE(class_):
    """
    Return a predicate, suitable as a `cond` for expectParam, that is true
    when its argument is an instance of class_ (a class or tuple of classes).
    """
    def testFn(param):
        return isinstance(param, class_)
    return testFn
def ISSUBCLASS(class_):
    """
    Return a predicate, suitable as a `cond` for expectParam, that is true
    when its argument is a class that is a subclass of class_.
    """
    def testFn(param):
        return issubclass(param, class_)
    return testFn
# NOTE(review): damaged span. A membership test (`param in container`) whose
# enclosing `def` lines are missing from this copy, so `param` and
# `container` are unbound here. Restore from the canonical source.
449 return param
in container
# NOTE(review): damaged span. An attribute-presence test whose enclosing
# `def` lines are missing from this copy, so `param` and `attr` are unbound
# here. Restore from the canonical source.
454 return hasattr(param
, attr
)
# NOTE(review): damaged span. HASMETHOD(method) appears intended to build a
# predicate that is true when its argument has a callable attribute named
# `method`, but the inner function that binds `param` is missing from this
# copy, and the factory's own return statement is not visible here (it may
# have been dropped or lie beyond this chunk). Restore from the canonical
# source.
457 def HASMETHOD(method
):
459 return hasattr(param
, method
) and callable(getattr(param
, method
))