conditional mapping in a dict to reduce code while maintaining clarity

This commit is contained in:
discountchubbs
2025-08-30 07:43:56 -05:00
parent 90240eb1bc
commit bbdccb019f
+9 -13
View File
@@ -353,19 +353,15 @@ def main(demo=False):
'traffic_convention': traffic_convention,
}
if "lateral_control_params" in model.numpy_inputs.keys():
inputs['lateral_control_params'] = np.array([v_ego, lat_delay], dtype=np.float32)
if "driving_style" in model.numpy_inputs.keys():
inputs['driving_style'] = np.array([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], dtype=np.float32)
if "nav_features" in model.numpy_inputs.keys():
nav_features_shape = model.model_runner.input_shapes.get('nav_features')
inputs['nav_features'] = np.zeros(nav_features_shape[1], dtype=np.float32)
if "nav_instructions" in model.numpy_inputs.keys():
nav_instructions_shape = model.model_runner.input_shapes.get('nav_instructions')
inputs['nav_instructions'] = np.zeros(nav_instructions_shape[1], dtype=np.float32)
conditional_inputs = {
"lateral_control_params": lambda: np.array([v_ego, lat_delay], dtype=np.float32),
"driving_style": lambda: np.array([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], dtype=np.float32),
"nav_features": lambda: np.zeros(model.model_runner.input_shapes.get('nav_features')[1], dtype=np.float32),
"nav_instructions": lambda: np.zeros(model.model_runner.input_shapes.get('nav_instructions')[1], dtype=np.float32),
}
for key, value in conditional_inputs.items():
if key in model.numpy_inputs:
inputs[key] = value()
mt1 = time.perf_counter()
model_output = model.run(bufs, transforms, inputs, prepare_only)