Files
sunnypilot/system/webrtc/webrtcd.py
T
stef 4cd93f3eee webrtcd: turn CerealProxyRunner into more general AsyncTaskRunner (#38093)
* make cereal proxy runner more general

* missing init()
2026-05-27 13:49:46 -07:00

344 lines
11 KiB
Python
Executable File

#!/usr/bin/env python3
from abc import abstractmethod
import time
import argparse
import asyncio
import contextlib
import json
import uuid
import logging
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING
# aiortc and its dependencies have lots of internal warnings :(
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning) # TODO: remove this when google-crc32c publish a python3.12 wheel
import capnp
from aiohttp import web
if TYPE_CHECKING:
from aiortc.rtcdatachannel import RTCDataChannel
from openpilot.system.webrtc.schema import generate_field
from cereal import messaging, log
class AsyncTaskRunner:
def __init__(self):
self.is_running = False
self.task = None
self.logger = logging.getLogger("webrtcd")
def start(self):
assert self.task is None
self.task = asyncio.create_task(self.run())
async def stop(self):
if self.task is None:
return
task = self.task
self.task = None
if task.done():
return
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
@abstractmethod
async def run(self):
pass
class CerealOutgoingMessageProxy(AsyncTaskRunner):
def __init__(self, sm: messaging.SubMaster):
super().__init__()
self.sm = sm
self.channels: list[RTCDataChannel] = []
def add_channel(self, channel: 'RTCDataChannel'):
self.channels.append(channel)
def to_json(self, msg_content: Any):
if isinstance(msg_content, capnp._DynamicStructReader):
msg_dict = msg_content.to_dict()
elif isinstance(msg_content, capnp._DynamicListReader):
msg_dict = [self.to_json(msg) for msg in msg_content]
elif isinstance(msg_content, bytes):
msg_dict = msg_content.decode()
else:
msg_dict = msg_content
return msg_dict
def update(self):
# this is blocking in async context...
self.sm.update(0)
for service, updated in self.sm.updated.items():
if not updated:
continue
msg_dict = self.to_json(self.sm[service])
mono_time, valid = self.sm.logMonoTime[service], self.sm.valid[service]
outgoing_msg = {"type": service, "logMonoTime": mono_time, "valid": valid, "data": msg_dict}
encoded_msg = json.dumps(outgoing_msg).encode()
for channel in self.channels:
channel.send(encoded_msg)
async def run(self):
from aiortc.exceptions import InvalidStateError
while True:
try:
self.update()
except InvalidStateError:
self.logger.warning("Cereal outgoing proxy invalid state (connection closed)")
break
except Exception:
self.logger.exception("Cereal outgoing proxy failure")
await asyncio.sleep(0.01)
class CerealIncomingMessageProxy:
def __init__(self, pm: messaging.PubMaster):
self.pm = pm
def send(self, message: bytes):
msg_json = json.loads(message)
msg_type, msg_data = msg_json["type"], msg_json["data"]
size = None
if not isinstance(msg_data, dict):
size = len(msg_data)
msg = messaging.new_message(msg_type, size=size)
setattr(msg, msg_type, msg_data)
self.pm.send(msg_type, msg)
class DynamicPubMaster(messaging.PubMaster):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lock = asyncio.Lock()
async def add_services_if_needed(self, services):
async with self.lock:
for service in services:
if service not in self.sock:
self.sock[service] = messaging.pub_sock(service)
class StreamSession:
shared_pub_master = DynamicPubMaster([])
def __init__(self, sdp: str, init_camera: str, incoming_services: list[str], outgoing_services: list[str], debug_mode: bool = False):
from aiortc.mediastreams import VideoStreamTrack
from openpilot.system.webrtc.device.video import LiveStreamVideoStreamTrack
from teleoprtc import WebRTCAnswerBuilder
builder = WebRTCAnswerBuilder(sdp)
self.video_track = LiveStreamVideoStreamTrack(init_camera) if not debug_mode else VideoStreamTrack()
builder.add_video_stream(init_camera, self.video_track)
self.stream = builder.stream()
self.identifier = str(uuid.uuid4())
self.incoming_bridge: CerealIncomingMessageProxy | None = None
self.incoming_bridge_services = incoming_services
self.outgoing_bridge: CerealOutgoingMessageProxy | None = None
if len(incoming_services) > 0:
self.incoming_bridge = CerealIncomingMessageProxy(self.shared_pub_master)
if len(outgoing_services) > 0:
self.outgoing_bridge = CerealOutgoingMessageProxy(messaging.SubMaster(outgoing_services))
self.run_task: asyncio.Task | None = None
self._cleanup_lock = asyncio.Lock()
self._cleanup_done = False
self.logger = logging.getLogger("webrtcd")
self.logger.info(
"New stream session (%s), init camera %s, incoming services %s, outgoing services %s",
self.identifier, init_camera, incoming_services, outgoing_services,
)
def start(self):
self.run_task = asyncio.create_task(self.run())
async def stop(self):
if self.run_task is not None and not self.run_task.done() and self.run_task is not asyncio.current_task():
self.run_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self.run_task
self.run_task = None
await self.post_run_cleanup()
async def get_answer(self):
return await self.stream.start()
def message_handler(self, message: bytes):
assert self.incoming_bridge is not None
try:
payload = json.loads(message) if isinstance(message, (bytes, str)) else None
if isinstance(payload, dict):
msg_type = payload.get("type")
if msg_type == "livestreamCameraSwitch":
self.video_track.switch_camera(payload["data"]["camera"])
return
if msg_type == "clockSync":
data = payload.get("data", {})
pong = json.dumps({"type": "clockSync", "data": {
"action": "pong", "browserSendTime": data.get("browserSendTime"), "deviceTime": time.time() * 1000, # noqa: TID251
}})
self.stream.get_messaging_channel().send(pong)
return
if msg_type == "enableTimingSei":
enabled = bool(payload.get("data", {}).get("enabled"))
if hasattr(self.video_track, 'timing_sei_enabled'):
self.video_track.timing_sei_enabled = enabled
return
if payload.get("type") not in self.incoming_bridge_services:
return
self.incoming_bridge.send(message)
except Exception:
self.logger.exception("Cereal incoming proxy failure")
async def run(self):
try:
await self.stream.wait_for_connection()
if self.stream.has_messaging_channel():
if self.incoming_bridge is not None:
await self.shared_pub_master.add_services_if_needed(self.incoming_bridge_services)
self.stream.set_message_handler(self.message_handler)
if self.outgoing_bridge is not None:
channel = self.stream.get_messaging_channel()
self.outgoing_bridge.add_channel(channel)
self.outgoing_bridge.start()
self.logger.info("Stream session (%s) connected", self.identifier)
await self.stream.wait_for_disconnection()
self.logger.info("Stream session (%s) ended", self.identifier)
except Exception:
self.logger.exception("Stream session failure")
finally:
await self.post_run_cleanup()
async def post_run_cleanup(self):
async with self._cleanup_lock:
if self._cleanup_done:
return
self._cleanup_done = True
if self.outgoing_bridge is not None:
await self.outgoing_bridge.stop()
await self.stream.stop()
@dataclass
class StreamRequestBody:
sdp: str
initCamera: str
bridge_services_in: list[str] = field(default_factory=list)
bridge_services_out: list[str] = field(default_factory=list)
async def get_stream(request: 'web.Request'):
stream_dict, debug_mode = request.app['streams'], request.app['debug']
raw_body = await request.json()
body = StreamRequestBody(**raw_body)
async with request.app['stream_lock']:
# Fully disconnect any other active stream before starting the replacement.
for sid, s in list(stream_dict.items()):
if s.run_task and not s.run_task.done():
try:
ch = s.stream.get_messaging_channel()
ch.send(json.dumps({"type": "connectionReplaced", "data": "Another device has connected, closing this session."}))
except Exception:
pass
await s.stop()
del stream_dict[sid]
session = StreamSession(body.sdp, body.initCamera, body.bridge_services_in, body.bridge_services_out, debug_mode)
try:
answer = await session.get_answer()
except ValueError as e:
await session.stop()
raise web.HTTPBadRequest(
text=json.dumps({"error": "invalid_sdp", "message": str(e)}),
content_type="application/json",
) from e
except Exception:
await session.stop()
raise
session.start()
stream_dict[session.identifier] = session
return web.json_response({"sdp": answer.sdp, "type": answer.type})
async def get_schema(request: 'web.Request'):
services = request.query["services"].split(",")
services = [s for s in services if s]
assert all(s in log.Event.schema.fields and not s.endswith("DEPRECATED") for s in services), "Invalid service name"
schema_dict = {s: generate_field(log.Event.schema.fields[s]) for s in services}
return web.json_response(schema_dict)
async def post_notify(request: 'web.Request'):
try:
payload = await request.json()
except Exception as e:
raise web.HTTPBadRequest(text="Invalid JSON") from e
for session in list(request.app.get('streams', {}).values()):
try:
ch = session.stream.get_messaging_channel()
ch.send(json.dumps(payload))
except Exception:
continue
return web.Response(status=200, text="OK")
async def on_shutdown(app: 'web.Application'):
for session in app['streams'].values():
await session.stop()
del app['streams']
def webrtcd_thread(host: str, port: int, debug: bool):
logging.basicConfig(level=logging.CRITICAL, handlers=[logging.StreamHandler()])
logging_level = logging.DEBUG if debug else logging.INFO
logging.getLogger("WebRTCStream").setLevel(logging_level)
logging.getLogger("webrtcd").setLevel(logging_level)
app = web.Application()
app['streams'] = dict()
app['stream_lock'] = asyncio.Lock()
app['debug'] = debug
app.on_shutdown.append(on_shutdown)
app.router.add_post("/stream", get_stream)
app.router.add_post("/notify", post_notify)
app.router.add_get("/schema", get_schema)
web.run_app(app, host=host, port=port)
def main():
parser = argparse.ArgumentParser(description="WebRTC daemon")
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to listen on")
parser.add_argument("--port", type=int, default=5001, help="Port to listen on")
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
args = parser.parse_args()
webrtcd_thread(args.host, args.port, args.debug)
if __name__=="__main__":
main()