Update tinygrad_modeld.py

This commit is contained in:
firestar5683
2026-02-24 13:24:53 -06:00
parent 984fd6a6cb
commit e2a15a0985
+5 -3
View File
@@ -191,6 +191,8 @@ class ModelState:
# Add policy_generation attribute after loading policy_metadata
self.policy_generation = model_version or "v8"
self.is_v11 = (self.policy_generation == "v11")
self.is_v10 = (self.policy_generation == "v10")
self.is_v12 = (self.policy_generation == "v12")
self.is_v9 = (self.policy_generation == "v9")
self.mlsim = (self.policy_generation in ("v8", "v10", "v11", "v12"))
@@ -329,14 +331,14 @@ class ModelState:
self.full_prev_desired_curv[0,-1,:] = policy_outputs_dict['desired_curvature'][0, :]
if self.prev_desired_curv_key is not None:
# v9 models expect zeros for prev_desired_curv(s); others use history
if self.is_v9:
# v9/v10/v11/v12 models expect zeros for prev_desired_curv(s); others use history
if self.is_v9 or self.is_v10 or self.is_v11 or self.is_v12:
self.numpy_inputs[self.prev_desired_curv_key][:] = 0 * self.full_prev_desired_curv[0, self.temporal_idxs]
else:
self.numpy_inputs[self.prev_desired_curv_key][:] = self.full_prev_desired_curv[0, self.temporal_idxs]
if self.off_policy_enabled and self.off_policy_prev_desired_curv_key is not None:
if self.is_v9:
if self.is_v9 or self.is_v12:
self.off_policy_numpy_inputs[self.off_policy_prev_desired_curv_key][:] = 0 * self.full_prev_desired_curv[0, self.temporal_idxs]
else:
self.off_policy_numpy_inputs[self.off_policy_prev_desired_curv_key][:] = self.full_prev_desired_curv[0, self.temporal_idxs]