Update compile_modeld.py

This commit is contained in:
discountchubbs
2026-06-07 09:47:40 -07:00
parent be20848487
commit 049dfd2eaa
+3 -3
View File
@@ -154,9 +154,9 @@ def create_jit_runner(vision_runner, policy_runners: list, nv12: NV12Frame, mode
new_feat = vision_out_realized[:, features_slice].reshape(1, -1).unsqueeze(0)
inputs['features_buffer'] = shift_and_sample(feat_q, new_feat, sample_skip_fn).realize()
policy_outs = [next(iter(runner(inputs).values())).cast('float32') for runner in policy_runners]
vision_out_kernel = vision_out_realized + 0
policy_out_kernels = [p + 0 for p in policy_outs]
return (vision_out_kernel, *policy_out_kernels) if len(policy_out_kernels) > 1 else (vision_out_kernel, policy_out_kernels[0])
vision_out_final = vision_out_realized.detach()
policy_out_final = [p.detach() for p in policy_outs]
return (vision_out_final, *policy_out_final) if len(policy_out_final) > 1 else (vision_out_final, policy_out_final[0])
inputs.update({road_key: img.realize(), wide_key: big_img.realize(), 'features_buffer': sample_skip_fn(feat_q).realize()})
policy_out = next(iter(policy_runners[0](inputs).values())).cast('float32')
new_feat = policy_out[:, features_slice].reshape(1, -1).unsqueeze(0)