mirror of
https://github.com/sunnypilot/sunnypilot.git
synced 2026-06-23 10:02:06 +08:00
Refactor type annotations and return types in model_runner
Removed specific type hints and return annotations for `self.inputs` and `run_model` methods to enhance flexibility and maintain consistency. These changes streamline the code and improve compatibility with varying input/output types during model inference.
This commit is contained in:
@@ -30,14 +30,14 @@ class ModelRunner(ABC):
|
||||
self.model_metadata = pickle.load(f)
|
||||
self.input_shapes = self.model_metadata['input_shapes']
|
||||
self.output_slices = self.model_metadata['output_slices']
|
||||
self.inputs: dict[str, np.ndarray | Tensor] = {}
|
||||
self.inputs: dict = {}
|
||||
|
||||
@abstractmethod
|
||||
def prepare_inputs(self, imgs_cl: dict[str, CLMem], numpy_inputs: dict[str, np.ndarray])-> dict:
|
||||
"""Prepare inputs for model inference."""
|
||||
|
||||
@abstractmethod
|
||||
def run_model(self) -> np.ndarray:
|
||||
def run_model(self):
|
||||
"""Run model inference with prepared inputs."""
|
||||
|
||||
def slice_outputs(self, model_outputs: np.ndarray) -> dict:
|
||||
@@ -70,7 +70,7 @@ class TinygradRunner(ModelRunner):
|
||||
|
||||
return self.inputs
|
||||
|
||||
def run_model(self) -> np.ndarray:
|
||||
def run_model(self):
|
||||
return self.model_run(**self.inputs).numpy().flatten()
|
||||
|
||||
|
||||
@@ -88,5 +88,5 @@ class ONNXRunner(ModelRunner):
|
||||
self.inputs[key] = self.frames[key].buffer_from_cl(imgs_cl[key]).reshape(self.input_shapes[key])
|
||||
return self.inputs
|
||||
|
||||
def run_model(self) -> np.ndarray:
|
||||
def run_model(self):
|
||||
return self.runner.run(None, self.inputs)[0].flatten()
|
||||
|
||||
Reference in New Issue
Block a user