mirror of
https://github.com/firestar5683/StarPilot.git
synced 2026-06-28 01:52:06 +08:00
modeld input queues class (#36072)
* move from xx * no get_single * stupid name * thats fine * desire_pulse * 1less * desire->desire_pulse * simplify * reduce copies * more less
This commit is contained in:
@@ -13,12 +13,9 @@ class ModelConstants:
|
||||
META_T_IDXS = [2., 4., 6., 8., 10.]
|
||||
|
||||
# model inputs constants
|
||||
MODEL_FREQ = 20
|
||||
HISTORY_FREQ = 5
|
||||
HISTORY_LEN_SECONDS = 5
|
||||
TEMPORAL_SKIP = MODEL_FREQ // HISTORY_FREQ
|
||||
FULL_HISTORY_BUFFER_LEN = MODEL_FREQ * HISTORY_LEN_SECONDS
|
||||
INPUT_HISTORY_BUFFER_LEN = HISTORY_FREQ * HISTORY_LEN_SECONDS
|
||||
N_FRAMES = 2
|
||||
MODEL_RUN_FREQ = 20
|
||||
MODEL_CONTEXT_FREQ = 5 # "model_trained_fps"
|
||||
|
||||
FEATURE_LEN = 512
|
||||
|
||||
|
||||
@@ -149,7 +149,7 @@ def fill_model_msg(base_msg: capnp._DynamicStructBuilder, extended_msg: capnp._D
|
||||
meta.hardBrakePredicted = hard_brake_predicted.item()
|
||||
|
||||
# confidence
|
||||
if vipc_frame_id % (2*ModelConstants.MODEL_FREQ) == 0:
|
||||
if vipc_frame_id % (2*ModelConstants.MODEL_RUN_FREQ) == 0:
|
||||
# any disengage prob
|
||||
brake_disengage_probs = net_output_data['meta'][0,Meta.BRAKE_DISENGAGE]
|
||||
gas_disengage_probs = net_output_data['meta'][0,Meta.GAS_DISENGAGE]
|
||||
|
||||
+69
-19
@@ -77,6 +77,64 @@ class FrameMeta:
|
||||
if vipc is not None:
|
||||
self.frame_id, self.timestamp_sof, self.timestamp_eof = vipc.frame_id, vipc.timestamp_sof, vipc.timestamp_eof
|
||||
|
||||
class InputQueues:
|
||||
def __init__ (self, model_fps, env_fps, n_frames_input):
|
||||
assert env_fps % model_fps == 0
|
||||
assert env_fps >= model_fps
|
||||
self.model_fps = model_fps
|
||||
self.env_fps = env_fps
|
||||
self.n_frames_input = n_frames_input
|
||||
|
||||
self.dtypes = {}
|
||||
self.shapes = {}
|
||||
self.q = {}
|
||||
|
||||
def update_dtypes_and_shapes(self, input_dtypes, input_shapes) -> None:
|
||||
self.dtypes.update(input_dtypes)
|
||||
if self.env_fps == self.model_fps:
|
||||
self.shapes.update(input_shapes)
|
||||
else:
|
||||
for k in input_shapes:
|
||||
shape = list(input_shapes[k])
|
||||
if 'img' in k:
|
||||
n_channels = shape[1] // self.n_frames_input
|
||||
shape[1] = (self.env_fps // self.model_fps + (self.n_frames_input - 1)) * n_channels
|
||||
else:
|
||||
shape[1] = (self.env_fps // self.model_fps) * shape[1]
|
||||
self.shapes[k] = tuple(shape)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.q = {k: np.zeros(self.shapes[k], dtype=self.dtypes[k]) for k in self.dtypes.keys()}
|
||||
|
||||
def enqueue(self, inputs:dict[str, np.ndarray]) -> None:
|
||||
for k in inputs.keys():
|
||||
if inputs[k].dtype != self.dtypes[k]:
|
||||
raise ValueError(f'supplied input <{k}({inputs[k].dtype})> has wrong dtype, expected {self.dtypes[k]}')
|
||||
input_shape = list(self.shapes[k])
|
||||
input_shape[1] = -1
|
||||
single_input = inputs[k].reshape(tuple(input_shape))
|
||||
sz = single_input.shape[1]
|
||||
self.q[k][:,:-sz] = self.q[k][:,sz:]
|
||||
self.q[k][:,-sz:] = single_input
|
||||
|
||||
def get(self, *names) -> dict[str, np.ndarray]:
|
||||
if self.env_fps == self.model_fps:
|
||||
return {k: self.q[k] for k in names}
|
||||
else:
|
||||
out = {}
|
||||
for k in names:
|
||||
shape = self.shapes[k]
|
||||
if 'img' in k:
|
||||
n_channels = shape[1] // (self.env_fps // self.model_fps + (self.n_frames_input - 1))
|
||||
out[k] = np.concatenate([self.q[k][:, s:s+n_channels] for s in np.linspace(0, shape[1] - n_channels, self.n_frames_input, dtype=int)], axis=1)
|
||||
elif 'pulse' in k:
|
||||
# any pulse within interval counts
|
||||
out[k] = self.q[k].reshape((shape[0], shape[1] * self.model_fps // self.env_fps, self.env_fps // self.model_fps, -1)).max(axis=2)
|
||||
else:
|
||||
idxs = np.arange(-1, -shape[1], -self.env_fps // self.model_fps)[::-1]
|
||||
out[k] = self.q[k][:, idxs]
|
||||
return out
|
||||
|
||||
class ModelState:
|
||||
frames: dict[str, DrivingModelFrame]
|
||||
inputs: dict[str, np.ndarray]
|
||||
@@ -97,19 +155,15 @@ class ModelState:
|
||||
self.policy_output_slices = policy_metadata['output_slices']
|
||||
policy_output_size = policy_metadata['output_shapes']['outputs'][1]
|
||||
|
||||
self.frames = {name: DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP) for name in self.vision_input_names}
|
||||
self.frames = {name: DrivingModelFrame(context, ModelConstants.MODEL_RUN_FREQ//ModelConstants.MODEL_CONTEXT_FREQ) for name in self.vision_input_names}
|
||||
self.prev_desire = np.zeros(ModelConstants.DESIRE_LEN, dtype=np.float32)
|
||||
|
||||
self.full_features_buffer = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32)
|
||||
self.full_desire = np.zeros((1, ModelConstants.FULL_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32)
|
||||
self.temporal_idxs = slice(-1-(ModelConstants.TEMPORAL_SKIP*(ModelConstants.INPUT_HISTORY_BUFFER_LEN-1)), None, ModelConstants.TEMPORAL_SKIP)
|
||||
|
||||
# policy inputs
|
||||
self.numpy_inputs = {
|
||||
'desire_pulse': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.DESIRE_LEN), dtype=np.float32),
|
||||
'traffic_convention': np.zeros((1, ModelConstants.TRAFFIC_CONVENTION_LEN), dtype=np.float32),
|
||||
'features_buffer': np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32),
|
||||
}
|
||||
self.numpy_inputs = {k: np.zeros(self.policy_input_shapes[k], dtype=np.float32) for k in self.policy_input_shapes}
|
||||
self.full_input_queues = InputQueues(ModelConstants.MODEL_CONTEXT_FREQ, ModelConstants.MODEL_RUN_FREQ, ModelConstants.N_FRAMES)
|
||||
for k in ['desire_pulse', 'features_buffer']:
|
||||
self.full_input_queues.update_dtypes_and_shapes({k: self.numpy_inputs[k].dtype}, {k: self.numpy_inputs[k].shape})
|
||||
self.full_input_queues.reset()
|
||||
|
||||
# img buffers are managed in openCL transform code
|
||||
self.vision_inputs: dict[str, Tensor] = {}
|
||||
@@ -135,11 +189,6 @@ class ModelState:
|
||||
new_desire = np.where(inputs['desire_pulse'] - self.prev_desire > .99, inputs['desire_pulse'], 0)
|
||||
self.prev_desire[:] = inputs['desire_pulse']
|
||||
|
||||
self.full_desire[0,:-1] = self.full_desire[0,1:]
|
||||
self.full_desire[0,-1] = new_desire
|
||||
self.numpy_inputs['desire_pulse'][:] = self.full_desire.reshape((1,ModelConstants.INPUT_HISTORY_BUFFER_LEN,ModelConstants.TEMPORAL_SKIP,-1)).max(axis=2)
|
||||
|
||||
self.numpy_inputs['traffic_convention'][:] = inputs['traffic_convention']
|
||||
imgs_cl = {name: self.frames[name].prepare(bufs[name], transforms[name].flatten()) for name in self.vision_input_names}
|
||||
|
||||
if TICI and not USBGPU:
|
||||
@@ -158,9 +207,10 @@ class ModelState:
|
||||
self.vision_output = self.vision_run(**self.vision_inputs).contiguous().realize().uop.base.buffer.numpy()
|
||||
vision_outputs_dict = self.parser.parse_vision_outputs(self.slice_outputs(self.vision_output, self.vision_output_slices))
|
||||
|
||||
self.full_features_buffer[0,:-1] = self.full_features_buffer[0,1:]
|
||||
self.full_features_buffer[0,-1] = vision_outputs_dict['hidden_state'][0, :]
|
||||
self.numpy_inputs['features_buffer'][:] = self.full_features_buffer[0, self.temporal_idxs]
|
||||
self.full_input_queues.enqueue({'features_buffer': vision_outputs_dict['hidden_state'], 'desire_pulse': new_desire})
|
||||
for k in ['desire_pulse', 'features_buffer']:
|
||||
self.numpy_inputs[k][:] = self.full_input_queues.get(k)[k]
|
||||
self.numpy_inputs['traffic_convention'][:] = inputs['traffic_convention']
|
||||
|
||||
self.policy_output = self.policy_run(**self.policy_inputs).contiguous().realize().uop.base.buffer.numpy()
|
||||
policy_outputs_dict = self.parser.parse_policy_outputs(self.slice_outputs(self.policy_output, self.policy_output_slices))
|
||||
@@ -218,7 +268,7 @@ def main(demo=False):
|
||||
params = Params()
|
||||
|
||||
# setup filter to track dropped frames
|
||||
frame_dropped_filter = FirstOrderFilter(0., 10., 1. / ModelConstants.MODEL_FREQ)
|
||||
frame_dropped_filter = FirstOrderFilter(0., 10., 1. / ModelConstants.MODEL_RUN_FREQ)
|
||||
frame_id = 0
|
||||
last_vipc_frame_id = 0
|
||||
run_count = 0
|
||||
|
||||
Reference in New Issue
Block a user