2 from contextlib
import contextmanager
5 from tempfile
import TemporaryDirectory
9 from qemu
.aqmp
import ConnectError
, Runstate
10 from qemu
.aqmp
.protocol
import AsyncProtocol
, StateError
11 from qemu
.aqmp
.util
import asyncio_run
, create_task
14 class NullProtocol(AsyncProtocol
[None]):
16 NullProtocol is a test mockup of an AsyncProtocol implementation.
18 It adds a fake_session instance variable that enables a code path
19 that bypasses the actual connection logic, but still allows the
20 reader/writers to start.
22 Because the message type is defined as None, an asyncio.Event named
23 'trigger_input' is created that prohibits the reader from
24 incessantly being able to yield None; this event can be poked to
25 simulate an incoming message.
27 For testing symmetry with do_recv, an interface is added to "send" a
30 For testing purposes, a "simulate_disconnection" method is also
31 added which allows us to trigger a bottom half disconnect without
32 injecting any real errors into the reader/writer loops; in essence
33 it performs exactly half of what disconnect() normally does.
35 def __init__(self
, name
=None):
36 self
.fake_session
= False
37 self
.trigger_input
: asyncio
.Event
38 super().__init
__(name
)
40 async def _establish_session(self
):
41 self
.trigger_input
= asyncio
.Event()
42 await super()._establish
_session
()
44 async def _do_accept(self
, address
, ssl
=None):
45 if not self
.fake_session
:
46 await super()._do
_accept
(address
, ssl
)
48 async def _do_connect(self
, address
, ssl
=None):
49 if not self
.fake_session
:
50 await super()._do
_connect
(address
, ssl
)
52 async def _do_recv(self
) -> None:
53 await self
.trigger_input
.wait()
54 self
.trigger_input
.clear()
56 def _do_send(self
, msg
: None) -> None:
59 async def send_msg(self
) -> None:
60 await self
._outgoing
.put(None)
62 async def simulate_disconnect(self
) -> None:
64 Simulates a bottom-half disconnect.
66 This method schedules a disconnection but does not wait for it
67 to complete. This is used to put the loop into the DISCONNECTING
68 state without fully quiescing it back to IDLE. This is normally
69 something you cannot coax AsyncProtocol to do on purpose, but it
70 will be similar to what happens with an unhandled Exception in
73 Under normal circumstances, the library design requires you to
74 await on disconnect(), which awaits the disconnect task and
75 returns bottom half errors as a pre-condition to allowing the
76 loop to return back to IDLE.
78 self
._schedule
_disconnect
()
81 class LineProtocol(AsyncProtocol
[str]):
82 def __init__(self
, name
=None):
83 super().__init
__(name
)
86 async def _do_recv(self
) -> str:
87 raw
= await self
._readline
()
89 self
.rx_history
.append(msg
)
92 def _do_send(self
, msg
: str) -> None:
93 assert self
._writer
is not None
94 self
._writer
.write(msg
.encode() + b
'\n')
96 async def send_msg(self
, msg
: str) -> None:
97 await self
._outgoing
.put(msg
)
100 def run_as_task(coro
, allow_cancellation
=False):
102 Run a given coroutine as a task.
104 Optionally, wrap it in a try..except block that allows this
105 coroutine to be canceled gracefully.
110 except asyncio
.CancelledError
:
111 if allow_cancellation
:
114 return create_task(_runner())
120 Opens up a random unused TCP port on localhost, then jams it.
125 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
126 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEADDR
, 1)
127 sock
.bind(('127.0.0.1', 0))
129 address
= sock
.getsockname()
133 # I don't *fully* understand why, but it takes *two* un-accepted
134 # connections to start jamming the socket.
136 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
137 sock
.connect(address
)
147 class Smoke(avocado
.Test
):
150 self
.proto
= NullProtocol()
152 def test__repr__(self
):
155 "<NullProtocol runstate=IDLE>"
158 def testRunstate(self
):
164 def testDefaultName(self
):
170 def testLogger(self
):
172 self
.proto
.logger
.name
,
177 self
.proto
= NullProtocol('Steve')
185 self
.proto
.logger
.name
,
186 'qemu.aqmp.protocol.Steve'
191 "<NullProtocol name='Steve' runstate=IDLE>"
195 class TestBase(avocado
.Test
):
198 self
.proto
= NullProtocol(type(self
).__name
__)
199 self
.assertEqual(self
.proto
.runstate
, Runstate
.IDLE
)
200 self
.runstate_watcher
= None
203 self
.assertEqual(self
.proto
.runstate
, Runstate
.IDLE
)
205 async def _asyncSetUp(self
):
208 async def _asyncTearDown(self
):
209 if self
.runstate_watcher
:
210 await self
.runstate_watcher
213 def async_test(async_test_method
):
215 Decorator; adds SetUp and TearDown to async tests.
217 async def _wrapper(self
, *args
, **kwargs
):
218 loop
= asyncio
.get_event_loop()
221 await self
._asyncSetUp
()
222 await async_test_method(self
, *args
, **kwargs
)
223 await self
._asyncTearDown
()
229 # The states we expect a "bad" connect/accept attempt to transition through
230 BAD_CONNECTION_STATES
= (
232 Runstate
.DISCONNECTING
,
236 # The states we expect a "good" session to transition through
237 GOOD_CONNECTION_STATES
= (
240 Runstate
.DISCONNECTING
,
246 async def _watch_runstates(self
, *states
):
248 This launches a task alongside (most) tests below to confirm that
249 the sequence of runstate changes that occur is exactly as
252 async def _watcher():
254 new_state
= await self
.proto
.runstate_changed()
258 msg
=f
"Expected state '{state.name}'",
261 self
.runstate_watcher
= create_task(_watcher())
262 # Kick the loop and force the task to block on the event.
263 await asyncio
.sleep(0)
266 class State(TestBase
):
269 async def testSuperfluousDisconnect(self
):
271 Test calling disconnect() while already disconnected.
273 await self
._watch
_runstates
(
274 Runstate
.DISCONNECTING
,
277 await self
.proto
.disconnect()
280 class Connect(TestBase
):
282 Tests primarily related to calling Connect().
284 async def _bad_connection(self
, family
: str):
285 assert family
in ('INET', 'UNIX')
288 await self
.proto
.connect(('127.0.0.1', 0))
289 elif family
== 'UNIX':
290 await self
.proto
.connect('/dev/null')
292 async def _hanging_connection(self
):
293 with
jammed_socket() as addr
:
294 await self
.proto
.connect(addr
)
296 async def _bad_connection_test(self
, family
: str):
297 await self
._watch
_runstates
(*self
.BAD_CONNECTION_STATES
)
299 with self
.assertRaises(ConnectError
) as context
:
300 await self
._bad
_connection
(family
)
302 self
.assertIsInstance(context
.exception
.exc
, OSError)
304 context
.exception
.error_message
,
305 "Failed to establish connection"
309 async def testBadINET(self
):
311 Test an immediately rejected call to an IP target.
313 await self
._bad
_connection
_test
('INET')
316 async def testBadUNIX(self
):
318 Test an immediately rejected call to a UNIX socket target.
320 await self
._bad
_connection
_test
('UNIX')
323 async def testCancellation(self
):
325 Test what happens when a connection attempt is aborted.
327 # Note that accept() cannot be cancelled outright, as it isn't a task.
328 # However, we can wrap it in a task and cancel *that*.
329 await self
._watch
_runstates
(*self
.BAD_CONNECTION_STATES
)
330 task
= run_as_task(self
._hanging
_connection
(), allow_cancellation
=True)
332 state
= await self
.proto
.runstate_changed()
333 self
.assertEqual(state
, Runstate
.CONNECTING
)
335 # This is insider baseball, but the connection attempt has
336 # yielded *just* before the actual connection attempt, so kick
337 # the loop to make sure it's truly wedged.
338 await asyncio
.sleep(0)
344 async def testTimeout(self
):
346 Test what happens when a connection attempt times out.
348 await self
._watch
_runstates
(*self
.BAD_CONNECTION_STATES
)
349 task
= run_as_task(self
._hanging
_connection
())
351 # More insider baseball: to improve the speed of this test while
352 # guaranteeing that the connection even gets a chance to start,
353 # verify that the connection hangs *first*, then await the
354 # result of the task with a nearly-zero timeout.
356 state
= await self
.proto
.runstate_changed()
357 self
.assertEqual(state
, Runstate
.CONNECTING
)
358 await asyncio
.sleep(0)
360 with self
.assertRaises(asyncio
.TimeoutError
):
361 await asyncio
.wait_for(task
, timeout
=0)
364 async def testRequire(self
):
366 Test what happens when a connection attempt is made while CONNECTING.
368 await self
._watch
_runstates
(*self
.BAD_CONNECTION_STATES
)
369 task
= run_as_task(self
._hanging
_connection
(), allow_cancellation
=True)
371 state
= await self
.proto
.runstate_changed()
372 self
.assertEqual(state
, Runstate
.CONNECTING
)
374 with self
.assertRaises(StateError
) as context
:
375 await self
._bad
_connection
('UNIX')
378 context
.exception
.error_message
,
379 "NullProtocol is currently connecting."
381 self
.assertEqual(context
.exception
.state
, Runstate
.CONNECTING
)
382 self
.assertEqual(context
.exception
.required
, Runstate
.IDLE
)
388 async def testImplicitRunstateInit(self
):
390 Test what happens if we do not wait on the runstate event until
391 AFTER a connection is made, i.e., connect()/accept() themselves
392 initialize the runstate event. All of the above tests force the
393 initialization by waiting on the runstate *first*.
395 task
= run_as_task(self
._hanging
_connection
(), allow_cancellation
=True)
397 # Kick the loop to coerce the state change
398 await asyncio
.sleep(0)
399 assert self
.proto
.runstate
== Runstate
.CONNECTING
401 # We already missed the transition to CONNECTING
402 await self
._watch
_runstates
(Runstate
.DISCONNECTING
, Runstate
.IDLE
)
408 class Accept(Connect
):
410 All of the same tests as Connect, but using the accept() interface.
412 async def _bad_connection(self
, family
: str):
413 assert family
in ('INET', 'UNIX')
416 await self
.proto
.accept(('example.com', 1))
417 elif family
== 'UNIX':
418 await self
.proto
.accept('/dev/null')
420 async def _hanging_connection(self
):
421 with
TemporaryDirectory(suffix
='.aqmp') as tmpdir
:
422 sock
= os
.path
.join(tmpdir
, type(self
.proto
).__name
__ + ".sock")
423 await self
.proto
.accept(sock
)
426 class FakeSession(TestBase
):
430 self
.proto
.fake_session
= True
432 async def _asyncSetUp(self
):
433 await super()._asyncSetUp
()
434 await self
._watch
_runstates
(*self
.GOOD_CONNECTION_STATES
)
436 async def _asyncTearDown(self
):
437 await self
.proto
.disconnect()
438 await super()._asyncTearDown
()
443 async def testFakeConnect(self
):
445 """Test the full state lifecycle (via connect) with a no-op session."""
446 await self
.proto
.connect('/not/a/real/path')
447 self
.assertEqual(self
.proto
.runstate
, Runstate
.RUNNING
)
450 async def testFakeAccept(self
):
451 """Test the full state lifecycle (via accept) with a no-op session."""
452 await self
.proto
.accept('/not/a/real/path')
453 self
.assertEqual(self
.proto
.runstate
, Runstate
.RUNNING
)
456 async def testFakeRecv(self
):
457 """Test receiving a fake/null message."""
458 await self
.proto
.accept('/not/a/real/path')
460 logname
= self
.proto
.logger
.name
461 with self
.assertLogs(logname
, level
='DEBUG') as context
:
462 self
.proto
.trigger_input
.set()
463 self
.proto
.trigger_input
.clear()
464 await asyncio
.sleep(0) # Kick reader.
468 [f
"DEBUG:{logname}:<-- None"],
472 async def testFakeSend(self
):
473 """Test sending a fake/null message."""
474 await self
.proto
.accept('/not/a/real/path')
476 logname
= self
.proto
.logger
.name
477 with self
.assertLogs(logname
, level
='DEBUG') as context
:
478 # Cheat: Send a Null message to nobody.
479 await self
.proto
.send_msg()
480 # Kick writer; awaiting on a queue.put isn't sufficient to yield.
481 await asyncio
.sleep(0)
485 [f
"DEBUG:{logname}:--> None"],
488 async def _prod_session_api(
490 current_state
: Runstate
,
494 with self
.assertRaises(StateError
) as context
:
496 await self
.proto
.accept('/not/a/real/path')
498 await self
.proto
.connect('/not/a/real/path')
500 self
.assertEqual(context
.exception
.error_message
, error_message
)
501 self
.assertEqual(context
.exception
.state
, current_state
)
502 self
.assertEqual(context
.exception
.required
, Runstate
.IDLE
)
505 async def testAcceptRequireRunning(self
):
506 """Test that accept() cannot be called when Runstate=RUNNING"""
507 await self
.proto
.accept('/not/a/real/path')
509 await self
._prod
_session
_api
(
511 "NullProtocol is already connected and running.",
516 async def testConnectRequireRunning(self
):
517 """Test that connect() cannot be called when Runstate=RUNNING"""
518 await self
.proto
.accept('/not/a/real/path')
520 await self
._prod
_session
_api
(
522 "NullProtocol is already connected and running.",
527 async def testAcceptRequireDisconnecting(self
):
528 """Test that accept() cannot be called when Runstate=DISCONNECTING"""
529 await self
.proto
.accept('/not/a/real/path')
531 # Cheat: force a disconnect.
532 await self
.proto
.simulate_disconnect()
534 await self
._prod
_session
_api
(
535 Runstate
.DISCONNECTING
,
536 ("NullProtocol is disconnecting."
537 " Call disconnect() to return to IDLE state."),
542 async def testConnectRequireDisconnecting(self
):
543 """Test that connect() cannot be called when Runstate=DISCONNECTING"""
544 await self
.proto
.accept('/not/a/real/path')
546 # Cheat: force a disconnect.
547 await self
.proto
.simulate_disconnect()
549 await self
._prod
_session
_api
(
550 Runstate
.DISCONNECTING
,
551 ("NullProtocol is disconnecting."
552 " Call disconnect() to return to IDLE state."),
557 class SimpleSession(TestBase
):
561 self
.server
= LineProtocol(type(self
).__name
__ + '-server')
563 async def _asyncSetUp(self
):
564 await super()._asyncSetUp
()
565 await self
._watch
_runstates
(*self
.GOOD_CONNECTION_STATES
)
567 async def _asyncTearDown(self
):
568 await self
.proto
.disconnect()
570 await self
.server
.disconnect()
573 await super()._asyncTearDown
()
576 async def testSmoke(self
):
577 with
TemporaryDirectory(suffix
='.aqmp') as tmpdir
:
578 sock
= os
.path
.join(tmpdir
, type(self
.proto
).__name
__ + ".sock")
579 server_task
= create_task(self
.server
.accept(sock
))
581 # give the server a chance to start listening [...]
582 await asyncio
.sleep(0)
583 await self
.proto
.connect(sock
)