mirror of
https://github.com/sunnypilot/sunnypilot.git
synced 2026-06-08 12:34:59 +08:00
webrtcd: turn CerealProxyRunner into more general AsyncTaskRunner (#38093)
* make cereal proxy runner more general * missing init()
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from abc import abstractmethod
|
||||
import time
|
||||
import argparse
|
||||
import asyncio
|
||||
@@ -24,8 +25,35 @@ from openpilot.system.webrtc.schema import generate_field
|
||||
from cereal import messaging, log
|
||||
|
||||
|
||||
class CerealOutgoingMessageProxy:
|
||||
class AsyncTaskRunner:
|
||||
def __init__(self):
|
||||
self.is_running = False
|
||||
self.task = None
|
||||
self.logger = logging.getLogger("webrtcd")
|
||||
|
||||
def start(self):
|
||||
assert self.task is None
|
||||
self.task = asyncio.create_task(self.run())
|
||||
|
||||
async def stop(self):
|
||||
if self.task is None:
|
||||
return
|
||||
task = self.task
|
||||
self.task = None
|
||||
if task.done():
|
||||
return
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
@abstractmethod
|
||||
async def run(self):
|
||||
pass
|
||||
|
||||
|
||||
class CerealOutgoingMessageProxy(AsyncTaskRunner):
|
||||
def __init__(self, sm: messaging.SubMaster):
|
||||
super().__init__()
|
||||
self.sm = sm
|
||||
self.channels: list[RTCDataChannel] = []
|
||||
|
||||
@@ -57,6 +85,19 @@ class CerealOutgoingMessageProxy:
|
||||
for channel in self.channels:
|
||||
channel.send(encoded_msg)
|
||||
|
||||
async def run(self):
|
||||
from aiortc.exceptions import InvalidStateError
|
||||
|
||||
while True:
|
||||
try:
|
||||
self.update()
|
||||
except InvalidStateError:
|
||||
self.logger.warning("Cereal outgoing proxy invalid state (connection closed)")
|
||||
break
|
||||
except Exception:
|
||||
self.logger.exception("Cereal outgoing proxy failure")
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
|
||||
class CerealIncomingMessageProxy:
|
||||
def __init__(self, pm: messaging.PubMaster):
|
||||
@@ -74,42 +115,6 @@ class CerealIncomingMessageProxy:
|
||||
self.pm.send(msg_type, msg)
|
||||
|
||||
|
||||
class CerealProxyRunner:
|
||||
def __init__(self, proxy: CerealOutgoingMessageProxy):
|
||||
self.proxy = proxy
|
||||
self.is_running = False
|
||||
self.task = None
|
||||
self.logger = logging.getLogger("webrtcd")
|
||||
|
||||
def start(self):
|
||||
assert self.task is None
|
||||
self.task = asyncio.create_task(self.run())
|
||||
|
||||
async def stop(self):
|
||||
if self.task is None:
|
||||
return
|
||||
task = self.task
|
||||
self.task = None
|
||||
if task.done():
|
||||
return
|
||||
task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
async def run(self):
|
||||
from aiortc.exceptions import InvalidStateError
|
||||
|
||||
while True:
|
||||
try:
|
||||
self.proxy.update()
|
||||
except InvalidStateError:
|
||||
self.logger.warning("Cereal outgoing proxy invalid state (connection closed)")
|
||||
break
|
||||
except Exception:
|
||||
self.logger.exception("Cereal outgoing proxy failure")
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
|
||||
class DynamicPubMaster(messaging.PubMaster):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -141,12 +146,10 @@ class StreamSession:
|
||||
self.incoming_bridge: CerealIncomingMessageProxy | None = None
|
||||
self.incoming_bridge_services = incoming_services
|
||||
self.outgoing_bridge: CerealOutgoingMessageProxy | None = None
|
||||
self.outgoing_bridge_runner: CerealProxyRunner | None = None
|
||||
if len(incoming_services) > 0:
|
||||
self.incoming_bridge = CerealIncomingMessageProxy(self.shared_pub_master)
|
||||
if len(outgoing_services) > 0:
|
||||
self.outgoing_bridge = CerealOutgoingMessageProxy(messaging.SubMaster(outgoing_services))
|
||||
self.outgoing_bridge_runner = CerealProxyRunner(self.outgoing_bridge)
|
||||
|
||||
self.run_task: asyncio.Task | None = None
|
||||
self._cleanup_lock = asyncio.Lock()
|
||||
@@ -209,10 +212,10 @@ class StreamSession:
|
||||
if self.incoming_bridge is not None:
|
||||
await self.shared_pub_master.add_services_if_needed(self.incoming_bridge_services)
|
||||
self.stream.set_message_handler(self.message_handler)
|
||||
if self.outgoing_bridge_runner is not None:
|
||||
if self.outgoing_bridge is not None:
|
||||
channel = self.stream.get_messaging_channel()
|
||||
self.outgoing_bridge_runner.proxy.add_channel(channel)
|
||||
self.outgoing_bridge_runner.start()
|
||||
self.outgoing_bridge.add_channel(channel)
|
||||
self.outgoing_bridge.start()
|
||||
self.logger.info("Stream session (%s) connected", self.identifier)
|
||||
|
||||
await self.stream.wait_for_disconnection()
|
||||
@@ -228,8 +231,8 @@ class StreamSession:
|
||||
if self._cleanup_done:
|
||||
return
|
||||
self._cleanup_done = True
|
||||
if self.outgoing_bridge_runner is not None:
|
||||
await self.outgoing_bridge_runner.stop()
|
||||
if self.outgoing_bridge is not None:
|
||||
await self.outgoing_bridge.stop()
|
||||
await self.stream.stop()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user