diff --git a/starpilot/assets/model_manager.py b/starpilot/assets/model_manager.py index f41febbc..26de0c0a 100644 --- a/starpilot/assets/model_manager.py +++ b/starpilot/assets/model_manager.py @@ -66,10 +66,12 @@ class ModelManager: self.params_memory = params_memory self.downloading_model = False - self.available_models = [entry for entry in (self.params.get("AvailableModels") or "").split(",") if entry] - self.model_versions = [entry for entry in (self.params.get("ModelVersions") or "").split(",") if entry] - self.model_series = [entry for entry in (self.params.get("AvailableModelSeries") or "").split(",") if entry] - self.available_model_names = [entry for entry in (self.params.get("AvailableModelNames") or "").split(",") if entry] + self.available_models: list[str] = [] + self.model_versions: list[str] = [] + self.model_series: list[str] = [] + self.available_model_names: list[str] = [] + + self._load_catalog_from_params() self._ensure_model_params() if boot_run: @@ -98,6 +100,12 @@ class ModelManager: return default_value.decode("utf-8", errors="ignore").strip() return str(default_value).strip() + def _load_catalog_from_params(self): + self.available_models = [entry for entry in self._param_text("AvailableModels").split(",") if entry] + self.model_versions = [entry for entry in self._param_text("ModelVersions").split(",") if entry] + self.model_series = [entry for entry in self._param_text("AvailableModelSeries").split(",") if entry] + self.available_model_names = [entry for entry in self._param_text("AvailableModelNames").split(",") if entry] + def _set_model_param_keys(self, model_key: str | None = None, model_name: str | None = None, model_version: str | None = None): if model_key is not None and model_key != "": canonical_key = self._canonical_model_key(model_key) @@ -313,6 +321,9 @@ class ModelManager: self.downloading_model = False return + # Refresh from params so long-lived workers pick up manifest refreshes done by + # a separate ModelManager instance before we validate the requested model. + self._load_catalog_from_params() version_map = self._model_version_map() model_version = version_map.get(model_to_download) required_files = self._required_files(model_to_download, model_version or "")