Bug 1758345 [wpt PR 33082] - Bump mypy from 0.812 to 0.931 in /tools, a=testonly
[gecko.git] / testing / web-platform / tests / tools / webdriver / webdriver / bidi / transport.py
blobafe054528e8a109c98925b2d55799fc296aa8454
1 import asyncio
2 import json
3 import logging
4 import sys
5 from typing import Any, Callable, Coroutine, List, Optional, Mapping
7 import websockets
# Module-level logger. Transport logs every outgoing ("→") and incoming ("←")
# JSON message on it at DEBUG level.
logger = logging.getLogger("webdriver.bidi")
def get_running_loop() -> asyncio.AbstractEventLoop:
    """Return the asyncio event loop for the current context.

    On Python >= 3.7 this delegates to ``asyncio.get_running_loop()``, which
    raises ``RuntimeError`` if no loop is currently running. Older Pythons
    fall back to ``asyncio.get_event_loop()``, which instead creates a loop
    when there isn't one. Running the tests on Python >= 3.7 should expose
    any behavioural difference between the two.
    """
    if sys.version_info >= (3, 7):
        loop = asyncio.get_running_loop()
    else:
        # The fallback deliberately lives in an explicit ``else`` of the
        # version check so mypy treats it as version-dependent code.
        loop = asyncio.get_event_loop()
    return loop
class Transport:
    """Low level message handler for the WebSockets connection.

    Outgoing messages are JSON-encoded and sent as text frames. Incoming
    text frames are JSON-decoded and passed to ``msg_handler``. Messages
    passed to :meth:`send` while no connection is open are buffered and
    flushed, in order, by the next :meth:`start`.
    """

    def __init__(self, url: str,
                 msg_handler: Callable[[Mapping[str, Any]], Coroutine[Any, Any, None]],
                 loop: Optional[asyncio.AbstractEventLoop] = None):
        """
        :param url: WebSocket URL that :meth:`start` connects to.
        :param msg_handler: Coroutine function awaited with every decoded
            incoming message.
        :param loop: Event loop used to schedule the reader task; defaults
            to the result of ``get_running_loop()``.
        """
        self.url = url
        self.connection: Optional[websockets.WebSocketClientProtocol] = None
        self.msg_handler = msg_handler
        # Messages queued by send() while there is no open connection.
        self.send_buf: List[Mapping[str, Any]] = []

        if loop is None:
            loop = get_running_loop()
        self.loop = loop

        self.read_message_task: Optional[asyncio.Task[Any]] = None

    async def start(self) -> None:
        """Connect, flush any buffered messages, then start the reader task.

        The buffer is drained *before* ``self.connection`` is published.
        While the flush is in progress, a concurrent :meth:`send` therefore
        still appends to the buffer, and this loop picks it up. It cannot
        overtake messages that were queued earlier. Flushed messages are
        removed from the buffer so that a later ``start()`` (for example
        after :meth:`end`) does not send them a second time.
        """
        connection = await websockets.client.connect(self.url)
        try:
            while self.send_buf:
                await self._send(connection, self.send_buf.pop(0))
        except BaseException:
            # The connection is not published yet, so end() could never
            # close it; don't leak it if flushing fails or is cancelled.
            await connection.close()
            raise
        # No await between the final emptiness check and publishing, so no
        # send() can slip into the buffer unnoticed.
        self.connection = connection
        self.read_message_task = self.loop.create_task(self.read_messages())

    async def send(self, data: Mapping[str, Any]) -> None:
        """Send ``data`` now if connected, otherwise buffer it for start()."""
        if self.connection is not None:
            await self._send(self.connection, data)
        else:
            self.send_buf.append(data)

    @staticmethod
    async def _send(
        connection: websockets.WebSocketClientProtocol,
        data: Mapping[str, Any]
    ) -> None:
        """JSON-encode ``data``, log it at DEBUG level and send it."""
        msg = json.dumps(data)
        logger.debug("→ %s", msg)
        await connection.send(msg)

    async def handle(self, msg: str) -> None:
        """Log, JSON-decode and dispatch one incoming message to msg_handler."""
        logger.debug("← %s", msg)
        data = json.loads(msg)
        await self.msg_handler(data)

    async def end(self) -> None:
        """Close the connection, if any; later sends are buffered again."""
        if self.connection is not None:
            await self.connection.close()
            self.connection = None

    async def read_messages(self) -> None:
        """Dispatch every incoming message until the connection's iterator ends.

        :raises ValueError: if a binary (non-text) frame is received.
        """
        # start() publishes the connection before creating this task.
        assert self.connection is not None
        async for msg in self.connection:
            if not isinstance(msg, str):
                raise ValueError("Got a binary message")
            await self.handle(msg)