diff --git a/messaging/__init__.py b/messaging/__init__.py index 5e36bd8..cb92b7a 100644 --- a/messaging/__init__.py +++ b/messaging/__init__.py @@ -1,5 +1,5 @@ # must be build with scons -from .messaging_pyx import Context, Poller, SubSocket, PubSocket # pylint: disable=no-name-in-module, import-error +from .messaging_pyx import Context, Poller, SubSocket, PubSocket, SubMaster # pylint: disable=no-name-in-module, import-error from .messaging_pyx import MultiplePublishersError, MessagingError # pylint: disable=no-name-in-module, import-error import capnp @@ -122,89 +122,6 @@ def recv_one_retry(sock): if dat is not None: return log.Event.from_bytes(dat) -class SubMaster(): - def __init__(self, services, ignore_alive=None, addr="127.0.0.1"): - self.poller = Poller() - self.frame = -1 - self.updated = {s: False for s in services} - self.rcv_time = {s: 0. for s in services} - self.rcv_frame = {s: 0 for s in services} - self.alive = {s: False for s in services} - self.sock = {} - self.freq = {} - self.data = {} - self.logMonoTime = {} - self.valid = {} - - if ignore_alive is not None: - self.ignore_alive = ignore_alive - else: - self.ignore_alive = [] - - for s in services: - if addr is not None: - self.sock[s] = sub_sock(s, poller=self.poller, addr=addr, conflate=True) - self.freq[s] = service_list[s].frequency - - try: - data = new_message(s) - except capnp.lib.capnp.KjException: # pylint: disable=c-extension-no-member - # lists - data = new_message(s, 0) - - self.data[s] = getattr(data, s) - self.logMonoTime[s] = 0 - self.valid[s] = data.valid - - def __getitem__(self, s): - return self.data[s] - - def update(self, timeout=1000): - msgs = [] - for sock in self.poller.poll(timeout): - msgs.append(recv_one_or_none(sock)) - self.update_msgs(sec_since_boot(), msgs) - - def update_msgs(self, cur_time, msgs): - # TODO: add optional input that specify the service to wait for - self.frame += 1 - self.updated = dict.fromkeys(self.updated, False) - for msg in msgs: - if msg is None: - continue - - s = msg.which() - self.updated[s] = True - self.rcv_time[s] = cur_time - self.rcv_frame[s] = self.frame - self.data[s] = getattr(msg, s) - self.logMonoTime[s] = msg.logMonoTime - self.valid[s] = msg.valid - - for s in self.data: - # arbitrary small number to avoid float comparison. If freq is 0, we can skip the check - if self.freq[s] > 1e-5: - # alive if delay is within 10x the expected frequency - self.alive[s] = (cur_time - self.rcv_time[s]) < (10. / self.freq[s]) - else: - self.alive[s] = True - - def all_alive(self, service_list=None): - if service_list is None: # check all - service_list = self.alive.keys() - return all(self.alive[s] for s in service_list if s not in self.ignore_alive) - - def all_valid(self, service_list=None): - if service_list is None: # check all - service_list = self.valid.keys() - return all(self.valid[s] for s in service_list) - - def all_alive_and_valid(self, service_list=None): - if service_list is None: # check all - service_list = self.alive.keys() - return self.all_alive(service_list=service_list) and self.all_valid(service_list=service_list) - - class PubMaster(): def __init__(self, services): self.sock = {} diff --git a/messaging/messaging_pyx.pyx b/messaging/messaging_pyx.pyx index 1bb4073..34efcc1 100644 --- a/messaging/messaging_pyx.pyx +++ b/messaging/messaging_pyx.pyx @@ -13,6 +13,16 @@ from messaging cimport PubSocket as cppPubSocket from messaging cimport Poller as cppPoller from messaging cimport Message as cppMessage +import capnp +from cereal import log +from cereal.services import service_list + +try: + from common.realtime import sec_since_boot +except ImportError: + import time + sec_since_boot = time.time + print("Warning, using python time.time() instead of faster sec_since_boot") class MessagingError(Exception): pass @@ -59,12 +69,12 @@ cdef class Poller: cdef int t = timeout with nogil: - result = self.poller.poll(t) + result = self.poller.poll(t) for s in result: - socket = SubSocket() - socket.setPtr(s) - sockets.append(socket) + socket = SubSocket() + socket.setPtr(s) + sockets.append(socket) return sockets @@ -149,3 +159,110 @@ cdef class PubSocket: raise MultiplePublishersError else: raise MessagingError + + +context = Context() + +cdef class SubMaster: + cdef: + Poller poller + + cdef readonly: + int frame + dict updated + dict rcv_time + dict rcv_frame + dict alive + dict sock + dict freq + dict raw_data + dict data + dict logMonoTime + dict valid + list ignore_alive + + def __init__(self, services, ignore_alive=None, addr="127.0.0.1"): + self.poller = Poller() + self.frame = -1 + self.updated = {s: False for s in services} + self.rcv_time = {s: 0. for s in services} + self.rcv_frame = {s: 0 for s in services} + self.alive = {s: False for s in services} + self.sock = {} + self.freq = {} + self.raw_data= {} + self.data = {} + self.logMonoTime = {} + self.valid = {} + + if ignore_alive is not None: + self.ignore_alive = ignore_alive + else: + self.ignore_alive = [] + + for s in services: + if addr is not None: + self.sock[s] = SubSocket() + self.sock[s].connect(context, s, addr.encode('utf8'), True) + self.poller.registerSocket(self.sock[s]) + self.freq[s] = service_list[s].frequency + + data = log.Event.new_message() + try: + data.init(s) + except capnp.lib.capnp.KjException: + data.init(s, 0) # lists + self.data[s] = getattr(data, s) + self.logMonoTime[s] = 0 + self.valid[s] = True + + def __getitem__(self, s): + return self.data[s] + + def update(self, timeout=1000): + msgs = [] + for sock in self.poller.poll(timeout): + msg = sock.receive(non_blocking=True) + if msg is not None: + msg = log.Event.from_bytes(msg) + msgs.append(msg) + self.update_msgs(sec_since_boot(), msgs) + + def update_msgs(self, cur_time, msgs): + self.frame += 1 + self.updated = dict.fromkeys(self.updated, False) + for msg in msgs: + if msg is None: + continue + + s = msg.which() + self.updated[s] = True + self.rcv_time[s] = cur_time + self.rcv_frame[s] = self.frame + self.data[s] = getattr(msg, s) + self.logMonoTime[s] = msg.logMonoTime + self.valid[s] = msg.valid + + for s in self.data: + # arbitrary small number to avoid float comparison. If freq is 0, we can skip the check + if self.freq[s] > 1e-5: + # alive if delay is within 10x the expected frequency + self.alive[s] = (cur_time - self.rcv_time[s]) < (10. / self.freq[s]) + else: + self.alive[s] = True + + def all_alive(self, service_list=None): + if service_list is None: # check all + service_list = self.alive.keys() + return all(self.alive[s] for s in service_list if s not in self.ignore_alive) + + def all_valid(self, service_list=None): + if service_list is None: # check all + service_list = self.valid.keys() + return all(self.valid[s] for s in service_list) + + def all_alive_and_valid(self, service_list=None): + if service_list is None: # check all + service_list = self.alive.keys() + return self.all_alive(service_list=service_list) and self.all_valid(service_list=service_list) +