diff --git a/sunnypilot/modeld_v2/modeld.py b/sunnypilot/modeld_v2/modeld.py index 648cd5ae95..a55cbd8ab2 100755 --- a/sunnypilot/modeld_v2/modeld.py +++ b/sunnypilot/modeld_v2/modeld.py @@ -353,19 +353,15 @@ def main(demo=False): 'traffic_convention': traffic_convention, } - if "lateral_control_params" in model.numpy_inputs.keys(): - inputs['lateral_control_params'] = np.array([v_ego, lat_delay], dtype=np.float32) - - if "driving_style" in model.numpy_inputs.keys(): - inputs['driving_style'] = np.array([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], dtype=np.float32) - - if "nav_features" in model.numpy_inputs.keys(): - nav_features_shape = model.model_runner.input_shapes.get('nav_features') - inputs['nav_features'] = np.zeros(nav_features_shape[1], dtype=np.float32) - - if "nav_instructions" in model.numpy_inputs.keys(): - nav_instructions_shape = model.model_runner.input_shapes.get('nav_instructions') - inputs['nav_instructions'] = np.zeros(nav_instructions_shape[1], dtype=np.float32) + conditional_inputs = { + "lateral_control_params": lambda: np.array([v_ego, lat_delay], dtype=np.float32), + "driving_style": lambda: np.array([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], dtype=np.float32), + "nav_features": lambda: np.zeros(model.model_runner.input_shapes.get('nav_features')[1], dtype=np.float32), + "nav_instructions": lambda: np.zeros(model.model_runner.input_shapes.get('nav_instructions')[1], dtype=np.float32), + } + for key, value in conditional_inputs.items(): + if key in model.numpy_inputs: + inputs[key] = value() mt1 = time.perf_counter() model_output = model.run(bufs, transforms, inputs, prepare_only)