mirror of
https://github.com/firestar5683/StarPilot.git
synced 2026-06-28 01:52:06 +08:00
LogReader: no redownloading on multiple iterations (#31141)
* no redownload * sort old-commit-hash: 88dcaa51c4d1fcc338d44f55134593760334ae23
This commit is contained in:
+13
-8
@@ -12,7 +12,7 @@ import sys
|
||||
import urllib.parse
|
||||
import warnings
|
||||
|
||||
from typing import Iterable, Iterator, List, Type
|
||||
from typing import Dict, Iterable, Iterator, List, Type
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from cereal import log as capnp_log
|
||||
@@ -232,20 +232,25 @@ class LogReader:
|
||||
self.sort_by_time = sort_by_time
|
||||
self.only_union_types = only_union_types
|
||||
|
||||
self.__lrs: Dict[int, _LogFileReader] = {}
|
||||
self.reset()
|
||||
|
||||
def __iter__(self):
|
||||
for identifier in self.logreader_identifiers:
|
||||
yield from _LogFileReader(identifier)
|
||||
def _get_lr(self, i):
|
||||
if i not in self.__lrs:
|
||||
self.__lrs[i] = _LogFileReader(self.logreader_identifiers[i])
|
||||
return self.__lrs[i]
|
||||
|
||||
def _run_on_segment(self, func, identifier):
|
||||
lr = _LogFileReader(identifier)
|
||||
return func(lr)
|
||||
def __iter__(self):
|
||||
for i in range(len(self.logreader_identifiers)):
|
||||
yield from self._get_lr(i)
|
||||
|
||||
def _run_on_segment(self, func, i):
|
||||
return func(self._get_lr(i))
|
||||
|
||||
def run_across_segments(self, num_processes, func):
|
||||
with multiprocessing.Pool(num_processes) as pool:
|
||||
ret = []
|
||||
for p in pool.map(partial(self._run_on_segment, func), self.logreader_identifiers):
|
||||
for p in pool.map(partial(self._run_on_segment, func), range(len(self.logreader_identifiers))):
|
||||
ret.extend(p)
|
||||
return ret
|
||||
|
||||
|
||||
@@ -3,8 +3,11 @@ import tempfile
|
||||
import numpy as np
|
||||
import unittest
|
||||
import pytest
|
||||
from parameterized import parameterized
|
||||
import requests
|
||||
|
||||
from parameterized import parameterized
|
||||
from unittest import mock
|
||||
|
||||
from openpilot.tools.lib.logreader import LogReader, parse_indirect, parse_slice, ReadMode
|
||||
from openpilot.tools.lib.route import SegmentRange
|
||||
|
||||
@@ -104,11 +107,15 @@ class TestLogReader(unittest.TestCase):
|
||||
self.assertEqual(qlog_len*2, qlog_len_2)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_multiple_iterations(self):
|
||||
@mock.patch("openpilot.tools.lib.logreader._LogFileReader")
|
||||
def test_multiple_iterations(self, init_mock):
|
||||
lr = LogReader(f"{TEST_ROUTE}/0/q")
|
||||
qlog_len1 = len(list(lr))
|
||||
qlog_len2 = len(list(lr))
|
||||
|
||||
# ensure we don't create multiple instances of _LogFileReader, which means downloading the files twice
|
||||
self.assertEqual(init_mock.call_count, 1)
|
||||
|
||||
self.assertEqual(qlog_len1, qlog_len2)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user