mirror of
https://github.com/sunnypilot/sunnypilot.git
synced 2026-06-09 01:25:11 +08:00
Compare commits
12 Commits
henc
...
refactor-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6682be7bf9 | ||
|
|
830ac999e6 | ||
|
|
bbdccb019f | ||
|
|
90240eb1bc | ||
|
|
cf42e21617 | ||
|
|
cea6763391 | ||
|
|
5a22bd9b16 | ||
|
|
357c3afe34 | ||
|
|
61e28c3c89 | ||
|
|
4a124fb615 | ||
|
|
be4028d8a9 | ||
|
|
428ed5b9b6 |
@@ -28,3 +28,30 @@ for pathdef, fn in {'TRANSFORM': 'transforms/transform.cl', 'LOADYUV': 'transfor
|
||||
cython_libs = envCython["LIBS"] + libs
|
||||
commonmodel_lib = lenv.Library('commonmodel', common_src)
|
||||
lenvCython.Program('models/commonmodel_pyx.so', 'models/commonmodel_pyx.pyx', LIBS=[commonmodel_lib, *cython_libs], FRAMEWORKS=frameworks)
|
||||
tinygrad_files = ["#"+x for x in glob.glob(env.Dir("#tinygrad_repo").relpath + "/**", recursive=True, root_dir=env.Dir("#").abspath) if 'pycache' not in x]
|
||||
|
||||
# Get model metadata
|
||||
for model_name in ['supercombo', 'driving_vision', 'driving_policy']:
|
||||
fn = File(f"models/{model_name}").abspath
|
||||
if File(f"models/{model_name}.onnx").exists():
|
||||
script_files = [File(Dir("#sunnypilot/modeld_v2").File("get_model_metadata.py").abspath)]
|
||||
cmd = f'python3 {Dir("#sunnypilot/modeld_v2").abspath}/get_model_metadata.py {fn}.onnx'
|
||||
lenv.Command(fn + "_metadata.pkl", [fn + ".onnx"] + tinygrad_files + script_files, cmd)
|
||||
|
||||
def tg_compile(flags, model_name):
|
||||
pythonpath_string = 'PYTHONPATH="${PYTHONPATH}:' + env.Dir("#tinygrad_repo").abspath + '"'
|
||||
fn = File(f"models/{model_name}").abspath
|
||||
return lenv.Command(
|
||||
fn + "_tinygrad.pkl",
|
||||
[fn + ".onnx"] + tinygrad_files,
|
||||
f'{pythonpath_string} {flags} python3 {Dir("#tinygrad_repo").abspath}/examples/openpilot/compile3.py {fn}.onnx {fn}_tinygrad.pkl'
|
||||
)
|
||||
|
||||
# Compile small models
|
||||
for model_name in ['supercombo', 'driving_vision', 'driving_policy']:
|
||||
if File(f"models/{model_name}.onnx").exists():
|
||||
flags = {
|
||||
'larch64': 'DEV=QCOM',
|
||||
'Darwin': 'DEV=CPU IMAGE=0',
|
||||
}.get(arch, 'DEV=LLVM IMAGE=0')
|
||||
tg_compile(flags, model_name)
|
||||
|
||||
@@ -82,3 +82,32 @@ class Meta:
|
||||
BRAKE_PRESS = slice(32, 55, 4)
|
||||
LEFT_BLINKER = slice(33, 55, 4)
|
||||
RIGHT_BLINKER = slice(34, 55, 4)
|
||||
|
||||
class MetaTombRaider:
|
||||
ENGAGED = slice(0, 1)
|
||||
# next 2, 4, 6, 8, 10 seconds
|
||||
GAS_DISENGAGE = slice(1, 41, 8)
|
||||
BRAKE_DISENGAGE = slice(2, 41, 8)
|
||||
STEER_OVERRIDE = slice(3, 41, 8)
|
||||
HARD_BRAKE_3 = slice(4, 41, 8)
|
||||
HARD_BRAKE_4 = slice(5, 41, 8)
|
||||
HARD_BRAKE_5 = slice(6, 41, 8)
|
||||
GAS_PRESS = slice(7, 41, 8)
|
||||
BRAKE_PRESS = slice(8, 41, 8)
|
||||
# next 0, 2, 4, 6, 8, 10 seconds
|
||||
LEFT_BLINKER = slice(41, 53, 2)
|
||||
RIGHT_BLINKER = slice(42, 53, 2)
|
||||
|
||||
class MetaSimPose:
|
||||
ENGAGED = slice(0, 1)
|
||||
# next 2, 4, 6, 8, 10 seconds
|
||||
GAS_DISENGAGE = slice(1, 36, 7)
|
||||
BRAKE_DISENGAGE = slice(2, 36, 7)
|
||||
STEER_OVERRIDE = slice(3, 36, 7)
|
||||
HARD_BRAKE_3 = slice(4, 36, 7)
|
||||
HARD_BRAKE_4 = slice(5, 36, 7)
|
||||
HARD_BRAKE_5 = slice(6, 36, 7)
|
||||
GAS_PRESS = slice(7, 36, 7)
|
||||
# next 0, 2, 4, 6, 8, 10 seconds
|
||||
LEFT_BLINKER = slice(36, 48, 2)
|
||||
RIGHT_BLINKER = slice(37, 48, 2)
|
||||
|
||||
@@ -4,6 +4,7 @@ import numpy as np
|
||||
from cereal import log
|
||||
from openpilot.sunnypilot.modeld_v2.constants import ModelConstants, Plan
|
||||
from openpilot.selfdrive.controls.lib.drive_helpers import get_curvature_from_plan
|
||||
from openpilot.sunnypilot.selfdrive.controls.lib.drive_helpers import CONTROL_N, get_lag_adjusted_curvature, MIN_SPEED
|
||||
|
||||
SEND_RAW_PRED = os.getenv('SEND_RAW_PRED')
|
||||
|
||||
@@ -12,8 +13,16 @@ ConfidenceClass = log.ModelDataV2.ConfidenceClass
|
||||
|
||||
def get_curvature_from_output(output, vego, lat_action_t, mlsim):
|
||||
if not mlsim:
|
||||
if desired_curv := output.get('desired_curvature'): # If the model outputs the desired curvature, use that directly
|
||||
return float(desired_curv[0, 0])
|
||||
if 'lat_planner_solution' in output:
|
||||
x, y, yaw, yawRate = [output['lat_planner_solution'][0, :, i].tolist() for i in range(4)]
|
||||
x_sol = np.column_stack([x, y, yaw, yawRate])
|
||||
v_ego = max(MIN_SPEED, vego)
|
||||
psis = x_sol[0:CONTROL_N, 2].tolist()
|
||||
curvatures = (x_sol[0:CONTROL_N, 3] / v_ego).tolist()
|
||||
desired_curvature = get_lag_adjusted_curvature(lat_action_t, v_ego, psis, curvatures)
|
||||
else:
|
||||
desired_curvature = float(output.get('desired_curvature')[0, 0])
|
||||
return desired_curvature
|
||||
|
||||
plan_output = output['plan'][0]
|
||||
return float(get_curvature_from_plan(plan_output[:, Plan.T_FROM_CURRENT_EULER][:, 2], plan_output[:, Plan.ORIENTATION_RATE][:, 2],
|
||||
@@ -118,14 +127,31 @@ def fill_model_msg(base_msg: capnp._DynamicStructBuilder, extended_msg: capnp._D
|
||||
# action (includes lateral planning now)
|
||||
modelV2.action = action
|
||||
|
||||
# times at X_IDXS of edges and lines aren't used
|
||||
LINE_T_IDXS: list[float] = []
|
||||
# times at X_IDXS according to model plan
|
||||
PLAN_T_IDXS = [np.nan] * ModelConstants.IDX_N
|
||||
PLAN_T_IDXS[0] = 0.0
|
||||
plan_x = net_output_data['plan'][0, :, Plan.POSITION][:, 0].tolist()
|
||||
for xidx in range(1, ModelConstants.IDX_N):
|
||||
tidx = 0
|
||||
# increment tidx until we find an element that's further away than the current xidx
|
||||
while tidx < ModelConstants.IDX_N - 1 and plan_x[tidx + 1] < ModelConstants.X_IDXS[xidx]:
|
||||
tidx += 1
|
||||
if tidx == ModelConstants.IDX_N - 1:
|
||||
# if the Plan doesn't extend far enough, set plan_t to the max value (10s), then break
|
||||
PLAN_T_IDXS[xidx] = ModelConstants.T_IDXS[ModelConstants.IDX_N - 1]
|
||||
break
|
||||
# interpolate to find `t` for the current xidx
|
||||
current_x_val = plan_x[tidx]
|
||||
next_x_val = plan_x[tidx + 1]
|
||||
p = (ModelConstants.X_IDXS[xidx] - current_x_val) / (next_x_val - current_x_val) if abs(
|
||||
next_x_val - current_x_val) > 1e-9 else float('nan')
|
||||
PLAN_T_IDXS[xidx] = p * ModelConstants.T_IDXS[tidx + 1] + (1 - p) * ModelConstants.T_IDXS[tidx]
|
||||
|
||||
# lane lines
|
||||
modelV2.init('laneLines', 4)
|
||||
for i in range(4):
|
||||
lane_line = modelV2.laneLines[i]
|
||||
fill_xyzt(lane_line, LINE_T_IDXS, np.array(ModelConstants.X_IDXS), net_output_data['lane_lines'][0,i,:,0], net_output_data['lane_lines'][0,i,:,1])
|
||||
fill_xyzt(lane_line, PLAN_T_IDXS, np.array(ModelConstants.X_IDXS), net_output_data['lane_lines'][0,i,:,0], net_output_data['lane_lines'][0,i,:,1])
|
||||
modelV2.laneLineStds = net_output_data['lane_lines_stds'][0,:,0,0].tolist()
|
||||
modelV2.laneLineProbs = net_output_data['lane_lines_prob'][0,1::2].tolist()
|
||||
|
||||
@@ -135,7 +161,7 @@ def fill_model_msg(base_msg: capnp._DynamicStructBuilder, extended_msg: capnp._D
|
||||
modelV2.init('roadEdges', 2)
|
||||
for i in range(2):
|
||||
road_edge = modelV2.roadEdges[i]
|
||||
fill_xyzt(road_edge, LINE_T_IDXS, np.array(ModelConstants.X_IDXS), net_output_data['road_edges'][0,i,:,0], net_output_data['road_edges'][0,i,:,1])
|
||||
fill_xyzt(road_edge, PLAN_T_IDXS, np.array(ModelConstants.X_IDXS), net_output_data['road_edges'][0,i,:,0], net_output_data['road_edges'][0,i,:,1])
|
||||
modelV2.roadEdgeStds = net_output_data['road_edges_stds'][0,:,0,0].tolist()
|
||||
|
||||
# leads
|
||||
|
||||
310
sunnypilot/modeld_v2/model_metadata_lookup.py
Normal file
310
sunnypilot/modeld_v2/model_metadata_lookup.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""
|
||||
Copyright (c) 2021-, Haibin Wen, sunnypilot, and a number of other contributors.
|
||||
|
||||
This file is part of sunnypilot and is licensed under the MIT License.
|
||||
See the LICENSE.md file in the root directory for more details.
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
|
||||
def parse_metadata_file(metadata_path: str, type_key: str = None) -> dict:
|
||||
if not os.path.exists(metadata_path):
|
||||
print(f"Error: File not found: {metadata_path}")
|
||||
sys.exit(1)
|
||||
with open(metadata_path, 'rb') as f:
|
||||
metadata = pickle.load(f)
|
||||
result = {
|
||||
"metadata_path": metadata_path,
|
||||
"model_checkpoint": metadata.get("model_checkpoint"),
|
||||
**{k: metadata[k] for k in ("input_shapes", "output_shapes", "output_slices") if k in metadata}
|
||||
}
|
||||
if type_key:
|
||||
result[type_key] = True
|
||||
return result
|
||||
|
||||
def add_to_model_metadata(new_data: dict, key: str) -> None:
|
||||
file_path = os.path.abspath(__file__)
|
||||
with open(file_path) as f:
|
||||
lines = f.readlines()
|
||||
# Find insertion point for new entry
|
||||
for i in range(len(lines)-1, -1, -1):
|
||||
if lines[i].strip() == '}' and any('MODEL_METADATA' in l for l in lines[:i]):
|
||||
insert_idx = i
|
||||
break
|
||||
else:
|
||||
print("Could not find where to insert new entry in MODEL_METADATA.")
|
||||
sys.exit(1)
|
||||
|
||||
def format_val(val: Any, indent: int = 2) -> str:
|
||||
if isinstance(val, dict):
|
||||
items = [f'"{k}": {format_val(v, indent+2)}' for k, v in val.items()]
|
||||
return '{\n' + ',\n'.join(' '*(indent+2) + item for item in items) + '\n' + ' '*indent + '}'
|
||||
if isinstance(val, str):
|
||||
return f'"{val}"'
|
||||
if isinstance(val, slice):
|
||||
if val.step is None:
|
||||
return f'slice({val.start}, {val.stop})'
|
||||
return f'slice({val.start}, {val.stop}, {val.step})'
|
||||
if isinstance(val, (tuple, list)):
|
||||
open_s, close_s = ('(', ')') if isinstance(val, tuple) else ('[', ']')
|
||||
return open_s + ', '.join(format_val(x, 0) for x in val) + close_s
|
||||
return repr(val)
|
||||
|
||||
entry_str = f' "{key}": {format_val(new_data, 2)},\n'
|
||||
lines.insert(insert_idx, entry_str)
|
||||
|
||||
with open(file_path, 'w') as f:
|
||||
f.writelines(lines)
|
||||
print(f"Added entry to MODEL_METADATA with key: {key}")
|
||||
|
||||
|
||||
MODEL_METADATA = {
|
||||
"supercombo_wd40": {
|
||||
"metadata_path": "/Users/james/Downloads/model-WD40 (April 09, 2024)-558/supercombo_wd40_metadata.pkl",
|
||||
"model_checkpoint": None,
|
||||
"non20hz": True,
|
||||
"input_shapes": {
|
||||
"input_imgs": (1, 12, 128, 256),
|
||||
"big_input_imgs": (1, 12, 128, 256),
|
||||
"desire": (1, 100, 8),
|
||||
"traffic_convention": (1, 2),
|
||||
"lateral_control_params": (1, 2),
|
||||
"prev_desired_curv": (1, 100, 1),
|
||||
"features_buffer": (1, 99, 512),
|
||||
},
|
||||
"output_shapes": {"outputs": (1, 6504)},
|
||||
"output_slices": {
|
||||
"plan": slice(0, 4955),
|
||||
"lane_lines": slice(4955, 5483),
|
||||
"lane_lines_prob": slice(5483, 5491),
|
||||
"road_edges": slice(5491, 5755),
|
||||
"lead": slice(5755, 5857),
|
||||
"lead_prob": slice(5857, 5860),
|
||||
"desire_state": slice(5860, 5868),
|
||||
"meta": slice(5868, 5916),
|
||||
"desire_pred": slice(5916, 5948),
|
||||
"pose": slice(5948, 5960),
|
||||
"wide_from_device_euler": slice(5960, 5966),
|
||||
"sim_pose": slice(5966, 5978),
|
||||
"road_transform": slice(5978, 5990),
|
||||
"desired_curvature": slice(5990, 5992),
|
||||
"hidden_state": slice(5992, None),
|
||||
},
|
||||
},
|
||||
"supercombo_farmville": {
|
||||
"metadata_path": "/Users/james/Downloads/model-Farmville (November 07, 2023)-557/supercombo_farmville_metadata.pkl",
|
||||
"model_checkpoint": None,
|
||||
"non20hz": True,
|
||||
"input_shapes": {
|
||||
"input_imgs": (1, 12, 128, 256),
|
||||
"big_input_imgs": (1, 12, 128, 256),
|
||||
"desire": (1, 100, 8),
|
||||
"traffic_convention": (1, 2),
|
||||
"lat_planner_state": (1, 4),
|
||||
"nav_features": (1, 256),
|
||||
"nav_instructions": (1, 150),
|
||||
"features_buffer": (1, 99, 512),
|
||||
},
|
||||
"output_shapes": {"outputs": (1, 6768)},
|
||||
"output_slices": {
|
||||
"plan": slice(0, 4955),
|
||||
"lane_lines": slice(4955, 5483),
|
||||
"lane_lines_prob": slice(5483, 5491),
|
||||
"road_edges": slice(5491, 5755),
|
||||
"lead": slice(5755, 5857),
|
||||
"lead_prob": slice(5857, 5860),
|
||||
"desire_state": slice(5860, 5868),
|
||||
"meta": slice(5868, 5916),
|
||||
"desire_pred": slice(5916, 5948),
|
||||
"pose": slice(5948, 5960),
|
||||
"wide_from_device_euler": slice(5960, 5966),
|
||||
"sim_pose": slice(5966, 5978),
|
||||
"road_transform": slice(5978, 5990),
|
||||
"lat_planner_solution": slice(5990, 6254),
|
||||
"hidden_state": slice(6254, -2),
|
||||
"pad": slice(-2, None),
|
||||
},
|
||||
},
|
||||
"driving_policy_steam_powered": {
|
||||
"metadata_path": "selfdrive/modeld/models/driving_policy_metadata.pkl",
|
||||
"model_checkpoint": "a8f96b93-bde2-4e28-a732-4df21ebba968/400",
|
||||
"split": True,
|
||||
"input_shapes": {
|
||||
"desire": (1, 25, 8),
|
||||
"traffic_convention": (1, 2),
|
||||
"features_buffer": (1, 25, 512),
|
||||
},
|
||||
"output_shapes": {"outputs": (1, 1000)},
|
||||
"output_slices": {
|
||||
"plan": slice(0, 990),
|
||||
"desire_state": slice(990, 998),
|
||||
"pad": slice(-2, None),
|
||||
},
|
||||
},
|
||||
"supercombo_op": {
|
||||
"metadata_path": "/Users/james/Downloads/model-Optimus Prime (September 21, 2023)-559/supercombo_op_metadata.pkl",
|
||||
"model_checkpoint": None,
|
||||
"non20hz": True,
|
||||
"input_shapes": {
|
||||
"input_imgs": (1, 12, 128, 256),
|
||||
"big_input_imgs": (1, 12, 128, 256),
|
||||
"desire": (1, 100, 8),
|
||||
"traffic_convention": (1, 2),
|
||||
"nav_features": (1, 256),
|
||||
"nav_instructions": (1, 150),
|
||||
"features_buffer": (1, 99, 512),
|
||||
},
|
||||
"output_shapes": {"outputs": (1, 6504)},
|
||||
"output_slices": {
|
||||
"plan": slice(0, 4955),
|
||||
"lane_lines": slice(4955, 5483),
|
||||
"lane_lines_prob": slice(5483, 5491),
|
||||
"road_edges": slice(5491, 5755),
|
||||
"lead": slice(5755, 5857),
|
||||
"lead_prob": slice(5857, 5860),
|
||||
"desire_state": slice(5860, 5868),
|
||||
"meta": slice(5868, 5916),
|
||||
"desire_pred": slice(5916, 5948),
|
||||
"pose": slice(5948, 5960),
|
||||
"wide_from_device_euler": slice(5960, 5966),
|
||||
"sim_pose": slice(5966, 5978),
|
||||
"road_transform": slice(5978, 5990),
|
||||
"hidden_state": slice(5990, -2),
|
||||
"pad": slice(-2, None),
|
||||
},
|
||||
},
|
||||
"supercombo_nd": {
|
||||
"metadata_path": "/Users/james/Downloads/model-Notre Dame (July 01, 2024)-568/supercombo_nd_metadata.pkl",
|
||||
"model_checkpoint": None,
|
||||
"non20hz": True,
|
||||
"input_shapes": {
|
||||
"input_imgs": (1, 12, 128, 256),
|
||||
"big_input_imgs": (1, 12, 128, 256),
|
||||
"desire": (1, 100, 8),
|
||||
"traffic_convention": (1, 2),
|
||||
"lateral_control_params": (1, 2),
|
||||
"prev_desired_curv": (1, 100, 1),
|
||||
"features_buffer": (1, 99, 512),
|
||||
},
|
||||
"output_shapes": {"outputs": (1, 6512)},
|
||||
"output_slices": {
|
||||
"plan": slice(0, 4955),
|
||||
"lane_lines": slice(4955, 5483),
|
||||
"lane_lines_prob": slice(5483, 5491),
|
||||
"road_edges": slice(5491, 5755),
|
||||
"lead": slice(5755, 5857),
|
||||
"lead_prob": slice(5857, 5860),
|
||||
"desire_state": slice(5860, 5868),
|
||||
"meta": slice(5868, 5921),
|
||||
"desire_pred": slice(5921, 5953),
|
||||
"pose": slice(5953, 5965),
|
||||
"wide_from_device_euler": slice(5965, 5971),
|
||||
"sim_pose": slice(5971, 5983),
|
||||
"road_transform": slice(5983, 5995),
|
||||
"desired_curvature": slice(5995, 5997),
|
||||
"hidden_state": slice(5997, -3),
|
||||
"pad": slice(-3, None),
|
||||
},
|
||||
},
|
||||
"supercombo_npr": {
|
||||
"metadata_path": "/Users/james/Downloads/supercombo_npr_metadata.pkl",
|
||||
"model_checkpoint": None,
|
||||
"input_shapes": {
|
||||
"input_imgs": (1, 12, 128, 256),
|
||||
"big_input_imgs": (1, 12, 128, 256),
|
||||
"desire": (1, 25, 8),
|
||||
"traffic_convention": (1, 2),
|
||||
"features_buffer": (1, 24, 512)
|
||||
},
|
||||
"output_shapes": {
|
||||
"outputs": (1, 6500)
|
||||
},
|
||||
"output_slices": {
|
||||
"plan": slice(0, 4955),
|
||||
"lane_lines": slice(4955, 5483),
|
||||
"lane_lines_prob": slice(5483, 5491),
|
||||
"road_edges": slice(5491, 5755),
|
||||
"lead": slice(5755, 5857),
|
||||
"lead_prob": slice(5857, 5860),
|
||||
"desire_state": slice(5860, 5868),
|
||||
"meta": slice(5868, 5923),
|
||||
"desire_pred": slice(5923, 5955),
|
||||
"pose": slice(5955, 5967),
|
||||
"wide_from_device_euler": slice(5967, 5973),
|
||||
"road_transform": slice(5973, 5985),
|
||||
"hidden_state": slice(5985, -3),
|
||||
"pad": slice(-3, None)
|
||||
},
|
||||
"20hz": True
|
||||
},
|
||||
"driving_policy_renamed_desire": {
|
||||
"metadata_path": "/Users/james/Downloads/model-ugh (August 27, 2025)-575/driving_policy_ugh_metadata.pkl",
|
||||
"model_checkpoint": "a8f96b93-bde2-4e28-a732-4df21ebba968/400",
|
||||
"split": True,
|
||||
"input_shapes": {
|
||||
"desire_pulse": (1, 25, 8),
|
||||
"traffic_convention": (1, 2),
|
||||
"features_buffer": (1, 25, 512)
|
||||
},
|
||||
"output_shapes": {
|
||||
"outputs": (1, 1000)
|
||||
},
|
||||
"output_slices": {
|
||||
"plan": slice(0, 990),
|
||||
"desire_state": slice(990, 998),
|
||||
"pad": slice(-2, None)
|
||||
}
|
||||
},
|
||||
"supercombo_nts": { # released in January of this year, so its not 20hz, but it is modern logic..
|
||||
"metadata_path": "/Users/james/Downloads/supercombo_nts_metadata.pkl",
|
||||
"model_checkpoint": None,
|
||||
"non20hz": True,
|
||||
"input_shapes": {
|
||||
"input_imgs": (1, 12, 128, 256),
|
||||
"big_input_imgs": (1, 12, 128, 256),
|
||||
"desire": (1, 100, 8),
|
||||
"traffic_convention": (1, 2),
|
||||
"lateral_control_params": (1, 2),
|
||||
"prev_desired_curv": (1, 100, 1),
|
||||
"features_buffer": (1, 99, 512)
|
||||
},
|
||||
"output_shapes": {
|
||||
"outputs": (1, 6512)
|
||||
},
|
||||
"output_slices": {
|
||||
"plan": slice(0, 4955),
|
||||
"lane_lines": slice(4955, 5483),
|
||||
"lane_lines_prob": slice(5483, 5491),
|
||||
"road_edges": slice(5491, 5755),
|
||||
"lead": slice(5755, 5857),
|
||||
"lead_prob": slice(5857, 5860),
|
||||
"desire_state": slice(5860, 5868),
|
||||
"meta": slice(5868, 5923),
|
||||
"desire_pred": slice(5923, 5955),
|
||||
"pose": slice(5955, 5967),
|
||||
"wide_from_device_euler": slice(5967, 5973),
|
||||
"sim_pose": slice(5973, 5985),
|
||||
"road_transform": slice(5985, 5997),
|
||||
"desired_curvature": slice(5997, 5999),
|
||||
"hidden_state": slice(5999, -1),
|
||||
"pad": slice(-1, None)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Add model metadata from .pkl file to lookup dictionary")
|
||||
parser.add_argument("metadata_file", help="Path to *_metadata.pkl file to parse")
|
||||
parser.add_argument("--type", dest="type_key", default=None,
|
||||
help="Type key to set (e.g., non20hz, split, 20hz)")
|
||||
args = parser.parse_args()
|
||||
basename = os.path.basename(args.metadata_file)
|
||||
dict_key = basename.replace("_metadata.pkl", "")
|
||||
|
||||
metadata = parse_metadata_file(args.metadata_file, args.type_key)
|
||||
add_to_model_metadata(metadata, dict_key)
|
||||
@@ -27,7 +27,7 @@ from openpilot.sunnypilot.modeld.modeld_base import ModelStateBase
|
||||
from openpilot.sunnypilot.models.helpers import get_active_bundle
|
||||
from openpilot.sunnypilot.models.runners.helpers import get_model_runner
|
||||
|
||||
PROCESS_NAME = "selfdrive.modeld.modeld"
|
||||
PROCESS_NAME = "selfdrive.modeld.modeld_tinygrad"
|
||||
|
||||
|
||||
class FrameMeta:
|
||||
@@ -40,11 +40,76 @@ class FrameMeta:
|
||||
self.frame_id, self.timestamp_sof, self.timestamp_eof = vipc.frame_id, vipc.timestamp_sof, vipc.timestamp_eof
|
||||
|
||||
|
||||
class InputQueues:
|
||||
def __init__(self, input_shapes: dict, input_dtypes: dict):
|
||||
self.input_shapes = input_shapes
|
||||
self.input_dtypes = input_dtypes
|
||||
self.buffers: dict[str, np.ndarray | None] = {}
|
||||
self.indices: dict[str, np.ndarray | None] = {}
|
||||
for key, shape in input_shapes.items():
|
||||
self._setup_buffer_for_key(key, shape, input_dtypes[key])
|
||||
|
||||
def _setup_buffer_for_key(self, key, shape, dtype):
|
||||
# Temporal input: shape is [batch, history, features]
|
||||
if len(shape) == 3 and shape[1] > 1:
|
||||
buffer_history_len = max(100, shape[1] * 4 if shape[1] < 100 else shape[1])
|
||||
self.buffers[key] = np.zeros((1, buffer_history_len, shape[2]), dtype=dtype)
|
||||
features_buffer_shape = self.input_shapes.get('features_buffer')
|
||||
if shape[1] in (24, 25) and features_buffer_shape and features_buffer_shape[1] == 24:
|
||||
step = int(-buffer_history_len / shape[1])
|
||||
self.indices[key] = np.arange(step, step * (shape[1] + 1), step)[::-1]
|
||||
elif shape[1] == 25:
|
||||
skip = buffer_history_len // shape[1]
|
||||
self.indices[key] = np.arange(buffer_history_len)[-1 - (skip * (shape[1] - 1))::skip]
|
||||
elif shape[1] == buffer_history_len:
|
||||
self.indices[key] = np.arange(buffer_history_len)
|
||||
else:
|
||||
self.indices[key] = None
|
||||
|
||||
def update_dtypes_and_shapes(self, input_dtypes: dict, input_shapes: dict) -> None:
|
||||
self.input_dtypes.update(input_dtypes)
|
||||
self.input_shapes.update(input_shapes)
|
||||
for key in input_dtypes:
|
||||
if key in self.buffers and self.buffers[key] is not None:
|
||||
shape = input_shapes[key]
|
||||
self._setup_buffer_for_key(key, shape, input_dtypes[key])
|
||||
|
||||
def enqueue(self, inputs: dict[str, np.ndarray]) -> None:
|
||||
for key, new_val in inputs.items():
|
||||
if key not in self.buffers or self.buffers[key] is None:
|
||||
continue
|
||||
if new_val.dtype != self.input_dtypes[key]:
|
||||
raise ValueError(f'Input {key} has wrong dtype {new_val.dtype}, expected {self.input_dtypes[key]}')
|
||||
buf = self.buffers[key]
|
||||
if buf is not None:
|
||||
if buf.shape[1] == new_val.shape[0]:
|
||||
buf[0, -new_val.shape[0]:] = new_val
|
||||
buf[0, :-new_val.shape[0]] = buf[0, new_val.shape[0]:]
|
||||
else:
|
||||
buf[0, :-1] = buf[0, 1:]
|
||||
buf[0, -1] = new_val
|
||||
|
||||
def get(self, *names) -> dict[str, np.ndarray]:
|
||||
result: dict[str, np.ndarray] = {}
|
||||
for key in names:
|
||||
buf = self.buffers.get(key, None)
|
||||
if buf is not None:
|
||||
out_shape = self.input_shapes.get(key)
|
||||
# Roll buffer and assign based on desire.shape[1] value
|
||||
if out_shape is not None and key.startswith('desire') and buf.shape[1] > out_shape[1]:
|
||||
skip = buf.shape[1] // out_shape[1]
|
||||
result[key] = buf.reshape((out_shape[0], out_shape[1], skip, -1)).max(axis=2)
|
||||
elif self.indices[key] is not None and buf.shape[1] > 1:
|
||||
result[key] = buf[0, self.indices[key]]
|
||||
elif out_shape is not None and buf.shape[1] >= out_shape[1]:
|
||||
result[key] = buf[0, -out_shape[1]:]
|
||||
return result
|
||||
|
||||
|
||||
class ModelState(ModelStateBase):
|
||||
frames: dict[str, DrivingModelFrame]
|
||||
inputs: dict[str, np.ndarray]
|
||||
prev_desire: np.ndarray # for tracking the rising edge of the pulse
|
||||
temporal_idxs: slice | np.ndarray
|
||||
|
||||
def __init__(self, context: CLContext):
|
||||
ModelStateBase.__init__(self)
|
||||
@@ -56,63 +121,47 @@ class ModelState(ModelStateBase):
|
||||
raise
|
||||
|
||||
model_bundle = get_active_bundle()
|
||||
self.generation = model_bundle.generation if model_bundle is not None else None
|
||||
overrides = {override.key: override.value for override in model_bundle.overrides}
|
||||
self.generation = model_bundle.generation if model_bundle else None
|
||||
overrides = {override.key: override.value for override in model_bundle.overrides} if model_bundle else {}
|
||||
|
||||
self.LAT_SMOOTH_SECONDS = float(overrides.get('lat', ".0"))
|
||||
self.LONG_SMOOTH_SECONDS = float(overrides.get('long', ".0"))
|
||||
self.MIN_LAT_CONTROL_SPEED = 0.3
|
||||
|
||||
buffer_length = 5 if self.model_runner.is_20hz else 2
|
||||
buffer_length = 4 if self.model_runner.is_20hz else 2
|
||||
self.frames = {name: DrivingModelFrame(context, buffer_length) for name in self.model_runner.vision_input_names}
|
||||
self.prev_desire = np.zeros(self.constants.DESIRE_LEN, dtype=np.float32)
|
||||
|
||||
# img buffers are managed in openCL transform code
|
||||
self.numpy_inputs = {}
|
||||
self.temporal_buffers = {}
|
||||
self.temporal_idxs_map = {}
|
||||
input_dtypes = dict.fromkeys(self.model_runner.input_shapes, np.float32)
|
||||
self.numpy_inputs = {k: np.zeros(shape, dtype=input_dtypes[k]) for k, shape in self.model_runner.input_shapes.items() if k not in self.frames}
|
||||
|
||||
for key, shape in self.model_runner.input_shapes.items():
|
||||
if key not in self.frames: # Managed by opencl
|
||||
self.numpy_inputs[key] = np.zeros(shape, dtype=np.float32)
|
||||
# Temporal input: shape is [batch, history, features]
|
||||
if len(shape) == 3 and shape[1] > 1:
|
||||
buffer_history_len = max(100, (shape[1] * 4 if shape[1] < 100 else shape[1])) # Allow for higher history buffers in the future
|
||||
feature_len = shape[2]
|
||||
self.temporal_buffers[key] = np.zeros((1, buffer_history_len, feature_len), dtype=np.float32)
|
||||
features_buffer_shape = self.model_runner.input_shapes.get('features_buffer')
|
||||
if shape[1] in (24, 25) and features_buffer_shape is not None and features_buffer_shape[1] == 24: # 20Hz
|
||||
step = int(-buffer_history_len / shape[1])
|
||||
self.temporal_idxs_map[key] = np.arange(step, step * (shape[1] + 1), step)[::-1]
|
||||
elif shape[1] == 25: # Split
|
||||
skip = buffer_history_len // shape[1]
|
||||
self.temporal_idxs_map[key] = np.arange(buffer_history_len)[-1 - (skip * (shape[1] - 1))::skip]
|
||||
elif shape[1] == buffer_history_len: # non20hz
|
||||
self.temporal_idxs_map[key] = np.arange(buffer_history_len)
|
||||
temporal_inputs = {k: v for k, v in self.model_runner.input_shapes.items() if len(v) == 3 and v[1] > 1}
|
||||
self.input_queues = InputQueues(temporal_inputs, dict.fromkeys(temporal_inputs, np.float32))
|
||||
self.prev_desire = np.zeros(self.numpy_inputs[self.desire_key].shape[2], dtype=np.float32)
|
||||
|
||||
@property
|
||||
def mlsim(self) -> bool:
|
||||
return bool(self.generation is not None and self.generation >= 11)
|
||||
|
||||
@property
|
||||
def desire_key(self) -> str:
|
||||
return next(key for key in self.numpy_inputs if key.startswith('desire'))
|
||||
|
||||
def run(self, bufs: dict[str, VisionBuf], transforms: dict[str, np.ndarray],
|
||||
inputs: dict[str, np.ndarray], prepare_only: bool) -> dict[str, np.ndarray] | None:
|
||||
# Model decides when action is completed, so desire input is just a pulse triggered on rising edge
|
||||
inputs['desire'][0] = 0
|
||||
new_desire = np.where(inputs['desire'] - self.prev_desire > .99, inputs['desire'], 0)
|
||||
self.prev_desire[:] = inputs['desire']
|
||||
self.temporal_buffers['desire'][0,:-1] = self.temporal_buffers['desire'][0,1:]
|
||||
self.temporal_buffers['desire'][0,-1] = new_desire
|
||||
inputs[self.desire_key][0] = 0
|
||||
new_desire = np.where(inputs[self.desire_key] - self.prev_desire > .99, inputs[self.desire_key], 0)
|
||||
self.prev_desire[:] = inputs[self.desire_key]
|
||||
|
||||
# Roll buffer and assign based on desire.shape[1] value
|
||||
if self.temporal_buffers['desire'].shape[1] > self.numpy_inputs['desire'].shape[1]:
|
||||
skip = self.temporal_buffers['desire'].shape[1] // self.numpy_inputs['desire'].shape[1]
|
||||
self.numpy_inputs['desire'][:] = (
|
||||
self.temporal_buffers['desire'][0].reshape(self.numpy_inputs['desire'].shape[0], self.numpy_inputs['desire'].shape[1], skip, -1).max(axis=2))
|
||||
else:
|
||||
self.numpy_inputs['desire'][:] = self.temporal_buffers['desire'][0, self.temporal_idxs_map['desire']]
|
||||
batch_inputs = {key: (new_desire if key == self.desire_key else inputs[key])
|
||||
for key in self.input_queues.buffers
|
||||
if not (key == 'features_buffer' and 'hidden_state' in self.numpy_inputs) and (key == self.desire_key or key in inputs)}
|
||||
self.input_queues.enqueue(batch_inputs)
|
||||
|
||||
for key in self.numpy_inputs:
|
||||
if key in inputs and key not in ['desire']:
|
||||
if key in self.input_queues.buffers:
|
||||
self.numpy_inputs[key][:] = self.input_queues.get(key)[key]
|
||||
elif key in inputs:
|
||||
self.numpy_inputs[key][:] = inputs[key]
|
||||
|
||||
imgs_cl = {name: self.frames[name].prepare(bufs[name], transforms[name].flatten()) for name in self.model_runner.vision_input_names}
|
||||
@@ -126,27 +175,27 @@ class ModelState(ModelStateBase):
|
||||
# Run model inference
|
||||
outputs = self.model_runner.run_model()
|
||||
|
||||
# Update features_buffer
|
||||
self.temporal_buffers['features_buffer'][0, :-1] = self.temporal_buffers['features_buffer'][0, 1:]
|
||||
self.temporal_buffers['features_buffer'][0, -1] = outputs['hidden_state'][0, :]
|
||||
self.numpy_inputs['features_buffer'][:] = self.temporal_buffers['features_buffer'][0, self.temporal_idxs_map['features_buffer']]
|
||||
if "lat_planner_solution" in outputs and "lat_planner_state" in self.numpy_inputs:
|
||||
idx_n = outputs['lat_planner_solution'].shape[1]
|
||||
t_idxs = [10.0 * ((i / (idx_n - 1))**2) for i in range(idx_n)]
|
||||
self.numpy_inputs['lat_planner_state'][2] = np.interp(DT_MDL, t_idxs, outputs['lat_planner_solution'][0, :, 2])
|
||||
self.numpy_inputs['lat_planner_state'][3] = np.interp(DT_MDL, t_idxs, outputs['lat_planner_solution'][0, :, 3])
|
||||
|
||||
# Enqueue features buffer
|
||||
self.input_queues.enqueue({'features_buffer': outputs['hidden_state'][0, :]})
|
||||
self.numpy_inputs['features_buffer'][:] = self.input_queues.get('features_buffer')['features_buffer']
|
||||
|
||||
if "desired_curvature" in outputs and "prev_desired_curv" in self.numpy_inputs:
|
||||
self.process_desired_curvature(outputs, 'prev_desired_curv')
|
||||
|
||||
if "desired_curvature" in outputs:
|
||||
input_name_prev = None
|
||||
if "prev_desired_curvs" in self.numpy_inputs.keys():
|
||||
input_name_prev = 'prev_desired_curvs'
|
||||
elif "prev_desired_curv" in self.numpy_inputs.keys():
|
||||
input_name_prev = 'prev_desired_curv'
|
||||
if input_name_prev and input_name_prev in self.temporal_buffers:
|
||||
self.process_desired_curvature(outputs, input_name_prev)
|
||||
return outputs
|
||||
|
||||
def process_desired_curvature(self, outputs, input_name_prev):
|
||||
self.temporal_buffers[input_name_prev][0,:-1] = self.temporal_buffers[input_name_prev][0,1:]
|
||||
self.temporal_buffers[input_name_prev][0,-1,:] = outputs['desired_curvature'][0, :]
|
||||
self.numpy_inputs[input_name_prev][:] = self.temporal_buffers[input_name_prev][0, self.temporal_idxs_map[input_name_prev]]
|
||||
def process_desired_curvature(self, outputs, input_name):
|
||||
self.input_queues.enqueue({input_name: outputs['desired_curvature'][0, :]})
|
||||
self.numpy_inputs[input_name][:] = self.input_queues.get(input_name)[input_name]
|
||||
if self.mlsim:
|
||||
self.numpy_inputs[input_name_prev][:] = 0*self.temporal_buffers[input_name_prev][0, self.temporal_idxs_map[input_name_prev]]
|
||||
self.numpy_inputs[input_name][:] = 0 * self.input_queues.get(input_name)[input_name]
|
||||
|
||||
|
||||
def get_action_from_model(self, model_output: dict[str, np.ndarray], prev_action: log.ModelDataV2.Action,
|
||||
lat_action_t: float, long_action_t: float, v_ego: float) -> log.ModelDataV2.Action:
|
||||
@@ -207,19 +256,13 @@ def main(demo=False):
|
||||
|
||||
publish_state = PublishState()
|
||||
params = Params()
|
||||
|
||||
# setup filter to track dropped frames
|
||||
frame_dropped_filter = FirstOrderFilter(0., 10., 1. / model.constants.MODEL_FREQ)
|
||||
frame_id = 0
|
||||
last_vipc_frame_id = 0
|
||||
run_count = 0
|
||||
frame_id = last_vipc_frame_id = run_count = 0
|
||||
|
||||
model_transform_main = np.zeros((3, 3), dtype=np.float32)
|
||||
model_transform_extra = np.zeros((3, 3), dtype=np.float32)
|
||||
model_transform_main = model_transform_extra = np.zeros((3, 3), dtype=np.float32)
|
||||
live_calib_seen = False
|
||||
buf_main, buf_extra = None, None
|
||||
meta_main = FrameMeta()
|
||||
meta_extra = FrameMeta()
|
||||
buf_main = buf_extra = None
|
||||
meta_main = meta_extra = FrameMeta()
|
||||
|
||||
|
||||
if demo:
|
||||
@@ -306,12 +349,19 @@ def main(demo=False):
|
||||
bufs = {name: buf_extra if 'big' in name else buf_main for name in model.model_runner.vision_input_names}
|
||||
transforms = {name: model_transform_extra if 'big' in name else model_transform_main for name in model.model_runner.vision_input_names}
|
||||
inputs:dict[str, np.ndarray] = {
|
||||
'desire': vec_desire,
|
||||
model.desire_key: vec_desire,
|
||||
'traffic_convention': traffic_convention,
|
||||
}
|
||||
|
||||
if "lateral_control_params" in model.numpy_inputs.keys():
|
||||
inputs['lateral_control_params'] = np.array([v_ego, lat_delay], dtype=np.float32)
|
||||
conditional_inputs = {
|
||||
"lateral_control_params": lambda v_ego=v_ego, lat_delay=lat_delay: np.array([v_ego, lat_delay], dtype=np.float32),
|
||||
"driving_style": lambda: np.array([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], dtype=np.float32),
|
||||
"nav_features": lambda: np.zeros(model.model_runner.input_shapes.get('nav_features')[1], dtype=np.float32),
|
||||
"nav_instructions": lambda: np.zeros(model.model_runner.input_shapes.get('nav_instructions')[1], dtype=np.float32),
|
||||
}
|
||||
for key, value in conditional_inputs.items():
|
||||
if key in model.numpy_inputs:
|
||||
inputs[key] = value()
|
||||
|
||||
mt1 = time.perf_counter()
|
||||
model_output = model.run(bufs, transforms, inputs, prepare_only)
|
||||
|
||||
@@ -5,16 +5,26 @@ from typing import Any
|
||||
import openpilot.sunnypilot.models.helpers as helpers
|
||||
import openpilot.sunnypilot.models.runners.helpers as runner_helpers
|
||||
import openpilot.sunnypilot.modeld_v2.modeld as modeld_module
|
||||
from openpilot.sunnypilot.modeld_v2.model_metadata_lookup import MODEL_METADATA
|
||||
|
||||
ModelState = modeld_module.ModelState
|
||||
SHAPE_MODE_PARAMS = []
|
||||
for _, meta in MODEL_METADATA.items():
|
||||
mode = ''
|
||||
if isinstance(meta, dict):
|
||||
if meta.get('split'):
|
||||
mode = 'split'
|
||||
elif meta.get('non20hz'):
|
||||
mode = 'non20hz'
|
||||
elif meta.get('20hz'):
|
||||
mode = '20hz'
|
||||
|
||||
|
||||
# These are the shapes extracted/loaded from the model onnx
|
||||
SHAPE_MODE_PARAMS = [
|
||||
({'desire': (1, 25, 8), 'features_buffer': (1, 25, 512), 'prev_desired_curv': (1, 25, 1)}, 'split'),
|
||||
({'desire': (1, 25, 8), 'features_buffer': (1, 24, 512), 'prev_desired_curv': (1, 25, 1)}, '20hz'),
|
||||
({'desire': (1, 100, 8), 'features_buffer': (1, 99, 512), 'prev_desired_curv': (1, 100, 1)}, 'non20hz'),
|
||||
]
|
||||
input_shapes = {}
|
||||
for k, v in meta.get('input_shapes', {}).items():
|
||||
if k not in ["input_imgs", "big_input_imgs"]:
|
||||
input_shapes[k] = v
|
||||
if input_shapes:
|
||||
SHAPE_MODE_PARAMS.append((input_shapes, mode))
|
||||
|
||||
|
||||
# This creates a dummy runner, override, and bundle instance for the tests to run, without actually trying to load a physical model.
|
||||
@@ -106,8 +116,10 @@ def test_buffer_shapes_and_indices(shapes, mode, apply_patches):
|
||||
state = ModelState(None)
|
||||
constants = DummyModelRunner(shapes).constants
|
||||
for key in shapes:
|
||||
buf = state.temporal_buffers.get(key, None)
|
||||
idxs = state.temporal_idxs_map.get(key, None)
|
||||
buf = state.input_queues.buffers.get(key, None)
|
||||
idxs = state.input_queues.indices.get(key, None)
|
||||
if buf is None:
|
||||
continue # not all shapes are 3D, and the non-3D ones are not buffered
|
||||
# Buffer shape logic
|
||||
if mode == 'split':
|
||||
expected_shape = (1, constants.FULL_HISTORY_BUFFER_LEN, shapes[key][2])
|
||||
@@ -130,10 +142,10 @@ def test_buffer_shapes_and_indices(shapes, mode, apply_patches):
|
||||
assert idxs is None or idxs.size == 0, f"{key}: buffer idxs should be None or empty"
|
||||
|
||||
|
||||
def legacy_buffer_update(buf, new_val, mode, key, constants, idxs):
|
||||
def legacy_buffer_update(buf, new_val, mode, key, constants, idxs, input_shape, prev_desire=None):
|
||||
# This is what we compare the new dynamic logic to, to ensure it does the same thing
|
||||
if mode == 'split':
|
||||
if key == 'desire':
|
||||
if key == 'desire' or key.startswith('desire'):
|
||||
buf[0,:-1] = buf[0,1:]
|
||||
buf[0,-1] = new_val
|
||||
return buf.reshape((1, constants.INPUT_HISTORY_BUFFER_LEN, constants.TEMPORAL_SKIP, -1)).max(axis=2)
|
||||
@@ -173,15 +185,23 @@ def legacy_buffer_update(buf, new_val, mode, key, constants, idxs):
|
||||
return legacy_buf[idxs]
|
||||
elif mode == 'non20hz':
|
||||
if key == 'desire':
|
||||
length = new_val.shape[0]
|
||||
buf[0,:-1,:length] = buf[0,1:,:length]
|
||||
buf[0,-1,:length] = new_val[:length]
|
||||
desire_len = constants.DESIRE_LEN
|
||||
if prev_desire is None:
|
||||
prev_desire = np.zeros(desire_len, dtype=np.float32)
|
||||
# Set first element to zero
|
||||
new_val = new_val.copy()
|
||||
new_val[0] = 0
|
||||
# Shift buffer by desire len
|
||||
buf[0][:-desire_len] = buf[0][desire_len:]
|
||||
# Only insert new desire if rising edge
|
||||
buf[0][-desire_len:] = np.where(new_val - prev_desire > 0.99, new_val, 0)
|
||||
prev_desire[:] = new_val
|
||||
return buf[0]
|
||||
elif key == 'features_buffer':
|
||||
feature_len = new_val.shape[0]
|
||||
buf[0,:-1,:feature_len] = buf[0,1:,:feature_len]
|
||||
buf[0,-1,:feature_len] = new_val[:feature_len]
|
||||
return buf[0]
|
||||
feature_len = constants.FEATURE_LEN
|
||||
buf[0, :-feature_len] = buf[0, feature_len:]
|
||||
buf[0, -feature_len:] = new_val
|
||||
return buf[0, -input_shape[1]:]
|
||||
elif key == 'prev_desired_curv':
|
||||
length = new_val.shape[0]
|
||||
buf[0,:-length,0] = buf[0,length:,0]
|
||||
@@ -191,32 +211,18 @@ def legacy_buffer_update(buf, new_val, mode, key, constants, idxs):
|
||||
|
||||
|
||||
def dynamic_buffer_update(state, key, new_val, mode):
|
||||
if key == 'desire':
|
||||
state.temporal_buffers['desire'][0,:-1] = state.temporal_buffers['desire'][0,1:]
|
||||
state.temporal_buffers['desire'][0,-1] = new_val
|
||||
if state.temporal_buffers['desire'].shape[1] > state.numpy_inputs['desire'].shape[1]:
|
||||
skip = state.temporal_buffers['desire'].shape[1] // state.numpy_inputs['desire'].shape[1]
|
||||
return state.temporal_buffers['desire'][0].reshape(
|
||||
state.numpy_inputs['desire'].shape[0], state.numpy_inputs['desire'].shape[1], skip, -1
|
||||
).max(axis=2)
|
||||
else:
|
||||
return state.temporal_buffers['desire'][0, state.temporal_idxs_map['desire']]
|
||||
|
||||
inputs = {'desire': np.zeros((1, state.constants.DESIRE_LEN), dtype=np.float32)}
|
||||
for k, tb in state.temporal_buffers.items():
|
||||
if k in state.temporal_idxs_map:
|
||||
continue
|
||||
buf_len = tb.shape[1]
|
||||
if k in state.numpy_inputs:
|
||||
out_len = state.numpy_inputs[k].shape[1]
|
||||
if out_len <= buf_len:
|
||||
state.temporal_idxs_map[k] = np.arange(buf_len)[-out_len:]
|
||||
else:
|
||||
state.temporal_idxs_map[k] = np.arange(buf_len)
|
||||
else:
|
||||
state.temporal_idxs_map[k] = np.arange(buf_len)
|
||||
if key == 'desire' or key.startswith('desire'):
|
||||
inputs = {k: np.zeros(v[2], dtype=np.float32) if len(v) == 3 else np.zeros(v[1], dtype=np.float32)
|
||||
for k, v in state.model_runner.input_shapes.items() if k != key}
|
||||
inputs[key] = new_val.copy()
|
||||
# ModelState.run expects desire as a pulse, so we zero the first element.
|
||||
inputs[key][0] = 0
|
||||
state.run({}, {}, inputs, prepare_only=False)
|
||||
return state.numpy_inputs[key]
|
||||
|
||||
if key == 'features_buffer':
|
||||
inputs = {k: np.zeros(v[2], dtype=np.float32) if len(v) == 3 else np.zeros(v[1], dtype=np.float32)
|
||||
for k, v in state.model_runner.input_shapes.items() if k != 'features_buffer'}
|
||||
def run_model_stub():
|
||||
return {
|
||||
'hidden_state': np.asarray(new_val, dtype=np.float32).reshape(1, -1),
|
||||
@@ -226,6 +232,8 @@ def dynamic_buffer_update(state, key, new_val, mode):
|
||||
return state.numpy_inputs['features_buffer'][0]
|
||||
|
||||
if key == 'prev_desired_curv':
|
||||
inputs = {k: np.zeros(v[2], dtype=np.float32) if len(v) == 3 else np.zeros(v[1], dtype=np.float32)
|
||||
for k, v in state.model_runner.input_shapes.items() if k != 'prev_desired_curv'}
|
||||
def run_model_stub():
|
||||
return {
|
||||
'hidden_state': np.zeros((1, state.constants.FEATURE_LEN), dtype=np.float32),
|
||||
@@ -241,16 +249,27 @@ def dynamic_buffer_update(state, key, new_val, mode):
|
||||
@pytest.mark.parametrize("key", ["desire", "features_buffer", "prev_desired_curv"])
|
||||
def test_buffer_update_equivalence(shapes, mode, key, apply_patches):
|
||||
state = ModelState(None)
|
||||
if key == "desire":
|
||||
desire_keys = [k for k in shapes.keys() if k.startswith('desire')]
|
||||
if desire_keys:
|
||||
actual_key = desire_keys[0] # Use the first (and likely only) desire key
|
||||
else:
|
||||
actual_key = key
|
||||
|
||||
if actual_key not in state.numpy_inputs:
|
||||
pytest.skip()
|
||||
|
||||
constants = DummyModelRunner(shapes).constants
|
||||
buf = state.temporal_buffers.get(key, None)
|
||||
idxs = state.temporal_idxs_map.get(key, None)
|
||||
input_shape = shapes[key]
|
||||
buf = state.input_queues.buffers.get(actual_key, None)
|
||||
idxs = state.input_queues.indices.get(actual_key, None)
|
||||
input_shape = shapes[actual_key]
|
||||
prev_desire = np.zeros(constants.DESIRE_LEN, dtype=np.float32) if key == 'desire' else None
|
||||
|
||||
for step in range(20): # multiple steps to ensure history is built up
|
||||
new_val = np.full((input_shape[2],), step, dtype=np.float32)
|
||||
expected = legacy_buffer_update(buf, new_val, mode, key, constants, idxs)
|
||||
actual = dynamic_buffer_update(state, key, new_val, mode)
|
||||
# Model returns the reduced numpy_inputs history, compare the last n entries so the test is checking the same slices.
|
||||
expected = legacy_buffer_update(buf, new_val, mode, actual_key, constants, idxs, input_shape, prev_desire)
|
||||
actual = dynamic_buffer_update(state, actual_key, new_val, mode)
|
||||
if expected is not None and actual is not None and expected.shape != actual.shape:
|
||||
if expected.ndim == 2 and actual.ndim == 2 and expected.shape[1] == actual.shape[1]:
|
||||
expected = expected[-actual.shape[0]:]
|
||||
assert np.allclose(actual, expected), f"{mode} {key}: dynamic buffer update does not match legacy logic"
|
||||
assert np.allclose(actual, expected), f"{mode} {actual_key}: dynamic buffer update does not match legacy logic"
|
||||
|
||||
@@ -11,16 +11,15 @@ import pickle
|
||||
CUSTOM_MODEL_PATH = Paths.model_root()
|
||||
|
||||
|
||||
# Set QCOM environment variable for TICI devices, potentially enabling hardware acceleration
|
||||
# Set device environment variable for hardware acceleration
|
||||
USBGPU = "USBGPU" in os.environ
|
||||
if USBGPU:
|
||||
os.environ['AMD'] = '1'
|
||||
os.environ['DEV'] = 'AMD'
|
||||
os.environ['AMD_IFACE'] = 'USB'
|
||||
elif TICI:
|
||||
os.environ['QCOM'] = '1'
|
||||
os.environ['DEV'] = 'QCOM'
|
||||
else:
|
||||
os.environ['LLVM'] = '1'
|
||||
os.environ['JIT'] = '2' # TODO: This may cause issues
|
||||
os.environ['DEV'] = 'LLVM'
|
||||
|
||||
|
||||
class ModelData:
|
||||
|
||||
Reference in New Issue
Block a user