move sections
[python/dscho.git] / Lib / test / test_contextlib.py
blobf28c95eadbc1968b50c2fa1ca79d8547b583b2c5
1 """Unit tests for contextlib.py, and other context managers."""
3 import sys
4 import tempfile
5 import unittest
6 from contextlib import * # Tests __all__
7 from test import test_support
8 try:
9 import threading
10 except ImportError:
11 threading = None
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator functions wrapped with ``@contextmanager``."""

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry, code after it on normal exit,
        # and the yielded value is what ``as`` binds.
        state = []

        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)

        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # A ``finally`` around the yield runs when the with-body raises, and
        # the exception still propagates to the caller.
        state = []

        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)

        with self.assertRaises(ZeroDivisionError):
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        @contextmanager
        def whee():
            yield

        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields again after the exception is thrown into it
        # must make __exit__ raise RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except:
                yield

        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # The generator may catch the exception raised in the with-body; the
        # with-statement then completes without propagating it.
        state = []

        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])

        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def _create_contextmanager_attribs(self):
        # Build a @contextmanager-decorated function that carries a custom
        # attribute (foo='bar') and a docstring, for the attribute tests below.
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate

        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""

        return baz

    def test_contextmanager_attribs(self):
        # The decorated function keeps the wrapped function's name and
        # custom attributes.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_contextmanager_doc_attrib(self):
        # The decorated function keeps the wrapped function's docstring.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")
class NestedTestCase(unittest.TestCase):
    """Tests for combining several context managers with ``nested``."""

    # XXX This needs more work

    def test_nested(self):
        # Each manager's yielded value shows up, in order, in the bound tuple.
        @contextmanager
        def one():
            yield 1

        @contextmanager
        def two():
            yield 2

        @contextmanager
        def three():
            yield 3

        with nested(one(), two(), three()) as (first, second, third):
            self.assertEqual(first, 1)
            self.assertEqual(second, 2)
            self.assertEqual(third, 3)

    def test_nested_cleanup(self):
        # Managers are entered left to right and cleaned up right to left
        # when the body raises.
        events = []

        @contextmanager
        def outer():
            events.append(1)
            try:
                yield 2
            finally:
                events.append(3)

        @contextmanager
        def inner():
            events.append(4)
            try:
                yield 5
            finally:
                events.append(6)

        with self.assertRaises(ZeroDivisionError):
            with nested(outer(), inner()) as (from_outer, from_inner):
                events.append(from_outer)
                events.append(from_inner)
                1 // 0
        self.assertEqual(events, [1, 4, 2, 5, 6, 3])

    def test_nested_right_exception(self):
        # An exception raised and handled inside one manager's __exit__ must
        # not replace the ZeroDivisionError coming from the with-body.
        @contextmanager
        def plain():
            yield 1

        class NoisyExit(object):
            def __enter__(self):
                return 2

            def __exit__(self, *exc_info):
                try:
                    raise Exception()
                except:
                    pass

        with self.assertRaises(ZeroDivisionError):
            with nested(plain(), NoisyExit()) as (left, right):
                1 // 0
        self.assertEqual((left, right), (1, 2))

    def test_nested_b_swallows(self):
        # If the second manager swallows the exception, nothing propagates.
        @contextmanager
        def passthrough():
            yield

        @contextmanager
        def swallower():
            try:
                yield
            except:
                # Swallow the exception
                pass

        try:
            with nested(passthrough(), swallower()):
                1 // 0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # ``break`` inside the with-body leaves the enclosing loop at once.
        @contextmanager
        def noop():
            yield

        counter = 0
        while True:
            counter += 1
            with nested(noop(), noop()):
                break
            counter += 10
        self.assertEqual(counter, 1)

    def test_nested_continue(self):
        # ``continue`` inside the with-body skips the rest of the iteration.
        @contextmanager
        def noop():
            yield

        counter = 0
        while counter < 3:
            counter += 1
            with nested(noop(), noop()):
                continue
            counter += 10
        self.assertEqual(counter, 3)

    def test_nested_return(self):
        # ``return`` inside the with-body yields the returned value even when
        # the managers catch every exception.
        @contextmanager
        def catch_all():
            try:
                yield
            except:
                pass

        def returns_from_body():
            with nested(catch_all(), catch_all()):
                return 1
            return 10

        self.assertEqual(returns_from_body(), 1)
class ClosingTestCase(unittest.TestCase):
    """Tests that ``closing`` calls ``close()`` on the object it wraps."""

    # XXX This needs more work

    def test_closing(self):
        # close() is called once when the with-block ends normally, and
        # ``as`` binds the wrapped object itself.
        calls = []

        class Closeable:
            def close(self):
                calls.append(1)

        target = Closeable()
        self.assertEqual(calls, [])
        with closing(target) as bound:
            self.assertEqual(target, bound)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() is still called when the with-body raises, and the
        # exception propagates.
        calls = []

        class Closeable:
            def close(self):
                calls.append(1)

        target = Closeable()
        self.assertEqual(calls, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(target) as bound:
                self.assertEqual(target, bound)
                1 // 0
        self.assertEqual(calls, [1])
class FileContextTestCase(unittest.TestCase):
    """Tests that file objects behave as context managers."""

    def testWithOpen(self):
        # The file is open inside the with-block and closed afterwards, both
        # on normal exit and when the body raises.
        path = tempfile.mktemp()
        try:
            handle = None
            with open(path, "w") as handle:
                self.assertFalse(handle.closed)
                handle.write("Booh\n")
            self.assertTrue(handle.closed)

            handle = None
            with self.assertRaises(ZeroDivisionError):
                with open(path, "r") as handle:
                    self.assertFalse(handle.closed)
                    self.assertEqual(handle.read(), "Booh\n")
                    1 // 0
            self.assertTrue(handle.closed)
        finally:
            # Remove the temporary file whether or not the checks passed.
            test_support.unlink(path)
@unittest.skipUnless(threading, 'Threading required for this test.')
class LockContextTestCase(unittest.TestCase):
    """Tests that the threading synchronisation objects work in ``with``."""

    def boilerPlate(self, lock, locked):
        # Shared check: ``lock`` is held only inside the with-block, and is
        # released again even when the body raises.  ``locked`` is a
        # zero-argument callable reporting whether the lock is held.
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())

        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1 // 0
        self.assertFalse(locked())

    def testWithLock(self):
        mutex = threading.Lock()
        self.boilerPlate(mutex, mutex.locked)

    def testWithRLock(self):
        rlock = threading.RLock()
        self.boilerPlate(rlock, rlock._is_owned)

    def testWithCondition(self):
        cond = threading.Condition()

        def held():
            return cond._is_owned()

        self.boilerPlate(cond, held)

    def testWithSemaphore(self):
        sem = threading.Semaphore()

        def held():
            # A successful non-blocking acquire means the semaphore was free;
            # give it straight back so the probe leaves no trace.
            if not sem.acquire(False):
                return True
            sem.release()
            return False

        self.boilerPlate(sem, held)

    def testWithBoundedSemaphore(self):
        sem = threading.BoundedSemaphore()

        def held():
            # Same non-blocking probe as in testWithSemaphore.
            if not sem.acquire(False):
                return True
            sem.release()
            return False

        self.boilerPlate(sem, held)
# This is needed to make the test actually run under regrtest.py!
def test_main():
    # Run every test in this module inside check_warnings, with a filter for
    # the DeprecationWarning message below.
    # NOTE(review): presumably this warning comes from the nested() tests --
    # confirm against contextlib.
    deprecation_filter = ("With-statements now directly support "
                          "multiple context managers",
                          DeprecationWarning)
    with test_support.check_warnings(deprecation_filter):
        test_support.run_unittest(__name__)


if __name__ == "__main__":
    test_main()