Refactor prepare_inputs to use explicit CLMem type.

This update replaces the generic `any` type with the more explicit `CLMem` type for better type safety and clarity. It ensures consistency across the `prepare_inputs` method implementations in derived classes, improving code readability and robustness.
This commit is contained in:
DevTekVE
2024-12-29 17:44:23 +01:00
parent 5d1b403015
commit f67e8aca47
+4 -4
View File
@@ -14,7 +14,7 @@ import pickle
import numpy as np
from pathlib import Path
from abc import ABC, abstractmethod
from openpilot.selfdrive.modeld.models.commonmodel_pyx import DrivingModelFrame
from openpilot.selfdrive.modeld.models.commonmodel_pyx import DrivingModelFrame, CLMem
SEND_RAW_PRED = os.getenv('SEND_RAW_PRED')
MODEL_PATH = Path(__file__).parent / '../models/supercombo.onnx'
@@ -34,7 +34,7 @@ class ModelRunner(ABC):
self.inputs = {}
@abstractmethod
def prepare_inputs(self, imgs_cl: dict[str, any], numpy_inputs: dict[str, np.ndarray]) -> dict[str, any]:
def prepare_inputs(self, imgs_cl: dict[str, CLMem], numpy_inputs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
"""Prepare inputs for model inference."""
@abstractmethod
@@ -58,7 +58,7 @@ class TinyGradRunner(ModelRunner):
with open(MODEL_PKL_PATH, "rb") as f:
self.model_run = pickle.load(f)
def prepare_inputs(self, imgs_cl: dict[str, any], numpy_inputs: dict[str, np.ndarray]) -> dict[str, any]:
def prepare_inputs(self, imgs_cl: dict[str, CLMem], numpy_inputs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
# Initialize image tensors if not already done
for key in imgs_cl:
if key not in self.inputs:
@@ -83,7 +83,7 @@ class ONNXRunner(ModelRunner):
self.runner = make_onnx_cpu_runner(MODEL_PATH)
self.frames = frames
def prepare_inputs(self, imgs_cl: dict[str, any], numpy_inputs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
def prepare_inputs(self, imgs_cl: dict[str, CLMem], numpy_inputs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
for key in imgs_cl:
numpy_inputs[key] = self.frames[key].buffer_from_cl(imgs_cl[key]).reshape(self.input_shapes[key])
self.inputs = numpy_inputs