mirror of
https://github.com/sunnypilot/sunnypilot.git
synced 2026-06-24 03:32:05 +08:00
This was actually quite revealing! This also has correlation with behaviors I saw when porting to tinygrad!
This commit is contained in:
+17
-31
@@ -103,44 +103,30 @@ class ModelState:
|
||||
return None
|
||||
|
||||
self.model.execute()
|
||||
outputs = self.parser.parse_outputs(self.slice_outputs(self.output))
|
||||
outputs = self.parser.parse_outputs(self.slice_outputs(self.output), self.inputs.keys())
|
||||
|
||||
self.full_features_20Hz[:-1] = self.full_features_20Hz[1:]
|
||||
self.full_features_20Hz[-1] = outputs['hidden_state'][0, :]
|
||||
|
||||
self.inputs['features_buffer'][:] = self.full_features_20Hz[self.feature_buffer_idxs].flatten()
|
||||
# Code below needs to be adjusted because the inputs in legacy models were received as flattened arrays
|
||||
# if "desired_curvature" in outputs:
|
||||
# input_name_prev = None
|
||||
#
|
||||
# if "prev_desired_curvs" in self.inputs.keys():
|
||||
# input_name_prev = 'prev_desired_curvs'
|
||||
# elif "prev_desired_curv" in self.inputs.keys():
|
||||
# input_name_prev = 'prev_desired_curv'
|
||||
#
|
||||
# if input_name_prev is not None:
|
||||
# len = outputs['desired_curvature'][0].size
|
||||
# self.inputs[input_name_prev][0, :-len, 0] = self.inputs[input_name_prev][0, len:, 0]
|
||||
# self.inputs[input_name_prev][0, -len:, 0] = outputs['desired_curvature'][0]
|
||||
#
|
||||
#
|
||||
# if "lat_planner_solution" in outputs:
|
||||
# if "lat_planner_state" in self.inputs.keys():
|
||||
# self.inputs['lat_planner_state'][2] = interp(DT_MDL, ModelConstants.T_IDXS, outputs['lat_planner_solution'][0, :, 2])
|
||||
# self.inputs['lat_planner_state'][3] = interp(DT_MDL, ModelConstants.T_IDXS, outputs['lat_planner_solution'][0, :, 3])
|
||||
#
|
||||
# if "desired_curvature" in outputs:
|
||||
# input_name_prev = None
|
||||
# if "prev_desired_curvs" in self.inputs.keys():
|
||||
# input_name_prev = 'prev_desired_curvs'
|
||||
# elif "prev_desired_curv" in self.inputs.keys():
|
||||
# input_name_prev = 'prev_desired_curv'
|
||||
#
|
||||
# if input_name_prev is not None:
|
||||
# len = outputs['desired_curvature'][0].size
|
||||
# self.inputs[input_name_prev][:-len] = self.inputs[input_name_prev][len:]
|
||||
# self.inputs[input_name_prev][-len:] = outputs['desired_curvature'][0, :len]
|
||||
if "desired_curvature" in outputs:
|
||||
input_name_prev = None
|
||||
|
||||
if "prev_desired_curvs" in self.inputs.keys():
|
||||
input_name_prev = 'prev_desired_curvs'
|
||||
elif "prev_desired_curv" in self.inputs.keys():
|
||||
input_name_prev = 'prev_desired_curv'
|
||||
|
||||
if input_name_prev is not None:
|
||||
len = outputs['desired_curvature'][0].size
|
||||
self.inputs[input_name_prev][:-len] = self.inputs[input_name_prev][len:]
|
||||
self.inputs[input_name_prev][-len:] = outputs['desired_curvature'][0, :]
|
||||
|
||||
if "lat_planner_solution" in outputs:
|
||||
if "lat_planner_state" in self.inputs.keys():
|
||||
self.inputs['lat_planner_state'][2] = interp(DT_MDL, ModelConstants.T_IDXS, outputs['lat_planner_solution'][0, :, 2])
|
||||
self.inputs['lat_planner_state'][3] = interp(DT_MDL, ModelConstants.T_IDXS, outputs['lat_planner_solution'][0, :, 3])
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
@@ -84,7 +84,8 @@ class Parser:
|
||||
outs[name] = pred_mu_final.reshape(final_shape)
|
||||
outs[name + '_stds'] = pred_std_final.reshape(final_shape)
|
||||
|
||||
def parse_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
|
||||
def parse_outputs(self, outs: dict[str, np.ndarray], input_keys: [str]) -> dict[str, np.ndarray]:
|
||||
""" Parse the model outputs into a dictionary of numpy arrays. The input_keys are used to determine how the output should be parsed. """
|
||||
self.parse_mdn('plan', outs, in_N=ModelConstants.PLAN_MHP_N, out_N=ModelConstants.PLAN_MHP_SELECTION,
|
||||
out_shape=(ModelConstants.IDX_N,ModelConstants.PLAN_WIDTH))
|
||||
self.parse_mdn('lane_lines', outs, in_N=0, out_N=0, out_shape=(ModelConstants.NUM_LANE_LINES,ModelConstants.IDX_N,ModelConstants.LANE_LINES_WIDTH))
|
||||
@@ -96,6 +97,8 @@ class Parser:
|
||||
out_shape=(ModelConstants.LEAD_TRAJ_LEN,ModelConstants.LEAD_WIDTH))
|
||||
if 'lat_planner_solution' in outs:
|
||||
self.parse_mdn('lat_planner_solution', outs, in_N=0, out_N=0, out_shape=(ModelConstants.IDX_N,ModelConstants.LAT_PLANNER_SOLUTION_WIDTH))
|
||||
if 'desired_curvature' in outs and "prev_desired_curv" in input_keys:
|
||||
self.parse_mdn('desired_curvature', outs, in_N=0, out_N=0, out_shape=(ModelConstants.DESIRED_CURV_WIDTH,))
|
||||
for k in ['lead_prob', 'lane_lines_prob', 'meta']:
|
||||
self.parse_binary_crossentropy(k, outs)
|
||||
self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,))
|
||||
|
||||
Reference in New Issue
Block a user