Compare commits

...

10 Commits

Author SHA1 Message Date
Jason Wen 4ba620cdd2 modeld_v2: add CUSTOM→CONST shape migration to pkl compat shim
Old tinygrad used Ops.CUSTOM for scalar shape constants promoted to
vec types. New tinygrad uses Ops.CONST for the same representation.
2026-06-06 03:42:53 -04:00
Jason Wen b5daec54e2 modeld_v2: full pkl backwards compat shim
Handles three tinygrad pkl format migrations during deserialization:
- SLICE UOps with realized buffers (QCOM compilation artifact)
- PARAM int arg → ParamArg dataclass
- PERMUTE(NOOP, CONST(vec)) → RESHAPE(NOOP, STACK(CONST...)) shape repr

Also makes action_t conditional in input queues for old models.
2026-06-06 03:39:53 -04:00
Jason Wen dded76d28e Merge remote-tracking branch 'sunnypilot/sunnypilot/master' into deep-model-wee 2026-06-06 03:15:37 -04:00
Jason Wen 41c7f23a74 send it 2026-06-06 03:07:38 -04:00
Jason Wen 69749f400f modeld_v2: add pkl compat shim for SLICE buffer and PARAM migration
Monkey-patches UOpMetaClass.__call__ during pickle.loads() to handle:
- SLICE UOps with realized buffers (both old and new tinygrad pkls)
- PARAM UOps with int arg (old pkl format, pre-ParamArg)

Zero runtime overhead — shim only active during deserialization.
2026-06-06 03:04:57 -04:00
Jason Wen c7f770d29a modeld_v2: adapt for deep model support
Parse action MDN output in split parser for deep model on_policy.

Only suppress off_policy plan when on_policy has its own plan output.
Deep models have on_policy producing action (no plan), so off_policy
plan must be kept.
2026-06-05 18:36:27 -04:00
Harald Schäfer e405157bdb Op model16 deep (#38073)
* modeld: RL driving model with 3-file split

Split the driving model into vision + off_policy + on_policy ONNX
files and wire up the RL policy:

- 3-file model split (vision / off_policy / on_policy), replacing the
  combined big_driving_policy/vision models
- compiler updates for the split models
- actually consume the policy action in modeld
- add desire state to the driving model
- model iterations (smoothness, off/on-policy weight updates)

* modeld: update driving model

* 1e72cf5a-785f-45ea-888f-28cdb14785de/100

* tinygrad hack

* fix parsing

* looser timing

* big

* Remove unnecessary modeld rebase changes

* Tighten modeld split cleanup

---------

Co-authored-by: Comma Device <device@comma.ai>
Co-authored-by: Armandpl <adpl33@gmail.com>
2026-06-05 18:34:10 -04:00
Harald Schäfer bdb6c15753 Refactor compile_modeld model setup (#38128) 2026-06-05 18:33:56 -04:00
Jason Wen 798d1836b2 tinygrad: bump to synced master + IMAGE hack for deep models
Bumps tinygrad submodule to sunnypilot/tinygrad master (synced with
upstream 556defa0f) plus comma's IMAGE hack (cherry-picked from
fd992d668) needed for deep model compilation with IMAGE=1.
2026-06-05 18:33:51 -04:00
Jason Wen 6dffdcca0b modeld_v2 prereqs deep models 2026-06-05 17:37:37 -04:00
21 changed files with 167 additions and 63 deletions
+2 -1
View File
@@ -87,13 +87,14 @@ frame_skip = ModelConstants.MODEL_RUN_FREQ // ModelConstants.MODEL_CONTEXT_FREQ
for usbgpu in [False, True] if USBGPU else [False]: for usbgpu in [False, True] if USBGPU else [False]:
target_pkl_path = File(modeld_pkl_path(usbgpu)).abspath target_pkl_path = File(modeld_pkl_path(usbgpu)).abspath
file_prefix, cmd_flags = ('big_', usbgpu_tg_flags) if usbgpu else ('', tg_flags) file_prefix, cmd_flags = ('big_', usbgpu_tg_flags) if usbgpu else ('', tg_flags)
driving_onnx_deps = [p for m in [f'{file_prefix}driving_vision', f'{file_prefix}driving_on_policy'] driving_onnx_deps = [p for m in [f'{file_prefix}driving_vision', f'{file_prefix}driving_on_policy', f'{file_prefix}driving_off_policy']
for p in get_existing_chunks(File(f"models/{m}.onnx").abspath)] for p in get_existing_chunks(File(f"models/{m}.onnx").abspath)]
camera_res_args = ' '.join(f'{cw}x{ch}' for cw, ch in CAMERA_CONFIGS) camera_res_args = ' '.join(f'{cw}x{ch}' for cw, ch in CAMERA_CONFIGS)
cmd = (f'{cmd_flags} {mac_brew_string} python3 {modeld_dir}/compile_modeld.py ' cmd = (f'{cmd_flags} {mac_brew_string} python3 {modeld_dir}/compile_modeld.py '
f'--model-size {model_w}x{model_h} ' f'--model-size {model_w}x{model_h} '
f'--camera-resolutions {camera_res_args} ' f'--camera-resolutions {camera_res_args} '
f'--vision-onnx {File(f"models/{file_prefix}driving_vision.onnx").abspath} ' f'--vision-onnx {File(f"models/{file_prefix}driving_vision.onnx").abspath} '
f'--off-policy-onnx {File(f"models/{file_prefix}driving_off_policy.onnx").abspath} '
f'--on-policy-onnx {File(f"models/{file_prefix}driving_on_policy.onnx").abspath} ' f'--on-policy-onnx {File(f"models/{file_prefix}driving_on_policy.onnx").abspath} '
f'--output {target_pkl_path} --frame-skip {frame_skip}') f'--output {target_pkl_path} --frame-skip {frame_skip}')
onnx_sizes_sum = sum(os.path.getsize(f) for f in driving_onnx_deps) onnx_sizes_sum = sum(os.path.getsize(f) for f in driving_onnx_deps)
+48 -31
View File
@@ -5,7 +5,7 @@ import os
import pickle import pickle
import time import time
from functools import partial from functools import partial
from collections import namedtuple, defaultdict from collections import namedtuple
import numpy as np import numpy as np
@@ -113,31 +113,43 @@ def make_frame_prepare(nv12: NV12Frame, model_w, model_h):
return frame_prepare_tinygrad return frame_prepare_tinygrad
def make_input_queues(vision_input_shapes, policy_input_shapes, frame_skip, device): def make_warp_input_queues(vision_input_shapes, frame_skip, device):
img = vision_input_shapes['img'] # (1, 12, 128, 256) img = vision_input_shapes['img'] # (1, 12, 128, 256)
n_frames = img[1] // 6 n_frames = img[1] // 6
img_buf_shape = (frame_skip * (n_frames - 1) + 1, 6, img[2], img[3]) img_buf_shape = (frame_skip * (n_frames - 1) + 1, 6, img[2], img[3])
npy = {
'tfm': np.zeros((3, 3), dtype=np.float32),
'big_tfm': np.zeros((3, 3), dtype=np.float32),
}
input_queues = {
'img_q': Tensor(np.zeros(img_buf_shape, dtype=np.uint8), device=device).contiguous().realize(),
'big_img_q': Tensor(np.zeros(img_buf_shape, dtype=np.uint8), device=device).contiguous().realize(),
**{k: Tensor(v, device='NPY').realize() for k, v in npy.items()},
}
return input_queues, npy
def make_input_queues(vision_input_shapes, policy_input_shapes, frame_skip, device):
input_queues, npy = make_warp_input_queues(vision_input_shapes, frame_skip, device)
fb = policy_input_shapes['features_buffer'] # (1, 25, 512) fb = policy_input_shapes['features_buffer'] # (1, 25, 512)
dp = policy_input_shapes['desire_pulse'] # (1, 25, 8) dp = policy_input_shapes['desire_pulse'] # (1, 25, 8)
tc = policy_input_shapes['traffic_convention'] # (1, 2) tc = policy_input_shapes['traffic_convention'] # (1, 2)
#TODO action_t is hardcoded to match tc for future compatibility #TODO action_t is hardcoded to match tc for future compatibility
at = tc at = tc
npy = { policy_npy = {
'desire': np.zeros(dp[2], dtype=np.float32), 'desire': np.zeros(dp[2], dtype=np.float32),
'traffic_convention': np.zeros(tc, dtype=np.float32), 'traffic_convention': np.zeros(tc, dtype=np.float32),
'tfm': np.zeros((3, 3), dtype=np.float32),
'big_tfm': np.zeros((3, 3), dtype=np.float32),
'action_t': np.zeros(at, dtype=np.float32), 'action_t': np.zeros(at, dtype=np.float32),
} }
input_queues = { npy.update(policy_npy)
'img_q': Tensor(np.zeros(img_buf_shape, dtype=np.uint8), device=device).contiguous().realize(), input_queues.update({
'big_img_q': Tensor(np.zeros(img_buf_shape, dtype=np.uint8), device=device).contiguous().realize(),
'feat_q': Tensor(np.zeros((frame_skip * (fb[1] - 1) + 1, fb[0], fb[2]), dtype=np.float32), device=device).contiguous().realize(), 'feat_q': Tensor(np.zeros((frame_skip * (fb[1] - 1) + 1, fb[0], fb[2]), dtype=np.float32), device=device).contiguous().realize(),
'desire_q': Tensor(np.zeros((frame_skip * dp[1], dp[0], dp[2]), dtype=np.float32), device=device).contiguous().realize(), 'desire_q': Tensor(np.zeros((frame_skip * dp[1], dp[0], dp[2]), dtype=np.float32), device=device).contiguous().realize(),
**{k: Tensor(v, device='NPY').realize() for k, v in npy.items()}, **{k: Tensor(v, device='NPY').realize() for k, v in policy_npy.items()},
} })
return input_queues, npy return input_queues, npy
@@ -171,9 +183,10 @@ def make_warp(nv12, model_w, model_h, frame_skip):
return warp_enqueue return warp_enqueue
def make_run_policy(vision_runner, on_policy_runner, vision_features_slice, frame_skip): def make_run_policy(model_runners, model_metadata, frame_skip):
sample_desire_fn = partial(sample_desire, frame_skip=frame_skip) sample_desire_fn = partial(sample_desire, frame_skip=frame_skip)
sample_skip_fn = partial(sample_skip, frame_skip=frame_skip) sample_skip_fn = partial(sample_skip, frame_skip=frame_skip)
vision_features_slice = model_metadata['vision']['output_slices']['hidden_state']
def run_policy(img, big_img, feat_q, desire_q, desire, traffic_convention, action_t): def run_policy(img, big_img, feat_q, desire_q, desire, traffic_convention, action_t):
desire = desire.to(Device.DEFAULT) desire = desire.to(Device.DEFAULT)
@@ -181,7 +194,7 @@ def make_run_policy(vision_runner, on_policy_runner, vision_features_slice, fram
action_t = action_t.to(Device.DEFAULT) action_t = action_t.to(Device.DEFAULT)
Tensor.realize(desire, traffic_convention, action_t) Tensor.realize(desire, traffic_convention, action_t)
desire_buf = shift_and_sample(desire_q, desire.reshape(1, 1, -1), sample_desire_fn) desire_buf = shift_and_sample(desire_q, desire.reshape(1, 1, -1), sample_desire_fn)
vision_out = next(iter(vision_runner({'img': img, 'big_img': big_img}).values())).cast('float32') vision_out = next(iter(model_runners['vision']({'img': img, 'big_img': big_img}).values())).cast('float32')
new_feat = vision_out[:, vision_features_slice].reshape(1, -1).unsqueeze(0) new_feat = vision_out[:, vision_features_slice].reshape(1, -1).unsqueeze(0)
feat_buf = shift_and_sample(feat_q, new_feat, sample_skip_fn) feat_buf = shift_and_sample(feat_q, new_feat, sample_skip_fn)
@@ -192,20 +205,16 @@ def make_run_policy(vision_runner, on_policy_runner, vision_features_slice, fram
'traffic_convention': traffic_convention, 'traffic_convention': traffic_convention,
'action_t': action_t, 'action_t': action_t,
} }
on_policy_out = next(iter(on_policy_runner(inputs).values())).cast('float32') on_policy_out = next(iter(model_runners['on_policy'](inputs).values())).cast('float32')
#off_policy_out = next(iter(off_policy_runner(inputs).values())).cast('float32') off_policy_out = next(iter(model_runners['off_policy'](inputs).values())).cast('float32')
return vision_out, on_policy_out return vision_out, on_policy_out, off_policy_out
return run_policy return run_policy
def compile_jit(jit, make_random_inputs, input_keys, frame_skip, vision_metadata, policy_metadata): def compile_jit(jit, make_random_inputs, input_keys, make_queues):
vision_input_shapes = vision_metadata['input_shapes']
policy_input_shapes = policy_metadata['input_shapes']
SEED = 42 SEED = 42
def random_inputs_run(fn, seed, test_val=None, test_buffers=None, expect_match=True): def random_inputs_run(fn, seed, test_val=None, test_buffers=None, expect_match=True):
input_queues, npy = make_input_queues(vision_input_shapes, policy_input_shapes, frame_skip, Device.DEFAULT) input_queues, npy = make_queues(Device.DEFAULT)
np.random.seed(seed) np.random.seed(seed)
Tensor.manual_seed(seed) Tensor.manual_seed(seed)
@@ -269,30 +278,38 @@ if __name__ == "__main__":
p.add_argument('--camera-resolutions', type=_parse_size, nargs='+', required=True, p.add_argument('--camera-resolutions', type=_parse_size, nargs='+', required=True,
help='camera resolutions WxH (one or more)') help='camera resolutions WxH (one or more)')
p.add_argument('--vision-onnx', required=True) p.add_argument('--vision-onnx', required=True)
p.add_argument('--off-policy-onnx', required=True)
p.add_argument('--on-policy-onnx', required=True) p.add_argument('--on-policy-onnx', required=True)
p.add_argument('--output', required=True) p.add_argument('--output', required=True)
p.add_argument('--frame-skip', type=int, required=True) p.add_argument('--frame-skip', type=int, required=True)
args = p.parse_args() args = p.parse_args()
out = defaultdict(dict) model_paths = {
vision_path, on_policy_path = read_file_chunked_to_shm(args.vision_onnx), read_file_chunked_to_shm(args.on_policy_onnx) 'vision': read_file_chunked_to_shm(args.vision_onnx),
'off_policy': read_file_chunked_to_shm(args.off_policy_onnx),
'on_policy': read_file_chunked_to_shm(args.on_policy_onnx),
}
model_w, model_h = args.model_size model_w, model_h = args.model_size
vision_runner = OnnxRunner(vision_path) model_runners = {name: OnnxRunner(path) for name, path in model_paths.items()}
on_policy_runner = OnnxRunner(on_policy_path) out = {'metadata': {name: make_metadata_dict(path) for name, path in model_paths.items()}}
vision_metadata, on_policy_metadata = make_metadata_dict(vision_path), make_metadata_dict(on_policy_path)
run_policy_jit = TinyJit(make_run_policy(vision_runner, on_policy_runner, vision_metadata['output_slices']['hidden_state'], args.frame_skip), prune=True) assert out['metadata']['off_policy']['input_shapes'] == out['metadata']['on_policy']['input_shapes']
out['metadata']['vision'], out['metadata']['on_policy'] = vision_metadata, on_policy_metadata
make_random_model_inputs = partial(make_random_images, keys=['img', 'big_img'], shape=vision_metadata['input_shapes']['img']) run_policy_jit = TinyJit(make_run_policy(model_runners, out['metadata'], args.frame_skip), prune=True)
out['run_policy'] = compile_jit(run_policy_jit, make_random_model_inputs, POLICY_INPUTS, args.frame_skip, vision_metadata, on_policy_metadata)
make_policy_queues = partial(make_input_queues, out['metadata']['vision']['input_shapes'],
out['metadata']['on_policy']['input_shapes'], args.frame_skip)
make_random_model_inputs = partial(make_random_images, keys=['img', 'big_img'], shape=out['metadata']['vision']['input_shapes']['img'])
out['run_policy'] = compile_jit(run_policy_jit, make_random_model_inputs, POLICY_INPUTS,
make_policy_queues)
for cam_w, cam_h in args.camera_resolutions: for cam_w, cam_h in args.camera_resolutions:
nv12 = NV12Frame(cam_w, cam_h, *get_nv12_info(cam_w, cam_h)) nv12 = NV12Frame(cam_w, cam_h, *get_nv12_info(cam_w, cam_h))
make_random_warp_inputs = partial(make_random_images, keys=['frame', 'big_frame'], shape=nv12.size, device=WARP_DEV) make_random_warp_inputs = partial(make_random_images, keys=['frame', 'big_frame'], shape=nv12.size, device=WARP_DEV)
warp_enqueue = TinyJit(make_warp(nv12, model_w, model_h, args.frame_skip), prune=True) warp_enqueue = TinyJit(make_warp(nv12, model_w, model_h, args.frame_skip), prune=True)
out[(cam_w,cam_h)] = compile_jit(warp_enqueue, make_random_warp_inputs, WARP_INPUTS, args.frame_skip, vision_metadata, on_policy_metadata) make_warp_queues = partial(make_warp_input_queues, out['metadata']['vision']['input_shapes'], args.frame_skip)
out[(cam_w,cam_h)] = compile_jit(warp_enqueue, make_random_warp_inputs, WARP_INPUTS, make_warp_queues)
with open(args.output, "wb") as f: with open(args.output, "wb") as f:
pickle.dump(out, f) pickle.dump(out, f)
+9 -4
View File
@@ -89,6 +89,9 @@ class ModelState(ModelStateBase):
self.vision_input_names = list(self.vision_input_shapes.keys()) self.vision_input_names = list(self.vision_input_shapes.keys())
self.vision_output_slices = vision_metadata['output_slices'] self.vision_output_slices = vision_metadata['output_slices']
off_policy_metadata = jits['metadata']['off_policy']
self.off_policy_output_slices = off_policy_metadata['output_slices']
policy_metadata = jits['metadata']['on_policy'] policy_metadata = jits['metadata']['on_policy']
self.policy_input_shapes = policy_metadata['input_shapes'] self.policy_input_shapes = policy_metadata['input_shapes']
self.policy_output_slices = policy_metadata['output_slices'] self.policy_output_slices = policy_metadata['output_slices']
@@ -133,18 +136,20 @@ class ModelState(ModelStateBase):
if prepare_only: if prepare_only:
return None return None
vision_output, on_policy_output = self.run_policy( vision_output, on_policy_output, off_policy_output = self.run_policy(
**{k: self.input_queues[k] for k in POLICY_INPUTS}, img=img, big_img=big_img **{k: self.input_queues[k] for k in POLICY_INPUTS if k in self.input_queues}, img=img, big_img=big_img
) )
vision_output = vision_output.numpy().flatten() vision_output = vision_output.numpy().flatten()
off_policy_output = off_policy_output.numpy().flatten()
on_policy_output = on_policy_output.numpy().flatten() on_policy_output = on_policy_output.numpy().flatten()
vision_outputs_dict = self.parser.parse_vision_outputs(self.slice_outputs(vision_output, self.vision_output_slices)) vision_outputs_dict = self.parser.parse_vision_outputs(self.slice_outputs(vision_output, self.vision_output_slices))
off_policy_outputs_dict = self.parser.parse_off_policy_outputs(self.slice_outputs(off_policy_output, self.off_policy_output_slices))
policy_outputs_dict = self.parser.parse_policy_outputs(self.slice_outputs(on_policy_output, self.policy_output_slices)) policy_outputs_dict = self.parser.parse_policy_outputs(self.slice_outputs(on_policy_output, self.policy_output_slices))
combined_outputs_dict = {**vision_outputs_dict, **policy_outputs_dict} combined_outputs_dict = {**vision_outputs_dict, **off_policy_outputs_dict, **policy_outputs_dict}
if SEND_RAW_PRED: if SEND_RAW_PRED:
combined_outputs_dict['raw_pred'] = np.concatenate([vision_output.copy(), on_policy_output.copy()]) combined_outputs_dict['raw_pred'] = np.concatenate([vision_output.copy(), on_policy_output.copy(), off_policy_output.copy()])
return combined_outputs_dict return combined_outputs_dict
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8a26866121d1d3a1152bfce024ed7584b8569507d120d4bc8917320093dcd31a
size 41191256
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:565e53c38dcd64c50dd3fe4d5ee1530213aeefd66c3f6b67ea6a72a32612a6bf oid sha256:94b07ef7a0f65d5c41ac696b4ae7bdc59e2d4c5f504460e2b0d720620892c2e8
size 14061419 size 33679037
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:1f0cab5033fe9e3bc5e174a2e790fa277f7d9fc44c65822d734064d2f899a9a0 oid sha256:eda005282417ffa825092ece5c16b5584142044cdbcf15b6d0246136ac6db601
size 296203378 size 120584466
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6173be8a69b1d9633a09969c80b2a8bd990bfe7d3e76e192a0e537f6fd72222b
size 41192485
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:78477124cbf3ffe30fa951ebada8410b43c4242c6054584d656f1d329b067e15 oid sha256:6b66ef783af3fa86190e85a6b4f729cd1443b20be41134aa258f9c376825a45c
size 14060847 size 33680163
+2 -2
View File
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:ee29ee5bce84d1ce23e9ff381280de9b4e4d96d2934cd751740354884e112c66 oid sha256:bbd0761201b3b161587d097f173c66bf82cd02966e5f0d1edd888c970d6f6d87
size 46877473 size 21735970
+10 -4
View File
@@ -96,11 +96,17 @@ class Parser:
self.parse_mdn('pose', outs, in_N=0, out_N=0, out_shape=(ModelConstants.POSE_WIDTH,)) self.parse_mdn('pose', outs, in_N=0, out_N=0, out_shape=(ModelConstants.POSE_WIDTH,))
self.parse_mdn('wide_from_device_euler', outs, in_N=0, out_N=0, out_shape=(ModelConstants.WIDE_FROM_DEVICE_WIDTH,)) self.parse_mdn('wide_from_device_euler', outs, in_N=0, out_N=0, out_shape=(ModelConstants.WIDE_FROM_DEVICE_WIDTH,))
self.parse_mdn('road_transform', outs, in_N=0, out_N=0, out_shape=(ModelConstants.POSE_WIDTH,)) self.parse_mdn('road_transform', outs, in_N=0, out_N=0, out_shape=(ModelConstants.POSE_WIDTH,))
self.parse_categorical_crossentropy('desire_pred', outs, out_shape=(ModelConstants.DESIRE_PRED_LEN,ModelConstants.DESIRE_PRED_WIDTH))
self.parse_binary_crossentropy('meta', outs)
self.parse_mdn('lane_lines', outs, in_N=0, out_N=0, out_shape=(ModelConstants.NUM_LANE_LINES,ModelConstants.IDX_N,ModelConstants.LANE_LINES_WIDTH)) self.parse_mdn('lane_lines', outs, in_N=0, out_N=0, out_shape=(ModelConstants.NUM_LANE_LINES,ModelConstants.IDX_N,ModelConstants.LANE_LINES_WIDTH))
self.parse_mdn('road_edges', outs, in_N=0, out_N=0, out_shape=(ModelConstants.NUM_ROAD_EDGES,ModelConstants.IDX_N,ModelConstants.LANE_LINES_WIDTH)) self.parse_mdn('road_edges', outs, in_N=0, out_N=0, out_shape=(ModelConstants.NUM_ROAD_EDGES,ModelConstants.IDX_N,ModelConstants.LANE_LINES_WIDTH))
self.parse_binary_crossentropy('lane_lines_prob', outs) self.parse_binary_crossentropy('lane_lines_prob', outs)
self.parse_categorical_crossentropy('desire_pred', outs, out_shape=(ModelConstants.DESIRE_PRED_LEN,ModelConstants.DESIRE_PRED_WIDTH)) return outs
self.parse_binary_crossentropy('meta', outs)
def parse_off_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
plan_mhp = self.is_mhp(outs, 'plan', ModelConstants.IDX_N * ModelConstants.PLAN_WIDTH)
plan_in_N, plan_out_N = (ModelConstants.PLAN_MHP_N, ModelConstants.PLAN_MHP_SELECTION) if plan_mhp else (0, 0)
self.parse_mdn('plan', outs, in_N=plan_in_N, out_N=plan_out_N, out_shape=(ModelConstants.IDX_N, ModelConstants.PLAN_WIDTH))
self.parse_binary_crossentropy('lead_prob', outs) self.parse_binary_crossentropy('lead_prob', outs)
lead_mhp = self.is_mhp(outs, 'lead', ModelConstants.LEAD_MHP_SELECTION * ModelConstants.LEAD_TRAJ_LEN * ModelConstants.LEAD_WIDTH) lead_mhp = self.is_mhp(outs, 'lead', ModelConstants.LEAD_MHP_SELECTION * ModelConstants.LEAD_TRAJ_LEN * ModelConstants.LEAD_WIDTH)
lead_in_N, lead_out_N = (ModelConstants.LEAD_MHP_N, ModelConstants.LEAD_MHP_SELECTION) if lead_mhp else (0, 0) lead_in_N, lead_out_N = (ModelConstants.LEAD_MHP_N, ModelConstants.LEAD_MHP_SELECTION) if lead_mhp else (0, 0)
@@ -110,11 +116,11 @@ class Parser:
return outs return outs
def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
self.parse_mdn('plan', outs, in_N=0, out_N=0, out_shape=(ModelConstants.IDX_N, ModelConstants.PLAN_WIDTH)) self.parse_mdn('action', outs, in_N=0, out_N=0, out_shape=(ModelConstants.ACTION_WIDTH,))
self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,))
return outs return outs
def parse_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: def parse_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
outs = self.parse_vision_outputs(outs) outs = self.parse_vision_outputs(outs)
outs = self.parse_off_policy_outputs(outs)
outs = self.parse_policy_outputs(outs) outs = self.parse_policy_outputs(outs)
return outs return outs
@@ -34,7 +34,7 @@ GITHUB = GithubUtils(API_TOKEN, DATA_TOKEN)
EXEC_TIMINGS = [ EXEC_TIMINGS = [
# model, instant max, average max # model, instant max, average max
("modelV2", 0.05, 0.028), ("modelV2", 0.05, 0.032),
("driverStateV2", 0.05, 0.018), ("driverStateV2", 0.05, 0.018),
] ]
+3 -1
View File
@@ -59,8 +59,10 @@ def make_split_input_queues(vision_input_shapes, policy_input_shapes, frame_skip
'tfm': np.zeros((3, 3), dtype=np.float32), 'tfm': np.zeros((3, 3), dtype=np.float32),
'big_tfm': np.zeros((3, 3), dtype=np.float32), 'big_tfm': np.zeros((3, 3), dtype=np.float32),
} }
if 'action_t' in policy_input_shapes:
npy['action_t'] = np.zeros(tc, dtype=np.float32)
handled = {'features_buffer', desire_key, 'traffic_convention'} handled = {'features_buffer', desire_key, 'traffic_convention', 'action_t'}
for key, shape in policy_input_shapes.items(): for key, shape in policy_input_shapes.items():
if key in handled: if key in handled:
continue continue
+1
View File
@@ -35,6 +35,7 @@ class ModelConstants:
LANE_LINES_WIDTH = 2 LANE_LINES_WIDTH = 2
ROAD_EDGES_WIDTH = 2 ROAD_EDGES_WIDTH = 2
PLAN_WIDTH = 15 PLAN_WIDTH = 15
ACTION_WIDTH = 2
DESIRE_PRED_WIDTH = 8 DESIRE_PRED_WIDTH = 8
LAT_PLANNER_SOLUTION_WIDTH = 4 LAT_PLANNER_SOLUTION_WIDTH = 4
DESIRED_CURV_WIDTH = 1 DESIRED_CURV_WIDTH = 1
+54 -5
View File
@@ -44,6 +44,34 @@ from openpilot.sunnypilot.models.helpers import get_active_bundle
PROCESS_NAME = "selfdrive.modeld.modeld_tinygrad" PROCESS_NAME = "selfdrive.modeld.modeld_tinygrad"
def _load_pkl_compat(data: bytes):
  """Unpickle ``data`` with a temporary compatibility shim on UOp construction.

  ``UOpMetaClass.__call__`` is replaced by ``_compat_call`` for the duration of
  ``pickle.loads(data)`` and restored afterwards, including when unpickling
  raises. While installed, the shim rewrites the constructor arguments of a few
  UOp forms before delegating to the original ``__call__``:

  - ``SLICE`` created with a ``_buffer``: the UOp is built without the buffer
    and the buffer is recorded in the ``buffers`` mapping instead.
  - ``PARAM`` with a plain ``int`` arg: the arg is wrapped in ``ParamArg``.
  - ``PERMUTE`` whose first source is a ``NOOP``: the op becomes ``RESHAPE``;
    a multi-element ``CONST`` shape source with a tuple arg is expanded into a
    ``STACK`` of per-element ``CONST`` UOps.
  - ``CUSTOM`` with a multi-element dtype and a non-string arg: the op becomes
    ``CONST``.

  Presumably these cover pickles written by older tinygrad revisions — confirm
  against the tinygrad version pinned by this repo.

  Args:
    data: raw pickle bytes.

  Returns:
    Whatever object ``pickle.loads`` reconstructs from ``data``.

  NOTE(review): ``pickle.loads`` can execute arbitrary code; ``data`` must come
  from a trusted source — confirm against callers.
  NOTE(review): the patch is applied to the class itself, so every UOp created
  anywhere in the process goes through the shim while loading is in progress —
  confirm no other thread builds UOps concurrently.
  """
  # Imported lazily so the tinygrad internals are only required when a pkl is actually loaded.
  from tinygrad.uop.ops import UOpMetaClass, UOp, Ops, buffers, ParamArg
  from tinygrad.dtype import dtypes
  # Keep the unpatched constructor: the shim delegates to it and the finally block restores it.
  _orig_call = UOpMetaClass.__call__
  def _compat_call(cls, op, dtype=dtypes.void, src=(), arg=None, tag=None, metadata=None, _buffer=None):
    # SLICE carrying a buffer: construct without it, then attach the buffer via the `buffers` mapping.
    # This returns early, so none of the rewrites below apply to this case.
    if _buffer is not None and op is Ops.SLICE:
      created = _orig_call(cls, op, dtype, src, arg, tag, metadata)
      buffers[created] = _buffer
      return created
    # PARAM with a bare int arg is upgraded to a ParamArg.
    if op is Ops.PARAM and isinstance(arg, int):
      arg = ParamArg(arg)
    # PERMUTE on a NOOP source is re-expressed as RESHAPE.
    if op is Ops.PERMUTE and len(src) >= 2 and src[0].op is Ops.NOOP:
      op = Ops.RESHAPE
      shape_uop = src[1]
      # A multi-element CONST with a tuple arg becomes a STACK of scalar CONSTs (one per tuple element).
      # The nested UOp(...) calls run while the patch is installed, so they also pass through this shim.
      if shape_uop.op is Ops.CONST and hasattr(shape_uop.dtype, 'count') and shape_uop.dtype.count > 1 and isinstance(shape_uop.arg, tuple):
        src = (src[0], UOp(Ops.STACK, shape_uop.dtype, tuple(UOp(Ops.CONST, dtypes.weakint, (), arg=v) for v in shape_uop.arg)))
    # CUSTOM with a multi-element dtype and a non-string arg is treated as CONST.
    if op is Ops.CUSTOM and hasattr(dtype, 'count') and dtype.count > 1 and not isinstance(arg, str):
      op = Ops.CONST
    # `op` may have been rewritten above; the buffer is forwarded only for BUFFER ops.
    return _orig_call(cls, op, dtype, src, arg, tag, metadata, _buffer if op is Ops.BUFFER else None)
  UOpMetaClass.__call__ = _compat_call
  try:
    return pickle.loads(data)
  finally:
    # Always restore the original constructor so the shim is inactive outside deserialization.
    UOpMetaClass.__call__ = _orig_call
def _pkl_exists(path): def _pkl_exists(path):
from openpilot.common.file_chunker import get_manifest_path from openpilot.common.file_chunker import get_manifest_path
return os.path.exists(path) or os.path.exists(get_manifest_path(path)) return os.path.exists(path) or os.path.exists(get_manifest_path(path))
@@ -111,7 +139,7 @@ class ModelState(ModelStateBase):
from openpilot.common.file_chunker import read_file_chunked from openpilot.common.file_chunker import read_file_chunked
cloudlog.warning(f"loading combined pkl: {pkl_path}") cloudlog.warning(f"loading combined pkl: {pkl_path}")
jits = pickle.loads(read_file_chunked(pkl_path)) jits = _load_pkl_compat(read_file_chunked(pkl_path))
self.DEV = Device.DEFAULT self.DEV = Device.DEFAULT
@@ -129,7 +157,7 @@ class ModelState(ModelStateBase):
else: else:
vision_metadata = metadata['vision'] vision_metadata = metadata['vision']
policy_keys = [k for k in metadata if k != 'vision'] policy_keys = [k for k in metadata if k != 'vision']
if policy_keys == ['policy']: if len(policy_keys) == 1 and policy_keys[0] in ('policy', 'on_policy'):
self._combined_model_type = 'split' self._combined_model_type = 'split'
else: else:
self._combined_model_type = 'multi_policy' self._combined_model_type = 'multi_policy'
@@ -138,6 +166,8 @@ class ModelState(ModelStateBase):
self._policy_slices_list = [metadata[k]['output_slices'] for k in policy_keys] self._policy_slices_list = [metadata[k]['output_slices'] for k in policy_keys]
self.policy_output_slices = self._policy_slices_list[0] self.policy_output_slices = self._policy_slices_list[0]
self._has_on_policy = any('on' in k.lower() for k in policy_keys) self._has_on_policy = any('on' in k.lower() for k in policy_keys)
on_policy_key = next((k for k in policy_keys if 'on' in k.lower()), None)
self._on_policy_has_plan = on_policy_key is not None and 'plan' in metadata[on_policy_key]['output_slices']
first_policy_metadata = metadata[policy_keys[0]] first_policy_metadata = metadata[policy_keys[0]]
vision_input_shapes = vision_metadata['input_shapes'] vision_input_shapes = vision_metadata['input_shapes']
policy_input_shapes = first_policy_metadata['input_shapes'] policy_input_shapes = first_policy_metadata['input_shapes']
@@ -201,7 +231,7 @@ class ModelState(ModelStateBase):
inputs[desire_key][0] = 0 inputs[desire_key][0] = 0
self.npy[desire_key][:] = np.where(inputs[desire_key] - self.prev_desire > .99, inputs[desire_key], 0) self.npy[desire_key][:] = np.where(inputs[desire_key] - self.prev_desire > .99, inputs[desire_key], 0)
self.prev_desire[:] = inputs[desire_key] self.prev_desire[:] = inputs[desire_key]
for key in ('traffic_convention', 'lateral_control_params'): for key in ('traffic_convention', 'lateral_control_params', 'action_t'):
if key in self.npy and key in inputs: if key in self.npy and key in inputs:
self.npy[key][:] = inputs[key] self.npy[key][:] = inputs[key]
@@ -229,7 +259,7 @@ class ModelState(ModelStateBase):
policy_output = raw_outputs[i + 1].numpy().flatten() policy_output = raw_outputs[i + 1].numpy().flatten()
policy_sliced = {k: policy_output[np.newaxis, v] for k, v in policy_slices.items()} policy_sliced = {k: policy_output[np.newaxis, v] for k, v in policy_slices.items()}
parsed = self.parser.parse_policy_outputs(policy_sliced) parsed = self.parser.parse_policy_outputs(policy_sliced)
if 'off' in self._policy_keys[i] and self._has_on_policy: if 'off' in self._policy_keys[i] and self._on_policy_has_plan:
parsed.pop('plan', None) parsed.pop('plan', None)
outputs.update(parsed) outputs.update(parsed)
@@ -245,6 +275,20 @@ class ModelState(ModelStateBase):
def get_action_from_model(self, model_output: dict[str, np.ndarray], prev_action: log.ModelDataV2.Action, def get_action_from_model(self, model_output: dict[str, np.ndarray], prev_action: log.ModelDataV2.Action,
lat_action_t: float, long_action_t: float, v_ego: float) -> log.ModelDataV2.Action: lat_action_t: float, long_action_t: float, v_ego: float) -> log.ModelDataV2.Action:
if 'action' in model_output:
desired_accel = model_output['action'][0, 1]
desired_curvature = model_output['action'][0, 0] / (max(1.0, v_ego))**2
should_stop = (v_ego < 0.3 and desired_accel < 0.1)
desired_accel = smooth_value(desired_accel, prev_action.desiredAcceleration, self.LONG_SMOOTH_SECONDS)
if self.generation is not None and self.generation >= 10:
if v_ego > self.MIN_LAT_CONTROL_SPEED:
desired_curvature = smooth_value(desired_curvature, prev_action.desiredCurvature, self.LAT_SMOOTH_SECONDS)
else:
desired_curvature = prev_action.desiredCurvature
return log.ModelDataV2.Action(desiredCurvature=float(desired_curvature),
desiredAcceleration=float(desired_accel),
shouldStop=bool(should_stop))
plan = model_output['plan'][0] plan = model_output['plan'][0]
desired_accel, should_stop = get_accel_from_plan(plan[:, Plan.VELOCITY][:, 0], plan[:, Plan.ACCELERATION][:, 0], self.constants.T_IDXS, desired_accel, should_stop = get_accel_from_plan(plan[:, Plan.VELOCITY][:, 0], plan[:, Plan.ACCELERATION][:, 0], self.constants.T_IDXS,
action_t=long_action_t) action_t=long_action_t)
@@ -404,9 +448,14 @@ def main(demo=False):
bufs = {name: buf_extra if 'big' in name else buf_main for name in model.vision_input_names} bufs = {name: buf_extra if 'big' in name else buf_main for name in model.vision_input_names}
transforms = {name: model_transform_extra if 'big' in name else model_transform_main for name in model.vision_input_names} transforms = {name: model_transform_extra if 'big' in name else model_transform_main for name in model.vision_input_names}
inputs:dict[str, np.ndarray] = { frame_delay = DT_MDL
action_delay = DT_MDL / 2
lat_action_t = lat_delay + frame_delay + action_delay
long_action_t = long_delay + frame_delay + action_delay
inputs: dict[str, np.ndarray] = {
model.desire_key: vec_desire, model.desire_key: vec_desire,
'traffic_convention': traffic_convention, 'traffic_convention': traffic_convention,
'action_t': np.array([lat_action_t, long_action_t], dtype=np.float32),
} }
if 'lateral_control_params' in model.npy: if 'lateral_control_params' in model.npy:
@@ -146,6 +146,8 @@ class Parser:
def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
self.parse_dynamic_outputs(outs) self.parse_dynamic_outputs(outs)
self.split_outputs(outs) self.split_outputs(outs)
if 'action' in outs:
self.parse_mdn('action', outs, in_N=0, out_N=0, out_shape=(SplitModelConstants.ACTION_WIDTH,))
return outs return outs
def parse_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: def parse_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
@@ -80,11 +80,10 @@ class TestStockEquivalence:
stock_queues, stock_npy = make_input_queues(SPLIT_VISION_INPUT_SHAPES, SPLIT_POLICY_INPUT_SHAPES, frame_skip, stock_queues, stock_npy = make_input_queues(SPLIT_VISION_INPUT_SHAPES, SPLIT_POLICY_INPUT_SHAPES, frame_skip,
device='NPY') device='NPY')
# TODO-SP: remove action_t skip once SP adds prerequisite for deep models (action_t input queue) optional_keys = {'action_t'} if 'action_t' not in SPLIT_POLICY_INPUT_SHAPES else set()
skip_keys = {'action_t'} assert set(state.input_queues.keys()) == set(stock_queues.keys()) - optional_keys, \
assert set(state.input_queues.keys()) == set(stock_queues.keys()) - skip_keys, \
f"Queue keys differ: v2={set(state.input_queues.keys())}, stock={set(stock_queues.keys())}" f"Queue keys differ: v2={set(state.input_queues.keys())}, stock={set(stock_queues.keys())}"
assert set(state.npy.keys()) == set(stock_npy.keys()) - skip_keys, \ assert set(state.npy.keys()) == set(stock_npy.keys()) - optional_keys, \
f"Npy keys differ: v2={set(state.npy.keys())}, stock={set(stock_npy.keys())}" f"Npy keys differ: v2={set(state.npy.keys())}, stock={set(stock_npy.keys())}"
def test_split_queue_keys_work_with_desire_key(self, model_state_factory): def test_split_queue_keys_work_with_desire_key(self, model_state_factory):
@@ -68,3 +68,18 @@ def test_recovery_power_scaling():
# For the below, yes, I know this isn't the same slicing as fillmodlmsg. This is to show that the values are only scaled on curv # For the below, yes, I know this isn't the same slicing as fillmodlmsg. This is to show that the values are only scaled on curv
expected_curv_plan_vel = plan[0, :, Plan.VELOCITY][:, 0] + control * planplus[0, :, Plan.VELOCITY][:, 0] expected_curv_plan_vel = plan[0, :, Plan.VELOCITY][:, 0] + control * planplus[0, :, Plan.VELOCITY][:, 0]
np.testing.assert_allclose(recorded_curv_plans[0][:, Plan.VELOCITY][:, 0], expected_curv_plan_vel, rtol=1e-5, atol=1e-6) np.testing.assert_allclose(recorded_curv_plans[0][:, Plan.VELOCITY][:, 0], expected_curv_plan_vel, rtol=1e-5, atol=1e-6)
def test_action_direct_output():
  """Smoke-test ModelState.get_action_from_model with a direct 'action' model output.

  The model output dict here only carries an 'action' entry (a 1x2 array), and the
  previous action is a default-constructed log.ModelDataV2.Action. The test only
  checks that the returned curvature/acceleration are non-zero and that shouldStop
  is a plain bool; it does not pin exact values.
  """
  # Stand-in for a ModelState instance: only these attributes are provided.
  mock_state = MockStruct(
    LONG_SMOOTH_SECONDS=0.3,
    LAT_SMOOTH_SECONDS=0.1,
    MIN_LAT_CONTROL_SPEED=0.3,
    generation=12,
  )
  previous = log.ModelDataV2.Action()
  outputs = {'action': np.array([[0.01, -0.5]])}

  # NOTE(review): the three trailing positionals are presumably the lateral/longitudinal
  # action times and the ego speed — confirm against get_action_from_model's signature.
  action_msg = ModelState.get_action_from_model(mock_state, outputs, previous, 0.1, 0.1, 10.0)

  assert action_msg.desiredAcceleration != 0.0
  assert action_msg.desiredCurvature != 0.0
  assert isinstance(action_msg.shouldStop, bool)
+2 -2
View File
@@ -143,8 +143,8 @@ class Warp:
self._nv12_cache[key] = get_nv12_info(cam_w, cam_h)[3] self._nv12_cache[key] = get_nv12_info(cam_w, cam_h)[3]
yuv_size = self._nv12_cache[key] yuv_size = self._nv12_cache[key]
road_ptr = bufs[road].data.ctypes.data road_ptr = np.frombuffer(bufs[road].data, dtype=np.uint8).ctypes.data
wide_ptr = bufs[wide].data.ctypes.data wide_ptr = np.frombuffer(bufs[wide].data, dtype=np.uint8).ctypes.data
if road_ptr not in self._blob_cache: if road_ptr not in self._blob_cache:
self._blob_cache[road_ptr] = Tensor.from_blob(road_ptr, (yuv_size,), dtype='uint8') self._blob_cache[road_ptr] = Tensor.from_blob(road_ptr, (yuv_size,), dtype='uint8')
if wide_ptr not in self._blob_cache: if wide_ptr not in self._blob_cache:
@@ -43,6 +43,7 @@ class SplitModelConstants:
LANE_LINES_WIDTH = 2 LANE_LINES_WIDTH = 2
ROAD_EDGES_WIDTH = 2 ROAD_EDGES_WIDTH = 2
PLAN_WIDTH = 15 PLAN_WIDTH = 15
ACTION_WIDTH = 2
DESIRE_PRED_WIDTH = 8 DESIRE_PRED_WIDTH = 8
LAT_PLANNER_SOLUTION_WIDTH = 4 LAT_PLANNER_SOLUTION_WIDTH = 4
DESIRED_CURV_WIDTH = 1 DESIRED_CURV_WIDTH = 1
@@ -32,7 +32,7 @@ class Proc:
PROCS = [ PROCS = [
Proc(['camerad'], 1.65, atol=0.4, msgs=['roadCameraState', 'wideRoadCameraState', 'driverCameraState']), Proc(['camerad'], 1.65, atol=0.4, msgs=['roadCameraState', 'wideRoadCameraState', 'driverCameraState']),
Proc(['modeld'], 1.5, atol=0.2, msgs=['modelV2']), Proc(['modeld'], 1.8, atol=0.2, msgs=['modelV2']),
Proc(['dmonitoringmodeld'], 0.65, atol=0.35, msgs=['driverStateV2']), Proc(['dmonitoringmodeld'], 0.65, atol=0.35, msgs=['driverStateV2']),
Proc(['encoderd'], 0.23, msgs=[]), Proc(['encoderd'], 0.23, msgs=[]),
] ]