This commit is contained in:
firestar5683
2026-02-05 14:15:36 -06:00
parent 51c54cc353
commit dc191e833e
7 changed files with 147 additions and 34 deletions
+24 -4
View File
@@ -107,7 +107,7 @@ class ModelManager:
self.downloading_model = False
return
if model_version in ("v8", "v9", "v10", "v11"):
if model_version in ("v8", "v9", "v10", "v11", "v12"):
# Download all PKL and metadata files for multi-file tinygrad models (v8 and v9)
filenames = [
f"{model_to_download}_driving_policy_tinygrad.pkl",
@@ -115,6 +115,11 @@ class ModelManager:
f"{model_to_download}_driving_policy_metadata.pkl",
f"{model_to_download}_driving_vision_metadata.pkl",
]
if model_version == "v12":
filenames += [
f"{model_to_download}_driving_off_policy_tinygrad.pkl",
f"{model_to_download}_driving_off_policy_metadata.pkl",
]
for filename in filenames:
model_path = MODELS_PATH / filename
model_url = f"{repo_url}/Models/{filename}"
@@ -208,13 +213,18 @@ class ModelManager:
except Exception:
model_version = None
if model_version in ("v8", "v9", "v10", "v11"):
if model_version in ("v8", "v9", "v10", "v11", "v12"):
v8_v9_files = [
f"{model}_driving_policy_tinygrad.pkl",
f"{model}_driving_vision_tinygrad.pkl",
f"{model}_driving_policy_metadata.pkl",
f"{model}_driving_vision_metadata.pkl",
]
if model_version == "v12":
v8_v9_files += [
f"{model}_driving_off_policy_tinygrad.pkl",
f"{model}_driving_off_policy_metadata.pkl",
]
if all((MODELS_PATH / f).is_file() for f in v8_v9_files):
downloaded_models.add(model)
elif model_version == "v7":
@@ -259,13 +269,18 @@ class ModelManager:
except Exception:
model_version = None
if model_version in ("v8", "v9", "v10", "v11"):
if model_version in ("v8", "v9", "v10", "v11", "v12"):
v8_v9_files = [
f"{model}_driving_policy_tinygrad.pkl",
f"{model}_driving_vision_tinygrad.pkl",
f"{model}_driving_policy_metadata.pkl",
f"{model}_driving_vision_metadata.pkl",
]
if model_version == "v12":
v8_v9_files += [
f"{model}_driving_off_policy_tinygrad.pkl",
f"{model}_driving_off_policy_metadata.pkl",
]
for filename in v8_v9_files:
path = MODELS_PATH / filename
expected_size = model_sizes.get(filename.rsplit(".", 1)[0])
@@ -378,13 +393,18 @@ class ModelManager:
handle_error(None, "Download cancelled...", "Download cancelled...", MODEL_DOWNLOAD_ALL_PARAM, DOWNLOAD_PROGRESS_PARAM, params_memory)
return
if model_version in ("v8", "v9", "v10", "v11"):
if model_version in ("v8", "v9", "v10", "v11", "v12"):
required_files = [
f"{model}_driving_policy_tinygrad.pkl",
f"{model}_driving_vision_tinygrad.pkl",
f"{model}_driving_policy_metadata.pkl",
f"{model}_driving_vision_metadata.pkl",
]
if model_version == "v12":
required_files += [
f"{model}_driving_off_policy_tinygrad.pkl",
f"{model}_driving_off_policy_metadata.pkl",
]
missing = [f for f in required_files if not (MODELS_PATH / f).is_file()]
if missing:
print(f"Tinygrad model {model} is missing files. Preparing to download...")
+3 -1
View File
@@ -96,6 +96,8 @@ EXCLUDED_KEYS = {
}
TINYGRAD_FILES = [
("driving_off_policy_metadata.pkl", "off-policy metadata"),
("driving_off_policy_tinygrad.pkl", "off-policy model"),
("driving_policy_metadata.pkl", "policy metadata"),
("driving_policy_tinygrad.pkl", "policy model"),
("driving_vision_metadata.pkl", "vision metadata"),
@@ -888,7 +890,7 @@ class FrogPilotVariables:
toggle.model_version = DEFAULT_MODEL_VERSION
toggle.classic_longitudinal = toggle.model_version in {"v1", "v2", "v3", "v4"}
toggle.classic_model = toggle.model_version in {"v1", "v2", "v3", "v4"}
toggle.tinygrad_model = toggle.model_version in {"v8", "v9", "v10", "v11"}
toggle.tinygrad_model = toggle.model_version in {"v8", "v9", "v10", "v11", "v12"}
toggle.tomb_raider = toggle.model == "space-lab"
toggle.model_ui = params.get_bool("ModelUI") if tuning_level >= level["ModelUI"] else default.get_bool("ModelUI")
+8 -2
View File
@@ -32,7 +32,10 @@ lenvCython.Program('models/commonmodel_pyx.so', 'models/commonmodel_pyx.pyx', LI
tinygrad_files = ["#"+x for x in glob.glob(env.Dir("#tinygrad_repo").relpath + "/**", recursive=True, root_dir=env.Dir("#").abspath) if 'pycache' not in x]
# Get model metadata
for model_name in ['driving_vision', 'driving_policy']:
model_metadata_names = ['driving_vision', 'driving_policy']
if File("models/driving_off_policy.onnx").exists():
model_metadata_names.append('driving_off_policy')
for model_name in model_metadata_names:
fn = File(f"models/{model_name}").abspath
script_files = [File(Dir("#frogpilot/tinygrad_modeld").File("get_model_metadata.py").abspath)]
cmd = f'python3 {Dir("#frogpilot/tinygrad_modeld").abspath}/get_model_metadata.py {fn}.onnx'
@@ -48,7 +51,10 @@ def tg_compile(flags, model_name):
)
# Compile small models
for model_name in ['driving_vision', 'driving_policy', 'dmonitoring_model']:
model_compile_names = ['driving_vision', 'driving_policy', 'dmonitoring_model']
if File("models/driving_off_policy.onnx").exists():
model_compile_names.append('driving_off_policy')
for model_name in model_compile_names:
flags = {
'larch64': 'DEV=QCOM',
'Darwin': 'DEV=CPU IMAGE=0',
Binary file not shown.
Binary file not shown.
+103 -27
View File
@@ -83,6 +83,34 @@ class ModelState:
output: np.ndarray
prev_desire: np.ndarray # for tracking the rising edge of the pulse
def _build_policy_inputs(self, input_shapes: dict[str, tuple[int, ...]]) -> tuple[dict[str, np.ndarray], str | None]:
numpy_inputs: dict[str, np.ndarray] = {}
# Always-supported inputs (if model expects them)
desire_key_init = next((k for k in input_shapes if k.startswith('desire')), None)
if desire_key_init:
numpy_inputs[desire_key_init] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32)
if 'traffic_convention' in input_shapes:
numpy_inputs['traffic_convention'] = np.zeros((1, ModelConstants.TRAFFIC_CONVENTION_LEN), dtype=np.float32)
if 'features_buffer' in input_shapes:
numpy_inputs['features_buffer'] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32)
# Optional inputs for non-v11 (and some v10/v9 variants)
# Lateral control params
if 'lateral_control_params' in input_shapes:
numpy_inputs['lateral_control_params'] = np.zeros((1, ModelConstants.LATERAL_CONTROL_PARAMS_LEN), dtype=np.float32)
# Previous desired curvature: handle both singular and plural key names across model versions
prev_desired_curv_key = None
if 'prev_desired_curv' in input_shapes:
prev_desired_curv_key = 'prev_desired_curv'
numpy_inputs['prev_desired_curv'] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32)
elif 'prev_desired_curvs' in input_shapes:
prev_desired_curv_key = 'prev_desired_curvs'
numpy_inputs['prev_desired_curvs'] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32)
return numpy_inputs, prev_desired_curv_key
def __init__(self, context: CLContext):
# Dynamically build paths based on current model ID
params = Params()
@@ -101,13 +129,17 @@ class ModelState:
models_dir = Path(__file__).parent / "models"
VISION_PKL_PATH = models_dir / "driving_vision_tinygrad.pkl"
POLICY_PKL_PATH = models_dir / "driving_policy_tinygrad.pkl"
OFF_POLICY_PKL_PATH = models_dir / "driving_off_policy_tinygrad.pkl"
VISION_METADATA_PATH = models_dir / "driving_vision_metadata.pkl"
POLICY_METADATA_PATH = models_dir / "driving_policy_metadata.pkl"
OFF_POLICY_METADATA_PATH = models_dir / "driving_off_policy_metadata.pkl"
else:
VISION_PKL_PATH = model_dir / f"{model_id}_driving_vision_tinygrad.pkl"
POLICY_PKL_PATH = model_dir / f"{model_id}_driving_policy_tinygrad.pkl"
OFF_POLICY_PKL_PATH = model_dir / f"{model_id}_driving_off_policy_tinygrad.pkl"
VISION_METADATA_PATH = model_dir / f"{model_id}_driving_vision_metadata.pkl"
POLICY_METADATA_PATH = model_dir / f"{model_id}_driving_policy_metadata.pkl"
OFF_POLICY_METADATA_PATH = model_dir / f"{model_id}_driving_off_policy_metadata.pkl"
# If ModelVersion is not set or not available, try to determine it from available model data
if not model_version:
@@ -159,7 +191,7 @@ class ModelState:
self.policy_generation = model_version or "v8"
self.is_v11 = (self.policy_generation == "v11")
self.is_v9 = (self.policy_generation == "v9")
self.mlsim = (self.policy_generation in ("v8", "v10", "v11"))
self.mlsim = (self.policy_generation in ("v8", "v10", "v11", "v12"))
self.frames = {name: DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP) for name in self.vision_input_names}
self.prev_desire = np.zeros(ModelConstants.DESIRE_LEN, dtype=np.float32)
@@ -170,33 +202,50 @@ class ModelState:
# policy inputs (built dynamically to support all generations)
self.numpy_inputs = {}
self.numpy_inputs, self.prev_desired_curv_key = self._build_policy_inputs(self.policy_input_shapes)
# Always-supported inputs (if model expects them)
desire_key_init = next((k for k in self.policy_input_shapes if k.startswith('desire')), None)
if desire_key_init:
self.numpy_inputs[desire_key_init] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32)
if 'traffic_convention' in self.policy_input_shapes:
self.numpy_inputs['traffic_convention'] = np.zeros((1, ModelConstants.TRAFFIC_CONVENTION_LEN), dtype=np.float32)
if 'features_buffer' in self.policy_input_shapes:
self.numpy_inputs['features_buffer'] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32)
# Off-policy model (optional)
self.off_policy_enabled = False
self.off_policy_input_shapes: dict[str, tuple[int, ...]] = {}
self.off_policy_output_slices: dict[str, slice] = {}
self.off_policy_numpy_inputs: dict[str, np.ndarray] = {}
self.off_policy_prev_desired_curv_key: str | None = None
self.off_policy_desire_key: str | None = None
self.off_policy_inputs: dict[str, Tensor] | None = None
self.off_policy_output: np.ndarray | None = None
# Optional inputs for non-v11 (and some v10/v9 variants)
# Lateral control params
if 'lateral_control_params' in self.policy_input_shapes:
self.numpy_inputs['lateral_control_params'] = np.zeros((1, ModelConstants.LATERAL_CONTROL_PARAMS_LEN), dtype=np.float32)
off_policy_metadata = None
if self.policy_generation == "v12" or OFF_POLICY_METADATA_PATH.is_file() or OFF_POLICY_PKL_PATH.is_file():
try:
with open(OFF_POLICY_METADATA_PATH, 'rb') as f:
off_policy_metadata = pickle.load(f)
except FileNotFoundError:
cloudlog.error(f"Missing metadata {OFF_POLICY_METADATA_PATH}, downloading...")
from openpilot.frogpilot.assets.model_manager import ModelManager
ModelManager().download_model(model_id)
try:
with open(OFF_POLICY_METADATA_PATH, 'rb') as f:
off_policy_metadata = pickle.load(f)
except FileNotFoundError:
cloudlog.warning(f"Off-policy metadata still missing: {OFF_POLICY_METADATA_PATH}")
# Previous desired curvature: handle both singular and plural key names across model versions
self.prev_desired_curv_key = None
if 'prev_desired_curv' in self.policy_input_shapes:
self.prev_desired_curv_key = 'prev_desired_curv'
self.numpy_inputs['prev_desired_curv'] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32)
elif 'prev_desired_curvs' in self.policy_input_shapes:
self.prev_desired_curv_key = 'prev_desired_curvs'
self.numpy_inputs['prev_desired_curvs'] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32)
if off_policy_metadata is not None:
self.off_policy_input_shapes = off_policy_metadata['input_shapes']
self.off_policy_output_slices = off_policy_metadata['output_slices']
off_policy_output_size = off_policy_metadata['output_shapes']['outputs'][1]
self.off_policy_numpy_inputs, self.off_policy_prev_desired_curv_key = self._build_policy_inputs(self.off_policy_input_shapes)
self.off_policy_desire_key = next((k for k in self.off_policy_numpy_inputs if k.startswith('desire')), None)
self.off_policy_inputs = {k: Tensor(v, device='NPY').realize() for k, v in self.off_policy_numpy_inputs.items()}
self.off_policy_output = np.zeros(off_policy_output_size, dtype=np.float32)
try:
with open(OFF_POLICY_PKL_PATH, "rb") as f:
self.off_policy_run = pickle.load(f)
self.off_policy_enabled = True
except FileNotFoundError:
cloudlog.warning(f"Missing off-policy model {OFF_POLICY_PKL_PATH}, skipping off-policy")
# Optional temporal buffer for previous desired curvature (allocate only if the policy expects it)
if getattr(self, 'prev_desired_curv_key', None) is not None:
# Optional temporal buffer for previous desired curvature (allocate only if any model expects it)
if self.prev_desired_curv_key is not None or self.off_policy_prev_desired_curv_key is not None:
self.full_prev_desired_curv = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32)
@@ -206,6 +255,7 @@ class ModelState:
self.policy_inputs = {k: Tensor(v, device='NPY').realize() for k,v in self.numpy_inputs.items()}
self.policy_output = np.zeros(policy_output_size, dtype=np.float32)
self.parser = Parser()
self.off_policy_parser = Parser(ignore_missing=True)
with open(VISION_PKL_PATH, "rb") as f:
self.vision_run = pickle.load(f)
@@ -231,10 +281,18 @@ class ModelState:
self.full_desire[0,:-1] = self.full_desire[0,1:]
self.full_desire[0,-1] = new_desire
self.numpy_inputs[self.desire_key][:] = self.full_desire.reshape((1,ModelConstants.INPUT_HISTORY_BUFFER_LEN,ModelConstants.TEMPORAL_SKIP,-1)).max(axis=2)
if self.off_policy_enabled and self.off_policy_desire_key is not None:
self.off_policy_numpy_inputs[self.off_policy_desire_key][:] = self.numpy_inputs[self.desire_key]
if 'traffic_convention' in self.numpy_inputs:
self.numpy_inputs['traffic_convention'][:] = inputs['traffic_convention']
if self.off_policy_enabled and 'traffic_convention' in self.off_policy_numpy_inputs:
self.off_policy_numpy_inputs['traffic_convention'][:] = inputs['traffic_convention']
self.numpy_inputs['traffic_convention'][:] = inputs['traffic_convention']
if 'lateral_control_params' in self.numpy_inputs:
self.numpy_inputs['lateral_control_params'][:] = inputs['lateral_control_params']
if self.off_policy_enabled and 'lateral_control_params' in self.off_policy_numpy_inputs:
self.off_policy_numpy_inputs['lateral_control_params'][:] = inputs['lateral_control_params']
if prepare_only:
return None
@@ -256,7 +314,10 @@ class ModelState:
self.full_features_buffer[0,:-1] = self.full_features_buffer[0,1:]
self.full_features_buffer[0,-1] = vision_outputs_dict['hidden_state'][0, :]
self.numpy_inputs['features_buffer'][:] = self.full_features_buffer[0, self.temporal_idxs]
if 'features_buffer' in self.numpy_inputs:
self.numpy_inputs['features_buffer'][:] = self.full_features_buffer[0, self.temporal_idxs]
if self.off_policy_enabled and 'features_buffer' in self.off_policy_numpy_inputs:
self.off_policy_numpy_inputs['features_buffer'][:] = self.full_features_buffer[0, self.temporal_idxs]
self.policy_output = self.policy_run(**self.policy_inputs).contiguous().realize().uop.base.buffer.numpy()
policy_outputs_dict = self.parser.parse_policy_outputs(self.slice_outputs(self.policy_output, self.policy_output_slices))
@@ -273,9 +334,24 @@ class ModelState:
else:
self.numpy_inputs[self.prev_desired_curv_key][:] = self.full_prev_desired_curv[0, self.temporal_idxs]
if self.off_policy_enabled and self.off_policy_prev_desired_curv_key is not None:
if self.is_v9:
self.off_policy_numpy_inputs[self.off_policy_prev_desired_curv_key][:] = 0 * self.full_prev_desired_curv[0, self.temporal_idxs]
else:
self.off_policy_numpy_inputs[self.off_policy_prev_desired_curv_key][:] = self.full_prev_desired_curv[0, self.temporal_idxs]
combined_outputs_dict = {**vision_outputs_dict, **policy_outputs_dict}
if self.off_policy_enabled:
self.off_policy_output = self.off_policy_run(**self.off_policy_inputs).contiguous().realize().uop.base.buffer.numpy()
off_policy_outputs_dict = self.off_policy_parser.parse_policy_outputs(
self.slice_outputs(self.off_policy_output, self.off_policy_output_slices)
)
combined_outputs_dict.update(off_policy_outputs_dict)
if SEND_RAW_PRED:
combined_outputs_dict['raw_pred'] = np.concatenate([self.vision_output.copy(), self.policy_output.copy()])
raw_pred = [self.vision_output.copy(), self.policy_output.copy()]
if self.off_policy_enabled and self.off_policy_output is not None:
raw_pred.append(self.off_policy_output.copy())
combined_outputs_dict['raw_pred'] = np.concatenate(raw_pred)
return combined_outputs_dict
@@ -503,6 +503,8 @@ bool FrogPilotModelPanel::isModelInstalled(const QString &key) const {
bool has_policy_tg = false;
bool has_vision_meta = false;
bool has_vision_tg = false;
bool has_off_policy_meta = false;
bool has_off_policy_tg = false;
bool foundAny = false;
for (const QString &file : modelDir.entryList(QDir::Files)) {
@@ -521,6 +523,10 @@ bool FrogPilotModelPanel::isModelInstalled(const QString &key) const {
has_policy_meta = true;
} else if (base.contains("_driving_policy_tinygrad")) {
has_policy_tg = true;
} else if (base.contains("_driving_off_policy_metadata")) {
has_off_policy_meta = true;
} else if (base.contains("_driving_off_policy_tinygrad")) {
has_off_policy_tg = true;
} else if (base.contains("_driving_vision_metadata")) {
has_vision_meta = true;
} else if (base.contains("_driving_vision_tinygrad")) {
@@ -534,6 +540,9 @@ bool FrogPilotModelPanel::isModelInstalled(const QString &key) const {
}
if (has_policy_meta && has_policy_tg && has_vision_meta && has_vision_tg) {
if (has_off_policy_meta || has_off_policy_tg) {
return has_off_policy_meta && has_off_policy_tg;
}
return true;
}