mirror of
https://github.com/sunnypilot/sunnypilot.git
synced 2026-06-19 16:52:06 +08:00
Refactor modeld for flexible input rates and output parsing.
Add support for both 20Hz and variable input rates in `modeld.py` by introducing conditional data updates based on a new `is_20hz` flag. Additionally, enhance `parse_outputs` in `parse_model_outputs.py` to handle `desired_curvature` parsing when relevant input keys are present.
This commit is contained in:
@@ -50,8 +50,9 @@ class ModelState:
|
||||
self.prev_desire = np.zeros(ModelConstants.DESIRE_LEN, dtype=np.float32)
|
||||
self.full_features_20Hz = np.zeros((ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32)
|
||||
self.desire_20Hz = np.zeros((ModelConstants.FULL_HISTORY_BUFFER_LEN + 1, ModelConstants.DESIRE_LEN), dtype=np.float32)
|
||||
self.is_20hz = False
|
||||
# Initialize model runner
|
||||
self.model_runner = TinygradRunner(self.frames)# if TICI else ONNXRunner(self.frames)
|
||||
self.model_runner = TinygradRunner(self.frames) if TICI else ONNXRunner(self.frames)
|
||||
|
||||
# img buffers are managed in openCL transform code
|
||||
self.numpy_inputs = {}
|
||||
@@ -77,9 +78,14 @@ class ModelState:
|
||||
new_desire = np.where(inputs['desire'] - self.prev_desire > .99, inputs['desire'], 0)
|
||||
self.prev_desire[:] = inputs['desire']
|
||||
|
||||
self.desire_20Hz[:-1] = self.desire_20Hz[1:]
|
||||
self.desire_20Hz[-1] = new_desire
|
||||
self.numpy_inputs['desire'][:] = self.desire_20Hz.reshape(self.desire_reshape_dims).max(axis=2)
|
||||
if self.is_20hz:
|
||||
self.desire_20Hz[:-1] = self.desire_20Hz[1:]
|
||||
self.desire_20Hz[-1] = new_desire
|
||||
self.numpy_inputs['desire'][:] = self.desire_20Hz.reshape(self.desire_reshape_dims).max(axis=2)
|
||||
else:
|
||||
len = inputs['desire'].shape[0]
|
||||
self.numpy_inputs['desire'][0, :-1] = self.numpy_inputs['desire'][0, 1:]
|
||||
self.numpy_inputs['desire'][0, -1, :len] = new_desire[:len]
|
||||
|
||||
for key in self.numpy_inputs:
|
||||
if key in inputs and key not in ['desire']:
|
||||
@@ -96,12 +102,17 @@ class ModelState:
|
||||
|
||||
# Run model inference
|
||||
self.output = self.model_runner.run_model()
|
||||
outputs = self.parser.parse_outputs(self.model_runner.slice_outputs(self.output))
|
||||
outputs = self.parser.parse_outputs(self.model_runner.slice_outputs(self.output), self.numpy_inputs.keys())
|
||||
|
||||
self.full_features_20Hz[:-1] = self.full_features_20Hz[1:]
|
||||
self.full_features_20Hz[-1] = outputs['hidden_state'][0, :]
|
||||
if self.is_20hz:
|
||||
self.full_features_20Hz[:-1] = self.full_features_20Hz[1:]
|
||||
self.full_features_20Hz[-1] = outputs['hidden_state'][0, :]
|
||||
self.numpy_inputs['features_buffer'][:] = self.full_features_20Hz[self.full_features_20Hz_idxs]
|
||||
else:
|
||||
feature_len = outputs['hidden_state'].shape[1]
|
||||
self.numpy_inputs['features_buffer'][0, :-1] = self.numpy_inputs['features_buffer'][0, 1:]
|
||||
self.numpy_inputs['features_buffer'][0, -1, :feature_len] = outputs['hidden_state'][0, :feature_len]
|
||||
|
||||
self.numpy_inputs['features_buffer'][:] = self.full_features_20Hz[self.full_features_20Hz_idxs]
|
||||
|
||||
if "desired_curvature" in outputs:
|
||||
input_name_prev = None
|
||||
|
||||
@@ -84,7 +84,8 @@ class Parser:
|
||||
outs[name] = pred_mu_final.reshape(final_shape)
|
||||
outs[name + '_stds'] = pred_std_final.reshape(final_shape)
|
||||
|
||||
def parse_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
|
||||
def parse_outputs(self, outs: dict[str, np.ndarray], input_keys: [str]) -> dict[str, np.ndarray]:
|
||||
""" Parse the model outputs into a dictionary of numpy arrays. The input_keys are used to determine how the output should be parsed. """
|
||||
self.parse_mdn('plan', outs, in_N=ModelConstants.PLAN_MHP_N, out_N=ModelConstants.PLAN_MHP_SELECTION,
|
||||
out_shape=(ModelConstants.IDX_N,ModelConstants.PLAN_WIDTH))
|
||||
self.parse_mdn('lane_lines', outs, in_N=0, out_N=0, out_shape=(ModelConstants.NUM_LANE_LINES,ModelConstants.IDX_N,ModelConstants.LANE_LINES_WIDTH))
|
||||
@@ -96,6 +97,8 @@ class Parser:
|
||||
out_shape=(ModelConstants.LEAD_TRAJ_LEN,ModelConstants.LEAD_WIDTH))
|
||||
if 'lat_planner_solution' in outs:
|
||||
self.parse_mdn('lat_planner_solution', outs, in_N=0, out_N=0, out_shape=(ModelConstants.IDX_N,ModelConstants.LAT_PLANNER_SOLUTION_WIDTH))
|
||||
if 'desired_curvature' in outs and "prev_desired_curv" in input_keys:
|
||||
self.parse_mdn('desired_curvature', outs, in_N=0, out_N=0, out_shape=(ModelConstants.DESIRED_CURV_WIDTH,))
|
||||
for k in ['lead_prob', 'lane_lines_prob', 'meta']:
|
||||
self.parse_binary_crossentropy(k, outs)
|
||||
self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,))
|
||||
|
||||
Reference in New Issue
Block a user