Merge tag 'qemu-macppc-20230206' of https://github.com/mcayland/qemu into staging
[qemu.git] / python / tests / protocol.py
blob56c4d441f9c4a994f4f440c350a84465a0e78b68
1 import asyncio
2 from contextlib import contextmanager
3 import os
4 import socket
5 from tempfile import TemporaryDirectory
7 import avocado
9 from qemu.qmp import ConnectError, Runstate
10 from qemu.qmp.protocol import AsyncProtocol, StateError
11 from qemu.qmp.util import asyncio_run, create_task
class NullProtocol(AsyncProtocol[None]):
    """
    NullProtocol is a test mockup of an AsyncProtocol implementation.

    It adds a fake_session instance variable that enables a code path
    that bypasses the actual connection logic, but still allows the
    reader/writers to start.

    Because the message type is defined as None, an asyncio.Event named
    'trigger_input' is created that prohibits the reader from
    incessantly being able to yield None; this event can be poked to
    simulate an incoming message.

    For testing symmetry with do_recv, an interface is added to "send" a
    Null message.

    For testing purposes, a "simulate_disconnect" method is also
    added which allows us to trigger a bottom half disconnect without
    injecting any real errors into the reader/writer loops; in essence
    it performs exactly half of what disconnect() normally does.
    """
    def __init__(self, name=None):
        # When True, the start_server/accept/connect hooks below skip all
        # real socket work and only drive the runstate.
        self.fake_session = False
        # Declared here; the Event itself is created in _establish_session().
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        # A fresh, unset Event for every session.
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_start_server(self, address, ssl=None):
        if self.fake_session:
            # Pretend to listen: no server socket, just enter CONNECTING
            # and yield once to the event loop.
            self._accepted = asyncio.Event()
            self._set_state(Runstate.CONNECTING)
            await asyncio.sleep(0)
        else:
            await super()._do_start_server(address, ssl)

    async def _do_accept(self):
        if self.fake_session:
            # Nothing to wait for; mark the (fake) accept as consumed.
            self._accepted = None
        else:
            await super()._do_accept()

    async def _do_connect(self, address, ssl=None):
        if self.fake_session:
            # Pretend to connect: enter CONNECTING and yield once.
            self._set_state(Runstate.CONNECTING)
            await asyncio.sleep(0)
        else:
            await super()._do_connect(address, ssl)

    async def _do_recv(self) -> None:
        # Block until a test sets trigger_input, then re-arm it so that
        # exactly one None "message" is produced per poke.
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        # Null messages are not written anywhere.
        pass

    async def send_msg(self) -> None:
        """Queue a single None message for the writer."""
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulates a bottom-half disconnect.

        This method schedules a disconnection but does not wait for it
        to complete. This is used to put the loop into the DISCONNECTING
        state without fully quiescing it back to IDLE. This is normally
        something you cannot coax AsyncProtocol to do on purpose, but it
        will be similar to what happens with an unhandled Exception in
        the reader/writer.

        Under normal circumstances, the library design requires you to
        await on disconnect(), which awaits the disconnect task and
        returns bottom half errors as a pre-condition to allowing the
        loop to return back to IDLE.
        """
        self._schedule_disconnect()
class LineProtocol(AsyncProtocol[str]):
    """
    A minimal newline-delimited text protocol, used as a real peer in tests.

    Outgoing messages are encoded and terminated with a newline; incoming
    lines are decoded and also appended, in arrival order, to
    ``rx_history``.
    """

    def __init__(self, name=None):
        super().__init__(name)
        # Every decoded message received so far, oldest first.
        self.rx_history = []

    async def _do_recv(self) -> str:
        line = (await self._readline()).decode()
        self.rx_history.append(line)
        return line

    def _do_send(self, msg: str) -> None:
        assert self._writer is not None
        payload = msg.encode() + b'\n'
        self._writer.write(payload)

    async def send_msg(self, msg: str) -> None:
        await self._outgoing.put(msg)
def run_as_task(coro, allow_cancellation=False):
    """
    Schedule ``coro`` to run in its own task and return that task.

    :param coro: The coroutine to run.
    :param allow_cancellation: When True, cancelling the returned task
        makes it finish quietly (with a None result) instead of
        propagating asyncio.CancelledError.
    :return: The created task. Its result is always None; the value
        produced by ``coro`` is discarded.
    """
    async def _guarded():
        try:
            await coro
        except asyncio.CancelledError:
            if not allow_cancellation:
                raise

    return create_task(_guarded())
@contextmanager
def jammed_socket():
    """
    Opens up a random unused TCP port on localhost, then jams it.

    A listening socket is bound to an ephemeral port on 127.0.0.1 with a
    backlog of 1 and then filled with un-accepted client connections. The
    intent is that further connection attempts to that address stall.

    All sockets created here are closed on exit, including when setup
    fails part-way through.

    :yield: The ``(host, port)`` address of the jammed listening socket.
    """
    socks = []

    def _new_socket():
        # Register each socket as soon as it exists, before any call that
        # can fail, so the ``finally`` clause below always closes it.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        socks.append(sock)
        return sock

    try:
        listener = _new_socket()
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        address = listener.getsockname()

        # I don't *fully* understand why, but it takes *two* un-accepted
        # connections to start jamming the socket.
        for _ in range(2):
            _new_socket().connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()
class Smoke(avocado.Test):
    """Sanity checks on a freshly constructed, never-connected NullProtocol."""

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        self.assertEqual(repr(self.proto), "<NullProtocol runstate=IDLE>")

    def testRunstate(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    def testDefaultName(self):
        self.assertEqual(self.proto.name, None)

    def testLogger(self):
        self.assertEqual(self.proto.logger.name, 'qemu.qmp.protocol')

    def testName(self):
        # A named protocol reflects its name in the attribute, the logger
        # name and the repr.
        steve = NullProtocol('Steve')
        self.proto = steve

        self.assertEqual(steve.name, 'Steve')
        self.assertEqual(steve.logger.name, 'qemu.qmp.protocol.Steve')
        self.assertEqual(
            repr(steve),
            "<NullProtocol name='Steve' runstate=IDLE>",
        )
class TestBase(avocado.Test):
    """
    Base class for tests that drive a NullProtocol from async code.

    setUp() creates ``self.proto`` (named after the test class) and checks
    that it starts IDLE; tearDown() checks that it has returned to IDLE.
    Test methods decorated with async_test() additionally get the
    _asyncSetUp() / _asyncTearDown() hooks run around them.
    """

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        # Task created by _watch_runstates(), if the test starts one.
        self.runstate_watcher = None

    def tearDown(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        pass

    async def _asyncTearDown(self):
        # Awaiting the watcher task re-raises any assertion it failed.
        if self.runstate_watcher:
            await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; adds SetUp and TearDown to async tests.

        The returned coroutine function turns on debug mode for the
        current event loop, then awaits _asyncSetUp(), the test method and
        _asyncTearDown(), in that order. _asyncTearDown() is not reached
        if the test method raises.

        functools.wraps copies the test method's ``__name__``, ``__doc__``
        and other metadata onto the wrapper, so runners that report a
        test's docstring (e.g. unittest's shortDescription()) still see it.
        """
        # Function-scope import keeps this decorator self-contained.
        from functools import wraps

        @wraps(async_test_method)
        async def _wrapper(self, *args, **kwargs):
            loop = asyncio.get_event_loop()
            loop.set_debug(True)

            await self._asyncSetUp()
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # The states we expect a "bad" connect/accept attempt to transition through
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # The states we expect a "good" session to transition through
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        This launches a task alongside (most) tests below to confirm that
        the sequence of runstate changes that occur is exactly as
        anticipated.
        """
        async def _watcher():
            for state in states:
                new_state = await self.proto.runstate_changed()
                self.assertEqual(
                    new_state,
                    state,
                    msg=f"Expected state '{state.name}'",
                )

        self.runstate_watcher = create_task(_watcher())
        # Kick the loop and force the task to block on the event.
        await asyncio.sleep(0)
class State(TestBase):
    """Runstate behavior that does not require any connection."""

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """
        Test calling disconnect() while already disconnected.
        """
        expected_states = (Runstate.DISCONNECTING, Runstate.IDLE)
        await self._watch_runstates(*expected_states)
        await self.proto.disconnect()
class Connect(TestBase):
    """
    Tests primarily related to calling connect().

    Subclasses may override _bad_connection() and _hanging_connection()
    to run the same tests against a different entry point.
    """
    async def _bad_connection(self, family: str):
        """Attempt a connection of the given family that should fail."""
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.connect(('127.0.0.1', 0))
        elif family == 'UNIX':
            await self.proto.connect('/dev/null')

    async def _hanging_connection(self):
        """Attempt a connection that is expected to stall indefinitely."""
        with jammed_socket() as addr:
            await self.proto.connect(addr)

    async def _bad_connection_test(self, family: str):
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as context:
            await self._bad_connection(family)

        self.assertIsInstance(context.exception.exc, OSError)
        self.assertEqual(
            context.exception.error_message,
            "Failed to establish connection"
        )

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # Note that connect()/accept() cannot be cancelled outright, as they
        # aren't tasks. However, we can wrap them in a task and cancel *that*.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        # This is insider baseball, but the connection attempt has
        # yielded *just* before the actual connection attempt, so kick
        # the loop to make sure it's truly wedged.
        await asyncio.sleep(0)

        task.cancel()
        await task

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection())

        # More insider baseball: to improve the speed of this test while
        # guaranteeing that the connection even gets a chance to start,
        # verify that the connection hangs *first*, then await the
        # result of the task with a nearly-zero timeout.

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(task, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        with self.assertRaises(StateError) as context:
            await self._bad_connection('UNIX')

        self.assertEqual(
            context.exception.error_message,
            "NullProtocol is currently connecting."
        )
        self.assertEqual(context.exception.state, Runstate.CONNECTING)
        self.assertEqual(context.exception.required, Runstate.IDLE)

        task.cancel()
        await task

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        # Kick the loop to coerce the state change
        await asyncio.sleep(0)
        # Use the test-case assertion (not a bare assert, which is
        # stripped under -O) for consistency with the rest of the suite.
        self.assertEqual(self.proto.runstate, Runstate.CONNECTING)

        # We already missed the transition to CONNECTING
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        task.cancel()
        await task
class Accept(Connect):
    """
    All of the same tests as Connect, but using the accept() interface.
    """
    async def _bad_connection(self, family: str):
        """Attempt to listen on an address of the given family that should fail."""
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            # 192.0.2.1 is in TEST-NET-1 (RFC 5737), reserved for
            # documentation, so it is not expected to be assigned to any
            # local interface and binding to it should fail with OSError.
            # An IP literal avoids the DNS lookup a hostname such as
            # example.com would need, which can be slow or unavailable in
            # offline CI environments.
            await self.proto.start_server_and_accept(('192.0.2.1', 1))
        elif family == 'UNIX':
            await self.proto.start_server_and_accept('/dev/null')

    async def _hanging_connection(self):
        """Listen on a fresh UNIX socket that no client ever connects to."""
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            await self.proto.start_server_and_accept(sock)
class FakeSession(TestBase):
    """
    Lifecycle and messaging tests against a NullProtocol in fake-session mode.

    With ``fake_session`` enabled, connect()/accept() skip all real socket
    work, so the addresses passed below are never actually opened. Every
    test is expected to move the protocol through
    CONNECTING -> RUNNING -> DISCONNECTING -> IDLE.
    """

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        await super()._asyncTearDown()

    async def _fake_accept(self):
        """Bring the protocol to RUNNING via the (faked) accept path."""
        await self.proto.start_server_and_accept('/not/a/real/path')

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Test the full state lifecycle (via connect) with a no-op session."""
        await self.proto.connect('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Test the full state lifecycle (via accept) with a no-op session."""
        await self._fake_accept()
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Test receiving a fake/null message."""
        await self._fake_accept()

        logger_name = self.proto.logger.name
        with self.assertLogs(logger_name, level='DEBUG') as captured:
            trigger = self.proto.trigger_input
            trigger.set()
            trigger.clear()
            await asyncio.sleep(0)  # Kick reader.

        self.assertEqual(captured.output, [f"DEBUG:{logger_name}:<-- None"])

    @TestBase.async_test
    async def testFakeSend(self):
        """Test sending a fake/null message."""
        await self._fake_accept()

        logger_name = self.proto.logger.name
        with self.assertLogs(logger_name, level='DEBUG') as captured:
            # Cheat: Send a Null message to nobody.
            await self.proto.send_msg()
            # Kick writer; awaiting on a queue.put isn't sufficient to yield.
            await asyncio.sleep(0)

        self.assertEqual(captured.output, [f"DEBUG:{logger_name}:--> None"])

    async def _prod_session_api(
        self,
        current_state: Runstate,
        error_message: str,
        accept: bool = True
    ):
        """
        Assert that starting a new session is refused with StateError.

        :param current_state: Runstate the error should report.
        :param error_message: Exact expected error message.
        :param accept: Try start_server_and_accept() when True,
            connect() otherwise.
        """
        # The call itself must happen inside assertRaises: the state check
        # may raise as soon as the method is called, before any await.
        with self.assertRaises(StateError) as context:
            if accept:
                await self.proto.start_server_and_accept('/not/a/real/path')
            else:
                await self.proto.connect('/not/a/real/path')

        err = context.exception
        self.assertEqual(err.error_message, error_message)
        self.assertEqual(err.state, current_state)
        self.assertEqual(err.required, Runstate.IDLE)

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """Test that accept() cannot be called when Runstate=RUNNING"""
        await self._fake_accept()

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """Test that connect() cannot be called when Runstate=RUNNING"""
        await self._fake_accept()

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=False,
        )

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
        await self._fake_accept()

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
        await self._fake_accept()

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=False,
        )
class SimpleSession(TestBase):
    """
    End-to-end test: a NullProtocol client connects to a LineProtocol
    server over a UNIX socket path in a temporary directory.
    """

    def setUp(self):
        super().setUp()
        self.server = LineProtocol(type(self).__name__ + '-server')

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        try:
            await self.server.disconnect()
        except EOFError:
            # NOTE(review): presumably the server observes EOF because the
            # client has already disconnected; tolerated on purpose.
            pass
        await super()._asyncTearDown()

    @TestBase.async_test
    async def testSmoke(self):
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            server_task = create_task(self.server.start_server_and_accept(sock))

            # give the server a chance to start listening [...]
            await asyncio.sleep(0)
            await self.proto.connect(sock)

            # Wait for the server side to finish accepting as well, while
            # the socket path still exists. This surfaces any server-side
            # failure in the test instead of leaving an unobserved task.
            await server_task