From 1bcfb0518d720103570ed2802ace1eaf765adcf2 Mon Sep 17 00:00:00 2001 From: firestar5683 <168790843+firestar5683@users.noreply.github.com> Date: Wed, 25 Mar 2026 23:56:26 -0500 Subject: [PATCH] Mici Model Switcher --- selfdrive/ui/mici/layouts/home.py | 88 ++- .../ui/mici/layouts/settings/driving_model.py | 732 ++++++++++++++++++ .../ui/mici/layouts/settings/settings.py | 4 + .../ui/mici/onroad/augmented_road_view.py | 29 +- selfdrive/ui/mici/widgets/dialog.py | 53 +- selfdrive/ui/onroad/augmented_road_view.py | 45 +- system/ui/widgets/selection_dialog.py | 19 +- 7 files changed, 908 insertions(+), 62 deletions(-) create mode 100644 selfdrive/ui/mici/layouts/settings/driving_model.py diff --git a/selfdrive/ui/mici/layouts/home.py b/selfdrive/ui/mici/layouts/home.py index 78d6cc31..44654356 100644 --- a/selfdrive/ui/mici/layouts/home.py +++ b/selfdrive/ui/mici/layouts/home.py @@ -1,4 +1,5 @@ import time +import re from cereal import log import pyray as rl @@ -6,6 +7,7 @@ from collections.abc import Callable from openpilot.system.ui.widgets.label import gui_label, MiciLabel, UnifiedLabel from openpilot.system.ui.widgets import Widget from openpilot.system.ui.lib.application import gui_app, FontWeight, DEFAULT_TEXT_COLOR, MousePos +from openpilot.frogpilot.common.frogpilot_variables import MODELS_PATH from openpilot.selfdrive.ui.ui_state import ui_state from openpilot.system.ui.text import wrap_text from openpilot.system.version import training_version, RELEASE_BRANCHES @@ -92,6 +94,7 @@ class MiciHomeLayout(Widget): self._version_text = None self._experimental_mode = False + self._current_model_name = "default" self._settings_txt = gui_app.texture("icons_mici/settings.png", 48, 48) self._experimental_txt = gui_app.texture("icons_mici/experimental_mode.png", 48, 48) @@ -127,6 +130,80 @@ class MiciHomeLayout(Widget): def _update_params(self): self._experimental_mode = ui_state.params.get_bool("ExperimentalMode") + def _clean_name(value: str) -> str: + return re.sub(r"[πŸ—ΊοΈπŸ‘€πŸ“‘]", "", value).replace("(Default)", "").strip() + + def _decode_default(value) -> str: + if isinstance(value, bytes): + return value.decode("utf-8", errors="ignore").strip() + return str(value or "").strip() + + model_key = (ui_state.params.get("Model", encoding="utf-8") or + ui_state.params.get("DrivingModel", encoding="utf-8") or "").strip() + current_param_name = _clean_name(ui_state.params.get("DrivingModelName", encoding="utf-8") or "") + + available_models = [entry.strip() for entry in (ui_state.params.get("AvailableModels", encoding="utf-8") or "").split(",")] + available_names = [entry.strip() for entry in (ui_state.params.get("AvailableModelNames", encoding="utf-8") or "").split(",")] + model_versions = [entry.strip() for entry in (ui_state.params.get("ModelVersions", encoding="utf-8") or "").split(",")] + model_name_map = { + key: _clean_name(name) + for key, name in zip(available_models, available_names) + if key and _clean_name(name) + } + model_version_map = { + key: version + for key, version in zip(available_models, model_versions) + if key and version + } + + default_key = _decode_default(ui_state.params.get_default_value("DrivingModel") or + ui_state.params.get_default_value("Model")) or "sc" + default_name = _clean_name(_decode_default(ui_state.params.get_default_value("DrivingModelName"))) or "South Carolina" + + def _is_model_installed(key: str) -> bool: + if not key: + return False + + # Built-in default model is always available. + if key == default_key: + return True + + if (MODELS_PATH / f"{key}.thneed").is_file(): + return True + + version = model_version_map.get(key, "") + required = [ + f"{key}_driving_policy_tinygrad.pkl", + f"{key}_driving_vision_tinygrad.pkl", + f"{key}_driving_policy_metadata.pkl", + f"{key}_driving_vision_metadata.pkl", + ] + if version == "v12": + required.extend([ + f"{key}_driving_off_policy_tinygrad.pkl", + f"{key}_driving_off_policy_metadata.pkl", + ]) + return all((MODELS_PATH / filename).is_file() for filename in required) + + # If a stale custom model is selected but not actually installed, show default. + if model_key and not _is_model_installed(model_key): + model_key = default_key + + resolved_name = "" + if model_key in model_name_map: + resolved_name = model_name_map[model_key] + elif model_key.endswith("2") and model_key[:-1] in model_name_map: + resolved_name = model_name_map[model_key[:-1]] + elif model_key == default_key or (model_key.endswith("2") and model_key[:-1] == default_key): + resolved_name = default_name + + if not resolved_name and current_param_name: + resolved_name = current_param_name + if not resolved_name: + resolved_name = default_name if (not model_key or model_key == default_key) else model_key + + self._current_model_name = resolved_name + def _update_state(self): if self.is_pressed and not self._is_pressed_prev: self._mouse_down_t = time.monotonic() @@ -187,8 +264,6 @@ class MiciHomeLayout(Widget): self._openpilot_label.render() if self._version_text is not None: - # release branch - release_branch = self._version_text[1] in RELEASE_BRANCHES version_pos = rl.Rectangle(text_pos.x, text_pos.y + self._openpilot_label.font_size + 16, 100, 44) self._version_label.set_text(self._version_text[0]) self._version_label.set_position(version_pos.x, version_pos.y) @@ -199,15 +274,20 @@ class MiciHomeLayout(Widget): self._date_label.render() self._branch_label.set_max_width(gui_app.width - self._version_label.rect.width - self._date_label.rect.width - 32) - self._branch_label.set_text(" " + ("release" if release_branch else self._version_text[1])) + self._branch_label.set_text(" " + self._current_model_name) self._branch_label.set_position(version_pos.x + self._version_label.rect.width + self._date_label.rect.width + 20, version_pos.y) self._branch_label.render() - if not release_branch: + if self._version_text[1] not in RELEASE_BRANCHES: # 2nd line self._version_commit_label.set_text(self._version_text[2]) self._version_commit_label.set_position(version_pos.x, version_pos.y + self._date_label.font_size + 7) self._version_commit_label.render() + else: + self._branch_label.set_max_width(gui_app.width - 32) + self._branch_label.set_text(self._current_model_name) + self._branch_label.set_position(text_pos.x, text_pos.y + self._openpilot_label.font_size + 16) + self._branch_label.render() self._render_bottom_status_bar() diff --git a/selfdrive/ui/mici/layouts/settings/driving_model.py b/selfdrive/ui/mici/layouts/settings/driving_model.py new file mode 100644 index 00000000..eae6cdd2 --- /dev/null +++ b/selfdrive/ui/mici/layouts/settings/driving_model.py @@ -0,0 +1,732 @@ +from __future__ import annotations + +import re +import threading +import time +from dataclasses import dataclass +from collections.abc import Callable + +from openpilot.common.params import Params +from openpilot.frogpilot.assets.model_manager import ( + CANCEL_DOWNLOAD_PARAM, + DOWNLOAD_PROGRESS_PARAM, + ModelManager, + TINYGRAD_VERSIONS, +) +from openpilot.frogpilot.common.frogpilot_variables import MODELS_PATH +from openpilot.selfdrive.ui.mici.widgets.button import BigButton +from openpilot.selfdrive.ui.mici.widgets.dialog import BigDialog, BigDialogBase, BigMultiOptionDialog +from openpilot.selfdrive.ui.ui_state import ui_state +from openpilot.system.ui.lib.application import gui_app, FontWeight +from openpilot.system.ui.widgets import DialogResult, Widget +from openpilot.system.ui.widgets.label import gui_label +import pyray as rl + +MANIFEST_STALE_SECONDS = 60 * 60 +_PROGRESS_HOLD_SECONDS = 2.5 +_DOWNLOAD_DIALOG_CLOSE_SECONDS = 1.0 +_TERMINAL_PROGRESS_PATTERNS = ( + "downloaded", + "cancelled", + "failed", + "offline", + "verification failed", + "server error", + "download invalid", +) +_SORT_MODE_PARAM = "ModelSortMode" +_SORT_MODE_ALPHABETICAL = "alphabetical" +_SORT_MODE_DATE_NEWEST = "date" +_SORT_MODE_DATE_OLDEST = "date_oldest" +_SORT_MODE_FAVORITES = "favorites" +_SORT_MODES = ( + _SORT_MODE_ALPHABETICAL, + _SORT_MODE_DATE_NEWEST, + _SORT_MODE_DATE_OLDEST, + _SORT_MODE_FAVORITES, +) +_SORT_MODE_LABELS = { + _SORT_MODE_ALPHABETICAL: "alphabetical", + _SORT_MODE_DATE_NEWEST: "newest", + _SORT_MODE_DATE_OLDEST: "oldest", + _SORT_MODE_FAVORITES: "favorites", +} +_LABEL_TO_SORT_MODE = {label: mode for mode, label in _SORT_MODE_LABELS.items()} + + +@dataclass +class ModelEntry: + key: str + name: str + series: str + version: str + released: str + community_favorite: bool + + +def _clean_model_name(name: str) -> str: + cleaned = re.sub(r"[πŸ—ΊοΈπŸ‘€πŸ“‘]", "", str(name or "")).replace("(Default)", "") + return cleaned.strip() + + +def _split_param(param_value: str | None) -> list[str]: + return [item.strip() for item in (param_value or "").split(",") if item.strip()] + + +class DownloadProgressDialog(BigDialogBase): + def __init__(self, params_memory: Params, is_downloading: Callable[[], bool], cancel_callback: Callable[[], None], + is_terminal_progress: Callable[[str], bool]): + super().__init__() + self._params_memory = params_memory + self._is_downloading = is_downloading + self._cancel_callback = cancel_callback + self._is_terminal_progress = is_terminal_progress + + self._progress = 0.0 + self._status = "Downloading..." + self._terminal_progress_since = 0.0 + self._downloading = False + + self._cancel_btn = DownloadActionButton("cancel download") + self._cancel_btn.set_click_callback(self._cancel_callback) + + def show_event(self): + super().show_event() + self._cancel_btn.show_event() + + def hide_event(self): + super().hide_event() + self._cancel_btn.hide_event() + + @staticmethod + def _parse_progress(progress: str) -> float | None: + match = re.search(r"(\d{1,3})\s*%", progress) + if match: + return max(0.0, min(float(match.group(1)) / 100.0, 1.0)) + lowered = progress.lower() + if "verifying" in lowered: + return 1.0 + if "downloaded" in lowered: + return 1.0 + return None + + def _update_state(self): + super()._update_state() + + progress = self._params_memory.get(DOWNLOAD_PROGRESS_PARAM, encoding="utf-8") or "" + is_downloading = self._is_downloading() + self._downloading = is_downloading + parsed_progress = self._parse_progress(progress) + + if parsed_progress is not None: + self._progress = parsed_progress + + if progress: + self._status = progress + elif is_downloading: + self._status = "Downloading..." + else: + self._status = "Downloaded!" + self._progress = max(self._progress, 1.0) + + self._cancel_btn.set_enabled(is_downloading) + + terminal = (not is_downloading) and ( + not progress or + self._is_terminal_progress(progress) or + (parsed_progress is not None and parsed_progress >= 1.0) + ) + if terminal: + if self._terminal_progress_since == 0.0: + self._terminal_progress_since = time.monotonic() + elif time.monotonic() - self._terminal_progress_since >= _DOWNLOAD_DIALOG_CLOSE_SECONDS: + self._ret = DialogResult.CONFIRM + else: + self._terminal_progress_since = 0.0 + + def _render(self, _): + super()._render(_) + + width = int(self._rect.width) + height = int(self._rect.height) + + rl.draw_rectangle(0, 0, width, height, rl.Color(0, 0, 0, 170)) + + panel_margin_x = 40 + panel_margin_y = 28 + panel_width = int(min(width - panel_margin_x * 2, 980)) + panel_height = int(min(height - panel_margin_y * 2, 300)) + panel_x = int((width - panel_width) / 2) + panel_y = int((height - panel_height) / 2) + panel_rect = rl.Rectangle(panel_x, panel_y, panel_width, panel_height) + rl.draw_rectangle_rounded(panel_rect, 0.08, 24, rl.Color(16, 16, 16, 245)) + rl.draw_rectangle_rounded_lines_ex(panel_rect, 0.08, 24, 2, rl.Color(255, 255, 255, 32)) + + cancel_w = int(self._cancel_btn.rect.width) + cancel_h = int(self._cancel_btn.rect.height) + cancel_x = int(panel_x + (panel_width - cancel_w) / 2) + + top_padding = 22 + side_padding = 36 + gap = 16 + bar_h = 34 + status_h = 44 + bottom_padding = 24 + + bar_y = int(panel_y + top_padding) + status_y = int(bar_y + bar_h + gap) + cancel_y = int(status_y + status_h + gap) + max_cancel_y = int(panel_y + panel_height - cancel_h - bottom_padding) + if cancel_y > max_cancel_y: + cancel_y = max_cancel_y + + status_rect = rl.Rectangle(panel_x + side_padding, status_y, panel_width - side_padding * 2, status_h) + gui_label( + status_rect, + self._status, + 32, + font_weight=FontWeight.MEDIUM, + alignment=rl.GuiTextAlignment.TEXT_ALIGN_CENTER, + ) + + bar_width = int(min(panel_width - side_padding * 2, 760)) + bar_x = int(panel_x + (panel_width - bar_width) / 2) + self._cancel_btn.render(rl.Rectangle(cancel_x, cancel_y, cancel_w, cancel_h)) + + # Draw the bar last so it cannot get hidden by other dialog graphics. + bar_rect = rl.Rectangle(bar_x, bar_y, bar_width, bar_h) + rl.draw_rectangle_rounded(bar_rect, 0.35, 16, rl.Color(34, 34, 34, 255)) + rl.draw_rectangle_rounded_lines_ex(bar_rect, 0.35, 16, 2, rl.Color(255, 255, 255, 95)) + + fill_padding = 4 + fill_h = bar_h - fill_padding * 2 + bar_inner_w = max(0, bar_width - fill_padding * 2) + clamped_progress = min(max(self._progress, 0.0), 1.0) + + if self._downloading and clamped_progress <= 0.001: + # If backend reports status text but no percentage yet, show animated activity. + segment_w = max(70.0, bar_inner_w * 0.22) + max_offset = max(1.0, bar_inner_w - segment_w) + phase = (time.monotonic() * 1.35) % 1.0 + fill_x = bar_x + fill_padding + (phase * max_offset) + rl.draw_rectangle_rounded( + rl.Rectangle(fill_x, bar_y + fill_padding, segment_w, fill_h), + 0.35, + 16, + rl.Color(70, 91, 234, 255), + ) + else: + fill_width = max(0.0, bar_inner_w * clamped_progress) + if fill_width > 0: + rl.draw_rectangle_rounded( + rl.Rectangle(bar_x + fill_padding, bar_y + fill_padding, fill_width, fill_h), + 0.35, + 16, + rl.Color(70, 91, 234, 255), + ) + + if clamped_progress > 0.0: + gui_label( + bar_rect, + f"{int(clamped_progress * 100)}%", + 24, + color=rl.Color(255, 255, 255, 210), + font_weight=FontWeight.BOLD, + alignment=rl.GuiTextAlignment.TEXT_ALIGN_CENTER, + ) + + return self._ret + + +class DownloadActionButton(Widget): + def __init__(self, label: str): + super().__init__() + self._label = label + self.set_rect(rl.Rectangle(0, 0, 380, 86)) + + def set_label(self, label: str): + self._label = label + + def _render(self, _): + if not self.enabled: + bg = rl.Color(48, 48, 48, 255) + border = rl.Color(255, 255, 255, 45) + label_alpha = 145 + elif self.is_pressed: + bg = rl.Color(72, 86, 170, 255) + border = rl.Color(255, 255, 255, 85) + label_alpha = 235 + else: + bg = rl.Color(58, 70, 146, 255) + border = rl.Color(255, 255, 255, 70) + label_alpha = 225 + + rl.draw_rectangle_rounded(self._rect, 0.35, 14, bg) + rl.draw_rectangle_rounded_lines_ex(self._rect, 0.35, 14, 2, border) + + gui_label( + self._rect, + self._label, + 34, + color=rl.Color(255, 255, 255, label_alpha), + font_weight=FontWeight.MEDIUM, + alignment=rl.GuiTextAlignment.TEXT_ALIGN_CENTER, + ) + + +class DrivingModelBigButton(BigButton): + def __init__(self): + super().__init__("driving model", "", "icons_mici/settings/device/lkas.png") + self._params = Params() + self._params_memory = Params(memory=True) + self._model_manager = ModelManager(self._params, self._params_memory) + + self._worker_thread: threading.Thread | None = None + self._active_job = "" + self._manifest_last_refresh_mono = 0.0 + self._terminal_progress_since = 0.0 + + self.set_click_callback(self._open_manager_menu) + self.refresh() + + def show_event(self): + super().show_event() + self.refresh() + # Always fetch manifest once when this settings pane opens. + self._maybe_refresh_manifest(force=(self._manifest_last_refresh_mono == 0.0)) + + def refresh(self): + self._update_button_value() + + def _update_state(self): + super()._update_state() + self._process_terminal_progress() + self._update_button_value() + + def _open_manager_menu(self): + options = [ + "set sort mode", + "switch model", + "download model", + "download all missing", + "refresh manifest", + ] + + if not options: + return + + def on_confirm(): + value = option_dialog.get_selected_option() + if value == "set sort mode": + self._open_sort_mode_dialog() + elif value == "switch model": + self._open_switch_dialog() + elif value == "download model": + self._open_download_dialog() + elif value == "download all missing": + self._download_all_missing() + elif value == "refresh manifest": + self._maybe_refresh_manifest(force=True) + + default_option = "switch model" if "switch model" in options else options[0] + option_dialog = BigMultiOptionDialog(options=options, default=default_option, right_btn_callback=on_confirm) + gui_app.set_modal_overlay(option_dialog) + + def _open_switch_dialog(self): + self._maybe_refresh_manifest(force=False) + entries = self._load_model_entries() + if not entries: + message = "Refreshing model list..." if self._active_job == "refresh" else "Refresh manifest and try again." + self._show_message("Model list unavailable", message, return_to_manager=True) + return + + installed = [entry for entry in entries if self._is_model_installed(entry.key, entry.version)] + if not installed: + self._show_message("No downloaded models", "Download a model first.", return_to_manager=True) + return + + current_key = self._get_current_model_key() + self._show_model_dialog("Select Driving Model", installed, current_key, self._switch_model) + + def _open_download_dialog(self): + if ui_state.started: + self._show_message("Downloads blocked while driving", "Try again offroad.", return_to_manager=True) + return + + self._maybe_refresh_manifest(force=False) + entries = self._load_model_entries() + if not entries: + message = "Refreshing model list..." if self._active_job == "refresh" else "Refresh manifest and try again." + self._show_message("Model list unavailable", message, return_to_manager=True) + return + + missing = [entry for entry in entries if not self._is_model_installed(entry.key, entry.version)] + if not missing: + self._show_message("All models downloaded", "No additional models are available.", return_to_manager=True) + return + + self._show_model_dialog("Download Driving Model", missing, "", self._start_model_download) + + def _download_all_missing(self): + if ui_state.started: + self._show_message("Downloads blocked while driving", "Try again offroad.", return_to_manager=True) + return + + self._maybe_refresh_manifest(force=False) + entries = self._load_model_entries() + if not entries: + self._show_message("Model list unavailable", "Refresh manifest and try again.", return_to_manager=True) + return + + missing_exists = any(not self._is_model_installed(entry.key, entry.version) for entry in entries) + if not missing_exists: + self._show_message("All models downloaded", "No additional models are available.", return_to_manager=True) + return + + if not self._start_worker("download_all", self._run_download_all): + self._show_message("Model manager busy", "Please wait for the current task.", return_to_manager=True) + return + + self._show_download_progress_dialog() + + def _cancel_download(self): + if not self._is_download_job_running(): + return + self._params_memory.put_bool(CANCEL_DOWNLOAD_PARAM, True) + + def _open_sort_mode_dialog(self): + options = [_SORT_MODE_LABELS[mode] for mode in _SORT_MODES] + current_mode = self._get_sort_mode() + + def on_confirm(): + selected_label = sort_dialog.get_selected_option() + selected_mode = _LABEL_TO_SORT_MODE.get(selected_label, _SORT_MODE_ALPHABETICAL) + self._params.put(_SORT_MODE_PARAM, selected_mode) + self._open_manager_menu_if_no_overlay() + + sort_dialog = BigMultiOptionDialog(options=options, default=_SORT_MODE_LABELS[current_mode], right_btn_callback=on_confirm) + sort_dialog.set_back_callback(self._open_manager_menu) + gui_app.set_modal_overlay(sort_dialog) + + def _get_sort_mode(self) -> str: + mode = (self._params.get(_SORT_MODE_PARAM, encoding="utf-8") or "").strip() + return mode if mode in _SORT_MODES else _SORT_MODE_ALPHABETICAL + + def _show_model_dialog(self, title: str, entries: list[ModelEntry], current_key: str, + on_selected: Callable[[str], None]): + options, option_to_key, key_to_option = self._build_model_options(entries) + if not options: + self._show_message("No models available", "Refresh manifest and try again.", return_to_manager=True) + return + + default_option = key_to_option.get(current_key, options[0]) + + def on_confirm(): + model_key = option_to_key.get(model_dialog.get_selected_option()) + if model_key: + on_selected(model_key) + self._open_manager_menu_if_no_overlay() + + model_dialog = BigMultiOptionDialog(options=options, default=default_option, right_btn_callback=on_confirm) + model_dialog.set_back_callback(self._open_manager_menu) + gui_app.set_modal_overlay(model_dialog) + + def _build_model_options(self, entries: list[ModelEntry]) -> tuple[list[str], dict[str, str], dict[str, str]]: + # Ensure display names are unique before applying status text (date/favorite). + display_names: dict[str, str] = {} + for entry in entries: + name = entry.name + if name in display_names and display_names[name] != entry.key: + name = f"{entry.name} [{entry.series}]" + if name in display_names and display_names[name] != entry.key: + name = f"{entry.name} [{entry.key}]" + display_names[name] = entry.key + + key_to_display = {key: name for name, key in display_names.items()} + sorted_entries = self._sort_entries(entries) + + options: list[str] = [] + option_to_key: dict[str, str] = {} + key_to_option: dict[str, str] = {} + + for entry in sorted_entries: + display_name = key_to_display.get(entry.key, entry.name) + favorite_prefix = "β™₯ " if entry.community_favorite else "" + option = f"{favorite_prefix}{display_name}" + if option in option_to_key: + option = f"{option} [{entry.key}]" + + options.append(option) + option_to_key[option] = entry.key + key_to_option[entry.key] = option + + return options, option_to_key, key_to_option + + def _sort_entries(self, entries: list[ModelEntry]) -> list[ModelEntry]: + sort_mode = self._get_sort_mode() + entry_list = sorted(entries, key=lambda entry: (entry.series.lower(), entry.name.lower())) + + def normalized_release(entry: ModelEntry) -> str: + return entry.released if entry.released else "0000-00-00" + + if sort_mode == _SORT_MODE_DATE_NEWEST: + return sorted(entry_list, key=normalized_release, reverse=True) + if sort_mode == _SORT_MODE_DATE_OLDEST: + return sorted(entry_list, key=normalized_release) + if sort_mode == _SORT_MODE_FAVORITES: + return sorted(entry_list, key=lambda entry: ( + 0 if entry.community_favorite else 1, + entry.series.lower(), + entry.name.lower(), + )) + return entry_list + + def _start_model_download(self, model_key: str): + if not self._start_worker("download", self._run_download_one, model_key): + self._show_message("Model manager busy", "Please wait for the current task.", return_to_manager=True) + return + + self._show_download_progress_dialog() + + def _run_download_one(self, model_key: str): + self._params_memory.put_bool(CANCEL_DOWNLOAD_PARAM, False) + self._model_manager.download_model(model_key) + + entries = {entry.key: entry for entry in self._load_model_entries()} + entry = entries.get(model_key) + model_version = entry.version if entry else "" + if not self._is_model_installed(model_key, model_version): + self._params_memory.put(DOWNLOAD_PROGRESS_PARAM, "Verification failed...") + + def _run_download_all(self): + self._params_memory.put_bool(CANCEL_DOWNLOAD_PARAM, False) + self._model_manager.download_all_models() + + def _run_manifest_refresh(self): + self._model_manager.update_models() + + def _switch_model(self, model_key: str): + entries = {entry.key: entry for entry in self._load_model_entries()} + entry = entries.get(model_key) + + if entry is None: + self._show_message("Model unavailable", "Refresh manifest and try again.", return_to_manager=True) + return + + if not self._is_model_installed(entry.key, entry.version): + self._show_message("Model not downloaded", "Download this model first.", return_to_manager=True) + return + + self._params.put("Model", entry.key) + self._params.put("DrivingModel", entry.key) + self._params.put("DrivingModelName", entry.name) + + version = entry.version.strip() + if version: + self._params.put("ModelVersion", version) + self._params.put("DrivingModelVersion", version) + + if ui_state.started: + self._params.put_bool("OnroadCycleRequested", True) + self._show_message("Model switched", "Drive-cycle requested for immediate apply.", return_to_manager=True) + + self.refresh() + + def _start_worker(self, job: str, target, *args) -> bool: + if self._worker_thread is not None and self._worker_thread.is_alive(): + return False + + self._active_job = job + + def run_job(): + try: + target(*args) + finally: + if job == "refresh": + self._manifest_last_refresh_mono = time.monotonic() + self._active_job = "" + + self._worker_thread = threading.Thread(target=run_job, daemon=True) + self._worker_thread.start() + return True + + def _maybe_refresh_manifest(self, force: bool): + if ui_state.started: + return + + now = time.monotonic() + has_entries = bool(self._load_model_entries()) + stale = (now - self._manifest_last_refresh_mono) > MANIFEST_STALE_SECONDS + + if force or not has_entries or stale: + self._start_worker("refresh", self._run_manifest_refresh) + + def _is_download_job_running(self) -> bool: + return self._active_job in {"download", "download_all"} + + def _process_terminal_progress(self): + if self._is_download_job_running(): + self._terminal_progress_since = 0.0 + return + + progress = self._params_memory.get(DOWNLOAD_PROGRESS_PARAM, encoding="utf-8") or "" + if not progress: + self._terminal_progress_since = 0.0 + return + + if not self._is_terminal_progress(progress): + return + + if self._terminal_progress_since == 0.0: + self._terminal_progress_since = time.monotonic() + return + + if time.monotonic() - self._terminal_progress_since >= _PROGRESS_HOLD_SECONDS: + self._params_memory.remove(CANCEL_DOWNLOAD_PARAM) + self._params_memory.remove(DOWNLOAD_PROGRESS_PARAM) + self._terminal_progress_since = 0.0 + + def _update_button_value(self): + if self._active_job == "refresh": + value = "refreshing" + elif self._is_download_job_running(): + value = self._params_memory.get(DOWNLOAD_PROGRESS_PARAM, encoding="utf-8") or "downloading..." + else: + progress = self._params_memory.get(DOWNLOAD_PROGRESS_PARAM, encoding="utf-8") or "" + if progress and self._is_terminal_progress(progress): + value = progress + else: + value = self._get_current_model_name() + + if value != self.get_value(): + self.set_value(value) + + def _load_model_entries(self) -> list[ModelEntry]: + available_models = _split_param(self._params.get("AvailableModels", encoding="utf-8")) + model_names = _split_param(self._params.get("AvailableModelNames", encoding="utf-8")) + model_series = [item.strip() for item in (self._params.get("AvailableModelSeries", encoding="utf-8") or "").split(",")] + model_versions = [item.strip() for item in (self._params.get("ModelVersions", encoding="utf-8") or "").split(",")] + released_dates = [item.strip() for item in (self._params.get("ModelReleasedDates", encoding="utf-8") or "").split(",")] + community_favs = set(_split_param(self._params.get("CommunityFavorites", encoding="utf-8"))) + + size = min(len(available_models), len(model_names)) + entries: list[ModelEntry] = [] + + for i in range(size): + key = available_models[i].strip() + name = _clean_model_name(model_names[i]) + if not key or not name: + continue + + series = model_series[i].strip() if i < len(model_series) and model_series[i].strip() else "Custom Series" + version = model_versions[i].strip() if i < len(model_versions) else "" + released = released_dates[i].strip() if i < len(released_dates) else "" + + entries.append(ModelEntry( + key=key, + name=name, + series=series, + version=version, + released=released, + community_favorite=(key in community_favs), + )) + + return entries + + def _get_current_model_key(self) -> str: + model_key = self._params.get("Model", encoding="utf-8") or self._params.get("DrivingModel", encoding="utf-8") or "" + if model_key: + return model_key + + default_key = self._params.get_default_value("Model") or self._params.get_default_value("DrivingModel") + if isinstance(default_key, bytes): + return default_key.decode("utf-8", errors="ignore").strip() + return str(default_key or "").strip() + + def _get_current_model_name(self) -> str: + current_name = _clean_model_name(self._params.get("DrivingModelName", encoding="utf-8") or "") + if current_name: + return current_name + + current_key = self._get_current_model_key() + for entry in self._load_model_entries(): + if entry.key == current_key: + return entry.name + + return "default" + + def _is_model_installed(self, key: str, version: str) -> bool: + if not key: + return False + + if self._is_builtin_default_model(key): + return True + + if (MODELS_PATH / f"{key}.thneed").is_file(): + return True + + required_files = self._required_files_for_version(key, version) + if not required_files: + return False + + return all((MODELS_PATH / filename).is_file() for filename in required_files) + + def _is_builtin_default_model(self, key: str) -> bool: + default_key = self._params.get_default_value("DrivingModel") or self._params.get_default_value("Model") + if isinstance(default_key, bytes): + default_key = default_key.decode("utf-8", errors="ignore") + default_key = str(default_key or "").strip() + if not default_key: + default_key = "sc" + + # Manifest can expose legacy IDs like "sc2" while default remains "sc". + if key == default_key: + return True + if default_key.endswith("2") and key == default_key[:-1]: + return True + if not default_key.endswith("2") and key == f"{default_key}2": + return True + return False + + def _required_files_for_version(self, key: str, version: str) -> list[str]: + if version not in TINYGRAD_VERSIONS: + return [] + + files = [ + f"{key}_driving_policy_tinygrad.pkl", + f"{key}_driving_vision_tinygrad.pkl", + f"{key}_driving_policy_metadata.pkl", + f"{key}_driving_vision_metadata.pkl", + ] + + if version == "v12": + files.extend([ + f"{key}_driving_off_policy_tinygrad.pkl", + f"{key}_driving_off_policy_metadata.pkl", + ]) + + return files + + @staticmethod + def _is_terminal_progress(progress: str) -> bool: + lower = progress.lower() + return any(pattern in lower for pattern in _TERMINAL_PROGRESS_PATTERNS) + + def _open_manager_menu_if_no_overlay(self): + if gui_app._modal_overlay.overlay is None: + self._open_manager_menu() + + def _show_download_progress_dialog(self): + dialog = DownloadProgressDialog( + params_memory=self._params_memory, + is_downloading=self._is_download_job_running, + cancel_callback=self._cancel_download, + is_terminal_progress=self._is_terminal_progress, + ) + gui_app.set_modal_overlay(dialog, callback=lambda _result: self._open_manager_menu()) + + def _show_message(self, title: str, description: str, return_to_manager: bool = False): + dialog = BigDialog(title, description) + if return_to_manager: + dialog.set_back_callback(self._open_manager_menu) + gui_app.set_modal_overlay(dialog) diff --git a/selfdrive/ui/mici/layouts/settings/settings.py b/selfdrive/ui/mici/layouts/settings/settings.py index 3471b348..28da9ad2 100644 --- a/selfdrive/ui/mici/layouts/settings/settings.py +++ b/selfdrive/ui/mici/layouts/settings/settings.py @@ -9,6 +9,7 @@ from openpilot.selfdrive.ui.mici.widgets.button import BigButton, BigMultiToggle from openpilot.selfdrive.ui.mici.layouts.settings.toggles import TogglesLayoutMici from openpilot.selfdrive.ui.mici.layouts.settings.network import NetworkLayoutMici from openpilot.selfdrive.ui.mici.layouts.settings.device import DeviceLayoutMici, PairBigButton, GalaxyBigButton +from openpilot.selfdrive.ui.mici.layouts.settings.driving_model import DrivingModelBigButton from openpilot.selfdrive.ui.mici.layouts.settings.developer import DeveloperLayoutMici from openpilot.selfdrive.ui.mici.layouts.settings.firehose import FirehoseLayout from openpilot.system.ui.lib.application import gui_app, FontWeight @@ -79,11 +80,13 @@ class SettingsLayout(NavWidget): firehose_btn.set_click_callback(lambda: self._set_current_panel(PanelType.FIREHOSE)) self._force_drive_state_btn = ForceDriveStateBigButton() + self._driving_model_btn = DrivingModelBigButton() self._scroller = Scroller([ toggles_btn, network_btn, self._force_drive_state_btn, + self._driving_model_btn, device_btn, PairBigButton(), GalaxyBigButton(), @@ -112,6 +115,7 @@ class SettingsLayout(NavWidget): def show_event(self): super().show_event() self._force_drive_state_btn.refresh() + self._driving_model_btn.refresh() self._set_current_panel(None) self._scroller.show_event() if self._current_panel is not None: diff --git a/selfdrive/ui/mici/onroad/augmented_road_view.py b/selfdrive/ui/mici/onroad/augmented_road_view.py index 83ef6656..9a827dac 100644 --- a/selfdrive/ui/mici/onroad/augmented_road_view.py +++ b/selfdrive/ui/mici/onroad/augmented_road_view.py @@ -6,7 +6,7 @@ from msgq.visionipc import VisionStreamType from openpilot.common.constants import CV from openpilot.selfdrive.ui.ui_state import ui_state from openpilot.selfdrive.ui.mici.onroad import SIDE_PANEL_WIDTH -from openpilot.selfdrive.ui.mici.onroad.alert_renderer import AlertRenderer +from openpilot.selfdrive.ui.mici.onroad.alert_renderer import AlertRenderer, ALERT_COLORS, AlertStatus from openpilot.selfdrive.ui.mici.onroad.driver_state import DriverStateRenderer from openpilot.selfdrive.ui.mici.onroad.hud_renderer import HudRenderer from openpilot.selfdrive.ui.mici.onroad.model_renderer import ModelRenderer @@ -252,18 +252,25 @@ class MinSteerSpeedBanner(Widget): if not self._showing_interval: return - banner_width = min(rect.width - 120, 760) - banner_height = 72 - banner_rect = rl.Rectangle( - rect.x + (rect.width - banner_width) / 2, - rect.y + 22, - banner_width, - banner_height, + color = ALERT_COLORS[AlertStatus.userPrompt] + color = rl.Color(color.r, color.g, color.b, int(255 * 0.9)) + translucent = rl.Color(color.r, color.g, color.b, 0) + dropdown_height = min(170, int(rect.height * 0.7)) + solid_height = max(26, int(dropdown_height * 0.2)) + + rl.draw_rectangle(int(rect.x), int(rect.y), int(rect.width), solid_height, color) + rl.draw_rectangle_gradient_v( + int(rect.x), + int(rect.y + solid_height), + int(rect.width), + int(dropdown_height - solid_height), + color, + translucent, ) - rl.draw_rectangle_rounded(banner_rect, 0.3, 12, rl.Color(0, 0, 0, 185)) - rl.draw_rectangle_rounded_lines_ex(banner_rect, 0.3, 12, 4, rl.Color(218, 111, 37, 255)) - self._label.render(banner_rect) + text_rect = rl.Rectangle(rect.x + 26, rect.y - 2, rect.width - 52, dropdown_height) + self._label.set_text_color(rl.Color(255, 255, 255, 242)) + self._label.render(text_rect) class AugmentedRoadView(CameraView): diff --git a/selfdrive/ui/mici/widgets/dialog.py b/selfdrive/ui/mici/widgets/dialog.py index 5f201766..d64571e3 100644 --- a/selfdrive/ui/mici/widgets/dialog.py +++ b/selfdrive/ui/mici/widgets/dialog.py @@ -73,26 +73,45 @@ class BigDialog(BigDialogBase): if self._right_btn: max_width -= self._right_btn._rect.width - title_wrapped = '\n'.join(wrap_text(gui_app.font(FontWeight.BOLD), self._title, 50, int(max_width))) - title_size = measure_text_cached(gui_app.font(FontWeight.BOLD), title_wrapped, 50) + title_font_size = 50 + desc_font_size = 30 + title_lines = wrap_text(gui_app.font(FontWeight.BOLD), self._title, title_font_size, int(max_width)) + if not title_lines: + title_lines = [""] + title_line_height = max(int(title_font_size * 1.2), int(measure_text_cached(gui_app.font(FontWeight.BOLD), "Ag", title_font_size).y)) text_x_offset = 0 - title_rect = rl.Rectangle(int(self._rect.x + text_x_offset + PADDING), - int(self._rect.y + PADDING), - int(max_width), - int(title_size.y)) - gui_label(title_rect, title_wrapped, 50, font_weight=FontWeight.BOLD, - alignment=rl.GuiTextAlignment.TEXT_ALIGN_CENTER) + title_x = int(self._rect.x + text_x_offset + PADDING) + title_y = int(self._rect.y + PADDING) + for i, line in enumerate(title_lines): + line_rect = rl.Rectangle( + title_x, + title_y + i * title_line_height, + int(max_width), + int(title_line_height), + ) + gui_label(line_rect, line, title_font_size, font_weight=FontWeight.BOLD, + alignment=rl.GuiTextAlignment.TEXT_ALIGN_CENTER, + alignment_vertical=rl.GuiTextAlignmentVertical.TEXT_ALIGN_TOP) # draw description - desc_wrapped = '\n'.join(wrap_text(gui_app.font(FontWeight.MEDIUM), self._description, 30, int(max_width))) - desc_size = measure_text_cached(gui_app.font(FontWeight.MEDIUM), desc_wrapped, 30) - desc_rect = rl.Rectangle(int(self._rect.x + text_x_offset + PADDING), - int(self._rect.y + self._rect.height / 3), - int(max_width), - int(desc_size.y)) - # TODO: text align doesn't seem to work properly with newlines - gui_label(desc_rect, desc_wrapped, 30, font_weight=FontWeight.MEDIUM, - alignment=rl.GuiTextAlignment.TEXT_ALIGN_CENTER) + desc_lines = wrap_text(gui_app.font(FontWeight.MEDIUM), self._description, desc_font_size, int(max_width)) + if not desc_lines: + desc_lines = [""] + desc_line_height = max(int(desc_font_size * 1.25), int(measure_text_cached(gui_app.font(FontWeight.MEDIUM), "Ag", desc_font_size).y)) + desc_y = max( + int(self._rect.y + self._rect.height / 3), + title_y + title_line_height * len(title_lines) + 22, + ) + for i, line in enumerate(desc_lines): + line_rect = rl.Rectangle( + title_x, + desc_y + i * desc_line_height, + int(max_width), + int(desc_line_height), + ) + gui_label(line_rect, line, desc_font_size, font_weight=FontWeight.MEDIUM, + alignment=rl.GuiTextAlignment.TEXT_ALIGN_CENTER, + alignment_vertical=rl.GuiTextAlignmentVertical.TEXT_ALIGN_TOP) return self._ret diff --git a/selfdrive/ui/onroad/augmented_road_view.py b/selfdrive/ui/onroad/augmented_road_view.py index f5ae56ef..2c6f5d1c 100644 --- a/selfdrive/ui/onroad/augmented_road_view.py +++ b/selfdrive/ui/onroad/augmented_road_view.py @@ -6,12 +6,12 @@ from msgq.visionipc import VisionStreamType from openpilot.common.constants import CV from openpilot.selfdrive.ui import UI_BORDER_SIZE from openpilot.selfdrive.ui.ui_state import ui_state, UIStatus -from openpilot.selfdrive.ui.onroad.alert_renderer import AlertRenderer +from openpilot.selfdrive.ui.onroad.alert_renderer import AlertRenderer, ALERT_COLORS, AlertStatus from openpilot.selfdrive.ui.onroad.driver_state import DriverStateRenderer from openpilot.selfdrive.ui.onroad.hud_renderer import HudRenderer from openpilot.selfdrive.ui.onroad.model_renderer import ModelRenderer from openpilot.selfdrive.ui.onroad.cameraview import CameraView -from openpilot.system.ui.lib.application import gui_app +from openpilot.system.ui.lib.application import gui_app, FontWeight from openpilot.common.transformations.camera import DEVICE_CAMERAS, DeviceCameraConfig, view_frame_from_device_frame from openpilot.common.transformations.orientation import rot_from_euler @@ -41,6 +41,7 @@ class MinSteerSpeedBanner: self._has_been_above_min = False self._was_under_min = False self._last_started_frame = -1 + self._font = gui_app.font(FontWeight.BOLD) def _reset(self): self._shown_this_drive = False @@ -98,27 +99,35 @@ class MinSteerSpeedBanner: if min_steer_speed <= 0: return - banner_width = min(rect.width - 120, 760) - banner_height = 84 - banner_rect = rl.Rectangle( - rect.x + (rect.width - banner_width) / 2, - rect.y + 24, - banner_width, - banner_height, - ) + color = ALERT_COLORS[AlertStatus.userPrompt] + color = rl.Color(color.r, color.g, color.b, int(255 * 0.93)) + translucent = rl.Color(color.r, color.g, color.b, 0) + dropdown_height = min(200, int(rect.height * 0.38)) + solid_height = max(34, int(dropdown_height * 0.22)) - rl.draw_rectangle_rounded(banner_rect, 0.3, 12, rl.Color(0, 0, 0, 185)) - rl.draw_rectangle_rounded_lines_ex(banner_rect, 0.3, 12, 4, rl.Color(218, 111, 37, 255)) + rl.draw_rectangle(int(rect.x), int(rect.y), int(rect.width), solid_height, color) + rl.draw_rectangle_gradient_v( + int(rect.x), + int(rect.y + solid_height), + int(rect.width), + int(dropdown_height - solid_height), + color, + translucent, + ) text = self._get_message(min_steer_speed) - font = gui_app.font() - font_size = 44 - text_size = rl.measure_text_ex(font, text, font_size, 0) + font_size = 52 + max_text_width = rect.width - 100 + text_size = rl.measure_text_ex(self._font, text, font_size, 0) + while font_size > 36 and text_size.x > max_text_width: + font_size -= 2 + text_size = rl.measure_text_ex(self._font, text, font_size, 0) + text_pos = rl.Vector2( - banner_rect.x + (banner_rect.width - text_size.x) / 2, - banner_rect.y + (banner_rect.height - text_size.y) / 2, + rect.x + (rect.width - text_size.x) / 2, + rect.y + max(12, (dropdown_height * 0.34) - (text_size.y / 2)), ) - rl.draw_text_ex(font, text, text_pos, font_size, 0, rl.WHITE) + rl.draw_text_ex(self._font, text, text_pos, font_size, 0, rl.Color(255, 255, 255, 242)) class AugmentedRoadView(CameraView): diff --git a/system/ui/widgets/selection_dialog.py b/system/ui/widgets/selection_dialog.py index 3720236c..f535547f 100644 --- a/system/ui/widgets/selection_dialog.py +++ b/system/ui/widgets/selection_dialog.py @@ -129,16 +129,6 @@ class SelectionItem(Widget): self._pressed = False self._fav_pressed = False - def _handle_mouse_press(self, mouse_pos): - if rl.check_collision_point_rec(mouse_pos, self._hit_rect): - self._pressed = True - - def _handle_mouse_release(self, mouse_pos): - if self._pressed and rl.check_collision_point_rec(mouse_pos, self._hit_rect): - if self._callback: - self._callback(self._text) - self._pressed = False - class SelectionDialog(Widget): def __init__(self, title: str, options, current_selection: str = "", on_close: Callable[[DialogResult, str], None] | None = None, @@ -146,7 +136,8 @@ class SelectionDialog(Widget): model_file_to_name: dict[str, str] | None = None, user_favorites: list[str] | None = None, community_favorites: list[str] | None = None, - on_favorite_toggled: Callable[[str], None] | None = None): + on_favorite_toggled: Callable[[str], None] | None = None, + favorites_editable: bool = True): super().__init__() self._title = title self._options_raw = options @@ -157,6 +148,7 @@ class SelectionDialog(Widget): self._user_favorites = user_favorites or [] self._community_favorites = community_favorites or [] self._on_favorite_toggled = on_favorite_toggled + self._favorites_editable = favorites_editable self._sort_mode = SortMode.ALPHABETICAL self._expanded_series = {s: True for s in (options.keys() if isinstance(options, dict) else [])} @@ -228,7 +220,7 @@ class SelectionDialog(Widget): is_selected=is_selected, is_favorite=is_fav, callback=self._on_item_selected, - fav_callback=self._toggle_favorite + fav_callback=self._toggle_favorite if self._favorites_editable else None )) else: for option in self._options_raw: @@ -243,6 +235,9 @@ class SelectionDialog(Widget): self._scroller.show_event() def _toggle_favorite(self, model_name: str): + if not self._favorites_editable: + return + key = self._name_to_file.get(model_name, model_name) if self._on_favorite_toggled: self._on_favorite_toggled(key)