fix model manager

This commit is contained in:
firestar5683
2026-04-07 11:50:50 -05:00
parent d59aca4fb6
commit f6601873a2
+15 -4
View File
@@ -66,10 +66,12 @@ class ModelManager:
self.params_memory = params_memory
self.downloading_model = False
self.available_models = [entry for entry in (self.params.get("AvailableModels") or "").split(",") if entry]
self.model_versions = [entry for entry in (self.params.get("ModelVersions") or "").split(",") if entry]
self.model_series = [entry for entry in (self.params.get("AvailableModelSeries") or "").split(",") if entry]
self.available_model_names = [entry for entry in (self.params.get("AvailableModelNames") or "").split(",") if entry]
self.available_models: list[str] = []
self.model_versions: list[str] = []
self.model_series: list[str] = []
self.available_model_names: list[str] = []
self._load_catalog_from_params()
self._ensure_model_params()
if boot_run:
@@ -98,6 +100,12 @@ class ModelManager:
return default_value.decode("utf-8", errors="ignore").strip()
return str(default_value).strip()
def _load_catalog_from_params(self):
self.available_models = [entry for entry in self._param_text("AvailableModels").split(",") if entry]
self.model_versions = [entry for entry in self._param_text("ModelVersions").split(",") if entry]
self.model_series = [entry for entry in self._param_text("AvailableModelSeries").split(",") if entry]
self.available_model_names = [entry for entry in self._param_text("AvailableModelNames").split(",") if entry]
def _set_model_param_keys(self, model_key: str | None = None, model_name: str | None = None, model_version: str | None = None):
if model_key is not None and model_key != "":
canonical_key = self._canonical_model_key(model_key)
@@ -313,6 +321,9 @@ class ModelManager:
self.downloading_model = False
return
# Refresh from params so long-lived workers pick up manifest refreshes done by
# a separate ModelManager instance before we validate the requested model.
self._load_catalog_from_params()
version_map = self._model_version_map()
model_version = version_map.get(model_to_download)
required_files = self._required_files(model_to_download, model_version or "")