1 """Loading unittests."""
import os
import re
import sys
import traceback
import types

from fnmatch import fnmatch

from . import case, suite
15 'Convert a cmp= function into a key= function'
17 def __init__(self
, obj
):
19 def __lt__(self
, other
):
20 return mycmp(self
.obj
, other
.obj
) == -1
# File names accepted as test modules during discovery: an identifier-like
# stem followed by '.py', matched case-insensitively.
# Compiled files ('.pyc', '.pyo', ...) are not matched.  Supporting them
# would first require avoiding loading the same test module several times
# from '.py', '.pyc' *and* '.pyo'.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', flags=re.IGNORECASE)
def _make_failed_import_test(name, suiteClass):
    """Build a suite holding one test that reports a module import failure.

    Meant to be called while an import error is being handled: when
    ``traceback.format_exc`` is available, the current traceback is appended
    to the failure message.

    :param name: dotted name of the module that failed to import.  It is also
        used as the generated test's method name.
    :param suiteClass: callable that wraps the generated test instance.
    :returns: ``suiteClass`` applied to a one-element tuple containing an
        instance of a ``case.TestCase`` subclass whose test raises ImportError.
    """
    message = 'Failed to import test module: %s' % name
    # Python 2.3 compatibility: traceback.format_exc only exists from 2.4.
    # format_exc returns two frames of discover.py as well
    if hasattr(traceback, 'format_exc'):
        message = message + '\n%s' % traceback.format_exc()

    def testImportFailure(self):
        raise ImportError(message)

    failure_class = type('ModuleImportFailure', (case.TestCase,),
                         {name: testImportFailure})
    return suiteClass((failure_class(name),))
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    testMethodPrefix = 'test'
    sortTestMethodsUsing = cmp
    suiteClass = suite.TestSuite
    # Top-level project directory recorded by discover().  It must exist as a
    # class attribute because discover() reads it before first assigning it.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass.

        If no method matches ``testMethodPrefix`` but the class defines
        ``runTest``, a single ``runTest`` test is created instead.

        :raises TypeError: if testCaseClass derives from TestSuite.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module.

        Every attribute of ``module`` that is a ``case.TestCase`` subclass is
        loaded with loadTestsFromTestCase().  If ``use_load_tests`` is true
        and the module defines ``load_tests``, the module's own
        ``load_tests(self, tests, None)`` result is returned instead, where
        ``tests`` is the list of per-class suites.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        if use_load_tests and load_tests is not None:
            return load_tests(self, tests, None)
        return self.suiteClass(tests)

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        :raises ImportError: if ``module`` is None and no leading prefix of
            ``name`` can be imported.
        :raises AttributeError: if a dotted component cannot be resolved.
        :raises TypeError: if the resolved object cannot be turned into a test.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable prefix of the dotted name; the
            # remaining components are resolved as attributes below.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so resolve every
            # component after the first one by attribute lookup.
            parts = parts[1:]
        obj = module
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            return self.suiteClass([parent(obj.__name__)])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass.

        A name qualifies when it starts with ``testMethodPrefix`` and the
        attribute is callable.  Names are ordered with the cmp-style
        ``sortTestMethodsUsing``; a false value leaves them in dir() order.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = [attrname for attrname in dir(testCaseClass)
                       if isTestMethod(attrname)]
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Side effect: top_level_dir is appended to sys.path when missing.

        :raises ImportError: if start_dir is not the top level directory and
            contains no '__init__.py'.
        """
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
        start_dir = os.path.abspath(os.path.normpath(start_dir))

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
        self._top_level_dir = top_level_dir

        if (start_dir != top_level_dir and
                not os.path.isfile(os.path.join(start_dir, '__init__.py'))):
            # what about __init__.pyc or pyo (etc)
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_name_from_path(self, path):
        """Return the dotted module name for ``path``.

        The extension is stripped and the path is made relative to the
        top-level directory recorded by discover().
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the dotted module ``name`` and return the module object.

        For dotted names __import__ returns the top-level package, so the
        submodule itself is fetched from sys.modules.
        """
        __import__(name)
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads."""
        # os.listdir order is filesystem-dependent; sort so that discovery
        # (and therefore test execution order) is deterministic.
        paths = sorted(os.listdir(start_dir))

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue

                if fnmatch(path, pattern):
                    # if the test file matches, load it
                    name = self._get_name_from_path(full_path)
                    try:
                        module = self._get_module_from_name(name)
                    except KeyboardInterrupt:
                        # Let the user abort discovery instead of recording
                        # the interrupt as a failed import.
                        raise
                    except:
                        # Anything a test module raises at import time
                        # (including SystemExit) becomes a failing test.
                        yield _make_failed_import_test(name, self.suiteClass)
                    else:
                        yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    package = self._get_module_from_name(name)
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    yield load_tests(self, tests, pattern)
# Module-level default TestLoader instance, created once at import time with
# the class-level defaults (prefix 'test', cmp ordering, suite.TestSuite).
defaultTestLoader = TestLoader()
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader configured with the given options.

    :param prefix: value for the loader's ``testMethodPrefix``.
    :param sortUsing: cmp-style function for ``sortTestMethodsUsing``; a
        false value disables sorting of test method names.
    :param suiteClass: optional replacement for ``suiteClass``.  When None,
        the TestLoader default is kept rather than overwritten with None.
    :returns: the configured TestLoader.
    """
    loader = TestLoader()
    loader.sortTestMethodsUsing = sortUsing
    loader.testMethodPrefix = prefix
    if suiteClass is not None:
        loader.suiteClass = suiteClass
    return loader
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Return the test method names of testCaseClass.

    A one-off loader is built with ``prefix`` as the method prefix and
    ``sortUsing`` as the cmp-style ordering function.
    """
    configured = _makeLoader(prefix, sortUsing)
    return configured.getTestCaseNames(testCaseClass)
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests defined on testCaseClass.

    Uses a one-off loader configured with ``prefix``, the cmp-style
    ``sortUsing`` ordering and ``suiteClass`` as the suite type.
    """
    configured = _makeLoader(prefix, sortUsing, suiteClass)
    return configured.loadTestsFromTestCase(testCaseClass)
def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests found in module.

    Uses a one-off loader configured with ``prefix``, the cmp-style
    ``sortUsing`` ordering and ``suiteClass`` as the suite type.  The
    module's own ``load_tests`` hook, if any, is honoured.
    """
    configured = _makeLoader(prefix, sortUsing, suiteClass)
    return configured.loadTestsFromModule(module)