mirror of
https://github.com/firestar5683/StarPilot.git
synced 2026-07-02 12:02:09 +08:00
Update tinygrad_modeld.py
This commit is contained in:
@@ -191,6 +191,8 @@ class ModelState:
|
||||
# Add policy_generation attribute after loading policy_metadata
|
||||
self.policy_generation = model_version or "v8"
|
||||
self.is_v11 = (self.policy_generation == "v11")
|
||||
self.is_v10 = (self.policy_generation == "v10")
|
||||
self.is_v12 = (self.policy_generation == "v12")
|
||||
self.is_v9 = (self.policy_generation == "v9")
|
||||
self.mlsim = (self.policy_generation in ("v8", "v10", "v11", "v12"))
|
||||
|
||||
@@ -329,14 +331,14 @@ class ModelState:
|
||||
self.full_prev_desired_curv[0,-1,:] = policy_outputs_dict['desired_curvature'][0, :]
|
||||
|
||||
if self.prev_desired_curv_key is not None:
|
||||
# v9 models expect zeros for prev_desired_curv(s); others use history
|
||||
if self.is_v9:
|
||||
# v9/v10/v11/v12 models expect zeros for prev_desired_curv(s); others use history
|
||||
if self.is_v9 or self.is_v10 or self.is_v11 or self.is_v12:
|
||||
self.numpy_inputs[self.prev_desired_curv_key][:] = 0 * self.full_prev_desired_curv[0, self.temporal_idxs]
|
||||
else:
|
||||
self.numpy_inputs[self.prev_desired_curv_key][:] = self.full_prev_desired_curv[0, self.temporal_idxs]
|
||||
|
||||
if self.off_policy_enabled and self.off_policy_prev_desired_curv_key is not None:
|
||||
if self.is_v9:
|
||||
if self.is_v9 or self.is_v12:
|
||||
self.off_policy_numpy_inputs[self.off_policy_prev_desired_curv_key][:] = 0 * self.full_prev_desired_curv[0, self.temporal_idxs]
|
||||
else:
|
||||
self.off_policy_numpy_inputs[self.off_policy_prev_desired_curv_key][:] = self.full_prev_desired_curv[0, self.temporal_idxs]
|
||||
|
||||
Reference in New Issue
Block a user