2 from contextlib
import contextmanager
5 from tempfile
import TemporaryDirectory
9 from qemu
.qmp
import ConnectError
, Runstate
10 from qemu
.qmp
.protocol
import AsyncProtocol
, StateError
11 from qemu
.qmp
.util
import asyncio_run
, create_task
14 class NullProtocol(AsyncProtocol
[None]):
16 NullProtocol is a test mockup of an AsyncProtocol implementation.
18 It adds a fake_session instance variable that enables a code path
19 that bypasses the actual connection logic, but still allows the
20 reader/writers to start.
22 Because the message type is defined as None, an asyncio.Event named
23 'trigger_input' is created that prohibits the reader from
24 incessantly being able to yield None; this event can be poked to
25 simulate an incoming message.
27 For testing symmetry with do_recv, an interface is added to "send" a Null message.
30 For testing purposes, a "simulate_disconnect" method is also
31 added which allows us to trigger a bottom half disconnect without
32 injecting any real errors into the reader/writer loops; in essence
33 it performs exactly half of what disconnect() normally does.
def __init__(self, name=None):
    """
    Set up the mock protocol.

    :param name: Optional name, handed through to the parent constructor.
    """
    # Off by default; tests flip this to True to take the code path
    # that bypasses the actual connection logic.
    self.fake_session = False
    # Annotation only: the Event object itself is created in
    # _establish_session().
    self.trigger_input: asyncio.Event
    super().__init__(name)
async def _establish_session(self):
    """
    Create a fresh 'trigger_input' event, then defer to the parent.

    The event is replaced on every call, before the parent's
    _establish_session() is awaited.
    """
    self.trigger_input = asyncio.Event()
    await super()._establish_session()
44 async def _do_start_server(self
, address
, ssl
=None):
46 self
._accepted
= asyncio
.Event()
47 self
._set
_state
(Runstate
.CONNECTING
)
48 await asyncio
.sleep(0)
50 await super()._do
_start
_server
(address
, ssl
)
52 async def _do_accept(self
):
56 await super()._do
_accept
()
58 async def _do_connect(self
, address
, ssl
=None):
60 self
._set
_state
(Runstate
.CONNECTING
)
61 await asyncio
.sleep(0)
63 await super()._do
_connect
(address
, ssl
)
async def _do_recv(self) -> None:
    """
    "Receive" a None message.

    Blocks until 'trigger_input' is set, then clears it again so that
    the next call blocks until the event is poked once more.
    """
    await self.trigger_input.wait()
    self.trigger_input.clear()
69 def _do_send(self
, msg
: None) -> None:
async def send_msg(self) -> None:
    """
    "Send" a Null message by placing None on the outgoing queue.
    """
    await self._outgoing.put(None)
async def simulate_disconnect(self) -> None:
    """
    Simulate a bottom-half disconnect.

    A disconnection is scheduled, but this method does not wait for
    it to complete. The purpose is to leave the loop in the
    DISCONNECTING state without fully quiescing it back to IDLE,
    which AsyncProtocol normally cannot be coaxed into doing on
    purpose.

    Under normal circumstances the library design requires callers
    to await disconnect(), which awaits the disconnect task and
    returns bottom half errors as a pre-condition to allowing the
    loop to return back to IDLE.
    """
    self._schedule_disconnect()
94 class LineProtocol(AsyncProtocol
[str]):
95 def __init__(self
, name
=None):
96 super().__init
__(name
)
99 async def _do_recv(self
) -> str:
100 raw
= await self
._readline
()
102 self
.rx_history
.append(msg
)
def _do_send(self, msg: str) -> None:
    """
    Write one message to the writer as a newline-terminated line.

    :param msg: Text to send; it is encoded and a newline is appended.
    """
    assert self._writer is not None
    payload = msg.encode() + b'\n'
    self._writer.write(payload)
async def send_msg(self, msg: str) -> None:
    """
    Place a text message on the outgoing queue.

    :param msg: The message to enqueue.
    """
    await self._outgoing.put(msg)
113 def run_as_task(coro
, allow_cancellation
=False):
115 Run a given coroutine as a task.
117 Optionally, wrap it in a try..except block that allows this
118 coroutine to be canceled gracefully.
123 except asyncio
.CancelledError
:
124 if allow_cancellation
:
127 return create_task(_runner())
133 Opens up a random unused TCP port on localhost, then jams it.
138 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
139 sock
.setsockopt(socket
.SOL_SOCKET
, socket
.SO_REUSEADDR
, 1)
140 sock
.bind(('127.0.0.1', 0))
142 address
= sock
.getsockname()
146 # I don't *fully* understand why, but it takes *two* un-accepted
147 # connections to start jamming the socket.
149 sock
= socket
.socket(socket
.AF_INET
, socket
.SOCK_STREAM
)
150 sock
.connect(address
)
160 class Smoke(avocado
.Test
):
163 self
.proto
= NullProtocol()
165 def test__repr__(self
):
168 "<NullProtocol runstate=IDLE>"
171 def testRunstate(self
):
177 def testDefaultName(self
):
183 def testLogger(self
):
185 self
.proto
.logger
.name
,
190 self
.proto
= NullProtocol('Steve')
198 self
.proto
.logger
.name
,
199 'qemu.qmp.protocol.Steve'
204 "<NullProtocol name='Steve' runstate=IDLE>"
208 class TestBase(avocado
.Test
):
211 self
.proto
= NullProtocol(type(self
).__name
__)
212 self
.assertEqual(self
.proto
.runstate
, Runstate
.IDLE
)
213 self
.runstate_watcher
= None
216 self
.assertEqual(self
.proto
.runstate
, Runstate
.IDLE
)
218 async def _asyncSetUp(self
):
async def _asyncTearDown(self):
    """
    Wait for the runstate watcher task, if one was started.
    """
    if not self.runstate_watcher:
        return
    await self.runstate_watcher
226 def async_test(async_test_method
):
228 Decorator; adds SetUp and TearDown to async tests.
230 async def _wrapper(self
, *args
, **kwargs
):
231 loop
= asyncio
.get_event_loop()
234 await self
._asyncSetUp
()
235 await async_test_method(self
, *args
, **kwargs
)
236 await self
._asyncTearDown
()
242 # The states we expect a "bad" connect/accept attempt to transition through
243 BAD_CONNECTION_STATES
= (
245 Runstate
.DISCONNECTING
,
249 # The states we expect a "good" session to transition through
250 GOOD_CONNECTION_STATES
= (
253 Runstate
.DISCONNECTING
,
259 async def _watch_runstates(self
, *states
):
261 This launches a task alongside (most) tests below to confirm that
262 the sequence of runstate changes that occur is exactly as
265 async def _watcher():
267 new_state
= await self
.proto
.runstate_changed()
271 msg
=f
"Expected state '{state.name}'",
274 self
.runstate_watcher
= create_task(_watcher())
275 # Kick the loop and force the task to block on the event.
276 await asyncio
.sleep(0)
279 class State(TestBase
):
282 async def testSuperfluousDisconnect(self
):
284 Test calling disconnect() while already disconnected.
286 await self
._watch
_runstates
(
287 Runstate
.DISCONNECTING
,
290 await self
.proto
.disconnect()
293 class Connect(TestBase
):
295 Tests primarily related to calling connect().
297 async def _bad_connection(self
, family
: str):
298 assert family
in ('INET', 'UNIX')
301 await self
.proto
.connect(('127.0.0.1', 0))
302 elif family
== 'UNIX':
303 await self
.proto
.connect('/dev/null')
async def _hanging_connection(self):
    """
    Attempt to connect to the address provided by jammed_socket().
    """
    with jammed_socket() as address:
        await self.proto.connect(address)
309 async def _bad_connection_test(self
, family
: str):
310 await self
._watch
_runstates
(*self
.BAD_CONNECTION_STATES
)
312 with self
.assertRaises(ConnectError
) as context
:
313 await self
._bad
_connection
(family
)
315 self
.assertIsInstance(context
.exception
.exc
, OSError)
317 context
.exception
.error_message
,
318 "Failed to establish connection"
async def testBadINET(self):
    """
    Test an immediately rejected call to an IP target.
    """
    family = 'INET'
    await self._bad_connection_test(family)
async def testBadUNIX(self):
    """
    Test an immediately rejected call to a UNIX socket target.
    """
    family = 'UNIX'
    await self._bad_connection_test(family)
336 async def testCancellation(self
):
338 Test what happens when a connection attempt is aborted.
340 # Note that accept() cannot be cancelled outright, as it isn't a task.
341 # However, we can wrap it in a task and cancel *that*.
342 await self
._watch
_runstates
(*self
.BAD_CONNECTION_STATES
)
343 task
= run_as_task(self
._hanging
_connection
(), allow_cancellation
=True)
345 state
= await self
.proto
.runstate_changed()
346 self
.assertEqual(state
, Runstate
.CONNECTING
)
348 # This is insider baseball, but the connection attempt has
349 # yielded *just* before the actual connection attempt, so kick
350 # the loop to make sure it's truly wedged.
351 await asyncio
.sleep(0)
async def testTimeout(self):
    """
    Test what happens when a connection attempt times out.
    """
    await self._watch_runstates(*self.BAD_CONNECTION_STATES)
    pending = run_as_task(self._hanging_connection())

    # More insider baseball: to keep this test fast while still
    # guaranteeing that the connection gets a chance to start,
    # confirm that the attempt is hanging *first*, and only then
    # await the task's result with a nearly-zero timeout.

    observed = await self.proto.runstate_changed()
    self.assertEqual(observed, Runstate.CONNECTING)
    await asyncio.sleep(0)

    with self.assertRaises(asyncio.TimeoutError):
        await asyncio.wait_for(pending, timeout=0)
377 async def testRequire(self
):
379 Test what happens when a connection attempt is made while CONNECTING.
381 await self
._watch
_runstates
(*self
.BAD_CONNECTION_STATES
)
382 task
= run_as_task(self
._hanging
_connection
(), allow_cancellation
=True)
384 state
= await self
.proto
.runstate_changed()
385 self
.assertEqual(state
, Runstate
.CONNECTING
)
387 with self
.assertRaises(StateError
) as context
:
388 await self
._bad
_connection
('UNIX')
391 context
.exception
.error_message
,
392 "NullProtocol is currently connecting."
394 self
.assertEqual(context
.exception
.state
, Runstate
.CONNECTING
)
395 self
.assertEqual(context
.exception
.required
, Runstate
.IDLE
)
401 async def testImplicitRunstateInit(self
):
403 Test what happens if we do not wait on the runstate event until
404 AFTER a connection is made, i.e., connect()/accept() themselves
405 initialize the runstate event. All of the above tests force the
406 initialization by waiting on the runstate *first*.
408 task
= run_as_task(self
._hanging
_connection
(), allow_cancellation
=True)
410 # Kick the loop to coerce the state change
411 await asyncio
.sleep(0)
412 assert self
.proto
.runstate
== Runstate
.CONNECTING
414 # We already missed the transition to CONNECTING
415 await self
._watch
_runstates
(Runstate
.DISCONNECTING
, Runstate
.IDLE
)
421 class Accept(Connect
):
423 All of the same tests as Connect, but using the accept() interface.
425 async def _bad_connection(self
, family
: str):
426 assert family
in ('INET', 'UNIX')
429 await self
.proto
.start_server_and_accept(('example.com', 1))
430 elif family
== 'UNIX':
431 await self
.proto
.start_server_and_accept('/dev/null')
async def _hanging_connection(self):
    """
    Start a server on a socket path inside a temporary directory and
    wait for a connection on it.
    """
    with TemporaryDirectory(suffix='.qmp') as workdir:
        # The socket file is named after the protocol class under test.
        sock_name = type(self.proto).__name__ + ".sock"
        sock_path = os.path.join(workdir, sock_name)
        await self.proto.start_server_and_accept(sock_path)
439 class FakeSession(TestBase
):
443 self
.proto
.fake_session
= True
async def _asyncSetUp(self):
    """
    Run the base setup, then start watching for the runstate
    sequence listed in GOOD_CONNECTION_STATES.
    """
    await super()._asyncSetUp()
    expected_states = self.GOOD_CONNECTION_STATES
    await self._watch_runstates(*expected_states)
async def _asyncTearDown(self):
    """
    Disconnect the protocol first, then run the base teardown.
    """
    await self.proto.disconnect()
    await super()._asyncTearDown()
async def testFakeConnect(self):
    """
    Test the full state lifecycle (via connect) with a no-op session.
    """
    await self.proto.connect('/not/a/real/path')
    current = self.proto.runstate
    self.assertEqual(current, Runstate.RUNNING)
async def testFakeAccept(self):
    """
    Test the full state lifecycle (via accept) with a no-op session.
    """
    await self.proto.start_server_and_accept('/not/a/real/path')
    current = self.proto.runstate
    self.assertEqual(current, Runstate.RUNNING)
469 async def testFakeRecv(self
):
470 """Test receiving a fake/null message."""
471 await self
.proto
.start_server_and_accept('/not/a/real/path')
473 logname
= self
.proto
.logger
.name
474 with self
.assertLogs(logname
, level
='DEBUG') as context
:
475 self
.proto
.trigger_input
.set()
476 self
.proto
.trigger_input
.clear()
477 await asyncio
.sleep(0) # Kick reader.
481 [f
"DEBUG:{logname}:<-- None"],
485 async def testFakeSend(self
):
486 """Test sending a fake/null message."""
487 await self
.proto
.start_server_and_accept('/not/a/real/path')
489 logname
= self
.proto
.logger
.name
490 with self
.assertLogs(logname
, level
='DEBUG') as context
:
491 # Cheat: Send a Null message to nobody.
492 await self
.proto
.send_msg()
493 # Kick writer; awaiting on a queue.put isn't sufficient to yield.
494 await asyncio
.sleep(0)
498 [f
"DEBUG:{logname}:--> None"],
501 async def _prod_session_api(
503 current_state
: Runstate
,
507 with self
.assertRaises(StateError
) as context
:
509 await self
.proto
.start_server_and_accept('/not/a/real/path')
511 await self
.proto
.connect('/not/a/real/path')
513 self
.assertEqual(context
.exception
.error_message
, error_message
)
514 self
.assertEqual(context
.exception
.state
, current_state
)
515 self
.assertEqual(context
.exception
.required
, Runstate
.IDLE
)
518 async def testAcceptRequireRunning(self
):
519 """Test that accept() cannot be called when Runstate=RUNNING"""
520 await self
.proto
.start_server_and_accept('/not/a/real/path')
522 await self
._prod
_session
_api
(
524 "NullProtocol is already connected and running.",
529 async def testConnectRequireRunning(self
):
530 """Test that connect() cannot be called when Runstate=RUNNING"""
531 await self
.proto
.start_server_and_accept('/not/a/real/path')
533 await self
._prod
_session
_api
(
535 "NullProtocol is already connected and running.",
540 async def testAcceptRequireDisconnecting(self
):
541 """Test that accept() cannot be called when Runstate=DISCONNECTING"""
542 await self
.proto
.start_server_and_accept('/not/a/real/path')
544 # Cheat: force a disconnect.
545 await self
.proto
.simulate_disconnect()
547 await self
._prod
_session
_api
(
548 Runstate
.DISCONNECTING
,
549 ("NullProtocol is disconnecting."
550 " Call disconnect() to return to IDLE state."),
555 async def testConnectRequireDisconnecting(self
):
556 """Test that connect() cannot be called when Runstate=DISCONNECTING"""
557 await self
.proto
.start_server_and_accept('/not/a/real/path')
559 # Cheat: force a disconnect.
560 await self
.proto
.simulate_disconnect()
562 await self
._prod
_session
_api
(
563 Runstate
.DISCONNECTING
,
564 ("NullProtocol is disconnecting."
565 " Call disconnect() to return to IDLE state."),
570 class SimpleSession(TestBase
):
574 self
.server
= LineProtocol(type(self
).__name
__ + '-server')
async def _asyncSetUp(self):
    """
    Run the base setup, then start watching for the runstate
    sequence listed in GOOD_CONNECTION_STATES.
    """
    await super()._asyncSetUp()
    expected_states = self.GOOD_CONNECTION_STATES
    await self._watch_runstates(*expected_states)
580 async def _asyncTearDown(self
):
581 await self
.proto
.disconnect()
583 await self
.server
.disconnect()
586 await super()._asyncTearDown
()
589 async def testSmoke(self
):
590 with
TemporaryDirectory(suffix
='.qmp') as tmpdir
:
591 sock
= os
.path
.join(tmpdir
, type(self
.proto
).__name
__ + ".sock")
592 server_task
= create_task(self
.server
.start_server_and_accept(sock
))
594 # give the server a chance to start listening [...]
595 await asyncio
.sleep(0)
596 await self
.proto
.connect(sock
)