webrtcd: turn CerealProxyRunner into more general AsyncTaskRunner (#38093)

* make cereal proxy runner more general

* missing init()
This commit is contained in:
stef
2026-05-27 13:49:46 -07:00
committed by GitHub
parent 4585e93066
commit 4cd93f3eee

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python3
from abc import abstractmethod
import time
import argparse
import asyncio
@@ -24,8 +25,35 @@ from openpilot.system.webrtc.schema import generate_field
from cereal import messaging, log
class CerealOutgoingMessageProxy:
class AsyncTaskRunner:
def __init__(self):
self.is_running = False
self.task = None
self.logger = logging.getLogger("webrtcd")
def start(self):
assert self.task is None
self.task = asyncio.create_task(self.run())
async def stop(self):
if self.task is None:
return
task = self.task
self.task = None
if task.done():
return
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
@abstractmethod
async def run(self):
pass
class CerealOutgoingMessageProxy(AsyncTaskRunner):
def __init__(self, sm: messaging.SubMaster):
super().__init__()
self.sm = sm
self.channels: list[RTCDataChannel] = []
@@ -57,6 +85,19 @@ class CerealOutgoingMessageProxy:
for channel in self.channels:
channel.send(encoded_msg)
async def run(self):
from aiortc.exceptions import InvalidStateError
while True:
try:
self.update()
except InvalidStateError:
self.logger.warning("Cereal outgoing proxy invalid state (connection closed)")
break
except Exception:
self.logger.exception("Cereal outgoing proxy failure")
await asyncio.sleep(0.01)
class CerealIncomingMessageProxy:
def __init__(self, pm: messaging.PubMaster):
@@ -74,42 +115,6 @@ class CerealIncomingMessageProxy:
self.pm.send(msg_type, msg)
class CerealProxyRunner:
def __init__(self, proxy: CerealOutgoingMessageProxy):
self.proxy = proxy
self.is_running = False
self.task = None
self.logger = logging.getLogger("webrtcd")
def start(self):
assert self.task is None
self.task = asyncio.create_task(self.run())
async def stop(self):
if self.task is None:
return
task = self.task
self.task = None
if task.done():
return
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
async def run(self):
from aiortc.exceptions import InvalidStateError
while True:
try:
self.proxy.update()
except InvalidStateError:
self.logger.warning("Cereal outgoing proxy invalid state (connection closed)")
break
except Exception:
self.logger.exception("Cereal outgoing proxy failure")
await asyncio.sleep(0.01)
class DynamicPubMaster(messaging.PubMaster):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -141,12 +146,10 @@ class StreamSession:
self.incoming_bridge: CerealIncomingMessageProxy | None = None
self.incoming_bridge_services = incoming_services
self.outgoing_bridge: CerealOutgoingMessageProxy | None = None
self.outgoing_bridge_runner: CerealProxyRunner | None = None
if len(incoming_services) > 0:
self.incoming_bridge = CerealIncomingMessageProxy(self.shared_pub_master)
if len(outgoing_services) > 0:
self.outgoing_bridge = CerealOutgoingMessageProxy(messaging.SubMaster(outgoing_services))
self.outgoing_bridge_runner = CerealProxyRunner(self.outgoing_bridge)
self.run_task: asyncio.Task | None = None
self._cleanup_lock = asyncio.Lock()
@@ -209,10 +212,10 @@ class StreamSession:
if self.incoming_bridge is not None:
await self.shared_pub_master.add_services_if_needed(self.incoming_bridge_services)
self.stream.set_message_handler(self.message_handler)
if self.outgoing_bridge_runner is not None:
if self.outgoing_bridge is not None:
channel = self.stream.get_messaging_channel()
self.outgoing_bridge_runner.proxy.add_channel(channel)
self.outgoing_bridge_runner.start()
self.outgoing_bridge.add_channel(channel)
self.outgoing_bridge.start()
self.logger.info("Stream session (%s) connected", self.identifier)
await self.stream.wait_for_disconnection()
@@ -228,8 +231,8 @@ class StreamSession:
if self._cleanup_done:
return
self._cleanup_done = True
if self.outgoing_bridge_runner is not None:
await self.outgoing_bridge_runner.stop()
if self.outgoing_bridge is not None:
await self.outgoing_bridge.stop()
await self.stream.stop()