mirror of
https://github.com/sunnypilot/sunnypilot.git
synced 2026-06-27 00:52:05 +08:00
Update compile_modeld.py
This commit is contained in:
@@ -154,9 +154,9 @@ def create_jit_runner(vision_runner, policy_runners: list, nv12: NV12Frame, mode
|
||||
new_feat = vision_out_realized[:, features_slice].reshape(1, -1).unsqueeze(0)
|
||||
inputs['features_buffer'] = shift_and_sample(feat_q, new_feat, sample_skip_fn).realize()
|
||||
policy_outs = [next(iter(runner(inputs).values())).cast('float32') for runner in policy_runners]
|
||||
vision_out_kernel = vision_out_realized + 0
|
||||
policy_out_kernels = [p + 0 for p in policy_outs]
|
||||
return (vision_out_kernel, *policy_out_kernels) if len(policy_out_kernels) > 1 else (vision_out_kernel, policy_out_kernels[0])
|
||||
vision_out_final = vision_out_realized.detach()
|
||||
policy_out_final = [p.detach() for p in policy_outs]
|
||||
return (vision_out_final, *policy_out_final) if len(policy_out_final) > 1 else (vision_out_final, policy_out_final[0])
|
||||
inputs.update({road_key: img.realize(), wide_key: big_img.realize(), 'features_buffer': sample_skip_fn(feat_q).realize()})
|
||||
policy_out = next(iter(policy_runners[0](inputs).values())).cast('float32')
|
||||
new_feat = policy_out[:, features_slice].reshape(1, -1).unsqueeze(0)
|
||||
|
||||
Reference in New Issue
Block a user