Refactor type annotations and return types in model_runner

Removed specific type hints and return annotations for `self.inputs` and `run_model` methods to enhance flexibility and maintain consistency. These changes streamline the code and improve compatibility with varying input/output types during model inference.
This commit is contained in:
DevTekVE
2024-12-29 18:04:05 +01:00
parent f739f9f71d
commit 8197d170bc
+4 -4
View File
@@ -30,14 +30,14 @@ class ModelRunner(ABC):
self.model_metadata = pickle.load(f)
self.input_shapes = self.model_metadata['input_shapes']
self.output_slices = self.model_metadata['output_slices']
self.inputs: dict[str, np.ndarray | Tensor] = {}
self.inputs: dict = {}
@abstractmethod
def prepare_inputs(self, imgs_cl: dict[str, CLMem], numpy_inputs: dict[str, np.ndarray])-> dict:
"""Prepare inputs for model inference."""
@abstractmethod
def run_model(self) -> np.ndarray:
def run_model(self):
"""Run model inference with prepared inputs."""
def slice_outputs(self, model_outputs: np.ndarray) -> dict:
@@ -70,7 +70,7 @@ class TinygradRunner(ModelRunner):
return self.inputs
def run_model(self) -> np.ndarray:
def run_model(self):
return self.model_run(**self.inputs).numpy().flatten()
@@ -88,5 +88,5 @@ class ONNXRunner(ModelRunner):
self.inputs[key] = self.frames[key].buffer_from_cl(imgs_cl[key]).reshape(self.input_shapes[key])
return self.inputs
def run_model(self) -> np.ndarray:
def run_model(self):
return self.runner.run(None, self.inputs)[0].flatten()