Files
StarPilot/starpilot/assets/model_manager.py
T
firestar5683 e9935624f7 smoosh smoosh
2026-05-26 21:59:46 -05:00

608 lines
23 KiB
Python

#!/usr/bin/env python3
import json
import random
import re
import urllib.request
from pathlib import Path
from openpilot.starpilot.assets.download_functions import (
GITLAB_URL,
download_file,
get_repository_url,
handle_error,
handle_request_error,
verify_download,
)
from openpilot.starpilot.common.model_versions import (
is_tinygrad_model_version,
uses_combined_driving_artifacts,
uses_split_off_policy_artifacts,
)
from openpilot.starpilot.common.starpilot_utilities import delete_file
from openpilot.starpilot.common.starpilot_variables import MODELS_PATH
# Manifest versions tried, in order, by ModelManager._get_manifest().
MANIFEST_CANDIDATES = ("v21",)
# Version tags "v8" through "v32" (range end is exclusive).
# NOTE(review): not referenced anywhere in the visible part of this module.
TINYGRAD_VERSIONS = {f"v{i}" for i in range(8, 33)}
# Key treated as the built-in model by is_builtin_model_key(); also the last-resort selection.
DEFAULT_MODEL_KEY = "sc2"
# File name (under MODELS_PATH) of the JSON cache mapping model key -> {artifact filename: URL}.
ARTIFACT_URLS_CACHE = ".model_artifact_urls.json"
# Alias -> canonical key. canonical_model_key() rewrites the left-hand keys to the right-hand ones.
MODEL_KEY_CANONICAL_MAP = {
"sc": DEFAULT_MODEL_KEY,
}
# Names of the params_memory entries read/written by ModelManager during downloads.
CANCEL_DOWNLOAD_PARAM = "CancelModelDownload"
DOWNLOAD_PROGRESS_PARAM = "ModelDownloadProgress"
MODEL_DOWNLOAD_PARAM = "ModelToDownload"
MODEL_DOWNLOAD_ALL_PARAM = "DownloadAllModels"
UPDATE_TINYGRAD_PARAM = "UpdateTinygrad"
def _clean_model_name(name: str) -> str:
return re.sub(r"[🗺️👀📡]", "", str(name or "")).strip()
def canonical_model_key(model_key: str) -> str:
    """Trim *model_key* and translate it through MODEL_KEY_CANONICAL_MAP.

    Keys without a mapping are returned trimmed but otherwise unchanged;
    falsy input yields an empty string.
    """
    trimmed = model_key.strip() if model_key else ""
    return MODEL_KEY_CANONICAL_MAP.get(trimmed, trimmed)
def is_builtin_model_key(model_key: str) -> bool:
    """Report whether *model_key* (or an alias of it) resolves to DEFAULT_MODEL_KEY."""
    resolved_key = canonical_model_key(model_key)
    return resolved_key == DEFAULT_MODEL_KEY
def model_key_aliases(model_key: str) -> list[str]:
    """Return every key under which *model_key* may appear, canonical key first.

    The result contains, in order and without duplicates or empty strings:
    the canonical key, each alias in MODEL_KEY_CANONICAL_MAP that maps to it,
    the key with a trailing ``"_default"`` removed (when present), and the key
    with ``"2"`` appended (when it does not already end in ``"2"``).

    ``None`` and whitespace-padded keys are normalised first, matching what
    canonical_model_key() already accepts, instead of raising or producing
    padded aliases.
    """
    raw_key = (model_key or "").strip()
    canonical_key = canonical_model_key(raw_key)

    aliases = [canonical_key]
    aliases.extend(
        alias
        for alias, canonical in MODEL_KEY_CANONICAL_MAP.items()
        if canonical == canonical_key
    )
    if raw_key.endswith("_default"):
        aliases.append(raw_key.removesuffix("_default"))
    if raw_key and not raw_key.endswith("2"):
        aliases.append(f"{raw_key}2")

    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(alias for alias in aliases if alias))
class ModelManager:
def __init__(self, params, params_memory, boot_run=False):
self.params = params
self.params_memory = params_memory
self.downloading_model = False
self.available_models: list[str] = []
self.model_versions: list[str] = []
self.model_series: list[str] = []
self.available_model_names: list[str] = []
self._load_catalog_from_params()
self._ensure_model_params()
if boot_run:
self._sync_selected_model_version()
@staticmethod
def _canonical_model_key(model_key: str) -> str:
return canonical_model_key(model_key)
def _param_text(self, key: str) -> str:
raw = self.params.get(key)
if raw is None:
return ""
if isinstance(raw, bytes):
return raw.decode("utf-8", errors="ignore").strip()
return str(raw).strip()
def _param_bool(self, key: str) -> bool:
try:
return bool(self.params.get_bool(key))
except Exception:
return self._param_text(key).lower() in {"1", "true", "yes", "on"}
def _default_param_text(self, key: str) -> str:
try:
default_value = self.params.get_default_value(key)
except Exception:
return ""
if default_value is None:
return ""
if isinstance(default_value, bytes):
return default_value.decode("utf-8", errors="ignore").strip()
return str(default_value).strip()
def _resolve_mirrored_param(self, primary_key: str, secondary_key: str) -> str:
primary_val = self._param_text(primary_key)
secondary_val = self._param_text(secondary_key)
if primary_val == secondary_val:
return secondary_val or primary_val
primary_default = self._default_param_text(primary_key)
secondary_default = self._default_param_text(secondary_key)
primary_non_default = bool(primary_val) and primary_val != primary_default
secondary_non_default = bool(secondary_val) and secondary_val != secondary_default
if secondary_non_default:
return secondary_val
if primary_non_default:
return primary_val
return secondary_val or primary_val
def _load_catalog_from_params(self):
self.available_models = [entry for entry in self._param_text("AvailableModels").split(",") if entry]
self.model_versions = [entry for entry in self._param_text("ModelVersions").split(",") if entry]
self.model_series = [entry for entry in self._param_text("AvailableModelSeries").split(",") if entry]
self.available_model_names = [entry for entry in self._param_text("AvailableModelNames").split(",") if entry]
@staticmethod
def _manifest_paths(manifest_version: str) -> tuple[str, ...]:
return (
f"Versions/model_names_{manifest_version}.json",
f"model_names_{manifest_version}.json",
)
def _set_model_param_keys(self, model_key: str | None = None, model_name: str | None = None, model_version: str | None = None):
if model_key is not None and model_key != "":
canonical_key = self._canonical_model_key(model_key)
self.params.put("Model", canonical_key)
self.params.put("DrivingModel", canonical_key)
if model_name is not None and model_name != "":
self.params.put("DrivingModelName", model_name)
if model_version is not None and model_version != "":
self.params.put("ModelVersion", model_version)
self.params.put("DrivingModelVersion", model_version)
def _ensure_model_params(self):
selected_model = self._selected_model()
current_version = self._resolve_mirrored_param("ModelVersion", "DrivingModelVersion")
if not current_version:
current_version = self._default_param_text("ModelVersion") or self._default_param_text("DrivingModelVersion") or "v11"
selected_name = self._param_text("DrivingModelName")
if not selected_name and selected_model in self.available_models:
selected_index = self.available_models.index(selected_model)
if selected_index < len(self.available_model_names):
selected_name = self.available_model_names[selected_index]
self._set_model_param_keys(selected_model, selected_name, current_version)
def _model_key_aliases(self, model_key: str) -> list[str]:
return model_key_aliases(model_key)
def _model_version_map(self) -> dict[str, str]:
return {
model_key: self.model_versions[index]
for index, model_key in enumerate(self.available_models)
if index < len(self.model_versions) and model_key
}
def _blacklisted_model_keys(self) -> set[str]:
return {
self._canonical_model_key(entry)
for entry in self._param_text("BlacklistedModels").split(",")
if entry.strip()
}
def _selected_model(self) -> str:
selected = self._resolve_mirrored_param("Model", "DrivingModel")
if selected:
return self._canonical_model_key(selected)
default_value = self._default_param_text("Model") or self._default_param_text("DrivingModel")
if default_value:
return self._canonical_model_key(default_value)
return DEFAULT_MODEL_KEY
def _required_files(self, model_key: str, model_version: str) -> list[str]:
if not is_tinygrad_model_version(model_version):
return []
if uses_combined_driving_artifacts(model_version):
return [f"{model_key}_driving_tinygrad.pkl"]
filenames = [
f"{model_key}_driving_policy_tinygrad.pkl",
f"{model_key}_driving_vision_tinygrad.pkl",
f"{model_key}_driving_policy_metadata.pkl",
f"{model_key}_driving_vision_metadata.pkl",
]
if uses_split_off_policy_artifacts(model_version):
filenames += [
f"{model_key}_driving_off_policy_tinygrad.pkl",
f"{model_key}_driving_off_policy_metadata.pkl",
]
return filenames
@staticmethod
def _artifact_urls_cache_path() -> Path:
return MODELS_PATH / ARTIFACT_URLS_CACHE
def _load_artifact_url_map(self) -> dict[str, dict[str, str]]:
try:
cache_path = self._artifact_urls_cache_path()
if not cache_path.is_file():
return {}
payload = json.loads(cache_path.read_text())
if not isinstance(payload, dict):
return {}
normalized: dict[str, dict[str, str]] = {}
for model_key, urls in payload.items():
if not isinstance(urls, dict):
continue
normalized[str(model_key)] = {
str(filename): str(url)
for filename, url in urls.items()
if filename and url
}
return normalized
except Exception as error:
print(f"Failed to load artifact URL cache: {error}")
return {}
def _build_artifact_url_map(self, model_info: list[dict]) -> dict[str, dict[str, str]]:
artifact_url_map: dict[str, dict[str, str]] = {}
for model in model_info:
model_key = self._canonical_model_key(str(model.get("id") or "").strip())
model_version = str(model.get("version") or "").strip()
required_files = self._required_files(model_key, model_version)
if not model_key or not required_files:
continue
urls: dict[str, str] = {}
explicit_urls = model.get("artifact_urls") or model.get("download_urls")
if isinstance(explicit_urls, dict):
for filename, url in explicit_urls.items():
if filename and url:
urls[str(filename).strip()] = str(url).strip()
base_url = str(model.get("artifact_base_url") or model.get("download_base_url") or "").strip()
if base_url:
base_url = base_url.rstrip("/")
for filename in required_files:
urls.setdefault(filename, f"{base_url}/{filename}")
direct_url = str(model.get("artifact_url") or model.get("download_url") or "").strip()
if direct_url:
if len(required_files) == 1:
urls.setdefault(required_files[0], direct_url)
else:
matched_filename = next((filename for filename in required_files if Path(filename).name == Path(direct_url).name), None)
if matched_filename is not None:
urls.setdefault(matched_filename, direct_url)
if urls:
artifact_url_map[model_key] = urls
return artifact_url_map
def _is_model_downloaded(self, model_key: str, model_version: str) -> bool:
if is_builtin_model_key(model_key):
return True
required_files = self._required_files(model_key, model_version)
if not required_files:
return False
return all((MODELS_PATH / filename).is_file() for filename in required_files)
def _installed_model_choices(self) -> list[tuple[str, str, str]]:
self._load_catalog_from_params()
version_map = self._model_version_map()
blacklisted_keys = self._blacklisted_model_keys()
choices: list[tuple[str, str, str]] = []
seen_keys: set[str] = set()
for index, model_key in enumerate(self.available_models):
if not model_key:
continue
canonical_key = self._canonical_model_key(model_key)
if canonical_key in blacklisted_keys or canonical_key in seen_keys:
continue
model_version = version_map.get(model_key) or version_map.get(canonical_key) or ""
if not model_version and is_builtin_model_key(canonical_key):
model_version = self._default_param_text("ModelVersion") or self._default_param_text("DrivingModelVersion") or "v11"
if not self._is_model_downloaded(model_key, model_version):
continue
model_name = self.available_model_names[index] if index < len(self.available_model_names) else canonical_key
choices.append((canonical_key, model_name, model_version))
seen_keys.add(canonical_key)
return choices
def randomize_selected_model(self) -> str | None:
if not self._param_bool("ModelRandomizer"):
return None
choices = self._installed_model_choices()
if not choices:
print("Model Randomizer skipped: no installed, non-blacklisted models available.")
return None
selected = self._selected_model()
eligible_choices = [choice for choice in choices if self._canonical_model_key(choice[0]) != selected]
if not eligible_choices:
eligible_choices = choices
model_key, model_name, model_version = random.choice(eligible_choices)
if not model_version:
model_version = self._default_param_text("ModelVersion") or self._default_param_text("DrivingModelVersion") or "v11"
self._set_model_param_keys(model_key, model_name, model_version)
try:
self.params_memory.put_bool("StarPilotTogglesUpdated", True)
except Exception:
pass
print(f"Model Randomizer selected {model_name} ({model_key}, {model_version}).")
return model_key
def _sync_selected_model_version(self):
version_map = self._model_version_map()
name_map = {model_key: model_name for model_key, model_name in zip(self.available_models, self.available_model_names)}
selected = self._selected_model()
version = version_map.get(selected)
if version:
self._set_model_param_keys(selected, name_map.get(selected), version)
return
for alias in self._model_key_aliases(selected):
version = version_map.get(alias)
if version:
selected_name = name_map.get(selected) or name_map.get(alias) or self._param_text("DrivingModelName")
self._set_model_param_keys(selected, selected_name, version)
return
fallback_version = self._resolve_mirrored_param("ModelVersion", "DrivingModelVersion")
if not fallback_version:
fallback_version = self._default_param_text("ModelVersion") or self._default_param_text("DrivingModelVersion") or "v11"
self._set_model_param_keys(selected, name_map.get(selected, ""), fallback_version)
@staticmethod
def _fetch_manifest(url: str) -> list[dict]:
try:
with urllib.request.urlopen(url, timeout=10) as response:
payload = json.loads(response.read().decode("utf-8"))
return payload.get("models", []) if isinstance(payload, dict) else []
except Exception as error:
handle_request_error(error, None, None, None, None)
return []
def _get_manifest(self, repo_url: str) -> tuple[str | None, list[dict]]:
for manifest_version in MANIFEST_CANDIDATES:
for manifest_path in self._manifest_paths(manifest_version):
model_info = self._fetch_manifest(f"{repo_url}/{manifest_path}")
if not model_info:
continue
# Desktop/dev build is tinygrad-only.
filtered = [model for model in model_info if is_tinygrad_model_version(model.get("version"))]
if not filtered:
continue
return manifest_version, filtered
return None, []
def _remove_stale_model_files(self):
valid_keys = set(self.available_models)
for model_file in MODELS_PATH.glob("*_driving_*"):
model_key = model_file.name.split("_driving_", 1)[0]
if model_key not in valid_keys:
delete_file(model_file, print_error=False)
for temp_file in MODELS_PATH.glob("tmp*"):
delete_file(temp_file, print_error=False)
def _enforce_selected_model(self):
if not self.available_models:
return
selected = self._selected_model()
aliases = self._model_key_aliases(selected)
if any(alias in self.available_models for alias in aliases):
self._sync_selected_model_version()
return
try:
default_model = self._default_param_text("Model") or self._default_param_text("DrivingModel")
except Exception:
default_model = DEFAULT_MODEL_KEY
candidates = self._model_key_aliases(default_model) + self._model_key_aliases(DEFAULT_MODEL_KEY) + self.available_models
replacement = next((entry for entry in candidates if entry in self.available_models), self.available_models[0])
replacement_index = self.available_models.index(replacement)
replacement_name = self.available_model_names[replacement_index] if replacement_index < len(self.available_model_names) else replacement
self._set_model_param_keys(replacement, replacement_name, None)
self._sync_selected_model_version()
def update_model_params(self, model_info: list[dict], manifest_version: str):
del manifest_version
self.available_models = [str(model.get("id") or "").strip() for model in model_info]
self.available_model_names = [_clean_model_name(model.get("name")) for model in model_info]
self.model_versions = [str(model.get("version") or "").strip() for model in model_info]
self.model_series = [str(model.get("series") or "Custom Series").strip() for model in model_info]
released_dates = [str(model.get("released") or "2023-01-01").strip() for model in model_info]
community_favorites = [model_key for model_key, model in zip(self.available_models, model_info) if model.get("community_favorite", False)]
self.params.put("AvailableModels", ",".join(self.available_models))
self.params.put("AvailableModelNames", ",".join(self.available_model_names))
self.params.put("AvailableModelSeries", ",".join(self.model_series))
self.params.put("ModelReleasedDates", ",".join(released_dates))
self.params.put("ModelVersions", ",".join(self.model_versions))
self.params.put("CommunityFavorites", ",".join(community_favorites))
self._sync_selected_model_version()
try:
version_map = {model_key: version for model_key, version in zip(self.available_models, self.model_versions)}
versions_file = MODELS_PATH / ".model_versions.json"
versions_file.parent.mkdir(parents=True, exist_ok=True)
versions_file.write_text(json.dumps(version_map))
artifact_urls_file = self._artifact_urls_cache_path()
artifact_urls_file.write_text(json.dumps(self._build_artifact_url_map(model_info)))
except Exception as error:
print(f"Failed to write model versions cache: {error}")
def check_models(self, boot_run: bool):
del boot_run # Not currently needed, retained for call-site parity.
self._remove_stale_model_files()
self._enforce_selected_model()
def update_models(self, boot_run=False):
if self.downloading_model:
return
repo_url = get_repository_url()
if repo_url is None:
print("GitHub and GitLab are offline...")
return
manifest_version, model_info = self._get_manifest(repo_url)
if not model_info:
print("No compatible tinygrad manifest found.")
return
self.update_model_params(model_info, manifest_version or "unknown")
self.check_models(boot_run)
def download_model(self, model_to_download: str):
self.downloading_model = True
if is_builtin_model_key(model_to_download):
self.params_memory.put(DOWNLOAD_PROGRESS_PARAM, "Built-in model already downloaded.")
self.params_memory.remove(MODEL_DOWNLOAD_PARAM)
self.downloading_model = False
return
repo_url = get_repository_url()
if not repo_url:
handle_error(None, "GitHub and GitLab are offline...", "Repository unavailable", MODEL_DOWNLOAD_PARAM, DOWNLOAD_PROGRESS_PARAM, self.params_memory)
self.downloading_model = False
return
# Refresh from params so long-lived workers pick up manifest refreshes done by
# a separate ModelManager instance before we validate the requested model.
self._load_catalog_from_params()
version_map = self._model_version_map()
model_version = version_map.get(model_to_download)
model_artifact_urls = self._load_artifact_url_map()
artifact_urls = model_artifact_urls.get(self._canonical_model_key(model_to_download)) or model_artifact_urls.get(model_to_download) or {}
required_files = self._required_files(model_to_download, model_version or "")
if not required_files:
handle_error(None, f"Unsupported model format for {model_to_download}", "Model download failed", MODEL_DOWNLOAD_PARAM, DOWNLOAD_PROGRESS_PARAM, self.params_memory)
self.downloading_model = False
return
for filename in required_files:
file_path = MODELS_PATH / filename
candidate_urls: list[tuple[str, bool]] = []
custom_url = artifact_urls.get(filename, "").strip()
if custom_url:
candidate_urls.append((custom_url, True))
file_url = f"{repo_url}/Models/{filename}"
candidate_urls.append((file_url, False))
fallback_url = f"{GITLAB_URL}/Models/{filename}"
if fallback_url != file_url:
candidate_urls.append((fallback_url, False))
download_succeeded = False
for candidate_url, allow_unknown_size in candidate_urls:
download_file(
CANCEL_DOWNLOAD_PARAM,
file_path,
DOWNLOAD_PROGRESS_PARAM,
candidate_url,
MODEL_DOWNLOAD_PARAM,
self.params_memory,
allow_unknown_size=allow_unknown_size,
)
if self.params_memory.get_bool(CANCEL_DOWNLOAD_PARAM):
handle_error(None, "Download cancelled...", "Download cancelled...", MODEL_DOWNLOAD_PARAM, DOWNLOAD_PROGRESS_PARAM, self.params_memory)
self.downloading_model = False
return
if verify_download(file_path, candidate_url, allow_unknown_size=allow_unknown_size):
download_succeeded = True
break
if not download_succeeded:
handle_error(file_path, "Verification failed...", f"Verification failed for {filename}", MODEL_DOWNLOAD_PARAM, DOWNLOAD_PROGRESS_PARAM, self.params_memory)
self.downloading_model = False
return
self.params_memory.put(DOWNLOAD_PROGRESS_PARAM, "Downloaded!")
self.params_memory.remove(MODEL_DOWNLOAD_PARAM)
self.downloading_model = False
def download_all_models(self):
repo_url = get_repository_url()
if not repo_url:
handle_error(None, "GitHub and GitLab are offline...", "Repository unavailable", MODEL_DOWNLOAD_ALL_PARAM, DOWNLOAD_PROGRESS_PARAM, self.params_memory)
return
manifest_version, model_info = self._get_manifest(repo_url)
if not model_info:
handle_error(None, "Unable to fetch models...", "Model list unavailable", MODEL_DOWNLOAD_ALL_PARAM, DOWNLOAD_PROGRESS_PARAM, self.params_memory)
return
self.update_model_params(model_info, manifest_version or "unknown")
for model_key, model_name in zip(self.available_models, self.available_model_names):
if self.params_memory.get_bool(CANCEL_DOWNLOAD_PARAM):
handle_error(None, "Download cancelled...", "Download cancelled...", MODEL_DOWNLOAD_ALL_PARAM, DOWNLOAD_PROGRESS_PARAM, self.params_memory)
return
model_version = self._model_version_map().get(model_key, "")
if self._is_model_downloaded(model_key, model_version):
continue
self.params_memory.put(DOWNLOAD_PROGRESS_PARAM, f"Downloading \"{model_name}\"...")
self.download_model(model_key)
if self.params_memory.get_bool(CANCEL_DOWNLOAD_PARAM):
return
self.params_memory.put(DOWNLOAD_PROGRESS_PARAM, "All models downloaded!")
self.params_memory.remove(MODEL_DOWNLOAD_ALL_PARAM)
self.randomize_selected_model()
def update_tinygrad(self):
# This branch ships tinygrad runtime in-tree. "Update" here refreshes local model files.
self.params_memory.put(DOWNLOAD_PROGRESS_PARAM, "Updating...")
for model_file in MODELS_PATH.glob("*_driving_*"):
if model_file.is_file():
delete_file(model_file, print_error=False)
model_versions_file = MODELS_PATH / ".model_versions.json"
if model_versions_file.is_file():
delete_file(model_versions_file, print_error=False)
artifact_urls_file = self._artifact_urls_cache_path()
if artifact_urls_file.is_file():
delete_file(artifact_urls_file, print_error=False)
self.params.put_bool("TinygradUpdateAvailable", False)
self.params_memory.remove(UPDATE_TINYGRAD_PARAM)
self.params_memory.remove(CANCEL_DOWNLOAD_PARAM)
if self.params.get_bool("AutomaticallyDownloadModels"):
self.params_memory.put_bool(MODEL_DOWNLOAD_ALL_PARAM, True)
self.params_memory.put(DOWNLOAD_PROGRESS_PARAM, "Downloading...")
else:
self.params_memory.put(DOWNLOAD_PROGRESS_PARAM, "Updated!")