s4:dsdb/drepl: update the source_dsa_obj/invocation_id in repsFrom
[Samba/gebeck_regimport.git] / lib / testtools / testtools / tests / test_spinner.py
blob 3d677bd75450cb62876378ad306fe53ec0f688ca
1 # Copyright (c) 2010 testtools developers. See LICENSE for details.
3 """Tests for the evil Twisted reactor-spinning we do."""
5 import os
6 import signal
8 from testtools import (
9 skipIf,
10 TestCase,
12 from testtools.helpers import try_import
13 from testtools.matchers import (
14 Equals,
15 Is,
16 MatchesException,
17 Raises,
# Optional dependencies, loaded in this order: the spinner module under test,
# Twisted's defer module, and Twisted's Failure class. A name left as None means
# the import was unavailable; NeedsTwistedTestCase checks defer/Failure for None
# and skips the Twisted-dependent tests in that case.
_spinner, defer, Failure = (
    try_import(dotted_name)
    for dotted_name in (
        'testtools._spinner',
        'twisted.internet.defer',
        'twisted.python.failure.Failure',
    )
)
class NeedsTwistedTestCase(TestCase):
    """Base TestCase whose tests are skipped when Twisted is not importable."""

    def setUp(self):
        super(NeedsTwistedTestCase, self).setUp()
        # defer and Failure are None when the module-level try_import calls
        # could not load them.
        if any(dependency is None for dependency in (defer, Failure)):
            self.skipTest("Need Twisted to run")
class TestNotReentrant(NeedsTwistedTestCase):
    """Tests for the _spinner.not_reentrant decorator."""

    def test_not_reentrant(self):
        # Calling a @not_reentrant function from inside itself raises
        # _spinner.ReentryError, so the body only gets to run once.
        invocations = []

        @_spinner.not_reentrant
        def log_something():
            invocations.append(None)
            if len(invocations) < 5:
                log_something()

        self.assertThat(
            log_something, Raises(MatchesException(_spinner.ReentryError)))
        self.assertEqual(1, len(invocations))

    def test_deeper_stack(self):
        # Re-entry is detected even through an intermediate function: f runs,
        # calls g, g calls f again, and that second entry into f raises.
        invocations = []

        @_spinner.not_reentrant
        def g():
            invocations.append(None)
            if len(invocations) < 5:
                f()

        @_spinner.not_reentrant
        def f():
            invocations.append(None)
            if len(invocations) < 5:
                g()

        self.assertThat(f, Raises(MatchesException(_spinner.ReentryError)))
        self.assertEqual(2, len(invocations))
class TestExtractResult(NeedsTwistedTestCase):
    """Tests for _spinner.extract_result."""

    def test_not_fired(self):
        # A Deferred that has not fired has no result to extract, so
        # extract_result raises _spinner.DeferredNotFired.
        def extract_from_unfired():
            return _spinner.extract_result(defer.Deferred())

        self.assertThat(
            extract_from_unfired,
            Raises(MatchesException(_spinner.DeferredNotFired)))

    def test_success(self):
        # A Deferred that fired successfully yields its value unchanged.
        sentinel = object()
        fired = defer.succeed(sentinel)
        extracted = _spinner.extract_result(fired)
        self.assertThat(extracted, Equals(sentinel))

    def test_failure(self):
        # A Deferred that is failing makes extract_result raise the exception
        # wrapped by the Failure.
        try:
            1 / 0
        except ZeroDivisionError:
            # Failure() with no arguments captures the exception being handled.
            failure = Failure()
        failed = defer.fail(failure)

        def extract_from_failed():
            return _spinner.extract_result(failed)

        self.assertThat(
            extract_from_failed, Raises(MatchesException(ZeroDivisionError)))
class TestTrapUnhandledErrors(NeedsTwistedTestCase):
    """Tests for _spinner.trap_unhandled_errors."""

    def test_no_deferreds(self):
        # When the function creates no Deferreds, its return value is passed
        # straight through and no errors are reported.
        marker = object()
        result, errors = _spinner.trap_unhandled_errors(lambda: marker)
        self.assertEqual([], errors)
        self.assertIs(marker, result)

    def test_unhandled_error(self):
        # A failing Deferred that nobody handles is reported in ``errors``;
        # each reported error's failResult is the Failure that was used.
        failures = []

        def make_deferred_but_dont_handle():
            try:
                # This statement must be present: an empty try body is a
                # SyntaxError, and Failure() below needs an exception that is
                # currently being handled to capture.
                1 / 0
            except ZeroDivisionError:
                f = Failure()
                failures.append(f)
                defer.fail(f)

        result, errors = _spinner.trap_unhandled_errors(
            make_deferred_but_dont_handle)
        self.assertIs(None, result)
        self.assertEqual(failures, [error.failResult for error in errors])
class TestRunInReactor(NeedsTwistedTestCase):
    """Tests for _spinner.Spinner.run and the spinner's junk handling."""

    def make_reactor(self):
        # The global Twisted reactor, imported lazily so this module can be
        # imported without Twisted.
        from twisted.internet import reactor
        return reactor

    def make_spinner(self, reactor=None):
        if reactor is None:
            reactor = self.make_reactor()
        return _spinner.Spinner(reactor)

    def make_timeout(self):
        # Timeout in seconds passed to Spinner.run; kept tiny so tests are fast.
        return 0.01

    def test_function_called(self):
        # Spinner.run actually calls the function given to it.
        calls = []
        marker = object()
        self.make_spinner().run(self.make_timeout(), calls.append, marker)
        self.assertThat(calls, Equals([marker]))

    def test_return_value_returned(self):
        # Spinner.run returns the value returned by the function given to it.
        marker = object()
        result = self.make_spinner().run(self.make_timeout(), lambda: marker)
        self.assertThat(result, Is(marker))

    def test_exception_reraised(self):
        # If the given function raises an error, Spinner.run re-raises that
        # error.
        self.assertThat(
            lambda: self.make_spinner().run(self.make_timeout(), lambda: 1 / 0),
            Raises(MatchesException(ZeroDivisionError)))

    def test_keyword_arguments(self):
        # Spinner.run passes keyword arguments on to the function.
        calls = []
        function = lambda *a, **kw: calls.extend([a, kw])
        self.make_spinner().run(self.make_timeout(), function, foo=42)
        self.assertThat(calls, Equals([(), {'foo': 42}]))

    def test_not_reentrant(self):
        # Spinner.run raises ReentryError if it is called inside another call
        # to Spinner.run.
        spinner = self.make_spinner()
        self.assertThat(lambda: spinner.run(
            self.make_timeout(), spinner.run, self.make_timeout(),
            lambda: None), Raises(MatchesException(_spinner.ReentryError)))

    def test_deferred_value_returned(self):
        # If the given function returns a Deferred, Spinner.run returns the
        # value in the Deferred at the end of the callback chain.
        marker = object()
        result = self.make_spinner().run(
            self.make_timeout(), lambda: defer.succeed(marker))
        self.assertThat(result, Is(marker))

    def test_preserve_signal_handler(self):
        # Signal handlers installed before Spinner.run are still installed
        # after it returns.
        signal_names = ['SIGINT', 'SIGTERM', 'SIGCHLD']
        # Not every platform defines all of these; drop the missing ones.
        # Build real lists rather than using filter()/map(): on Python 3 those
        # return one-shot iterators, so ``signals`` would be exhausted by the
        # first loop and the final comparison would be list-vs-iterator.
        candidates = [getattr(signal, name, None) for name in signal_names]
        signals = [sig for sig in candidates if sig]
        for sig in signals:
            self.addCleanup(signal.signal, sig, signal.getsignal(sig))
        # One distinct handler object per signal, so identity can be checked.
        new_hdlrs = [lambda *a: None for _ in signals]
        for sig, hdlr in zip(signals, new_hdlrs):
            signal.signal(sig, hdlr)
        spinner = self.make_spinner()
        spinner.run(self.make_timeout(), lambda: None)
        self.assertEqual(
            new_hdlrs, [signal.getsignal(sig) for sig in signals])

    def test_timeout(self):
        # If the function takes too long to run, we raise a
        # _spinner.TimeoutError.
        timeout = self.make_timeout()
        self.assertThat(
            lambda: self.make_spinner().run(timeout, lambda: defer.Deferred()),
            Raises(MatchesException(_spinner.TimeoutError)))

    def test_no_junk_by_default(self):
        # If the reactor hasn't spun yet, then there cannot be any junk.
        spinner = self.make_spinner()
        self.assertThat(spinner.get_junk(), Equals([]))

    def test_clean_do_nothing(self):
        # If there's nothing going on in the reactor, then clean does nothing
        # and returns an empty list.
        spinner = self.make_spinner()
        result = spinner._clean()
        self.assertThat(result, Equals([]))

    def test_clean_delayed_call(self):
        # If there's a delayed call in the reactor, then clean cancels it and
        # returns a list containing that call.
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        call = reactor.callLater(10, lambda: None)
        results = spinner._clean()
        self.assertThat(results, Equals([call]))
        self.assertThat(call.active(), Equals(False))

    def test_clean_delayed_call_cancelled(self):
        # If there's a delayed call that's just been cancelled, then it's no
        # longer there.
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        call = reactor.callLater(10, lambda: None)
        call.cancel()
        results = spinner._clean()
        self.assertThat(results, Equals([]))

    def test_clean_selectables(self):
        # If there's still a selectable (e.g. a listening socket), then
        # clean() removes it from the reactor's registry.
        #
        # Note that the socket is left open. This emulates a bug in trial.
        from twisted.internet.protocol import ServerFactory
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        port = reactor.listenTCP(0, ServerFactory())
        spinner.run(self.make_timeout(), lambda: None)
        results = spinner.get_junk()
        self.assertThat(results, Equals([port]))

    def test_clean_running_threads(self):
        # Threads started in the reactor's thread pool during Spinner.run are
        # no longer alive once run returns.
        import threading
        import time

        def is_alive(thread):
            # Thread.isAlive() was removed in Python 3.9; prefer is_alive()
            # (available since Python 2.6) and fall back for older versions.
            check = getattr(thread, 'is_alive', None)
            if check is None:
                check = thread.isAlive
            return check()

        current_threads = list(threading.enumerate())
        reactor = self.make_reactor()
        timeout = self.make_timeout()
        spinner = self.make_spinner(reactor)
        spinner.run(timeout, reactor.callInThread, time.sleep, timeout / 2.0)
        # Python before 2.5 has a race condition with thread handling where
        # join() does not remove threads from enumerate before returning - the
        # thread being joined does the removal. This was fixed in Python 2.5
        # but we still support 2.4, so we have to workaround the issue by
        # only counting threads that are still alive.
        # http://bugs.python.org/issue1703448.
        self.assertThat(
            [thread for thread in threading.enumerate() if is_alive(thread)],
            Equals(current_threads))

    def test_leftover_junk_available(self):
        # If 'run' is given a function that leaves the reactor dirty in some
        # way, 'run' will clean up the reactor and then store information
        # about the junk. This information can be got using get_junk.
        from twisted.internet.protocol import ServerFactory
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        port = spinner.run(
            self.make_timeout(), reactor.listenTCP, 0, ServerFactory())
        self.assertThat(spinner.get_junk(), Equals([port]))

    def test_will_not_run_with_previous_junk(self):
        # If 'run' is called and there's still junk in the spinner's junk
        # list, then the spinner will refuse to run.
        from twisted.internet.protocol import ServerFactory
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        timeout = self.make_timeout()
        spinner.run(timeout, reactor.listenTCP, 0, ServerFactory())
        self.assertThat(lambda: spinner.run(timeout, lambda: None),
            Raises(MatchesException(_spinner.StaleJunkError)))

    def test_clear_junk_clears_previous_junk(self):
        # clear_junk() returns the junk left behind by a previous run and
        # empties the spinner's junk list.
        from twisted.internet.protocol import ServerFactory
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        timeout = self.make_timeout()
        port = spinner.run(timeout, reactor.listenTCP, 0, ServerFactory())
        junk = spinner.clear_junk()
        self.assertThat(junk, Equals([port]))
        self.assertThat(spinner.get_junk(), Equals([]))

    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
    def test_sigint_raises_no_result_error(self):
        # If we get a SIGINT during a run, we raise _spinner.NoResultError,
        # and the reactor is left clean afterwards.
        SIGINT = getattr(signal, 'SIGINT', None)
        if not SIGINT:
            self.skipTest("SIGINT not available")
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        timeout = self.make_timeout()
        reactor.callLater(timeout, os.kill, os.getpid(), SIGINT)
        self.assertThat(lambda: spinner.run(timeout * 5, defer.Deferred),
            Raises(MatchesException(_spinner.NoResultError)))
        self.assertEqual([], spinner._clean())

    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
    def test_sigint_raises_no_result_error_second_time(self):
        # If we get a SIGINT during a run, we raise _spinner.NoResultError.
        # This test is exactly the same as test_sigint_raises_no_result_error,
        # and exists to make sure we haven't futzed with state.
        self.test_sigint_raises_no_result_error()

    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
    def test_fast_sigint_raises_no_result_error(self):
        # If a SIGINT arrives as soon as the reactor starts running, we still
        # raise _spinner.NoResultError.
        SIGINT = getattr(signal, 'SIGINT', None)
        if not SIGINT:
            self.skipTest("SIGINT not available")
        reactor = self.make_reactor()
        spinner = self.make_spinner(reactor)
        timeout = self.make_timeout()
        reactor.callWhenRunning(os.kill, os.getpid(), SIGINT)
        self.assertThat(lambda: spinner.run(timeout * 5, defer.Deferred),
            Raises(MatchesException(_spinner.NoResultError)))
        self.assertEqual([], spinner._clean())

    @skipIf(os.name != "posix", "Sending SIGINT with os.kill is posix only")
    def test_fast_sigint_raises_no_result_error_second_time(self):
        # Repeats the fast-SIGINT test to make sure no state leaked between
        # runs.
        self.test_fast_sigint_raises_no_result_error()
def test_suite():
    """Build and return a unittest suite of every test in this module."""
    from unittest import TestLoader
    loader = TestLoader()
    suite = loader.loadTestsFromName(__name__)
    return suite