Issue #1515: Enable use of deepcopy() with instance methods. Patch by Robert Collins.
[python.git] / Lib / unittest / loader.py
blob31c343b49b1e5f9d4badda68755e584b6ea02a17
1 """Loading unittests."""
3 import os
4 import re
5 import sys
6 import traceback
7 import types
9 from fnmatch import fnmatch
11 from . import case, suite
def _CmpToKey(mycmp):
    """Convert a cmp= function into a key= function.

    Returns a wrapper class K.  K(obj) instances order themselves by
    calling ``mycmp`` on the wrapped objects, so K can be handed to
    ``list.sort(key=...)`` / ``sorted(key=...)``.

    ``mycmp(a, b)`` is treated as "a sorts before b" whenever it returns
    any negative number, not only exactly -1, so comparison functions such
    as ``lambda a, b: a - b`` are ordered correctly.
    """
    class K(object):
        def __init__(self, obj):
            self.obj = obj

        def __lt__(self, other):
            # Only __lt__ is defined; it is the one comparison sorting needs.
            return mycmp(self.obj, other.obj) < 0
    return K
# File names accepted as test modules during discovery: an identifier-like
# stem (letter or underscore, then word characters) followed by '.py',
# matched case-insensitively.
# Open question: what about .pyc or .pyo (etc)?  Accepting them would
# require avoiding loading the same tests multiple times from '.py',
# '.pyc' *and* '.pyo'.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
def _make_failed_import_test(name, suiteClass):
    """Build a one-test suite that reports a module import failure.

    A TestCase subclass is created on the fly with a single test method
    named ``name``; running it raises ImportError carrying the failure
    message (and the current traceback text, when available).  The lone
    test instance is returned wrapped in ``suiteClass``.
    """
    message_parts = ['Failed to import test module: %s' % name]
    if hasattr(traceback, 'format_exc'):
        # Python 2.3 compatibility
        # format_exc returns two frames of discover.py as well
        message_parts.append(traceback.format_exc())
    message = '\n'.join(message_parts)

    def testImportFailure(self):
        raise ImportError(message)

    # The test method is stored under the module's name so the resulting
    # test is identified by the module that could not be imported.
    failure_class = type('ModuleImportFailure', (case.TestCase,),
                         {name: testImportFailure})
    return suiteClass((failure_class(name),))
class TestLoader(object):
    """
    This class is responsible for loading tests according to various criteria
    and returning them wrapped in a TestSuite
    """
    # Only callables whose attribute name starts with this prefix are
    # collected by getTestCaseNames().
    testMethodPrefix = 'test'
    # cmp-style function used to order test method names; a false value
    # leaves the names in dir() order.
    sortTestMethodsUsing = cmp
    # Callable used to build every suite this loader returns.
    suiteClass = suite.TestSuite
    # Remembered by discover() so that nested discover() calls (e.g. from a
    # package's load_tests) may omit top_level_dir.
    _top_level_dir = None

    def loadTestsFromTestCase(self, testCaseClass):
        """Return a suite of all tests cases contained in testCaseClass

        Raises TypeError if testCaseClass derives from TestSuite.  When no
        test method names are found but the class has a 'runTest'
        attribute, a single 'runTest' test is loaded instead.
        """
        if issubclass(testCaseClass, suite.TestSuite):
            raise TypeError("Test cases should not be derived from TestSuite."
                            " Maybe you meant to derive from TestCase?")
        testCaseNames = self.getTestCaseNames(testCaseClass)
        if not testCaseNames and hasattr(testCaseClass, 'runTest'):
            testCaseNames = ['runTest']
        # One testCaseClass instance is created per test method name.
        loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
        return loaded_suite

    def loadTestsFromModule(self, module, use_load_tests=True):
        """Return a suite of all tests cases contained in the given module

        Every TestCase subclass found among the module's attributes is
        loaded.  If the module defines 'load_tests' and use_load_tests is
        true, the result of load_tests(loader, tests, None) is returned
        instead, where 'tests' is the suite of tests loaded by default.
        """
        tests = []
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, type) and issubclass(obj, case.TestCase):
                tests.append(self.loadTestsFromTestCase(obj))

        load_tests = getattr(module, 'load_tests', None)
        # Wrap before calling the hook so load_tests receives a suite here,
        # exactly as it does when invoked for a package in _find_tests().
        tests = self.suiteClass(tests)
        if use_load_tests and load_tests is not None:
            return load_tests(self, tests, None)
        return tests

    def loadTestsFromName(self, name, module=None):
        """Return a suite of all tests cases given a string specifier.

        The name may resolve either to a module, a test case class, a
        test method within a test case class, or a callable object which
        returns a TestCase or TestSuite instance.

        The method optionally resolves the names relative to a given module.

        Raises TypeError when the resolved object cannot be turned into a
        test, and re-raises ImportError when no leading part of the name
        can be imported.
        """
        parts = name.split('.')
        if module is None:
            # Import the longest importable dotted prefix of the name.
            parts_copy = parts[:]
            while parts_copy:
                try:
                    module = __import__('.'.join(parts_copy))
                    break
                except ImportError:
                    del parts_copy[-1]
                    if not parts_copy:
                        raise
            # __import__ returns the top-level package, so the remaining
            # parts are resolved with getattr starting from it.
            parts = parts[1:]
        obj = module
        parent = None  # stays None when there are no parts left to resolve
        for part in parts:
            parent, obj = obj, getattr(obj, part)

        if isinstance(obj, types.ModuleType):
            return self.loadTestsFromModule(obj)
        elif isinstance(obj, type) and issubclass(obj, case.TestCase):
            return self.loadTestsFromTestCase(obj)
        elif (isinstance(obj, types.UnboundMethodType) and
              isinstance(parent, type) and
              issubclass(parent, case.TestCase)):
            return self.suiteClass([parent(obj.__name__)])
        elif isinstance(obj, suite.TestSuite):
            return obj
        elif hasattr(obj, '__call__'):
            test = obj()
            if isinstance(test, suite.TestSuite):
                return test
            elif isinstance(test, case.TestCase):
                return self.suiteClass([test])
            else:
                raise TypeError("calling %s returned %s, not a test" %
                                (obj, test))
        else:
            raise TypeError("don't know how to make test from: %s" % obj)

    def loadTestsFromNames(self, names, module=None):
        """Return a suite of all tests cases found using the given sequence
        of string specifiers. See 'loadTestsFromName()'.
        """
        suites = [self.loadTestsFromName(name, module) for name in names]
        return self.suiteClass(suites)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass

        A name qualifies when it starts with testMethodPrefix and the
        attribute it refers to is callable.  Sorting is skipped when
        sortTestMethodsUsing is false.
        """
        def isTestMethod(attrname, testCaseClass=testCaseClass,
                         prefix=self.testMethodPrefix):
            return attrname.startswith(prefix) and \
                hasattr(getattr(testCaseClass, attrname), '__call__')
        testFnNames = [attrname for attrname in dir(testCaseClass)
                       if isTestMethod(attrname)]
        if self.sortTestMethodsUsing:
            testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
        return testFnNames

    def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
        """Find and return all test modules from the specified start
        directory, recursing into subdirectories to find them. Only test files
        that match the pattern will be loaded. (Using shell style pattern
        matching.)

        All test modules must be importable from the top level of the project.
        If the start directory is not the top level directory then the top
        level directory must be specified separately.

        If a test package name (directory with '__init__.py') matches the
        pattern then the package will be checked for a 'load_tests' function. If
        this exists then it will be called with loader, tests, pattern.

        If load_tests exists then discovery does *not* recurse into the package,
        load_tests is responsible for loading all tests in the package.

        The pattern is deliberately not stored as a loader attribute so that
        packages can continue discovery themselves. top_level_dir is stored so
        load_tests does not need to pass this argument in to loader.discover().

        Raises ImportError if start_dir differs from the top level directory
        and does not contain an '__init__.py' file.
        """
        if top_level_dir is None and self._top_level_dir is not None:
            # make top_level_dir optional if called from load_tests in a package
            top_level_dir = self._top_level_dir
        elif top_level_dir is None:
            top_level_dir = start_dir

        top_level_dir = os.path.abspath(os.path.normpath(top_level_dir))
        start_dir = os.path.abspath(os.path.normpath(start_dir))

        if top_level_dir not in sys.path:
            # all test modules must be importable from the top level directory
            sys.path.append(top_level_dir)
        self._top_level_dir = top_level_dir

        if start_dir != top_level_dir and not os.path.isfile(os.path.join(start_dir, '__init__.py')):
            # what about __init__.pyc or pyo (etc)
            raise ImportError('Start directory is not importable: %r' % start_dir)

        tests = list(self._find_tests(start_dir, pattern))
        return self.suiteClass(tests)

    def _get_name_from_path(self, path):
        """Return the dotted module name for a file or directory path.

        The extension is dropped and the path is taken relative to
        self._top_level_dir, with path separators replaced by dots.  The
        path is asserted to lie inside the top level directory.
        """
        path = os.path.splitext(os.path.normpath(path))[0]

        _relpath = os.path.relpath(path, self._top_level_dir)
        assert not os.path.isabs(_relpath), "Path must be within the project"
        assert not _relpath.startswith('..'), "Path must be within the project"

        name = _relpath.replace(os.path.sep, '.')
        return name

    def _get_module_from_name(self, name):
        """Import the module with the given dotted name and return it."""
        __import__(name)
        # __import__ returns the top-level package for dotted names, so the
        # requested module itself is fetched from sys.modules.
        return sys.modules[name]

    def _find_tests(self, start_dir, pattern):
        """Used by discovery. Yields test suites it loads."""
        paths = os.listdir(start_dir)

        for path in paths:
            full_path = os.path.join(start_dir, path)
            if os.path.isfile(full_path):
                if not VALID_MODULE_NAME.match(path):
                    # valid Python identifiers only
                    continue

                if fnmatch(path, pattern):
                    # if the test file matches, load it
                    name = self._get_name_from_path(full_path)
                    try:
                        module = self._get_module_from_name(name)
                    except:
                        # Whatever the import raises is turned into a
                        # failing test instead of aborting discovery.
                        yield _make_failed_import_test(name, self.suiteClass)
                    else:
                        yield self.loadTestsFromModule(module)
            elif os.path.isdir(full_path):
                if not os.path.isfile(os.path.join(full_path, '__init__.py')):
                    # not a package: do not descend into it
                    continue

                load_tests = None
                tests = None
                if fnmatch(path, pattern):
                    # only check load_tests if the package directory itself matches the filter
                    name = self._get_name_from_path(full_path)
                    # NOTE: unlike the module case above, an exception raised
                    # while importing the package propagates to the caller.
                    package = self._get_module_from_name(name)
                    load_tests = getattr(package, 'load_tests', None)
                    tests = self.loadTestsFromModule(package, use_load_tests=False)

                if load_tests is None:
                    if tests is not None:
                        # tests loaded from package file
                        yield tests
                    # recurse into the package
                    for test in self._find_tests(full_path, pattern):
                        yield test
                else:
                    # the package takes over loading; no recursion here
                    yield load_tests(self, tests, pattern)
# Module-level TestLoader instance created with the class defaults
# (prefix 'test', cmp ordering, suite.TestSuite).
defaultTestLoader = TestLoader()
def _makeLoader(prefix, sortUsing, suiteClass=None):
    """Return a new TestLoader set up with the given options.

    ``prefix`` becomes the loader's testMethodPrefix and ``sortUsing`` its
    sortTestMethodsUsing.  ``suiteClass`` replaces the loader's default
    suite class only when a true value is supplied.
    """
    configured = TestLoader()
    configured.testMethodPrefix = prefix
    configured.sortTestMethodsUsing = sortUsing
    if suiteClass:
        configured.suiteClass = suiteClass
    return configured
def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    """Return the test method names of testCaseClass.

    Convenience wrapper: builds a loader with the given prefix and
    ordering function and delegates to its getTestCaseNames().
    """
    loader = _makeLoader(prefix, sortUsing)
    return loader.getTestCaseNames(testCaseClass)
def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
              suiteClass=suite.TestSuite):
    """Return a suite of the tests found in testCaseClass.

    Convenience wrapper: builds a loader with the given prefix, ordering
    function and suite class, then delegates to loadTestsFromTestCase().
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromTestCase(testCaseClass)
def findTestCases(module, prefix='test', sortUsing=cmp,
                  suiteClass=suite.TestSuite):
    """Return a suite of the tests found in the given module.

    Convenience wrapper: builds a loader with the given prefix, ordering
    function and suite class, then delegates to loadTestsFromModule().
    """
    loader = _makeLoader(prefix, sortUsing, suiteClass)
    return loader.loadTestsFromModule(module)