5 from typing
import Any
, Callable
, Coroutine
, List
, Optional
, Mapping
# Module-level logger under the "webdriver.bidi" namespace; the transport
# code traces outgoing and incoming messages through it at DEBUG level.
logger: logging.Logger = logging.getLogger(name="webdriver.bidi")
def get_running_loop() -> asyncio.AbstractEventLoop:
    """Return the event loop of the current context.

    On Python >= 3.7 this delegates to :func:`asyncio.get_running_loop`,
    which raises ``RuntimeError`` when no loop is currently running.
    Older interpreters fall back to :func:`asyncio.get_event_loop`.
    """
    if sys.version_info >= (3, 7):
        return asyncio.get_running_loop()
    else:
        # Unlike the above, this will actually create an event loop
        # if there isn't one; hopefully running tests in Python >= 3.7
        # will allow us to catch any behaviour difference
        # (Needs to be in else for mypy to believe this is reachable)
        return asyncio.get_event_loop()
class Transport:
    """Low level message handler for the WebSockets connection.

    Outgoing messages are mappings that get JSON-encoded before being written
    to the socket. Messages passed to :meth:`send` before :meth:`start` has
    opened the connection are buffered and flushed, in order, once it is
    open. Every incoming text frame is JSON-decoded and awaited through
    ``msg_handler``; a binary frame makes the reader task raise ``ValueError``.
    """

    def __init__(self, url: str,
                 msg_handler: Callable[[Mapping[str, Any]], Coroutine[Any, Any, None]],
                 loop: Optional[asyncio.AbstractEventLoop] = None):
        """Create an unconnected transport.

        :param url: WebSocket URL that :meth:`start` connects to.
        :param msg_handler: coroutine function awaited with each decoded
            incoming message.
        :param loop: event loop used to schedule the reader task. When
            omitted it is obtained from ``get_running_loop()``, so (on
            Python >= 3.7) construction must then happen inside a running
            event loop.
        """
        # start() reads self.url, so it must be stored here.
        self.url = url
        self.connection: Optional[websockets.WebSocketClientProtocol] = None
        self.msg_handler = msg_handler
        # Messages queued by send() while no connection is open.
        self.send_buf: List[Mapping[str, Any]] = []

        # Honour an explicitly supplied loop; only look one up otherwise.
        if loop is None:
            loop = get_running_loop()
        self.loop = loop

        self.read_message_task: Optional[asyncio.Task[Any]] = None

    async def start(self) -> None:
        """Open the connection, start the reader task and flush the buffer."""
        connection = await websockets.client.connect(self.url)
        self.connection = connection
        self.read_message_task = self.loop.create_task(self.read_messages())

        # Flush buffered messages in FIFO order. Once connected, send() writes
        # directly and never appends, so the buffer is not mutated while we
        # iterate it. Clearing afterwards prevents the same messages from
        # being kept forever or re-sent by a later start().
        for msg in self.send_buf:
            await self._send(connection, msg)
        self.send_buf.clear()

    async def send(self, data: Mapping[str, Any]) -> None:
        """Send ``data`` now if connected, otherwise buffer it for start()."""
        if self.connection is not None:
            await self._send(self.connection, data)
        else:
            self.send_buf.append(data)

    @staticmethod
    async def _send(
        connection: websockets.WebSocketClientProtocol,
        data: Mapping[str, Any]
    ) -> None:
        """JSON-encode ``data``, log it, and write it to ``connection``."""
        msg = json.dumps(data)
        logger.debug("→ %s", msg)
        await connection.send(msg)

    async def handle(self, msg: str) -> None:
        """Log a raw incoming text message, decode it and pass it on."""
        logger.debug("← %s", msg)
        data = json.loads(msg)
        await self.msg_handler(data)

    async def end(self) -> None:
        """Close the connection if one is open; a no-op otherwise."""
        # Guard so calling end() before start() (or twice) does not fail
        # with AttributeError on None.
        if self.connection is not None:
            await self.connection.close()
            self.connection = None

    async def read_messages(self) -> None:
        """Reader loop: dispatch every incoming text frame to handle().

        :raises ValueError: if a binary frame is received.
        """
        # Internal invariant: start() sets the connection before scheduling
        # this task.
        assert self.connection is not None
        async for msg in self.connection:
            if not isinstance(msg, str):
                raise ValueError("Got a binary message")
            await self.handle(msg)