diff --git a/frogpilot/tinygrad_modeld/tinygrad_modeld.py b/frogpilot/tinygrad_modeld/tinygrad_modeld.py index 77af121f0..dbd8047cc 100755 --- a/frogpilot/tinygrad_modeld/tinygrad_modeld.py +++ b/frogpilot/tinygrad_modeld/tinygrad_modeld.py @@ -191,6 +191,8 @@ class ModelState: # Add policy_generation attribute after loading policy_metadata self.policy_generation = model_version or "v8" self.is_v11 = (self.policy_generation == "v11") + self.is_v10 = (self.policy_generation == "v10") + self.is_v12 = (self.policy_generation == "v12") self.is_v9 = (self.policy_generation == "v9") self.mlsim = (self.policy_generation in ("v8", "v10", "v11", "v12")) @@ -329,14 +331,14 @@ class ModelState: self.full_prev_desired_curv[0,-1,:] = policy_outputs_dict['desired_curvature'][0, :] if self.prev_desired_curv_key is not None: - # v9 models expect zeros for prev_desired_curv(s); others use history - if self.is_v9: + # v9/v10/v11/v12 models expect zeros for prev_desired_curv(s); others use history + if self.is_v9 or self.is_v10 or self.is_v11 or self.is_v12: self.numpy_inputs[self.prev_desired_curv_key][:] = 0 * self.full_prev_desired_curv[0, self.temporal_idxs] else: self.numpy_inputs[self.prev_desired_curv_key][:] = self.full_prev_desired_curv[0, self.temporal_idxs] if self.off_policy_enabled and self.off_policy_prev_desired_curv_key is not None: - if self.is_v9: + if self.is_v9 or self.is_v12: self.off_policy_numpy_inputs[self.off_policy_prev_desired_curv_key][:] = 0 * self.full_prev_desired_curv[0, self.temporal_idxs] else: self.off_policy_numpy_inputs[self.off_policy_prev_desired_curv_key][:] = self.full_prev_desired_curv[0, self.temporal_idxs]