diff --git a/frogpilot/assets/model_manager.py b/frogpilot/assets/model_manager.py index 63bf6a2f3..f760b1195 100644 --- a/frogpilot/assets/model_manager.py +++ b/frogpilot/assets/model_manager.py @@ -107,7 +107,7 @@ class ModelManager: self.downloading_model = False return - if model_version in ("v8", "v9", "v10", "v11"): + if model_version in ("v8", "v9", "v10", "v11", "v12"): # Download all PKL and metadata files for multi-file tinygrad models (v8 and v9) filenames = [ f"{model_to_download}_driving_policy_tinygrad.pkl", @@ -115,6 +115,11 @@ class ModelManager: f"{model_to_download}_driving_policy_metadata.pkl", f"{model_to_download}_driving_vision_metadata.pkl", ] + if model_version == "v12": + filenames += [ + f"{model_to_download}_driving_off_policy_tinygrad.pkl", + f"{model_to_download}_driving_off_policy_metadata.pkl", + ] for filename in filenames: model_path = MODELS_PATH / filename model_url = f"{repo_url}/Models/{filename}" @@ -208,13 +213,18 @@ class ModelManager: except Exception: model_version = None - if model_version in ("v8", "v9", "v10", "v11"): + if model_version in ("v8", "v9", "v10", "v11", "v12"): v8_v9_files = [ f"{model}_driving_policy_tinygrad.pkl", f"{model}_driving_vision_tinygrad.pkl", f"{model}_driving_policy_metadata.pkl", f"{model}_driving_vision_metadata.pkl", ] + if model_version == "v12": + v8_v9_files += [ + f"{model}_driving_off_policy_tinygrad.pkl", + f"{model}_driving_off_policy_metadata.pkl", + ] if all((MODELS_PATH / f).is_file() for f in v8_v9_files): downloaded_models.add(model) elif model_version == "v7": @@ -259,13 +269,18 @@ class ModelManager: except Exception: model_version = None - if model_version in ("v8", "v9", "v10", "v11"): + if model_version in ("v8", "v9", "v10", "v11", "v12"): v8_v9_files = [ f"{model}_driving_policy_tinygrad.pkl", f"{model}_driving_vision_tinygrad.pkl", f"{model}_driving_policy_metadata.pkl", f"{model}_driving_vision_metadata.pkl", ] + if model_version == "v12": + v8_v9_files += [ + f"{model}_driving_off_policy_tinygrad.pkl", + f"{model}_driving_off_policy_metadata.pkl", + ] for filename in v8_v9_files: path = MODELS_PATH / filename expected_size = model_sizes.get(filename.rsplit(".", 1)[0]) @@ -378,13 +393,18 @@ class ModelManager: handle_error(None, "Download cancelled...", "Download cancelled...", MODEL_DOWNLOAD_ALL_PARAM, DOWNLOAD_PROGRESS_PARAM, params_memory) return - if model_version in ("v8", "v9", "v10", "v11"): + if model_version in ("v8", "v9", "v10", "v11", "v12"): required_files = [ f"{model}_driving_policy_tinygrad.pkl", f"{model}_driving_vision_tinygrad.pkl", f"{model}_driving_policy_metadata.pkl", f"{model}_driving_vision_metadata.pkl", ] + if model_version == "v12": + required_files += [ + f"{model}_driving_off_policy_tinygrad.pkl", + f"{model}_driving_off_policy_metadata.pkl", + ] missing = [f for f in required_files if not (MODELS_PATH / f).is_file()] if missing: print(f"Tinygrad model {model} is missing files. Preparing to download...") diff --git a/frogpilot/common/frogpilot_variables.py b/frogpilot/common/frogpilot_variables.py index e0970491a..7374a2be2 100644 --- a/frogpilot/common/frogpilot_variables.py +++ b/frogpilot/common/frogpilot_variables.py @@ -96,6 +96,8 @@ EXCLUDED_KEYS = { } TINYGRAD_FILES = [ + ("driving_off_policy_metadata.pkl", "off-policy metadata"), + ("driving_off_policy_tinygrad.pkl", "off-policy model"), ("driving_policy_metadata.pkl", "policy metadata"), ("driving_policy_tinygrad.pkl", "policy model"), ("driving_vision_metadata.pkl", "vision metadata"), @@ -888,7 +890,7 @@ class FrogPilotVariables: toggle.model_version = DEFAULT_MODEL_VERSION toggle.classic_longitudinal = toggle.model_version in {"v1", "v2", "v3", "v4"} toggle.classic_model = toggle.model_version in {"v1", "v2", "v3", "v4"} - toggle.tinygrad_model = toggle.model_version in {"v8", "v9", "v10", "v11"} + toggle.tinygrad_model = toggle.model_version in {"v8", "v9", "v10", "v11", "v12"} toggle.tomb_raider = toggle.model == "space-lab" toggle.model_ui = params.get_bool("ModelUI") if tuning_level >= level["ModelUI"] else default.get_bool("ModelUI") diff --git a/frogpilot/tinygrad_modeld/SConscript b/frogpilot/tinygrad_modeld/SConscript index 42851bf91..0885c3c24 100644 --- a/frogpilot/tinygrad_modeld/SConscript +++ b/frogpilot/tinygrad_modeld/SConscript @@ -32,7 +32,10 @@ lenvCython.Program('models/commonmodel_pyx.so', 'models/commonmodel_pyx.pyx', LI tinygrad_files = ["#"+x for x in glob.glob(env.Dir("#tinygrad_repo").relpath + "/**", recursive=True, root_dir=env.Dir("#").abspath) if 'pycache' not in x] # Get model metadata -for model_name in ['driving_vision', 'driving_policy']: +model_metadata_names = ['driving_vision', 'driving_policy'] +if File("models/driving_off_policy.onnx").exists(): + model_metadata_names.append('driving_off_policy') +for model_name in model_metadata_names: fn = File(f"models/{model_name}").abspath script_files = [File(Dir("#frogpilot/tinygrad_modeld").File("get_model_metadata.py").abspath)] cmd = f'python3 {Dir("#frogpilot/tinygrad_modeld").abspath}/get_model_metadata.py {fn}.onnx' @@ -48,7 +51,10 @@ def tg_compile(flags, model_name): ) # Compile small models -for model_name in ['driving_vision', 'driving_policy', 'dmonitoring_model']: +model_compile_names = ['driving_vision', 'driving_policy', 'dmonitoring_model'] +if File("models/driving_off_policy.onnx").exists(): + model_compile_names.append('driving_off_policy') +for model_name in model_compile_names: flags = { 'larch64': 'DEV=QCOM', 'Darwin': 'DEV=CPU IMAGE=0', diff --git a/frogpilot/tinygrad_modeld/models/driving_policy.onnx b/frogpilot/tinygrad_modeld/models/driving_policy.onnx index e42bb8ea9..7de335528 100644 Binary files a/frogpilot/tinygrad_modeld/models/driving_policy.onnx and b/frogpilot/tinygrad_modeld/models/driving_policy.onnx differ diff --git a/frogpilot/tinygrad_modeld/models/driving_vision.onnx b/frogpilot/tinygrad_modeld/models/driving_vision.onnx index 902f1dd34..aff86857a 100644 Binary files a/frogpilot/tinygrad_modeld/models/driving_vision.onnx and b/frogpilot/tinygrad_modeld/models/driving_vision.onnx differ diff --git a/frogpilot/tinygrad_modeld/tinygrad_modeld.py b/frogpilot/tinygrad_modeld/tinygrad_modeld.py index 05eaa56b7..61e7ca57e 100755 --- a/frogpilot/tinygrad_modeld/tinygrad_modeld.py +++ b/frogpilot/tinygrad_modeld/tinygrad_modeld.py @@ -83,6 +83,34 @@ class ModelState: output: np.ndarray prev_desire: np.ndarray # for tracking the rising edge of the pulse + def _build_policy_inputs(self, input_shapes: dict[str, tuple[int, ...]]) -> tuple[dict[str, np.ndarray], str | None]: + numpy_inputs: dict[str, np.ndarray] = {} + + # Always-supported inputs (if model expects them) + desire_key_init = next((k for k in input_shapes if k.startswith('desire')), None) + if desire_key_init: + numpy_inputs[desire_key_init] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32) + if 'traffic_convention' in input_shapes: + numpy_inputs['traffic_convention'] = np.zeros((1, ModelConstants.TRAFFIC_CONVENTION_LEN), dtype=np.float32) + if 'features_buffer' in input_shapes: + numpy_inputs['features_buffer'] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32) + + # Optional inputs for non-v11 (and some v10/v9 variants) + # Lateral control params + if 'lateral_control_params' in input_shapes: + numpy_inputs['lateral_control_params'] = np.zeros((1, ModelConstants.LATERAL_CONTROL_PARAMS_LEN), dtype=np.float32) + + # Previous desired curvature: handle both singular and plural key names across model versions + prev_desired_curv_key = None + if 'prev_desired_curv' in input_shapes: + prev_desired_curv_key = 'prev_desired_curv' + numpy_inputs['prev_desired_curv'] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32) + elif 'prev_desired_curvs' in input_shapes: + prev_desired_curv_key = 'prev_desired_curvs' + numpy_inputs['prev_desired_curvs'] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32) + + return numpy_inputs, prev_desired_curv_key + def __init__(self, context: CLContext): # Dynamically build paths based on current model ID params = Params() @@ -101,13 +129,17 @@ class ModelState: models_dir = Path(__file__).parent / "models" VISION_PKL_PATH = models_dir / "driving_vision_tinygrad.pkl" POLICY_PKL_PATH = models_dir / "driving_policy_tinygrad.pkl" + OFF_POLICY_PKL_PATH = models_dir / "driving_off_policy_tinygrad.pkl" VISION_METADATA_PATH = models_dir / "driving_vision_metadata.pkl" POLICY_METADATA_PATH = models_dir / "driving_policy_metadata.pkl" + OFF_POLICY_METADATA_PATH = models_dir / "driving_off_policy_metadata.pkl" else: VISION_PKL_PATH = model_dir / f"{model_id}_driving_vision_tinygrad.pkl" POLICY_PKL_PATH = model_dir / f"{model_id}_driving_policy_tinygrad.pkl" + OFF_POLICY_PKL_PATH = model_dir / f"{model_id}_driving_off_policy_tinygrad.pkl" VISION_METADATA_PATH = model_dir / f"{model_id}_driving_vision_metadata.pkl" POLICY_METADATA_PATH = model_dir / f"{model_id}_driving_policy_metadata.pkl" + OFF_POLICY_METADATA_PATH = model_dir / f"{model_id}_driving_off_policy_metadata.pkl" # If ModelVersion is not set or not available, try to determine it from available model data if not model_version: @@ -159,7 +191,7 @@ class ModelState: self.policy_generation = model_version or "v8" self.is_v11 = (self.policy_generation == "v11") self.is_v9 = (self.policy_generation == "v9") - self.mlsim = (self.policy_generation in ("v8", "v10", "v11")) + self.mlsim = (self.policy_generation in ("v8", "v10", "v11", "v12")) self.frames = {name: DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP) for name in self.vision_input_names} self.prev_desire = np.zeros(ModelConstants.DESIRE_LEN, dtype=np.float32) @@ -170,33 +202,50 @@ class ModelState: # policy inputs (built dynamically to support all generations) - self.numpy_inputs = {} + self.numpy_inputs, self.prev_desired_curv_key = self._build_policy_inputs(self.policy_input_shapes) - # Always-supported inputs (if model expects them) - desire_key_init = next((k for k in self.policy_input_shapes if k.startswith('desire')), None) - if desire_key_init: - self.numpy_inputs[desire_key_init] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32) - if 'traffic_convention' in self.policy_input_shapes: - self.numpy_inputs['traffic_convention'] = np.zeros((1, ModelConstants.TRAFFIC_CONVENTION_LEN), dtype=np.float32) - if 'features_buffer' in self.policy_input_shapes: - self.numpy_inputs['features_buffer'] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32) + # Off-policy model (optional) + self.off_policy_enabled = False + self.off_policy_input_shapes: dict[str, tuple[int, ...]] = {} + self.off_policy_output_slices: dict[str, slice] = {} + self.off_policy_numpy_inputs: dict[str, np.ndarray] = {} + self.off_policy_prev_desired_curv_key: str | None = None + self.off_policy_desire_key: str | None = None + self.off_policy_inputs: dict[str, Tensor] | None = None + self.off_policy_output: np.ndarray | None = None - # Optional inputs for non-v11 (and some v10/v9 variants) - # Lateral control params - if 'lateral_control_params' in self.policy_input_shapes: - self.numpy_inputs['lateral_control_params'] = np.zeros((1, ModelConstants.LATERAL_CONTROL_PARAMS_LEN), dtype=np.float32) + off_policy_metadata = None + if self.policy_generation == "v12" or OFF_POLICY_METADATA_PATH.is_file() or OFF_POLICY_PKL_PATH.is_file(): + try: + with open(OFF_POLICY_METADATA_PATH, 'rb') as f: + off_policy_metadata = pickle.load(f) + except FileNotFoundError: + cloudlog.error(f"Missing metadata {OFF_POLICY_METADATA_PATH}, downloading...") + from openpilot.frogpilot.assets.model_manager import ModelManager + ModelManager().download_model(model_id) + try: + with open(OFF_POLICY_METADATA_PATH, 'rb') as f: + off_policy_metadata = pickle.load(f) + except FileNotFoundError: + cloudlog.warning(f"Off-policy metadata still missing: {OFF_POLICY_METADATA_PATH}") - # Previous desired curvature: handle both singular and plural key names across model versions - self.prev_desired_curv_key = None - if 'prev_desired_curv' in self.policy_input_shapes: - self.prev_desired_curv_key = 'prev_desired_curv' - self.numpy_inputs['prev_desired_curv'] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32) - elif 'prev_desired_curvs' in self.policy_input_shapes: - self.prev_desired_curv_key = 'prev_desired_curvs' - self.numpy_inputs['prev_desired_curvs'] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32) + if off_policy_metadata is not None: + self.off_policy_input_shapes = off_policy_metadata['input_shapes'] + self.off_policy_output_slices = off_policy_metadata['output_slices'] + off_policy_output_size = off_policy_metadata['output_shapes']['outputs'][1] + self.off_policy_numpy_inputs, self.off_policy_prev_desired_curv_key = self._build_policy_inputs(self.off_policy_input_shapes) + self.off_policy_desire_key = next((k for k in self.off_policy_numpy_inputs if k.startswith('desire')), None) + self.off_policy_inputs = {k: Tensor(v, device='NPY').realize() for k, v in self.off_policy_numpy_inputs.items()} + self.off_policy_output = np.zeros(off_policy_output_size, dtype=np.float32) + try: + with open(OFF_POLICY_PKL_PATH, "rb") as f: + self.off_policy_run = pickle.load(f) + self.off_policy_enabled = True + except FileNotFoundError: + cloudlog.warning(f"Missing off-policy model {OFF_POLICY_PKL_PATH}, skipping off-policy") - # Optional temporal buffer for previous desired curvature (allocate only if the policy expects it) - if getattr(self, 'prev_desired_curv_key', None) is not None: + # Optional temporal buffer for previous desired curvature (allocate only if any model expects it) + if self.prev_desired_curv_key is not None or self.off_policy_prev_desired_curv_key is not None: self.full_prev_desired_curv = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.PREV_DESIRED_CURV_LEN), dtype=np.float32) @@ -206,6 +255,7 @@ class ModelState: self.policy_inputs = {k: Tensor(v, device='NPY').realize() for k,v in self.numpy_inputs.items()} self.policy_output = np.zeros(policy_output_size, dtype=np.float32) self.parser = Parser() + self.off_policy_parser = Parser(ignore_missing=True) with open(VISION_PKL_PATH, "rb") as f: self.vision_run = pickle.load(f) @@ -231,10 +281,18 @@ class ModelState: self.full_desire[0,:-1] = self.full_desire[0,1:] self.full_desire[0,-1] = new_desire self.numpy_inputs[self.desire_key][:] = self.full_desire.reshape((1,ModelConstants.INPUT_HISTORY_BUFFER_LEN,ModelConstants.TEMPORAL_SKIP,-1)).max(axis=2) + if self.off_policy_enabled and self.off_policy_desire_key is not None: + self.off_policy_numpy_inputs[self.off_policy_desire_key][:] = self.numpy_inputs[self.desire_key] + + if 'traffic_convention' in self.numpy_inputs: + self.numpy_inputs['traffic_convention'][:] = inputs['traffic_convention'] + if self.off_policy_enabled and 'traffic_convention' in self.off_policy_numpy_inputs: + self.off_policy_numpy_inputs['traffic_convention'][:] = inputs['traffic_convention'] - self.numpy_inputs['traffic_convention'][:] = inputs['traffic_convention'] if 'lateral_control_params' in self.numpy_inputs: self.numpy_inputs['lateral_control_params'][:] = inputs['lateral_control_params'] + if self.off_policy_enabled and 'lateral_control_params' in self.off_policy_numpy_inputs: + self.off_policy_numpy_inputs['lateral_control_params'][:] = inputs['lateral_control_params'] if prepare_only: return None @@ -256,7 +314,10 @@ class ModelState: self.full_features_buffer[0,:-1] = self.full_features_buffer[0,1:] self.full_features_buffer[0,-1] = vision_outputs_dict['hidden_state'][0, :] - self.numpy_inputs['features_buffer'][:] = self.full_features_buffer[0, self.temporal_idxs] + if 'features_buffer' in self.numpy_inputs: + self.numpy_inputs['features_buffer'][:] = self.full_features_buffer[0, self.temporal_idxs] + if self.off_policy_enabled and 'features_buffer' in self.off_policy_numpy_inputs: + self.off_policy_numpy_inputs['features_buffer'][:] = self.full_features_buffer[0, self.temporal_idxs] self.policy_output = self.policy_run(**self.policy_inputs).contiguous().realize().uop.base.buffer.numpy() policy_outputs_dict = self.parser.parse_policy_outputs(self.slice_outputs(self.policy_output, self.policy_output_slices)) @@ -273,9 +334,24 @@ class ModelState: else: self.numpy_inputs[self.prev_desired_curv_key][:] = self.full_prev_desired_curv[0, self.temporal_idxs] + if self.off_policy_enabled and self.off_policy_prev_desired_curv_key is not None: + if self.is_v9: + self.off_policy_numpy_inputs[self.off_policy_prev_desired_curv_key][:] = 0 * self.full_prev_desired_curv[0, self.temporal_idxs] + else: + self.off_policy_numpy_inputs[self.off_policy_prev_desired_curv_key][:] = self.full_prev_desired_curv[0, self.temporal_idxs] + combined_outputs_dict = {**vision_outputs_dict, **policy_outputs_dict} + if self.off_policy_enabled: + self.off_policy_output = self.off_policy_run(**self.off_policy_inputs).contiguous().realize().uop.base.buffer.numpy() + off_policy_outputs_dict = self.off_policy_parser.parse_policy_outputs( + self.slice_outputs(self.off_policy_output, self.off_policy_output_slices) + ) + combined_outputs_dict.update(off_policy_outputs_dict) if SEND_RAW_PRED: - combined_outputs_dict['raw_pred'] = np.concatenate([self.vision_output.copy(), self.policy_output.copy()]) + raw_pred = [self.vision_output.copy(), self.policy_output.copy()] + if self.off_policy_enabled and self.off_policy_output is not None: + raw_pred.append(self.off_policy_output.copy()) + combined_outputs_dict['raw_pred'] = np.concatenate(raw_pred) return combined_outputs_dict diff --git a/frogpilot/ui/qt/offroad/model_settings.cc b/frogpilot/ui/qt/offroad/model_settings.cc index 0550202a7..67c0f9db9 100644 --- a/frogpilot/ui/qt/offroad/model_settings.cc +++ b/frogpilot/ui/qt/offroad/model_settings.cc @@ -503,6 +503,8 @@ bool FrogPilotModelPanel::isModelInstalled(const QString &key) const { bool has_policy_tg = false; bool has_vision_meta = false; bool has_vision_tg = false; + bool has_off_policy_meta = false; + bool has_off_policy_tg = false; bool foundAny = false; for (const QString &file : modelDir.entryList(QDir::Files)) { @@ -521,6 +523,10 @@ bool FrogPilotModelPanel::isModelInstalled(const QString &key) const { has_policy_meta = true; } else if (base.contains("_driving_policy_tinygrad")) { has_policy_tg = true; + } else if (base.contains("_driving_off_policy_metadata")) { + has_off_policy_meta = true; + } else if (base.contains("_driving_off_policy_tinygrad")) { + has_off_policy_tg = true; } else if (base.contains("_driving_vision_metadata")) { has_vision_meta = true; } else if (base.contains("_driving_vision_tinygrad")) { @@ -534,6 +540,9 @@ bool FrogPilotModelPanel::isModelInstalled(const QString &key) const { } if (has_policy_meta && has_policy_tg && has_vision_meta && has_vision_tg) { + if (has_off_policy_meta || has_off_policy_tg) { + return has_off_policy_meta && has_off_policy_tg; + } return true; }