Compare commits

...

3 Commits

Author SHA1 Message Date
discountchubbs
be20848487 Update compile_modeld.py 2026-06-07 04:17:14 -07:00
discountchubbs
cdd232b606 Merge branch 'deep-rl' of github.com:sunnypilot/sunnypilot into deep-rl 2026-06-07 04:06:16 -07:00
discountchubbs
b21c70b1ba Update compile_modeld.py 2026-06-07 04:05:54 -07:00

View File

@@ -144,19 +144,20 @@ def create_jit_runner(vision_runner, policy_runners: list, nv12: NV12Frame, mode
return img, big_img
desire_buf = shift_and_sample(desire_q, desire_dev.reshape(1, 1, -1), sample_desire_fn)
inputs = {desire_key: desire_buf.realize(), **extra_tensors}
inputs = {desire_key: desire_buf.realize(), **{key: value.realize() for key, value in extra_tensors.items()}}
if traffic_conv_dev is not None:
inputs['traffic_convention'] = traffic_conv_dev
inputs['traffic_convention'] = traffic_conv_dev.realize()
if vision_runner:
vision_out = next(iter(vision_runner({road_key: img, wide_key: big_img}).values()))
vision_out = vision_out.cast('float32').realize().numpy()
vision_out_tensor = Tensor(vision_out, device=Device.DEFAULT)
new_feat = vision_out_tensor[:, features_slice].reshape(1, -1).unsqueeze(0)
vision_out_realized = vision_out.cast('float32').realize()
new_feat = vision_out_realized[:, features_slice].reshape(1, -1).unsqueeze(0)
inputs['features_buffer'] = shift_and_sample(feat_q, new_feat, sample_skip_fn).realize()
policy_outs = [next(iter(runner(inputs).values())).cast('float32') for runner in policy_runners]
return (vision_out_tensor, *policy_outs) if len(policy_outs) > 1 else (vision_out_tensor, policy_outs[0])
inputs.update({road_key: img, wide_key: big_img, 'features_buffer': sample_skip_fn(feat_q)})
vision_out_kernel = vision_out_realized + 0
policy_out_kernels = [p + 0 for p in policy_outs]
return (vision_out_kernel, *policy_out_kernels) if len(policy_out_kernels) > 1 else (vision_out_kernel, policy_out_kernels[0])
inputs.update({road_key: img.realize(), wide_key: big_img.realize(), 'features_buffer': sample_skip_fn(feat_q).realize()})
policy_out = next(iter(policy_runners[0](inputs).values())).cast('float32')
new_feat = policy_out[:, features_slice].reshape(1, -1).unsqueeze(0)
shift_and_sample(feat_q, new_feat, sample_skip_fn).realize()