mirror of
https://github.com/dragonpilot/dragonpilot.git
synced 2026-06-20 13:32:04 +08:00
segmentrangereader: support direct parsing (#30973)
* use correct source * revert * cleanup imports * clean * direct parsing * rename * move up * fixes * fix that * better error message
This commit is contained in:
+1
-20
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import re
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from urllib.parse import urlparse
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Optional
|
||||
@@ -231,27 +231,8 @@ class SegmentName:
|
||||
def __str__(self) -> str: return self._canonical_name
|
||||
|
||||
|
||||
def parse_useradmin(segment_range):
|
||||
if "useradmin.comma.ai" in segment_range:
|
||||
query = parse_qs(urlparse(segment_range).query)
|
||||
return query["onebox"][0]
|
||||
return segment_range
|
||||
|
||||
def parse_cabana(segment_range):
|
||||
if "cabana.comma.ai" in segment_range:
|
||||
query = parse_qs(urlparse(segment_range).query)
|
||||
return query["route"][0]
|
||||
return segment_range
|
||||
|
||||
def parse_cd(segment_range):
|
||||
return segment_range.replace("cd:/", "")
|
||||
|
||||
class SegmentRange:
|
||||
def __init__(self, segment_range: str):
|
||||
segment_range = parse_useradmin(segment_range)
|
||||
segment_range = parse_cabana(segment_range)
|
||||
segment_range = parse_cd(segment_range)
|
||||
|
||||
self.m = re.fullmatch(RE.SEGMENT_RANGE, segment_range)
|
||||
assert self.m, f"Segment range is not valid {segment_range}"
|
||||
|
||||
|
||||
+62
-5
@@ -1,6 +1,9 @@
|
||||
import enum
|
||||
import re
|
||||
import numpy as np
|
||||
import pathlib
|
||||
import re
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from openpilot.selfdrive.test.openpilotci import get_url
|
||||
from openpilot.tools.lib.helpers import RE
|
||||
from openpilot.tools.lib.logreader import LogReader
|
||||
@@ -37,6 +40,10 @@ def comma_api_source(sr: SegmentRange, mode=ReadMode.RLOG, sort_by_time=False):
|
||||
|
||||
log_paths = route.log_paths() if mode == ReadMode.RLOG else route.qlog_paths()
|
||||
|
||||
invalid_segs = [seg for seg in segs if log_paths[seg] is None]
|
||||
|
||||
assert not len(invalid_segs), f"Some of the requested segments are not available: {invalid_segs}"
|
||||
|
||||
for seg in segs:
|
||||
yield LogReader(log_paths[seg], sort_by_time=sort_by_time)
|
||||
|
||||
@@ -52,6 +59,9 @@ def openpilotci_source(sr: SegmentRange, mode=ReadMode.RLOG, sort_by_time=False)
|
||||
for seg in segs:
|
||||
yield LogReader(get_url(sr.route_name, seg, 'rlog' if mode == ReadMode.RLOG else 'qlog'), sort_by_time=sort_by_time)
|
||||
|
||||
def direct_source(file_or_url, sort_by_time):
|
||||
yield LogReader(file_or_url, sort_by_time=sort_by_time)
|
||||
|
||||
def auto_source(*args, **kwargs):
|
||||
# Automatically determine viable source
|
||||
|
||||
@@ -69,14 +79,61 @@ def auto_source(*args, **kwargs):
|
||||
|
||||
return comma_api_source(*args, **kwargs)
|
||||
|
||||
def parse_useradmin(identifier):
|
||||
if "useradmin.comma.ai" in identifier:
|
||||
query = parse_qs(urlparse(identifier).query)
|
||||
return query["onebox"][0]
|
||||
return None
|
||||
|
||||
def parse_cabana(identifier):
|
||||
if "cabana.comma.ai" in identifier:
|
||||
query = parse_qs(urlparse(identifier).query)
|
||||
return query["route"][0]
|
||||
return None
|
||||
|
||||
def parse_cd(identifier):
|
||||
if "cd:/" in identifier:
|
||||
return identifier.replace("cd:/", "")
|
||||
return None
|
||||
|
||||
def parse_direct(identifier):
|
||||
if "https://" in identifier or "http://" in identifier or pathlib.Path(identifier).exists():
|
||||
return identifier
|
||||
return None
|
||||
|
||||
def parse_indirect(identifier):
|
||||
parsed = parse_useradmin(identifier) or parse_cabana(identifier)
|
||||
|
||||
if parsed is not None:
|
||||
return parsed, comma_api_source, True
|
||||
|
||||
parsed = parse_cd(identifier)
|
||||
if parsed is not None:
|
||||
return parsed, internal_source, True
|
||||
|
||||
return identifier, None, False
|
||||
|
||||
class SegmentRangeReader:
|
||||
def __init__(self, segment_range: str, default_mode=ReadMode.RLOG, default_source=auto_source, sort_by_time=False):
|
||||
sr = SegmentRange(segment_range)
|
||||
def _logreaders_from_identifier(self, identifier):
|
||||
parsed, source, is_indirect = parse_indirect(identifier)
|
||||
|
||||
mode = default_mode if sr.selector is None else ReadMode(sr.selector)
|
||||
if not is_indirect:
|
||||
direct_parsed = parse_direct(identifier)
|
||||
if direct_parsed is not None:
|
||||
return direct_source(identifier, sort_by_time=self.sort_by_time)
|
||||
|
||||
self.lrs = default_source(sr, mode, sort_by_time)
|
||||
sr = SegmentRange(parsed)
|
||||
mode = self.default_mode if sr.selector is None else ReadMode(sr.selector)
|
||||
source = self.default_source if source is None else source
|
||||
|
||||
return source(sr, mode, sort_by_time=self.sort_by_time)
|
||||
|
||||
def __init__(self, identifier: str, default_mode=ReadMode.RLOG, default_source=auto_source, sort_by_time=False):
|
||||
self.default_mode = default_mode
|
||||
self.default_source = default_source
|
||||
self.sort_by_time = sort_by_time
|
||||
|
||||
self.lrs = self._logreaders_from_identifier(identifier)
|
||||
|
||||
def __iter__(self):
|
||||
for lr in self.lrs:
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import unittest
|
||||
from parameterized import parameterized
|
||||
import requests
|
||||
|
||||
from openpilot.tools.lib.route import SegmentRange
|
||||
from openpilot.tools.lib.srreader import ReadMode, SegmentRangeReader, parse_slice
|
||||
from openpilot.tools.lib.srreader import ReadMode, SegmentRangeReader, parse_slice, parse_indirect
|
||||
|
||||
NUM_SEGS = 17 # number of segments in the test route
|
||||
ALL_SEGS = list(np.arange(NUM_SEGS))
|
||||
TEST_ROUTE = "344c5c15b34f2d8a/2024-01-03--09-37-12"
|
||||
QLOG_FILE = "https://commadataci.blob.core.windows.net/openpilotci/0375fdf7b1ce594d/2019-06-13--08-32-25/3/qlog.bz2"
|
||||
|
||||
class TestSegmentRangeReader(unittest.TestCase):
|
||||
@parameterized.expand([
|
||||
@@ -36,11 +40,23 @@ class TestSegmentRangeReader(unittest.TestCase):
|
||||
(f"https://cabana.comma.ai/?route={TEST_ROUTE}", ALL_SEGS),
|
||||
(f"cd:/{TEST_ROUTE}", ALL_SEGS),
|
||||
])
|
||||
def test_parse_slice(self, segment_range, expected):
|
||||
sr = SegmentRange(segment_range)
|
||||
def test_indirect_parsing(self, identifier, expected):
|
||||
parsed, _, _ = parse_indirect(identifier)
|
||||
sr = SegmentRange(parsed)
|
||||
segs = parse_slice(sr)
|
||||
self.assertListEqual(list(segs), expected)
|
||||
|
||||
def test_direct_parsing(self):
|
||||
qlog = tempfile.NamedTemporaryFile(mode='wb', delete=False)
|
||||
|
||||
with requests.get(QLOG_FILE, stream=True) as r:
|
||||
with qlog as f:
|
||||
shutil.copyfileobj(r.raw, f)
|
||||
|
||||
for f in [QLOG_FILE, qlog.name]:
|
||||
l = len(list(SegmentRangeReader(f)))
|
||||
self.assertGreater(l, 100)
|
||||
|
||||
@parameterized.expand([
|
||||
(f"{TEST_ROUTE}///",),
|
||||
(f"{TEST_ROUTE}---",),
|
||||
|
||||
Reference in New Issue
Block a user