This commit is contained in:
discountchubbs
2026-06-07 03:56:09 -07:00
parent dd35c27981
commit 74692d0b5f
2 changed files with 11 additions and 12 deletions

View File

@@ -12,11 +12,11 @@ on:
required: false required: false
type: string type: string
recompiled_dir: recompiled_dir:
description: 'Existing recompiled directory number (e.g. 3 for recompiled3)' description: 'Existing recompiled directory number (e.g. 1 for recompiled1)'
required: true required: true
type: string type: string
json_version: json_version:
description: 'driving_models version number to update (e.g. 5 for driving_models_v5.json)' description: 'driving_models version number to update (e.g. 18 for driving_models_v18.json)'
required: true required: true
type: string type: string
artifact_suffix: artifact_suffix:
@@ -63,12 +63,11 @@ on:
default: 'None' default: 'None'
options: options:
- None - None
- Simple Plan Models - Master Models
- Space Lab Models - Release Models
- TR Models - 2025 World Models
- DTR Models - 2026 World Models
- Custom Merge Models - Custom Merge Models
- FOF series models
- Other - Other
custom_model_folder: custom_model_folder:
description: 'Custom model folder name (if "Other" selected)' description: 'Custom model folder name (if "Other" selected)'

View File

@@ -149,13 +149,13 @@ def create_jit_runner(vision_runner, policy_runners: list, nv12: NV12Frame, mode
inputs['traffic_convention'] = traffic_conv_dev inputs['traffic_convention'] = traffic_conv_dev
if vision_runner: if vision_runner:
vision_out = next(iter(vision_runner({road_key: img, wide_key: big_img}).values())).cast('float32') vision_out = next(iter(vision_runner({road_key: img, wide_key: big_img}).values()))
vision_out = vision_out.realize() vision_out = vision_out.cast('float32').realize().numpy()
new_feat = vision_out[:, features_slice].reshape(1, -1).unsqueeze(0) vision_out_tensor = Tensor(vision_out, device=Device.DEFAULT)
new_feat = vision_out_tensor[:, features_slice].reshape(1, -1).unsqueeze(0)
inputs['features_buffer'] = shift_and_sample(feat_q, new_feat, sample_skip_fn).realize() inputs['features_buffer'] = shift_and_sample(feat_q, new_feat, sample_skip_fn).realize()
policy_outs = [next(iter(runner(inputs).values())).cast('float32') for runner in policy_runners] policy_outs = [next(iter(runner(inputs).values())).cast('float32') for runner in policy_runners]
return (vision_out, *policy_outs) if len(policy_outs) > 1 else (vision_out, policy_outs[0]) return (vision_out_tensor, *policy_outs) if len(policy_outs) > 1 else (vision_out_tensor, policy_outs[0])
inputs.update({road_key: img, wide_key: big_img, 'features_buffer': sample_skip_fn(feat_q)}) inputs.update({road_key: img, wide_key: big_img, 'features_buffer': sample_skip_fn(feat_q)})
policy_out = next(iter(policy_runners[0](inputs).values())).cast('float32') policy_out = next(iter(policy_runners[0](inputs).values())).cast('float32')
new_feat = policy_out[:, features_slice].reshape(1, -1).unsqueeze(0) new_feat = policy_out[:, features_slice].reshape(1, -1).unsqueeze(0)