#!/usr/bin/env python3
import json
import re
import urllib.request

from pathlib import Path

from openpilot.frogpilot.assets.download_functions import (
  GITLAB_URL,
  download_file,
  get_repository_url,
  handle_error,
  handle_request_error,
  verify_download,
)
from openpilot.frogpilot.common.frogpilot_utilities import delete_file
from openpilot.frogpilot.common.frogpilot_variables import MODELS_PATH

MANIFEST_CANDIDATES = ("v21",)
TINYGRAD_VERSIONS = {"v8", "v9", "v10", "v11", "v12"}
DEFAULT_MODEL_KEY = "sc"
MODEL_KEY_CANONICAL_MAP = {
  "sc2": "sc",
}

CANCEL_DOWNLOAD_PARAM = "CancelModelDownload"
DOWNLOAD_PROGRESS_PARAM = "ModelDownloadProgress"
MODEL_DOWNLOAD_PARAM = "ModelToDownload"
MODEL_DOWNLOAD_ALL_PARAM = "DownloadAllModels"
UPDATE_TINYGRAD_PARAM = "UpdateTinygrad"

# Version used when neither the stored params nor their defaults provide one.
_FALLBACK_MODEL_VERSION = "v11"

# Decorative emoji stripped from manifest model names: world map (U+1F5FA),
# its variation selector (U+FE0F), eyes (U+1F440) and satellite antenna
# (U+1F4E1). Written as escapes so the pattern survives re-encoding of this file.
# NOTE(review): the last glyph was ambiguous in the previously garbled pattern
# (U+1F4E1 vs U+1F4D1) - confirm against the names used in the manifest.
_MODEL_NAME_EMOJI = re.compile("[\U0001F5FA\uFE0F\U0001F440\U0001F4E1]")


def _clean_model_name(name: str) -> str:
  """Return ``name`` as text with the decorative emoji removed and whitespace trimmed."""
  return _MODEL_NAME_EMOJI.sub("", str(name or "")).strip()


class ModelManager:
  """Tracks the model manifest in params and downloads model files into MODELS_PATH."""

  def __init__(self, params, params_memory, boot_run=False):
    """Load the cached manifest lists from ``params`` and normalise the selected-model params.

    When ``boot_run`` is true the selected model's version is also re-synced
    against the cached manifest.
    """
    self.params = params
    self.params_memory = params_memory

    self.downloading_model = False

    self.available_models = self._param_list("AvailableModels")
    self.model_versions = self._param_list("ModelVersions")
    self.model_series = self._param_list("AvailableModelSeries")
    self.available_model_names = self._param_list("AvailableModelNames")

    self._ensure_model_params()
    if boot_run:
      self._sync_selected_model_version()

  @staticmethod
  def _canonical_model_key(model_key: str) -> str:
    """Strip ``model_key`` and map legacy keys through MODEL_KEY_CANONICAL_MAP."""
    key = (model_key or "").strip()
    return MODEL_KEY_CANONICAL_MAP.get(key, key)

  @staticmethod
  def _to_text(raw) -> str:
    """Convert a raw param value (None, bytes or anything else) to stripped text."""
    if raw is None:
      return ""
    if isinstance(raw, bytes):
      return raw.decode("utf-8", errors="ignore").strip()
    return str(raw).strip()

  def _param_text(self, key: str) -> str:
    """Return the stored value of ``key`` as stripped text ("" when unset)."""
    return self._to_text(self.params.get(key))

  def _param_list(self, key: str) -> list[str]:
    """Return the comma-separated param ``key`` as a list without empty entries."""
    return [entry for entry in self._param_text(key).split(",") if entry]

  def _default_param_text(self, key: str) -> str:
    """Return the default value of ``key`` as stripped text, or "" if it cannot be read."""
    try:
      default_value = self.params.get_default_value(key)
    except Exception:
      # Best effort: any failure while looking up the default is treated as "no default".
      return ""
    return self._to_text(default_value)

  def _stored_or_default_version(self) -> str:
    """Return the stored model version, else its param default, else the built-in fallback."""
    version = self._param_text("ModelVersion") or self._param_text("DrivingModelVersion")
    if version:
      return version
    return (
      self._default_param_text("ModelVersion")
      or self._default_param_text("DrivingModelVersion")
      or _FALLBACK_MODEL_VERSION
    )

  def _set_model_param_keys(self, model_key: str | None = None, model_name: str | None = None, model_version: str | None = None):
    """Write the given values to both the legacy and the "Driving*" params.

    Arguments that are None or empty are skipped so existing values are kept.
    The model key is canonicalised before it is stored.
    """
    if model_key is not None and model_key != "":
      canonical_key = self._canonical_model_key(model_key)
      self.params.put("Model", canonical_key)
      self.params.put("DrivingModel", canonical_key)

    if model_name is not None and model_name != "":
      self.params.put("DrivingModelName", model_name)

    if model_version is not None and model_version != "":
      self.params.put("ModelVersion", model_version)
      self.params.put("DrivingModelVersion", model_version)

  def _ensure_model_params(self):
    """Make sure the model key, name and version params are all populated."""
    selected_model = self._selected_model()
    current_version = self._stored_or_default_version()

    selected_name = self._param_text("DrivingModelName")
    if not selected_name and selected_model in self.available_models:
      selected_index = self.available_models.index(selected_model)
      if selected_index < len(self.available_model_names):
        selected_name = self.available_model_names[selected_index]

    self._set_model_param_keys(selected_model, selected_name, current_version)

  def _model_key_aliases(self, model_key: str) -> list[str]:
    """Return the de-duplicated, non-empty keys under which ``model_key`` may appear in a manifest."""
    raw_key = model_key or ""
    canonical_key = self._canonical_model_key(raw_key)
    aliases = [canonical_key]

    # Preserve legacy alias lookups (e.g. sc2) even when canonicalized to sc.
    aliases.extend(alias for alias, canonical in MODEL_KEY_CANONICAL_MAP.items() if canonical == canonical_key)

    if raw_key.endswith("_default"):
      aliases.append(raw_key.removesuffix("_default"))

    # v21 manifest uses legacy IDs with a trailing "2" (e.g. sc -> sc2).
    if raw_key and not raw_key.endswith("2"):
      aliases.append(f"{raw_key}2")

    # dict.fromkeys() de-duplicates while keeping the first-seen order.
    return [alias for alias in dict.fromkeys(aliases) if alias]

  def _model_version_map(self) -> dict[str, str]:
    """Map each available model key to the version at the same index in ``model_versions``."""
    return {
      model_key: self.model_versions[index]
      for index, model_key in enumerate(self.available_models)
      if index < len(self.model_versions) and model_key
    }

  def _selected_model(self) -> str:
    """Return the canonical selected model key, falling back to the param default and then DEFAULT_MODEL_KEY."""
    selected = self._param_text("Model") or self._param_text("DrivingModel")
    if selected:
      return self._canonical_model_key(selected)

    default_value = self._default_param_text("Model") or self._default_param_text("DrivingModel")
    if default_value:
      return self._canonical_model_key(default_value)

    return DEFAULT_MODEL_KEY

  def _required_files(self, model_key: str, model_version: str) -> list[str]:
    """Return the file names that make up ``model_key``; empty when the version is not in TINYGRAD_VERSIONS."""
    if model_version not in TINYGRAD_VERSIONS:
      return []

    filenames = [
      f"{model_key}_driving_policy_tinygrad.pkl",
      f"{model_key}_driving_vision_tinygrad.pkl",
      f"{model_key}_driving_policy_metadata.pkl",
      f"{model_key}_driving_vision_metadata.pkl",
    ]
    if model_version == "v12":
      # v12 models additionally require the off-policy pair.
      filenames += [
        f"{model_key}_driving_off_policy_tinygrad.pkl",
        f"{model_key}_driving_off_policy_metadata.pkl",
      ]
    return filenames

  def _is_model_downloaded(self, model_key: str, model_version: str) -> bool:
    """Return True when every required file of the model exists under MODELS_PATH."""
    required_files = self._required_files(model_key, model_version)
    if not required_files:
      return False
    return all((MODELS_PATH / filename).is_file() for filename in required_files)

  def _sync_selected_model_version(self):
    """Store the version (and name) that the cached manifest lists for the selected model.

    The selected key is tried first, then each of its aliases; if none is
    listed, the stored/default version is written back instead.
    """
    version_map = self._model_version_map()
    name_map = dict(zip(self.available_models, self.available_model_names))
    selected = self._selected_model()

    version = version_map.get(selected)
    if version:
      self._set_model_param_keys(selected, name_map.get(selected), version)
      return

    for alias in self._model_key_aliases(selected):
      version = version_map.get(alias)
      if version:
        selected_name = name_map.get(selected) or name_map.get(alias) or self._param_text("DrivingModelName")
        self._set_model_param_keys(selected, selected_name, version)
        return

    self._set_model_param_keys(selected, name_map.get(selected, ""), self._stored_or_default_version())

  @staticmethod
  def _fetch_manifest(url: str) -> list[dict]:
    """Download and parse the manifest at ``url``; return its "models" list or [] on any failure."""
    try:
      with urllib.request.urlopen(url, timeout=10) as response:
        payload = json.loads(response.read().decode("utf-8"))
    except Exception as error:
      handle_request_error(error, None, None, None, None)
      return []

    if not isinstance(payload, dict):
      return []

    # The manifest comes from the network: only accept a real list of entries.
    models = payload.get("models", [])
    return models if isinstance(models, list) else []

  def _get_manifest(self, repo_url: str) -> tuple[str | None, list[dict]]:
    """Return ``(manifest_version, models)`` for the first candidate manifest with tinygrad models.

    Returns ``(None, [])`` when no candidate yields a usable entry.
    """
    for manifest_version in MANIFEST_CANDIDATES:
      model_info = self._fetch_manifest(f"{repo_url}/Versions/model_names_{manifest_version}.json")
      if not model_info:
        continue

      # Desktop/dev build is tinygrad-only.
      filtered = [
        model for model in model_info
        if isinstance(model, dict) and model.get("version") in TINYGRAD_VERSIONS
      ]
      if not filtered:
        continue

      return manifest_version, filtered

    return None, []

  def _remove_stale_model_files(self):
    """Delete model files whose key is not in ``available_models`` and leftover temp files."""
    # Without a model list nothing can be classified as stale; skipping avoids
    # wiping every downloaded model (mirrors the guard in _enforce_selected_model).
    if self.available_models:
      valid_keys = set(self.available_models)
      for model_file in MODELS_PATH.glob("*_driving_*"):
        model_key = model_file.name.split("_driving_", 1)[0]
        if model_key not in valid_keys:
          delete_file(model_file, print_error=False)

    for temp_file in MODELS_PATH.glob("tmp*"):
      delete_file(temp_file, print_error=False)

  def _enforce_selected_model(self):
    """Switch to an available model when the selected one (or any alias) is not in the manifest."""
    if not self.available_models:
      return

    selected = self._selected_model()
    if any(alias in self.available_models for alias in self._model_key_aliases(selected)):
      self._sync_selected_model_version()
      return

    # _default_param_text() never raises and returns "" when no default exists.
    default_model = self._default_param_text("Model") or self._default_param_text("DrivingModel") or DEFAULT_MODEL_KEY

    # Preference order: param default, built-in default, then first listed model.
    candidates = self._model_key_aliases(default_model) + self._model_key_aliases(DEFAULT_MODEL_KEY) + self.available_models
    replacement = next((entry for entry in candidates if entry in self.available_models), self.available_models[0])

    replacement_index = self.available_models.index(replacement)
    if replacement_index < len(self.available_model_names):
      replacement_name = self.available_model_names[replacement_index]
    else:
      replacement_name = replacement

    self._set_model_param_keys(replacement, replacement_name, None)
    self._sync_selected_model_version()

  def update_model_params(self, model_info: list[dict], manifest_version: str):
    """Replace the cached manifest lists with ``model_info`` and persist them to params.

    Also re-syncs the selected model's version and writes a key -> version map
    to ``MODELS_PATH/.model_versions.json`` (best effort).
    """
    del manifest_version

    self.available_models = [str(model.get("id") or "").strip() for model in model_info]
    self.available_model_names = [_clean_model_name(model.get("name")) for model in model_info]
    self.model_versions = [str(model.get("version") or "").strip() for model in model_info]
    self.model_series = [str(model.get("series") or "Custom Series").strip() for model in model_info]

    released_dates = [str(model.get("released") or "2023-01-01").strip() for model in model_info]
    community_favorites = [
      model_key
      for model_key, model in zip(self.available_models, model_info)
      if model.get("community_favorite", False)
    ]

    self.params.put("AvailableModels", ",".join(self.available_models))
    self.params.put("AvailableModelNames", ",".join(self.available_model_names))
    self.params.put("AvailableModelSeries", ",".join(self.model_series))
    self.params.put("ModelReleasedDates", ",".join(released_dates))
    self.params.put("ModelVersions", ",".join(self.model_versions))
    self.params.put("CommunityFavorites", ",".join(community_favorites))

    self._sync_selected_model_version()

    try:
      version_map = dict(zip(self.available_models, self.model_versions))
      versions_file = MODELS_PATH / ".model_versions.json"
      versions_file.parent.mkdir(parents=True, exist_ok=True)
      versions_file.write_text(json.dumps(version_map))
    except Exception as error:
      # The cache file is optional; a failed write must not abort the update.
      print(f"Failed to write model versions cache: {error}")

  def check_models(self, boot_run: bool):
    """Remove stale model files and make sure the selected model is available."""
    del boot_run  # Not currently needed, retained for call-site parity.

    self._remove_stale_model_files()
    self._enforce_selected_model()

  def update_models(self, boot_run=False):
    """Refresh the manifest from the repository and re-validate local models.

    Does nothing while a model download is in progress.
    """
    if self.downloading_model:
      return

    repo_url = get_repository_url()
    if repo_url is None:
      print("GitHub and GitLab are offline...")
      return

    manifest_version, model_info = self._get_manifest(repo_url)
    if not model_info:
      print("No compatible tinygrad manifest found.")
      return

    self.update_model_params(model_info, manifest_version or "unknown")
    self.check_models(boot_run)

  def download_model(self, model_to_download: str):
    """Download and verify every file of ``model_to_download``.

    ``downloading_model`` is set for the duration of the call and is always
    cleared afterwards, even if a helper raises.
    """
    self.downloading_model = True
    try:
      self._download_model_files(model_to_download)
    finally:
      self.downloading_model = False

  def _download_model_files(self, model_to_download: str):
    """Fetch each required file from the repository, retrying once from GITLAB_URL when verification fails."""
    repo_url = get_repository_url()
    if not repo_url:
      handle_error(None, "GitHub and GitLab are offline...", "Repository unavailable", MODEL_DOWNLOAD_PARAM, DOWNLOAD_PROGRESS_PARAM, self.params_memory)
      return

    model_version = self._model_version_map().get(model_to_download)
    required_files = self._required_files(model_to_download, model_version or "")
    if not required_files:
      handle_error(None, f"Unsupported model format for {model_to_download}", "Model download failed", MODEL_DOWNLOAD_PARAM, DOWNLOAD_PROGRESS_PARAM, self.params_memory)
      return

    for filename in required_files:
      file_path = MODELS_PATH / filename

      verified = False
      # Primary repository first, then the GitLab mirror as fallback.
      for base_url in (repo_url, GITLAB_URL):
        file_url = f"{base_url}/Models/{filename}"
        download_file(CANCEL_DOWNLOAD_PARAM, file_path, DOWNLOAD_PROGRESS_PARAM, file_url, MODEL_DOWNLOAD_PARAM, self.params_memory)

        if self.params_memory.get_bool(CANCEL_DOWNLOAD_PARAM):
          handle_error(None, "Download cancelled...", "Download cancelled...", MODEL_DOWNLOAD_PARAM, DOWNLOAD_PROGRESS_PARAM, self.params_memory)
          return

        if verify_download(file_path, file_url):
          verified = True
          break

      if not verified:
        handle_error(file_path, "Verification failed...", f"Verification failed for {filename}", MODEL_DOWNLOAD_PARAM, DOWNLOAD_PROGRESS_PARAM, self.params_memory)
        return

    self.params_memory.put(DOWNLOAD_PROGRESS_PARAM, "Downloaded!")
    self.params_memory.remove(MODEL_DOWNLOAD_PARAM)

  def download_all_models(self):
    """Refresh the manifest and download every listed model that is not fully present locally."""
    repo_url = get_repository_url()
    if not repo_url:
      handle_error(None, "GitHub and GitLab are offline...", "Repository unavailable", MODEL_DOWNLOAD_ALL_PARAM, DOWNLOAD_PROGRESS_PARAM, self.params_memory)
      return

    manifest_version, model_info = self._get_manifest(repo_url)
    if not model_info:
      handle_error(None, "Unable to fetch models...", "Model list unavailable", MODEL_DOWNLOAD_ALL_PARAM, DOWNLOAD_PROGRESS_PARAM, self.params_memory)
      return

    self.update_model_params(model_info, manifest_version or "unknown")

    # The manifest lists do not change while downloading, so build the map once.
    version_map = self._model_version_map()

    for model_key, model_name in zip(self.available_models, self.available_model_names):
      if self.params_memory.get_bool(CANCEL_DOWNLOAD_PARAM):
        handle_error(None, "Download cancelled...", "Download cancelled...", MODEL_DOWNLOAD_ALL_PARAM, DOWNLOAD_PROGRESS_PARAM, self.params_memory)
        return

      if self._is_model_downloaded(model_key, version_map.get(model_key, "")):
        continue

      self.params_memory.put(DOWNLOAD_PROGRESS_PARAM, f"Downloading \"{model_name}\"...")
      self.download_model(model_key)

      # download_model() has already reported the cancellation for this case.
      if self.params_memory.get_bool(CANCEL_DOWNLOAD_PARAM):
        return

    self.params_memory.put(DOWNLOAD_PROGRESS_PARAM, "All models downloaded!")
    self.params_memory.remove(MODEL_DOWNLOAD_ALL_PARAM)

  def update_tinygrad(self):
    """Delete all local model files and the version cache, then optionally queue a full re-download."""
    # This branch ships tinygrad runtime in-tree. "Update" here refreshes local model files.
    self.params_memory.put(DOWNLOAD_PROGRESS_PARAM, "Updating...")

    for model_file in MODELS_PATH.glob("*_driving_*"):
      if model_file.is_file():
        delete_file(model_file, print_error=False)

    model_versions_file = MODELS_PATH / ".model_versions.json"
    if model_versions_file.is_file():
      delete_file(model_versions_file, print_error=False)

    self.params.put_bool("TinygradUpdateAvailable", False)
    self.params_memory.remove(UPDATE_TINYGRAD_PARAM)
    self.params_memory.remove(CANCEL_DOWNLOAD_PARAM)

    if self.params.get_bool("AutomaticallyDownloadModels"):
      self.params_memory.put_bool(MODEL_DOWNLOAD_ALL_PARAM, True)
      self.params_memory.put(DOWNLOAD_PROGRESS_PARAM, "Downloading...")
    else:
      self.params_memory.put(DOWNLOAD_PROGRESS_PARAM, "Updated!")