mirror of
https://github.com/firestar5683/StarPilot.git
synced 2026-06-13 02:54:37 +08:00
RL code
This commit is contained in:
@@ -229,7 +229,7 @@ class LongitudinalPlanner:
|
||||
|
||||
@property
|
||||
def mlsim(self):
|
||||
return self.generation in ("v8", "v10", "v11", "v12")
|
||||
return self.generation in ("v8", "v10", "v11", "v12", "v13")
|
||||
|
||||
def get_mpc_mode(self) -> str:
|
||||
if not self.mlsim:
|
||||
|
||||
@@ -105,7 +105,7 @@ def make_toggles(model_version: str = "v11"):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12"])
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12", "v13"])
|
||||
def test_experimental_mlsim_uses_vehicle_min_accel_floor(model_version):
|
||||
v_ego = 18.0
|
||||
desired_accel = -1.0
|
||||
@@ -126,7 +126,7 @@ def test_experimental_mlsim_uses_vehicle_min_accel_floor(model_version):
|
||||
assert planner.output_a_target < comfort_min_accel
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12"])
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12", "v13"])
|
||||
def test_acc_mode_uses_close_raw_lead_when_tracking_lead_is_debounced(model_version):
|
||||
v_ego = 5.0
|
||||
|
||||
@@ -151,7 +151,7 @@ def test_acc_mode_uses_close_raw_lead_when_tracking_lead_is_debounced(model_vers
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12"])
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12", "v13"])
|
||||
def test_acc_mode_matches_no_lead_baseline_for_far_vision_only_lead_without_tracking(model_version):
|
||||
v_ego = 29.0
|
||||
|
||||
@@ -259,7 +259,7 @@ def test_vision_slow_stopped_lead_cap_ignores_far_high_speed_stop_candidate():
|
||||
assert planner.get_vision_slow_stopped_lead_cap(lead, v_ego, -1.0, 1.45) is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12"])
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12", "v13"])
|
||||
def test_dynamic_t_follow_increases_modestly_for_closing_lead(model_version):
|
||||
v_ego = 21.535
|
||||
|
||||
@@ -283,7 +283,7 @@ def test_dynamic_t_follow_increases_modestly_for_closing_lead(model_version):
|
||||
assert planner.effective_t_follow < sm["starpilotPlan"].tFollow + 0.45
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12"])
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12", "v13"])
|
||||
def test_dynamic_t_follow_stays_near_base_for_far_highway_lead(model_version):
|
||||
v_ego = 29.26
|
||||
|
||||
@@ -305,7 +305,7 @@ def test_dynamic_t_follow_stays_near_base_for_far_highway_lead(model_version):
|
||||
assert planner.effective_t_follow == pytest.approx(sm["starpilotPlan"].tFollow, abs=0.02)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12"])
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12", "v13"])
|
||||
def test_dynamic_t_follow_releases_toward_base_after_lead_opens(model_version):
|
||||
v_ego = 21.535
|
||||
|
||||
@@ -333,7 +333,7 @@ def test_dynamic_t_follow_releases_toward_base_after_lead_opens(model_version):
|
||||
assert planner.effective_t_follow == pytest.approx(sm["starpilotPlan"].tFollow, abs=0.02)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12"])
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12", "v13"])
|
||||
def test_acc_mode_vision_lead_approach_cap_smooths_before_close_brake(model_version):
|
||||
approach_v_ego = 21.535
|
||||
close_v_ego = 21.435
|
||||
@@ -370,7 +370,7 @@ def test_acc_mode_vision_lead_approach_cap_smooths_before_close_brake(model_vers
|
||||
assert planner_close.output_a_target < planner_approach.output_a_target - 0.25
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12"])
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12", "v13"])
|
||||
def test_acc_mode_low_speed_vision_stop_buffer_sets_should_stop_before_tiny_gap(model_version):
|
||||
v_ego = 3.8
|
||||
|
||||
@@ -393,7 +393,7 @@ def test_acc_mode_low_speed_vision_stop_buffer_sets_should_stop_before_tiny_gap(
|
||||
assert planner.output_a_target < -1.0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12"])
|
||||
@pytest.mark.parametrize("model_version", ["v11", "v12", "v13"])
|
||||
def test_acc_mode_damps_far_radar_mild_lead_brake_more_than_close_brake(model_version):
|
||||
far_v_ego = 29.26
|
||||
far_v_cruise = 32.22
|
||||
|
||||
@@ -44,6 +44,7 @@ class ModelConstants:
|
||||
LANE_LINES_WIDTH = 2
|
||||
ROAD_EDGES_WIDTH = 2
|
||||
PLAN_WIDTH = 15
|
||||
ACTION_WIDTH = 2
|
||||
DESIRE_PRED_WIDTH = 8
|
||||
LAT_PLANNER_SOLUTION_WIDTH = 4
|
||||
DESIRED_CURV_WIDTH = 1
|
||||
|
||||
@@ -177,7 +177,10 @@ def fill_model_msg(base_msg: capnp._DynamicStructBuilder, extended_msg: capnp._D
|
||||
|
||||
# meta
|
||||
meta = modelV2.meta
|
||||
meta.desireState = net_output_data['desire_state'][0].reshape(-1).tolist()
|
||||
if 'desire_state' in net_output_data:
|
||||
meta.desireState = net_output_data['desire_state'][0].reshape(-1).tolist()
|
||||
else:
|
||||
meta.desireState = [0.0] * ModelConstants.DESIRE_PRED_WIDTH
|
||||
meta.desirePrediction = net_output_data['desire_pred'][0].reshape(-1).tolist()
|
||||
meta.engagedProb = net_output_data['meta'][0,Meta.ENGAGED].item()
|
||||
meta.init('disengagePredictions')
|
||||
|
||||
@@ -156,6 +156,8 @@ class ModelState:
|
||||
numpy_inputs['traffic_convention'] = np.zeros((1, ModelConstants.TRAFFIC_CONVENTION_LEN), dtype=np.float32)
|
||||
if 'features_buffer' in input_shapes:
|
||||
numpy_inputs['features_buffer'] = np.zeros((1, ModelConstants.INPUT_HISTORY_BUFFER_LEN, ModelConstants.FEATURE_LEN), dtype=np.float32)
|
||||
if 'action_t' in input_shapes:
|
||||
numpy_inputs['action_t'] = np.zeros(input_shapes['action_t'], dtype=np.float32)
|
||||
|
||||
# Optional inputs for non-v11 (and some v10/v9 variants)
|
||||
# Lateral control params
|
||||
@@ -282,8 +284,10 @@ class ModelState:
|
||||
self.is_v11 = (self.policy_generation == "v11")
|
||||
self.is_v10 = (self.policy_generation == "v10")
|
||||
self.is_v12 = (self.policy_generation == "v12")
|
||||
self.is_v13 = (self.policy_generation == "v13")
|
||||
self.is_v9 = (self.policy_generation == "v9")
|
||||
self.mlsim = (self.policy_generation in ("v8", "v10", "v11", "v12"))
|
||||
self.mlsim = (self.policy_generation in ("v8", "v10", "v11", "v12", "v13"))
|
||||
self.policy_has_plan = 'plan' in self.policy_output_slices
|
||||
|
||||
self.frames = {name: DrivingModelFrame(context, ModelConstants.TEMPORAL_SKIP) for name in self.vision_input_names}
|
||||
self.prev_desire = np.zeros(ModelConstants.DESIRE_LEN, dtype=np.float32)
|
||||
@@ -307,7 +311,7 @@ class ModelState:
|
||||
self.off_policy_output: np.ndarray | None = None
|
||||
|
||||
off_policy_metadata = None
|
||||
if self.policy_generation == "v12" or OFF_POLICY_METADATA_PATH.is_file() or OFF_POLICY_PKL_PATH.is_file():
|
||||
if self.policy_generation in ("v12", "v13") or OFF_POLICY_METADATA_PATH.is_file() or OFF_POLICY_PKL_PATH.is_file():
|
||||
resolved_off_policy_meta = ensure_artifact(OFF_POLICY_METADATA_PATH, "driving_off_policy_metadata.pkl", optional=True)
|
||||
if resolved_off_policy_meta is not None:
|
||||
with open(resolved_off_policy_meta, 'rb') as f:
|
||||
@@ -316,6 +320,7 @@ class ModelState:
|
||||
if off_policy_metadata is not None:
|
||||
self.off_policy_input_shapes = off_policy_metadata['input_shapes']
|
||||
self.off_policy_output_slices = off_policy_metadata['output_slices']
|
||||
self.off_policy_has_plan = 'plan' in self.off_policy_output_slices
|
||||
off_policy_output_size = off_policy_metadata['output_shapes']['outputs'][1]
|
||||
self.off_policy_numpy_inputs, self.off_policy_prev_desired_curv_key = self._build_policy_inputs(self.off_policy_input_shapes)
|
||||
self.off_policy_desire_key = next((k for k in self.off_policy_numpy_inputs if k.startswith('desire')), None)
|
||||
@@ -326,6 +331,8 @@ class ModelState:
|
||||
with open(resolved_off_policy_pkl, "rb") as f:
|
||||
self.off_policy_run = pickle.load(f)
|
||||
self.off_policy_enabled = True
|
||||
else:
|
||||
self.off_policy_has_plan = False
|
||||
|
||||
# Optional temporal buffer for previous desired curvature (allocate only if any model expects it)
|
||||
if self.prev_desired_curv_key is not None or self.off_policy_prev_desired_curv_key is not None:
|
||||
@@ -374,6 +381,11 @@ class ModelState:
|
||||
if self.off_policy_enabled and 'traffic_convention' in self.off_policy_numpy_inputs:
|
||||
self.off_policy_numpy_inputs['traffic_convention'][:] = inputs['traffic_convention']
|
||||
|
||||
if 'action_t' in self.numpy_inputs:
|
||||
self.numpy_inputs['action_t'][:] = inputs['action_t']
|
||||
if self.off_policy_enabled and 'action_t' in self.off_policy_numpy_inputs:
|
||||
self.off_policy_numpy_inputs['action_t'][:] = inputs['action_t']
|
||||
|
||||
if 'lateral_control_params' in self.numpy_inputs:
|
||||
self.numpy_inputs['lateral_control_params'][:] = inputs['lateral_control_params']
|
||||
if self.off_policy_enabled and 'lateral_control_params' in self.off_policy_numpy_inputs:
|
||||
@@ -413,14 +425,14 @@ class ModelState:
|
||||
self.full_prev_desired_curv[0,-1,:] = policy_outputs_dict['desired_curvature'][0, :]
|
||||
|
||||
if self.prev_desired_curv_key is not None:
|
||||
# v9/v10/v11/v12 models expect zeros for prev_desired_curv(s); others use history
|
||||
if self.is_v9 or self.is_v10 or self.is_v11 or self.is_v12:
|
||||
# v9/v10/v11/v12/v13 models expect zeros for prev_desired_curv(s); others use history
|
||||
if self.is_v9 or self.is_v10 or self.is_v11 or self.is_v12 or self.is_v13:
|
||||
self.numpy_inputs[self.prev_desired_curv_key][:] = 0 * self.full_prev_desired_curv[0, self.temporal_idxs]
|
||||
else:
|
||||
self.numpy_inputs[self.prev_desired_curv_key][:] = self.full_prev_desired_curv[0, self.temporal_idxs]
|
||||
|
||||
if self.off_policy_enabled and self.off_policy_prev_desired_curv_key is not None:
|
||||
if self.is_v9 or self.is_v12:
|
||||
if self.is_v9 or self.is_v12 or self.is_v13:
|
||||
self.off_policy_numpy_inputs[self.off_policy_prev_desired_curv_key][:] = 0 * self.full_prev_desired_curv[0, self.temporal_idxs]
|
||||
else:
|
||||
self.off_policy_numpy_inputs[self.off_policy_prev_desired_curv_key][:] = self.full_prev_desired_curv[0, self.temporal_idxs]
|
||||
@@ -431,7 +443,8 @@ class ModelState:
|
||||
off_policy_outputs_dict = self.off_policy_parser.parse_policy_outputs(
|
||||
self.slice_outputs(self.off_policy_output, self.off_policy_output_slices)
|
||||
)
|
||||
off_policy_outputs_dict.pop('plan', None)
|
||||
if self.policy_has_plan:
|
||||
off_policy_outputs_dict.pop('plan', None)
|
||||
combined_outputs_dict = {**combined_outputs_dict, **off_policy_outputs_dict, **policy_outputs_dict}
|
||||
else:
|
||||
combined_outputs_dict = {**combined_outputs_dict, **policy_outputs_dict}
|
||||
@@ -587,10 +600,17 @@ def main(demo=False):
|
||||
bufs = {name: buf_extra if 'big' in name else buf_main for name in model.vision_input_names}
|
||||
transforms = {name: model_transform_extra if 'big' in name else model_transform_main for name in model.vision_input_names}
|
||||
|
||||
frame_delay = DT_MDL # Average time elapsed since the current frame finished exposing.
|
||||
action_delay = DT_MDL / 2 # Target the midpoint between current output and the next model step.
|
||||
lat_action_t = lat_delay + frame_delay + action_delay
|
||||
long_action_t = long_delay + frame_delay + action_delay
|
||||
|
||||
inputs:dict[str, np.ndarray] = {
|
||||
model.desire_key: vec_desire,
|
||||
'traffic_convention': traffic_convention,
|
||||
}
|
||||
if 'action_t' in model.numpy_inputs or (model.off_policy_enabled and 'action_t' in model.off_policy_numpy_inputs):
|
||||
inputs['action_t'] = np.array([lat_action_t, long_action_t], dtype=np.float32)
|
||||
# Include optional inputs only if the loaded model expects them
|
||||
if 'lateral_control_params' in model.numpy_inputs:
|
||||
inputs['lateral_control_params'] = lateral_control_params
|
||||
@@ -606,12 +626,10 @@ def main(demo=False):
|
||||
drivingdata_send = messaging.new_message('drivingModelData')
|
||||
posenet_send = messaging.new_message('cameraOdometry')
|
||||
|
||||
frame_delay = DT_MDL # Average time elapsed since the current frame finished exposing.
|
||||
action_delay = DT_MDL / 2 # Target the midpoint between current output and the next model step.
|
||||
action = get_action_from_model(
|
||||
model_output, prev_action,
|
||||
lat_delay + frame_delay + action_delay,
|
||||
long_delay + frame_delay + action_delay,
|
||||
lat_action_t,
|
||||
long_action_t,
|
||||
v_ego, model.mlsim, model.is_v9, starpilot_toggles,
|
||||
)
|
||||
prev_action = action
|
||||
|
||||
@@ -131,11 +131,14 @@ class Parser:
|
||||
|
||||
def parse_policy_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
|
||||
self.split_outputs(outs)
|
||||
if 'action' in outs:
|
||||
self.parse_mdn('action', outs, in_N=0, out_N=0, out_shape=(ModelConstants.ACTION_WIDTH,))
|
||||
if 'lat_planner_solution' in outs:
|
||||
self.parse_mdn('lat_planner_solution', outs, in_N=0, out_N=0, out_shape=(ModelConstants.IDX_N, ModelConstants.LAT_PLANNER_SOLUTION_WIDTH))
|
||||
if 'desired_curvature' in outs:
|
||||
self.parse_mdn('desired_curvature', outs, in_N=0, out_N=0, out_shape=(ModelConstants.DESIRED_CURV_WIDTH,))
|
||||
self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,))
|
||||
if 'desire_state' in outs:
|
||||
self.parse_categorical_crossentropy('desire_state', outs, out_shape=(ModelConstants.DESIRE_PRED_WIDTH,))
|
||||
return outs
|
||||
|
||||
def parse_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
|
||||
|
||||
@@ -772,7 +772,7 @@ class StarPilotDrivingModelLayout(StarPilotPanel):
|
||||
f"{key}_driving_vision_metadata.pkl",
|
||||
]
|
||||
|
||||
if version == "v12":
|
||||
if version in {"v12", "v13"}:
|
||||
files.extend(
|
||||
[
|
||||
f"{key}_driving_off_policy_tinygrad.pkl",
|
||||
|
||||
@@ -725,7 +725,7 @@ class DrivingModelBigButton(BigButton):
|
||||
f"{key}_driving_vision_metadata.pkl",
|
||||
]
|
||||
|
||||
if version == "v12":
|
||||
if version in {"v12", "v13"}:
|
||||
files.extend([
|
||||
f"{key}_driving_off_policy_tinygrad.pkl",
|
||||
f"{key}_driving_off_policy_metadata.pkl",
|
||||
|
||||
@@ -18,7 +18,7 @@ from openpilot.starpilot.common.starpilot_utilities import delete_file
|
||||
from openpilot.starpilot.common.starpilot_variables import MODELS_PATH
|
||||
|
||||
MANIFEST_CANDIDATES = ("v21",)
|
||||
TINYGRAD_VERSIONS = {"v8", "v9", "v10", "v11", "v12"}
|
||||
TINYGRAD_VERSIONS = {"v8", "v9", "v10", "v11", "v12", "v13"}
|
||||
DEFAULT_MODEL_KEY = "sc2"
|
||||
MODEL_KEY_CANONICAL_MAP = {
|
||||
"sc": DEFAULT_MODEL_KEY,
|
||||
@@ -192,7 +192,7 @@ class ModelManager:
|
||||
f"{model_key}_driving_vision_metadata.pkl",
|
||||
]
|
||||
|
||||
if model_version == "v12":
|
||||
if model_version in {"v12", "v13"}:
|
||||
filenames += [
|
||||
f"{model_key}_driving_off_policy_tinygrad.pkl",
|
||||
f"{model_key}_driving_off_policy_metadata.pkl",
|
||||
|
||||
@@ -972,7 +972,7 @@ class StarPilotVariables:
|
||||
if isinstance(toggle.model_version, bytes):
|
||||
toggle.model_version = toggle.model_version.decode("utf-8", "ignore")
|
||||
toggle.classic_model = toggle.model_version in {"v1", "v2", "v3", "v4"}
|
||||
toggle.tinygrad_model = toggle.model_version in {"v8", "v9", "v10", "v11", "v12"}
|
||||
toggle.tinygrad_model = toggle.model_version in {"v8", "v9", "v10", "v11", "v12", "v13"}
|
||||
toggle.tomb_raider = toggle.model == "space-lab"
|
||||
|
||||
toggle.model_ui = self.get_value("ModelUI")
|
||||
|
||||
@@ -4340,14 +4340,14 @@ def setup(app):
|
||||
if f"{model_key}.thneed" in on_disk_files:
|
||||
return True
|
||||
|
||||
if model_version in ("v8", "v9", "v10", "v11", "v12"):
|
||||
if model_version in ("v8", "v9", "v10", "v11", "v12", "v13"):
|
||||
required_files = {
|
||||
f"{model_key}_driving_policy_tinygrad.pkl",
|
||||
f"{model_key}_driving_vision_tinygrad.pkl",
|
||||
f"{model_key}_driving_policy_metadata.pkl",
|
||||
f"{model_key}_driving_vision_metadata.pkl",
|
||||
}
|
||||
if model_version == "v12":
|
||||
if model_version in ("v12", "v13"):
|
||||
required_files |= {
|
||||
f"{model_key}_driving_off_policy_tinygrad.pkl",
|
||||
f"{model_key}_driving_off_policy_metadata.pkl",
|
||||
|
||||
Reference in New Issue
Block a user