diff --git a/sunnypilot/modeld/modeld.py b/sunnypilot/modeld/modeld.py index a37600fc0f..aab07605d1 100755 --- a/sunnypilot/modeld/modeld.py +++ b/sunnypilot/modeld/modeld.py @@ -103,44 +103,30 @@ class ModelState: return None self.model.execute() - outputs = self.parser.parse_outputs(self.slice_outputs(self.output)) + outputs = self.parser.parse_outputs(self.slice_outputs(self.output), self.inputs.keys()) self.full_features_20Hz[:-1] = self.full_features_20Hz[1:] self.full_features_20Hz[-1] = outputs['hidden_state'][0, :] self.inputs['features_buffer'][:] = self.full_features_20Hz[self.feature_buffer_idxs].flatten() # Code below needs to be adjusted because the inputs in legacy models were received as flattened arrays - # if "desired_curvature" in outputs: - # input_name_prev = None - # - # if "prev_desired_curvs" in self.inputs.keys(): - # input_name_prev = 'prev_desired_curvs' - # elif "prev_desired_curv" in self.inputs.keys(): - # input_name_prev = 'prev_desired_curv' - # - # if input_name_prev is not None: - # len = outputs['desired_curvature'][0].size - # self.inputs[input_name_prev][0, :-len, 0] = self.inputs[input_name_prev][0, len:, 0] - # self.inputs[input_name_prev][0, -len:, 0] = outputs['desired_curvature'][0] - # - # - # if "lat_planner_solution" in outputs: - # if "lat_planner_state" in self.inputs.keys(): - # self.inputs['lat_planner_state'][2] = interp(DT_MDL, ModelConstants.T_IDXS, outputs['lat_planner_solution'][0, :, 2]) - # self.inputs['lat_planner_state'][3] = interp(DT_MDL, ModelConstants.T_IDXS, outputs['lat_planner_solution'][0, :, 3]) - # - # if "desired_curvature" in outputs: - # input_name_prev = None - # if "prev_desired_curvs" in self.inputs.keys(): - # input_name_prev = 'prev_desired_curvs' - # elif "prev_desired_curv" in self.inputs.keys(): - # input_name_prev = 'prev_desired_curv' - # - # if input_name_prev is not None: - # len = outputs['desired_curvature'][0].size - # self.inputs[input_name_prev][:-len] = self.inputs[input_name_prev][len:] - # self.inputs[input_name_prev][-len:] = outputs['desired_curvature'][0, :len] + if "desired_curvature" in outputs: + input_name_prev = None + if "prev_desired_curvs" in self.inputs.keys(): + input_name_prev = 'prev_desired_curvs' + elif "prev_desired_curv" in self.inputs.keys(): + input_name_prev = 'prev_desired_curv' + + if input_name_prev is not None: + len = outputs['desired_curvature'][0].size + self.inputs[input_name_prev][:-len] = self.inputs[input_name_prev][len:] + self.inputs[input_name_prev][-len:] = outputs['desired_curvature'][0, :] + + if "lat_planner_solution" in outputs: + if "lat_planner_state" in self.inputs.keys(): + self.inputs['lat_planner_state'][2] = interp(DT_MDL, ModelConstants.T_IDXS, outputs['lat_planner_solution'][0, :, 2]) + self.inputs['lat_planner_state'][3] = interp(DT_MDL, ModelConstants.T_IDXS, outputs['lat_planner_solution'][0, :, 3]) return outputs diff --git a/sunnypilot/modeld/parse_model_outputs.py b/sunnypilot/modeld/parse_model_outputs.py index ad664efc55..1941ab4780 100644 --- a/sunnypilot/modeld/parse_model_outputs.py +++ b/sunnypilot/modeld/parse_model_outputs.py @@ -84,7 +84,8 @@ class Parser: outs[name] = pred_mu_final.reshape(final_shape) outs[name + '_stds'] = pred_std_final.reshape(final_shape) - def parse_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]: + def parse_outputs(self, outs: dict[str, np.ndarray], input_keys: [str]) -> dict[str, np.ndarray]: + """ Parse the model outputs into a dictionary of numpy arrays. The input_keys are used to determine how the output should be parsed. """ self.parse_mdn('plan', outs, in_N=ModelConstants.PLAN_MHP_N, out_N=ModelConstants.PLAN_MHP_SELECTION, out_shape=(ModelConstants.IDX_N,ModelConstants.PLAN_WIDTH)) self.parse_mdn('lane_lines', outs, in_N=0, out_N=0, out_shape=(ModelConstants.NUM_LANE_LINES,ModelConstants.IDX_N,ModelConstants.LANE_LINES_WIDTH)) @@ -96,6 +97,8 @@ class Parser: out_shape=(ModelConstants.LEAD_TRAJ_LEN,ModelConstants.LEAD_WIDTH)) if 'lat_planner_solution' in outs: self.parse_mdn('lat_planner_solution', outs, in_N=0, out_N=0, out_shape=(ModelConstants.IDX_N,ModelConstants.LAT_PLANNER_SOLUTION_WIDTH)) + if 'desired_curvature' in outs and "prev_desired_curv" in input_keys: + self.parse_mdn('desired_curvature', outs, in_N=0, out_N=0, out_shape=(ModelConstants.DESIRED_CURV_WIDTH,)) for k in ['lead_prob', 'lane_lines_prob', 'meta']: self.parse_binary_crossentropy(k, outs) self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,))