diff --git a/system/webrtc/webrtcd.py b/system/webrtc/webrtcd.py index 4fd922688e..9ff937b673 100755 --- a/system/webrtc/webrtcd.py +++ b/system/webrtc/webrtcd.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +from abc import abstractmethod import time import argparse import asyncio @@ -24,8 +25,35 @@ from openpilot.system.webrtc.schema import generate_field from cereal import messaging, log -class CerealOutgoingMessageProxy: +class AsyncTaskRunner: + def __init__(self): + self.is_running = False + self.task = None + self.logger = logging.getLogger("webrtcd") + + def start(self): + assert self.task is None + self.task = asyncio.create_task(self.run()) + + async def stop(self): + if self.task is None: + return + task = self.task + self.task = None + if task.done(): + return + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + + @abstractmethod + async def run(self): + pass + + +class CerealOutgoingMessageProxy(AsyncTaskRunner): def __init__(self, sm: messaging.SubMaster): + super().__init__() self.sm = sm self.channels: list[RTCDataChannel] = [] @@ -57,6 +85,19 @@ class CerealOutgoingMessageProxy: for channel in self.channels: channel.send(encoded_msg) + async def run(self): + from aiortc.exceptions import InvalidStateError + + while True: + try: + self.update() + except InvalidStateError: + self.logger.warning("Cereal outgoing proxy invalid state (connection closed)") + break + except Exception: + self.logger.exception("Cereal outgoing proxy failure") + await asyncio.sleep(0.01) + class CerealIncomingMessageProxy: def __init__(self, pm: messaging.PubMaster): @@ -74,42 +115,6 @@ class CerealIncomingMessageProxy: self.pm.send(msg_type, msg) -class CerealProxyRunner: - def __init__(self, proxy: CerealOutgoingMessageProxy): - self.proxy = proxy - self.is_running = False - self.task = None - self.logger = logging.getLogger("webrtcd") - - def start(self): - assert self.task is None - self.task = asyncio.create_task(self.run()) - - async def stop(self): - if self.task is None: - return - task = self.task - self.task = None - if task.done(): - return - task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await task - - async def run(self): - from aiortc.exceptions import InvalidStateError - - while True: - try: - self.proxy.update() - except InvalidStateError: - self.logger.warning("Cereal outgoing proxy invalid state (connection closed)") - break - except Exception: - self.logger.exception("Cereal outgoing proxy failure") - await asyncio.sleep(0.01) - - class DynamicPubMaster(messaging.PubMaster): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -141,12 +146,10 @@ class StreamSession: self.incoming_bridge: CerealIncomingMessageProxy | None = None self.incoming_bridge_services = incoming_services self.outgoing_bridge: CerealOutgoingMessageProxy | None = None - self.outgoing_bridge_runner: CerealProxyRunner | None = None if len(incoming_services) > 0: self.incoming_bridge = CerealIncomingMessageProxy(self.shared_pub_master) if len(outgoing_services) > 0: self.outgoing_bridge = CerealOutgoingMessageProxy(messaging.SubMaster(outgoing_services)) - self.outgoing_bridge_runner = CerealProxyRunner(self.outgoing_bridge) self.run_task: asyncio.Task | None = None self._cleanup_lock = asyncio.Lock() @@ -209,10 +212,10 @@ class StreamSession: if self.incoming_bridge is not None: await self.shared_pub_master.add_services_if_needed(self.incoming_bridge_services) self.stream.set_message_handler(self.message_handler) - if self.outgoing_bridge_runner is not None: + if self.outgoing_bridge is not None: channel = self.stream.get_messaging_channel() - self.outgoing_bridge_runner.proxy.add_channel(channel) - self.outgoing_bridge_runner.start() + self.outgoing_bridge.add_channel(channel) + self.outgoing_bridge.start() self.logger.info("Stream session (%s) connected", self.identifier) await self.stream.wait_for_disconnection() @@ -228,8 +231,8 @@ class StreamSession: if self._cleanup_done: return self._cleanup_done = True - if self.outgoing_bridge_runner is not None: - await self.outgoing_bridge_runner.stop() + if self.outgoing_bridge is not None: + await self.outgoing_bridge.stop() await self.stream.stop()