1 """Loading unittests."""
import os
import re
import sys
import traceback
import types

from functools import cmp_to_key as _CmpToKey
from fnmatch import fnmatch

from . import case, suite
# Discovery only considers source files whose names start with a letter or
# underscore, continue with word characters and end in '.py' (compared
# case-insensitively).  Compiled variants (.pyc, .pyo, ...) are not matched:
# accepting them too would require avoiding loading the same tests several
# times, once from each of '.py', '.pyc' *and* '.pyo'.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', flags=re.I)
def _make_failed_import_test(name, suiteClass):
    """Return a suite whose single test reports that module *name* failed to import.

    The failure message embeds ``traceback.format_exc()``, i.e. the traceback
    of the exception currently being handled, so this must be called from
    inside an ``except`` block to produce a useful message.

    :param name: dotted name of the module that could not be imported; also
        used as the name of the synthetic test method.
    :param suiteClass: callable used to wrap the synthetic test in a suite.
    """
    message = 'Failed to import test module: %s\n%s' % (
        name, traceback.format_exc())
    return _make_failed_test('ModuleImportFailure', name,
                             ImportError(message), suiteClass)
def _make_failed_load_tests(name, exception, suiteClass):
    """Wrap *exception* in a suite containing one test that re-raises it.

    The synthetic test case class is called ``LoadTestsFailure`` and its
    single test method is named *name*.
    """
    failure_class_name = 'LoadTestsFailure'
    return _make_failed_test(failure_class_name, name, exception, suiteClass)
def _make_failed_test(classname, methodname, exception, suiteClass):
    """Return a suite holding one synthetic test that raises *exception*.

    A throwaway ``case.TestCase`` subclass named *classname* is created with a
    single method *methodname*. Running that test re-raises *exception*, so
    the problem is reported as an ordinary test failure.

    :param classname: name given to the dynamically created TestCase subclass.
    :param methodname: name of the test method (and of the test instance).
    :param exception: exception instance raised when the test runs.
    :param suiteClass: callable used to wrap the single test instance.
    """
    def testFailure(self):
        # The whole point of the synthetic test: surface the stored error.
        raise exception
    attrs = {methodname: testFailure}
    TestClass = type(classname, (case.TestCase,), attrs)
    return suiteClass((TestClass(methodname),))
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    testMethodPrefix = 'test'
    sortTestMethodsUsing = cmp
    suiteClass = suite.TestSuite
    # Set by discover(). It must exist on the class so that the first
    # discover() call can compare it against None. Later (nested) discover()
    # calls, e.g. from a package's load_tests, may then omit top_level_dir.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass.

        Falls back to a single 'runTest' test when no prefixed test methods
        exist but the class defines runTest.

        :raises TypeError: if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module.

        If use_load_tests is true and the module defines ``load_tests``, it is
        called as ``load_tests(self, tests, None)`` and its result is returned
        instead. An exception raised by load_tests is converted into a
        failing test rather than propagated.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            try:
                return load_tests(self, tests, None)
            except Exception as e:
                return _make_failed_load_tests(module.__name__, e,
                                               self.suiteClass)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        :raises ImportError: if no prefix of name can be imported.
        :raises TypeError: if the resolved object cannot be turned into tests.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of name; the
            # remaining components are resolved as attributes below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so attribute lookup
            # starts again from the second component.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            return self.suiteClass([parent(obj.__name__)])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            # Wrap obj in a 1-tuple so a tuple-valued obj is formatted rather
            # than unpacked as the format arguments.
            raise TypeError("don't know how to make test from: %s" % (obj,))

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass.

        A name qualifies when it starts with self.testMethodPrefix and the
        attribute it names is callable. Sorting uses self.sortTestMethodsUsing
        (a cmp-style function) unless it is false.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        # A comprehension always yields a real list that can be sorted in
        # place, whatever filter() returns on the running interpreter.
        testFnNames = [attrname for attrname in dir(testCaseClass)
                       if isTestMethod(attrname)]
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        :raises ImportError: if start_dir is not importable.
        """
        set_implicit_top = False
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            set_implicit_top = True
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(top_level_dir)

        # Remember whether we changed sys.path so we only ever undo our own
        # insertion, never an entry the caller already had.
        inserted_top_level_dir = False
        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            # should we *unconditionally* put the start directory in first
            # in sys.path to minimise likelihood of conflicts between installed
            # modules and development versions?
            sys.path.insert(0, top_level_dir)
            inserted_top_level_dir = True
        self._top_level_dir = top_level_dir

        is_not_importable = False
        if os.path.isdir(os.path.abspath(start_dir)):
            start_dir = os.path.abspath(start_dir)
            if start_dir != top_level_dir:
                is_not_importable = not os.path.isfile(
                    os.path.join(start_dir, '__init__.py'))
        else:
            # support for discovery from dotted module names
            try:
                __import__(start_dir)
            except ImportError:
                is_not_importable = True
            else:
                the_module = sys.modules[start_dir]
                top_part = start_dir.split('.')[0]
                start_dir = os.path.abspath(os.path.dirname(the_module.__file__))
                if set_implicit_top:
                    # The implicit top level dir was derived from a dotted
                    # name, not a real directory: use the directory holding
                    # the top-level package instead and drop the bogus entry.
                    self._top_level_dir = \
                        self._get_directory_containing_module(top_part)
                    if inserted_top_level_dir:
                        sys.path.remove(top_level_dir)

        if is_not_importable:
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_directory_containing_module(self, module_name):
        """Return the directory from which module_name's top level is importable.

        For a package (its __file__ is an __init__.py*), that is the parent of
        the package directory; for a plain module, its own directory.
        """
        module = sys.modules[module_name]
        full_path = os.path.abspath(module.__file__)

        if os.path.basename(full_path).lower().startswith('__init__.py'):
            return os.path.dirname(os.path.dirname(full_path))
        else:
            # here we have been given a module rather than a package - so
            # all we can do is search the *same* directory the module is in
            # TODO: should an exception be raised instead?
            return os.path.dirname(full_path)

    def _get_name_from_path(self, path):
        """Convert a file path below self._top_level_dir to a dotted module name.

        The extension is stripped and path separators become dots.
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module name and return the module object itself.

        __import__ returns the top-level package for dotted names, so the
        leaf module is fetched from sys.modules instead.
        """
        __import__(name)
        return sys.modules[name]

    def _match_path(self, path, full_path, pattern):
        # override this method to use alternative matching strategy
        return fnmatch(path, pattern)

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads.

        Matching modules in start_dir are loaded directly. Packages are
        recursed into unless a matching package defines load_tests, in which
        case that hook is responsible for the package's tests. Import errors
        and load_tests errors are yielded as failing tests.
        """
        # Sort so that discovery order does not depend on the filesystem.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue
                if not self._match_path(path, full_path, pattern):
                    continue
                # if the test file matches, load it
                name = self._get_name_from_path(full_path)
                try:
                    module = self._get_module_from_name(name)
                except:
                    # Deliberately broad: any import-time failure of a test
                    # module (SyntaxError, SystemExit, ...) is reported as a
                    # failing test instead of aborting discovery.
                    yield _make_failed_import_test(name, self.suiteClass)
                else:
                    mod_file = os.path.abspath(
                        getattr(module, '__file__', full_path))
                    realpath = os.path.splitext(mod_file)[0]
                    fullpath_noext = os.path.splitext(full_path)[0]
                    if realpath.lower() != fullpath_noext.lower():
                        # The name resolved to a module elsewhere on sys.path
                        # rather than the file we found.
                        module_dir = os.path.dirname(realpath)
                        mod_name = os.path.splitext(os.path.basename(full_path))[0]
                        expected_dir = os.path.dirname(full_path)
                        msg = ("%r module incorrectly imported from %r. Expected %r. "
                               "Is this module globally installed?")
                        raise ImportError(msg % (mod_name, module_dir, expected_dir))
                    yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    package = self._get_module_from_name(name)
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    try:
                        yield load_tests(self, tests, pattern)
                    except Exception as e:
                        yield _make_failed_load_tests(package.__name__, e,
                                                      self.suiteClass)
# Module-level TestLoader instance created with the class's default settings
# (no prefix, sort function or suite class overrides).
defaultTestLoader = TestLoader()
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader configured for the legacy helper functions.

    :param prefix: method-name prefix identifying test methods.
    :param sortUsing: cmp-style function used to order test method names.
    :param suiteClass: suite factory to use; None keeps the TestLoader
        class default rather than overwriting it with None.
    """
    loader = TestLoader()
    loader.sortTestMethodsUsing = sortUsing
    loader.testMethodPrefix = prefix
    if suiteClass is not None:
        loader.suiteClass = suiteClass
    return loader
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Return the test method names of testCaseClass, as collected by a loader
    configured with *prefix* and ordered by the *sortUsing* cmp function."""
    configured = _makeLoader(prefix, sortUsing)
    return configured.getTestCaseNames(testCaseClass)
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Build a suite from the test methods of testCaseClass, using a loader
    configured with *prefix*, *sortUsing* and *suiteClass*."""
    configured = _makeLoader(prefix, sortUsing, suiteClass)
    return configured.loadTestsFromTestCase(testCaseClass)
def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Load the tests contained in *module* with a loader configured with
    *prefix*, *sortUsing* and *suiteClass* (see loadTestsFromModule)."""
    configured = _makeLoader(prefix, sortUsing, suiteClass)
    return configured.loadTestsFromModule(module)