mirror of
https://github.com/sunnypilot/sunnypilot.git
synced 2026-06-10 14:44:21 +08:00
Compare commits
17 Commits
compile-mo
...
deep-rl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6e7d9e5e52 | ||
|
|
a4a7c2335d | ||
|
|
7d7b6ee306 | ||
|
|
6a4c59c3e0 | ||
|
|
fb5cb7a1cc | ||
|
|
049dfd2eaa | ||
|
|
be20848487 | ||
|
|
cdd232b606 | ||
|
|
b21c70b1ba | ||
|
|
67e5bd3c1e | ||
|
|
74692d0b5f | ||
|
|
dd35c27981 | ||
|
|
159140e64e | ||
|
|
f1ab6c8dfb | ||
|
|
fba521dcff | ||
|
|
a8ef55bfaa | ||
|
|
a232f54e2d |
@@ -12,11 +12,11 @@ on:
|
||||
required: false
|
||||
type: string
|
||||
recompiled_dir:
|
||||
description: 'Existing recompiled directory number (e.g. 3 for recompiled3)'
|
||||
description: 'Existing recompiled directory number (e.g. 1 for recompiled1)'
|
||||
required: true
|
||||
type: string
|
||||
json_version:
|
||||
description: 'driving_models version number to update (e.g. 5 for driving_models_v5.json)'
|
||||
description: 'driving_models version number to update (e.g. 18 for driving_models_v18.json)'
|
||||
required: true
|
||||
type: string
|
||||
artifact_suffix:
|
||||
@@ -63,12 +63,11 @@ on:
|
||||
default: 'None'
|
||||
options:
|
||||
- None
|
||||
- Simple Plan Models
|
||||
- Space Lab Models
|
||||
- TR Models
|
||||
- DTR Models
|
||||
- Master Models
|
||||
- Release Models
|
||||
- 2025 World Models
|
||||
- 2026 World Models
|
||||
- Custom Merge Models
|
||||
- FOF series models
|
||||
- Other
|
||||
custom_model_folder:
|
||||
description: 'Custom model folder name (if "Other" selected)'
|
||||
|
||||
29
.github/workflows/sunnypilot-build-model.yaml
vendored
29
.github/workflows/sunnypilot-build-model.yaml
vendored
@@ -30,6 +30,11 @@ on:
|
||||
required: false
|
||||
type: string
|
||||
default: ''
|
||||
target_hardware:
|
||||
description: 'Hardware target to compile for (qcom or usbgpu)'
|
||||
required: false
|
||||
type: string
|
||||
default: 'qcom'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
upstream_branch:
|
||||
@@ -46,6 +51,14 @@ on:
|
||||
required: false
|
||||
type: boolean
|
||||
default: true
|
||||
target_hardware:
|
||||
description: 'Hardware target to compile for'
|
||||
required: true
|
||||
type: choice
|
||||
options:
|
||||
- qcom
|
||||
- usbgpu
|
||||
default: 'qcom'
|
||||
|
||||
|
||||
run-name: Build model [${{ inputs.custom_name || inputs.upstream_branch }}] from ref [${{ inputs.upstream_branch }}]
|
||||
@@ -169,7 +182,17 @@ jobs:
|
||||
COMPILE_MODELD="${{ github.workspace }}/sunnypilot/modeld_v2/compile_modeld.py"
|
||||
MODEL_SIZE=$(python3 -c "from openpilot.common.transformations.model import MEDMODEL_INPUT_SIZE as s; print(f'{s[0]}x{s[1]}')")
|
||||
CAMERA_RES=$(python3 -c "from openpilot.common.transformations.camera import _ar_ox_fisheye as a, _os_fisheye as o; print(f'{a.width}x{a.height} {o.width}x{o.height}')")
|
||||
TG_FLAGS="DEV=QCOM IMAGE=1 FLOAT16=1 NOLOCALS=1 JIT_BATCH_SIZE=0 OPENPILOT_HACKS=1"
|
||||
|
||||
if [ "${{ inputs.target_hardware }}" == "usbgpu" ]; then
|
||||
echo "USBGPU build"
|
||||
export USBGPU=1
|
||||
TG_FLAGS="DEV=AMD USBGPU=1 IMAGE=1 FLOAT16=1 NOLOCALS=1 JIT_BATCH_SIZE=0 OPENPILOT_HACKS=1"
|
||||
OUTPUT_PKL="${{ env.MODELS_DIR }}/big_driving_tinygrad.pkl"
|
||||
else
|
||||
echo "QCOM build"
|
||||
TG_FLAGS="DEV=QCOM IMAGE=1 FLOAT16=1 NOLOCALS=1 JIT_BATCH_SIZE=0 OPENPILOT_HACKS=1"
|
||||
OUTPUT_PKL="${{ env.MODELS_DIR }}/driving_tinygrad.pkl"
|
||||
fi
|
||||
|
||||
# Generate metadata for all ONNX files
|
||||
find "${{ env.MODELS_DIR }}" -maxdepth 1 -name '*.onnx' | while IFS= read -r onnx_file; do
|
||||
@@ -203,13 +226,13 @@ jobs:
|
||||
fi
|
||||
|
||||
if [ -n "$MODEL_TYPE" ]; then
|
||||
echo "Detected: $MODEL_TYPE -> driving_tinygrad.pkl"
|
||||
echo "Detected: $MODEL_TYPE -> $OUTPUT_PKL"
|
||||
env ${TG_FLAGS} python3 "$COMPILE_MODELD" \
|
||||
--model-type $MODEL_TYPE \
|
||||
--model-size $MODEL_SIZE \
|
||||
--camera-resolutions $CAMERA_RES \
|
||||
$ONNX_ARGS \
|
||||
--output "${{ env.MODELS_DIR }}/driving_tinygrad.pkl"
|
||||
--output "$OUTPUT_PKL"
|
||||
fi
|
||||
|
||||
- name: Validate Model Outputs
|
||||
|
||||
@@ -137,10 +137,16 @@ struct ModelManagerSP @0xaedffd8f31e7b55d {
|
||||
eta @2 :UInt32;
|
||||
}
|
||||
|
||||
struct Chunk {
|
||||
fileName @0 :Text;
|
||||
sha256 @1 :Text;
|
||||
}
|
||||
|
||||
struct Artifact {
|
||||
fileName @0 :Text;
|
||||
downloadUri @1 :DownloadUri;
|
||||
downloadProgress @2 :DownloadProgress;
|
||||
chunks @3 :List(Chunk);
|
||||
}
|
||||
|
||||
struct Model {
|
||||
|
||||
@@ -53,7 +53,7 @@ def validate_model_outputs(metadata_paths: list[Path]) -> None:
|
||||
print(f"Optional output keys detected: {detected_optional}")
|
||||
|
||||
|
||||
def create_short_name(full_name):
|
||||
def create_short_name(full_name: str) -> str:
|
||||
# Remove parentheses and extract alphanumeric words
|
||||
clean_name = re.sub(r'\([^)]*\)', '', full_name)
|
||||
words = [re.sub(r'[^a-zA-Z0-9]', '', word) for word in clean_name.split() if re.sub(r'[^a-zA-Z0-9]', '', word)]
|
||||
@@ -121,7 +121,7 @@ def _rename_pkl_with_chunks(old_pkl: Path, new_pkl: Path) -> Path:
|
||||
return old_pkl.rename(new_pkl)
|
||||
|
||||
|
||||
def generate_metadata(model_path: Path, output_dir: Path, short_name: str, driving_pkl: Path):
|
||||
def generate_metadata(model_path: Path, output_dir: Path, short_name: str, driving_pkl: Path) -> dict | None:
|
||||
base = model_path.stem
|
||||
metadata_file = output_dir / f"{base}_metadata.pkl"
|
||||
|
||||
@@ -134,7 +134,7 @@ def generate_metadata(model_path: Path, output_dir: Path, short_name: str, drivi
|
||||
|
||||
if not metadata_file.exists():
|
||||
print(f"Warning: Missing metadata for {base} ({metadata_file}), skipping", file=sys.stderr)
|
||||
return
|
||||
return None
|
||||
|
||||
tinygrad_hash = hashlib.sha256(_read_pkl_bytes(driving_pkl)).hexdigest()
|
||||
|
||||
@@ -143,15 +143,33 @@ def generate_metadata(model_path: Path, output_dir: Path, short_name: str, drivi
|
||||
|
||||
model_type = "offPolicy" if "off_policy" in base else "onPolicy" if "on_policy" in base else base.split("_")[-1]
|
||||
|
||||
chunks_config = []
|
||||
manifest_file = Path(f"{driving_pkl}.chunkmanifest")
|
||||
if manifest_file.exists():
|
||||
num_chunks = int(manifest_file.read_text().strip())
|
||||
for i in range(num_chunks):
|
||||
chunk_path = Path(f"{driving_pkl}.chunk{i + 1:02d}of{num_chunks:02d}")
|
||||
if chunk_path.exists():
|
||||
chunk_hash = hashlib.sha256(chunk_path.read_bytes()).hexdigest()
|
||||
chunks_config.append({
|
||||
"file_name": chunk_path.name,
|
||||
"sha256": chunk_hash
|
||||
})
|
||||
|
||||
artifact_data = {
|
||||
"file_name": driving_pkl.name,
|
||||
"download_uri": {
|
||||
"url": "https://gitlab.com/sunnypilot/public/docs.sunnypilot.ai/-/raw/main/",
|
||||
"sha256": tinygrad_hash
|
||||
}
|
||||
}
|
||||
|
||||
if chunks_config:
|
||||
artifact_data["chunks"] = chunks_config
|
||||
|
||||
return {
|
||||
"type": model_type,
|
||||
"artifact": {
|
||||
"file_name": driving_pkl.name,
|
||||
"download_uri": {
|
||||
"url": "https://gitlab.com/sunnypilot/public/docs.sunnypilot.ai/-/raw/main/",
|
||||
"sha256": tinygrad_hash
|
||||
}
|
||||
},
|
||||
"artifact": artifact_data,
|
||||
"metadata": {
|
||||
"file_name": metadata_file.name,
|
||||
"download_uri": {
|
||||
@@ -162,8 +180,8 @@ def generate_metadata(model_path: Path, output_dir: Path, short_name: str, drivi
|
||||
}
|
||||
|
||||
|
||||
def create_metadata_json(models: list, output_dir: Path, custom_name=None, short_name=None, is_20hz=False, upstream_branch="unknown"):
|
||||
metadata_json = {
|
||||
def create_metadata_json(models: list, output_dir: Path, custom_name=None, short_name=None, is_20hz=False, upstream_branch="unknown") -> None:
|
||||
bundle_json = {
|
||||
"short_name": short_name,
|
||||
"display_name": custom_name or upstream_branch,
|
||||
"is_20hz": is_20hz,
|
||||
@@ -179,6 +197,10 @@ def create_metadata_json(models: list, output_dir: Path, custom_name=None, short
|
||||
}
|
||||
|
||||
# Write metadata to output_dir
|
||||
metadata_json = {
|
||||
"bundles": [bundle_json]
|
||||
}
|
||||
|
||||
with open(output_dir / "metadata.json", "w") as f:
|
||||
json.dump(metadata_json, f, indent=2)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ See the LICENSE.md file in the root directory for more details.
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ['GMMU'] = '0'
|
||||
from openpilot.system.hardware import TICI
|
||||
os.environ['DEV'] = 'QCOM' if TICI else 'CPU'
|
||||
USBGPU = "USBGPU" in os.environ
|
||||
@@ -109,6 +110,8 @@ class ModelState(ModelStateBase):
|
||||
jits = pickle.loads(read_file_chunked(pkl_path))
|
||||
|
||||
self.DEV = Device.DEFAULT
|
||||
self.WARP_DEV = 'CPU' if USBGPU else self.DEV
|
||||
self.QUEUE_DEV = self.DEV
|
||||
|
||||
metadata = jits['metadata']
|
||||
if 'model' in metadata:
|
||||
@@ -120,7 +123,7 @@ class ModelState(ModelStateBase):
|
||||
self._vision_input_names = [k for k in model_metadata['input_shapes'] if 'img' in k]
|
||||
from openpilot.sunnypilot.modeld_v2.compile_modeld import make_supercombo_input_queues
|
||||
frame_skip = derive_frame_skip({}, model_metadata['input_shapes'])
|
||||
self.input_queues, self.numpy_inputs = make_supercombo_input_queues(model_metadata['input_shapes'], frame_skip, device=self.DEV)
|
||||
self.input_queues, self.numpy_inputs = make_supercombo_input_queues(model_metadata['input_shapes'], frame_skip, device=self.QUEUE_DEV)
|
||||
else:
|
||||
vision_metadata = metadata['vision']
|
||||
policy_keys = [k for k in metadata if k != 'vision']
|
||||
@@ -138,7 +141,11 @@ class ModelState(ModelStateBase):
|
||||
policy_input_shapes = first_policy_metadata['input_shapes']
|
||||
self._vision_input_names = [k for k in vision_input_shapes if 'img' in k]
|
||||
frame_skip = derive_frame_skip(vision_input_shapes, policy_input_shapes)
|
||||
self.input_queues, self.numpy_inputs = make_split_input_queues(vision_input_shapes, policy_input_shapes, frame_skip, device=self.DEV)
|
||||
self.input_queues, self.numpy_inputs = make_split_input_queues(vision_input_shapes, policy_input_shapes, frame_skip, device=self.QUEUE_DEV)
|
||||
|
||||
self._desire_key = next(key for key in self.numpy_inputs if key.startswith('desire'))
|
||||
self._road_key = next(key for key in self._vision_input_names if 'big' not in key)
|
||||
self._wide_key = next(key for key in self._vision_input_names if 'big' in key)
|
||||
|
||||
from openpilot.sunnypilot.modeld_v2.parse_model_outputs_split import Parser as SplitParser
|
||||
from openpilot.sunnypilot.modeld_v2.parse_model_outputs import Parser as CombinedParser
|
||||
@@ -160,12 +167,11 @@ class ModelState(ModelStateBase):
|
||||
|
||||
self._run_policy = jits[(cam_w, cam_h)]['run_policy']
|
||||
self._warp_enqueue = jits[(cam_w, cam_h)]['warp_enqueue']
|
||||
road_name = next(k for k in self._vision_input_names if 'big' not in k)
|
||||
yuv_size = self.frame_buf_params[road_name][3]
|
||||
yuv_size = self.frame_buf_params[self._road_key][3]
|
||||
self._warp_enqueue(
|
||||
**self.input_queues,
|
||||
frame=Tensor(np.zeros(yuv_size, dtype=np.uint8), device=self.DEV).contiguous().realize(),
|
||||
big_frame=Tensor(np.zeros(yuv_size, dtype=np.uint8), device=self.DEV).contiguous().realize())
|
||||
frame=Tensor(np.zeros(yuv_size, dtype=np.uint8), device=self.WARP_DEV).contiguous().realize(),
|
||||
big_frame=Tensor(np.zeros(yuv_size, dtype=np.uint8), device=self.WARP_DEV).contiguous().realize())
|
||||
|
||||
|
||||
@property
|
||||
@@ -178,7 +184,7 @@ class ModelState(ModelStateBase):
|
||||
|
||||
@property
|
||||
def desire_key(self) -> str:
|
||||
return next(k for k in self.numpy_inputs if k.startswith('desire'))
|
||||
return self._desire_key
|
||||
|
||||
def run(self, bufs: dict[str, VisionBuf], transforms: dict[str, np.ndarray],
|
||||
inputs: dict[str, np.ndarray], prepare_only: bool) -> dict[str, np.ndarray] | None:
|
||||
@@ -189,19 +195,19 @@ class ModelState(ModelStateBase):
|
||||
yuv_size = self.frame_buf_params[key][3]
|
||||
cache_key = (key, ptr)
|
||||
if cache_key not in self._blob_cache:
|
||||
self._blob_cache[cache_key] = Tensor.from_blob(ptr, (yuv_size,), dtype='uint8', device=self.DEV)
|
||||
self._blob_cache[cache_key] = Tensor.from_blob(ptr, (yuv_size,), dtype='uint8', device=self.WARP_DEV)
|
||||
self.full_frames[key] = self._blob_cache[cache_key]
|
||||
|
||||
desire_key = self.desire_key
|
||||
inputs[desire_key][0] = 0
|
||||
self.numpy_inputs[desire_key][:] = np.where(inputs[desire_key] - self.prev_desire > .99, inputs[desire_key], 0)
|
||||
self.prev_desire[:] = inputs[desire_key]
|
||||
for key in ('traffic_convention', 'lateral_control_params'):
|
||||
for key in ('traffic_convention', 'lateral_control_params', 'action_t'):
|
||||
if key in self.numpy_inputs and key in inputs:
|
||||
self.numpy_inputs[key][:] = inputs[key]
|
||||
|
||||
road_key = next(n for n in bufs if 'big' not in n)
|
||||
wide_key = next(n for n in bufs if 'big' in n)
|
||||
road_key = self._road_key
|
||||
wide_key = self._wide_key
|
||||
self.numpy_inputs['tfm'][:, :] = transforms[road_key].reshape(3, 3)
|
||||
self.numpy_inputs['big_tfm'][:, :] = transforms[wide_key].reshape(3, 3)
|
||||
|
||||
@@ -240,13 +246,20 @@ class ModelState(ModelStateBase):
|
||||
|
||||
def get_action_from_model(self, model_output: dict[str, np.ndarray], prev_action: log.ModelDataV2.Action,
|
||||
lat_action_t: float, long_action_t: float, v_ego: float) -> log.ModelDataV2.Action:
|
||||
plan = model_output['plan'][0]
|
||||
desired_accel, should_stop = get_accel_from_plan(plan[:, Plan.VELOCITY][:, 0], plan[:, Plan.ACCELERATION][:, 0], self.constants.T_IDXS,
|
||||
action_t=long_action_t)
|
||||
desired_accel = smooth_value(desired_accel, prev_action.desiredAcceleration, self.LONG_SMOOTH_SECONDS)
|
||||
if 'action' not in model_output:
|
||||
plan = model_output['plan'][0]
|
||||
desired_accel, should_stop = get_accel_from_plan(plan[:, Plan.VELOCITY][:, 0], plan[:, Plan.ACCELERATION][:, 0], self.constants.T_IDXS,
|
||||
action_t=long_action_t)
|
||||
desired_accel = smooth_value(desired_accel, prev_action.desiredAcceleration, self.LONG_SMOOTH_SECONDS)
|
||||
|
||||
curvature_plan = (plan + (self.PLANPLUS_CONTROL - 1.0) * model_output['planplus'][0]
|
||||
if 'planplus' in model_output and self.PLANPLUS_CONTROL != 1.0 else plan)
|
||||
desired_curvature = get_curvature_from_output(model_output, curvature_plan, v_ego, lat_action_t, self.mlsim)
|
||||
else:
|
||||
desired_accel = model_output['action'][0, 1]
|
||||
desired_curvature = model_output['action'][0, 0] / (max(1.0, v_ego))**2
|
||||
should_stop = (v_ego < 0.3 and desired_accel < 0.1)
|
||||
|
||||
curvature_plan = plan + (self.PLANPLUS_CONTROL - 1.0) * model_output['planplus'][0] if 'planplus' in model_output and self.PLANPLUS_CONTROL != 1.0 else plan
|
||||
desired_curvature = get_curvature_from_output(model_output, curvature_plan, v_ego, lat_action_t, self.mlsim)
|
||||
if self.generation is not None and self.generation >= 10: # smooth curvature for post FOF models
|
||||
if v_ego > self.MIN_LAT_CONTROL_SPEED:
|
||||
desired_curvature = smooth_value(desired_curvature, prev_action.desiredCurvature, self.LAT_SMOOTH_SECONDS)
|
||||
@@ -399,6 +412,12 @@ def main(demo=False):
|
||||
|
||||
bufs = {name: buf_extra if 'big' in name else buf_main for name in model.vision_input_names}
|
||||
transforms = {name: model_transform_extra if 'big' in name else model_transform_main for name in model.vision_input_names}
|
||||
|
||||
frame_delay = DT_MDL # compensate for time passed since the frame was captured: current_time - timestamp_eof is 50ms on average
|
||||
action_delay = DT_MDL / 2 # middle of the interval between model output (current state) and next frame (expected state)
|
||||
lat_action_t = lat_delay + frame_delay + action_delay
|
||||
long_action_t = long_delay + frame_delay + action_delay
|
||||
|
||||
inputs:dict[str, np.ndarray] = {
|
||||
model.desire_key: vec_desire,
|
||||
'traffic_convention': traffic_convention,
|
||||
@@ -407,6 +426,9 @@ def main(demo=False):
|
||||
if 'lateral_control_params' in model.numpy_inputs:
|
||||
inputs['lateral_control_params'] = np.array([v_ego, lat_delay], dtype=np.float32)
|
||||
|
||||
if 'action_t' in model.numpy_inputs:
|
||||
inputs['action_t'] = np.array([lat_action_t, long_action_t], dtype=np.float32)
|
||||
|
||||
mt1 = time.perf_counter()
|
||||
model_output = model.run(bufs, transforms, inputs, prepare_only)
|
||||
mt2 = time.perf_counter()
|
||||
@@ -418,7 +440,7 @@ def main(demo=False):
|
||||
posenet_send = messaging.new_message('cameraOdometry')
|
||||
mdv2sp_send = messaging.new_message('modelDataV2SP')
|
||||
|
||||
action = model.get_action_from_model(model_output, prev_action, lat_delay + DT_MDL, long_delay + DT_MDL, v_ego)
|
||||
action = model.get_action_from_model(model_output, prev_action, lat_action_t, long_action_t, v_ego)
|
||||
prev_action = action
|
||||
fill_model_msg(drivingdata_send, modelv2_send, model_output, action,
|
||||
publish_state, meta_main.frame_id, meta_extra.frame_id, frame_id,
|
||||
|
||||
@@ -134,6 +134,8 @@ class Parser:
|
||||
out_shape=(SplitModelConstants.NUM_ROAD_EDGES,SplitModelConstants.IDX_N,SplitModelConstants.LANE_LINES_WIDTH))
|
||||
if 'sim_pose' in outs:
|
||||
self.parse_mdn('sim_pose', outs, in_N=0, out_N=0, out_shape=(SplitModelConstants.POSE_WIDTH,))
|
||||
if 'action' in outs:
|
||||
self.parse_mdn('action', outs, in_N=0, out_N=0, out_shape=(SplitModelConstants.ACTION_WIDTH,))
|
||||
|
||||
def parse_vision_outputs(self, outs: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
|
||||
self.parse_mdn('pose', outs, in_N=0, out_N=0, out_shape=(SplitModelConstants.POSE_WIDTH,))
|
||||
|
||||
@@ -26,11 +26,22 @@ class ModelParser:
|
||||
download_uri.sha256 = download_uri_data.get("sha256")
|
||||
return download_uri
|
||||
|
||||
@staticmethod
|
||||
def _parse_chunk(chunk_data) -> custom.ModelManagerSP.Chunk:
|
||||
chunk = custom.ModelManagerSP.Chunk()
|
||||
chunk.fileName = chunk_data.get("file_name")
|
||||
chunk.sha256 = chunk_data.get("sha256")
|
||||
return chunk
|
||||
|
||||
@staticmethod
|
||||
def _parse_artifact(artifact_data) -> custom.ModelManagerSP.Artifact:
|
||||
artifact = custom.ModelManagerSP.Artifact()
|
||||
artifact.fileName = artifact_data.get("file_name")
|
||||
artifact.downloadUri = ModelParser._parse_download_uri(artifact_data.get("download_uri", {}))
|
||||
|
||||
if "chunks" in artifact_data:
|
||||
artifact.chunks = [ModelParser._parse_chunk(chunk_data) for chunk_data in artifact_data["chunks"]]
|
||||
|
||||
return artifact
|
||||
|
||||
@staticmethod
|
||||
@@ -116,7 +127,7 @@ class ModelCache:
|
||||
|
||||
class ModelFetcher:
|
||||
"""Handles fetching and caching of model data from remote source"""
|
||||
MODEL_URL = "https://raw.githubusercontent.com/sunnypilot/sunnypilot-models/refs/heads/gh-pages/docs/driving_models_v17.json"
|
||||
MODEL_URL = "https://raw.githubusercontent.com/sunnypilot/sunnypilot-models/refs/heads/gh-pages/docs/driving_models_v18.json"
|
||||
|
||||
def __init__(self, params: Params):
|
||||
self.params = params
|
||||
@@ -184,4 +195,6 @@ if __name__ == "__main__":
|
||||
# Print artifact details
|
||||
print(f"Artifact: {model.artifact.fileName}, Download URI: {model.artifact.downloadUri.uri}")
|
||||
# Print metadata details
|
||||
if model.artifact.chunks:
|
||||
print(f"Contains {len(model.artifact.chunks)} chunks.")
|
||||
print(f"Metadata: {model.metadata.fileName}, Download URI: {model.metadata.downloadUri.uri}")
|
||||
|
||||
@@ -89,20 +89,16 @@ class ModelManagerSP:
|
||||
del self._download_start_times[model.fileName]
|
||||
|
||||
async def _download_chunked(self, base_url: str, base_path: str, artifact) -> None:
|
||||
from openpilot.common.file_chunker import get_manifest_path, get_chunk_name
|
||||
manifest_url = get_manifest_path(base_url)
|
||||
from openpilot.common.file_chunker import get_chunk_name, get_manifest_path
|
||||
|
||||
num_chunks = len(artifact.chunks)
|
||||
if num_chunks == 0:
|
||||
raise ValueError("No chunks defined in artifact")
|
||||
|
||||
manifest_path = get_manifest_path(base_path)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(manifest_url) as resp:
|
||||
if resp.status == 404:
|
||||
raise FileNotFoundError
|
||||
resp.raise_for_status()
|
||||
num_chunks = int((await resp.read()).strip())
|
||||
|
||||
self._download_start_times[artifact.fileName] = time.monotonic()
|
||||
|
||||
for i in range(num_chunks):
|
||||
for i, _ in enumerate(artifact.chunks):
|
||||
chunk_url = get_chunk_name(base_url, i, num_chunks)
|
||||
chunk_path = get_chunk_name(base_path, i, num_chunks)
|
||||
chunk_downloaded = 0
|
||||
@@ -117,7 +113,7 @@ class ModelManagerSP:
|
||||
if self.params.get("ModelManager_DownloadIndex") is None:
|
||||
raise Exception("Download cancelled")
|
||||
intra = chunk_downloaded / max(chunk_size, 1)
|
||||
progress = min(99, (i + intra) / num_chunks * 100)
|
||||
progress = min(99.0, ((i + intra) / num_chunks) * 100)
|
||||
artifact.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.downloading
|
||||
artifact.downloadProgress.progress = progress
|
||||
artifact.downloadProgress.eta = self._calculate_eta(artifact.fileName, progress)
|
||||
@@ -148,9 +144,9 @@ class ModelManagerSP:
|
||||
self._report_status()
|
||||
return
|
||||
|
||||
try:
|
||||
if len(artifact.chunks) > 0:
|
||||
await self._download_chunked(url, full_path, artifact)
|
||||
except (FileNotFoundError, aiohttp.ClientResponseError):
|
||||
else:
|
||||
await self._download_file(url, full_path, artifact)
|
||||
|
||||
if not await verify_file(full_path, expected_hash):
|
||||
@@ -170,18 +166,16 @@ class ModelManagerSP:
|
||||
artifact.downloadProgress.status = custom.ModelManagerSP.DownloadStatus.failed
|
||||
artifact.downloadProgress.eta = 0
|
||||
self._sync_artifact_progress(artifact)
|
||||
self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.failed
|
||||
if self.selected_bundle:
|
||||
self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.failed
|
||||
self._report_status()
|
||||
self._download_start_times.pop(artifact.fileName, None)
|
||||
raise
|
||||
|
||||
async def _process_model(self, model, destination_path: str) -> None:
|
||||
"""Processes a single model download including verification"""
|
||||
model_artifact = model.artifact
|
||||
metadata_artifact = model.metadata
|
||||
|
||||
await self._process_artifact(metadata_artifact, destination_path)
|
||||
await self._process_artifact(model_artifact, destination_path)
|
||||
await self._process_artifact(model.metadata, destination_path)
|
||||
await self._process_artifact(model.artifact, destination_path)
|
||||
|
||||
def _report_status(self) -> None:
|
||||
"""Reports current status through messaging system"""
|
||||
@@ -222,7 +216,8 @@ class ModelManagerSP:
|
||||
self.selected_bundle = None
|
||||
|
||||
except Exception:
|
||||
self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.failed
|
||||
if self.selected_bundle:
|
||||
self.selected_bundle.status = custom.ModelManagerSP.DownloadStatus.failed
|
||||
raise
|
||||
|
||||
finally:
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
from openpilot.sunnypilot.models.helpers import get_active_bundle
|
||||
from openpilot.sunnypilot.models.runners.model_runner import ModelRunner
|
||||
from openpilot.sunnypilot.models.runners.tinygrad.tinygrad_runner import TinygradRunner, TinygradSplitRunner
|
||||
from openpilot.sunnypilot.models.runners.constants import ModelType
|
||||
|
||||
|
||||
def get_model_runner() -> ModelRunner:
|
||||
"""
|
||||
Factory function to create and return the appropriate ModelRunner instance.
|
||||
|
||||
Selects TinygradRunner, choosing TinygradSplitRunner if separate vision/policy
|
||||
models are detected in the active bundle.
|
||||
|
||||
:return: An instance of a ModelRunner subclass (ONNXRunner, TinygradRunner, or TinygradSplitRunner).
|
||||
"""
|
||||
bundle = get_active_bundle()
|
||||
if bundle and bundle.models:
|
||||
model_types = {m.type.raw for m in bundle.models}
|
||||
# Check if the bundle uses separate vision and policy models (legacy or new split format)
|
||||
split_types = {ModelType.vision, ModelType.policy, ModelType.offPolicy, ModelType.onPolicy}
|
||||
if model_types & split_types:
|
||||
return TinygradSplitRunner()
|
||||
# Otherwise, assume a single model (likely supercombo)
|
||||
if bundle.models:
|
||||
return TinygradRunner(bundle.models[0].type.raw)
|
||||
|
||||
# Default fallback to TinygradRunner with the supercombo type if bundle info is missing/incomplete
|
||||
return TinygradRunner(ModelType.supercombo)
|
||||
@@ -1,174 +0,0 @@
|
||||
from abc import abstractmethod, ABC
|
||||
|
||||
import numpy as np
|
||||
from openpilot.sunnypilot.models.helpers import get_active_bundle
|
||||
from openpilot.sunnypilot.models.runners.constants import NumpyDict, ShapeDict, Model, SliceDict, SEND_RAW_PRED
|
||||
from openpilot.system.hardware.hw import Paths
|
||||
import pickle
|
||||
|
||||
CUSTOM_MODEL_PATH = Paths.model_root()
|
||||
|
||||
|
||||
class ModelData:
|
||||
"""
|
||||
Stores metadata and configuration for a specific machine learning model.
|
||||
|
||||
This class loads model metadata (like input shapes and output slices)
|
||||
from a pickle file associated with a model instance.
|
||||
|
||||
:param model: The machine learning model object containing metadata.
|
||||
"""
|
||||
def __init__(self, model: Model):
|
||||
self.model = model
|
||||
self.metadata = model.metadata
|
||||
self.input_shapes: ShapeDict = {}
|
||||
self.output_slices: SliceDict = {}
|
||||
if self.metadata:
|
||||
self._load_metadata()
|
||||
|
||||
def _load_metadata(self) -> None:
|
||||
"""Loads input shapes and output slices from the model's metadata pickle file."""
|
||||
metadata_path = f"{CUSTOM_MODEL_PATH}/{self.metadata.fileName}"
|
||||
with open(metadata_path, 'rb') as f:
|
||||
model_metadata = pickle.load(f)
|
||||
self.input_shapes = model_metadata.get('input_shapes', {})
|
||||
self.output_slices = model_metadata.get('output_slices', {})
|
||||
|
||||
|
||||
class ModularRunner(ABC):
|
||||
"""
|
||||
Represents a modular runner for handling and slicing model outputs.
|
||||
|
||||
This abstract base class is designed to provide an interface for modular
|
||||
parsing and processing of model outputs. Classes inheriting from it must
|
||||
implement the specified abstract methods, defining how model outputs
|
||||
should be handled and stored. The primary goal is to enable structured
|
||||
parsing of outputs through a dictionary-based method mapping.
|
||||
|
||||
:ivar parser_method_dict: Mapping dictionary containing parser methods
|
||||
for handling specific types of outputs.
|
||||
:type parser_method_dict: dict
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def parser_method_dict(self) -> dict:
|
||||
pass
|
||||
|
||||
@parser_method_dict.setter
|
||||
@abstractmethod
|
||||
def parser_method_dict(self, value: dict) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _slice_outputs(self, model_outputs: np.ndarray) -> NumpyDict:
|
||||
pass
|
||||
|
||||
|
||||
class ModelRunner(ModularRunner):
|
||||
"""
|
||||
Abstract base class for managing and executing machine learning models.
|
||||
|
||||
Provides a common interface for loading models, preparing inputs, running
|
||||
inference, and slicing/parsing outputs based on model metadata. Derived
|
||||
classes implement the specifics of input preparation and model execution
|
||||
for different frameworks (e.g., Tinygrad, ONNX).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes the model runner, loading the active model bundle."""
|
||||
self.is_20hz: bool | None = None
|
||||
self.is_20hz_3d: bool | None = None
|
||||
self.models: dict[int, ModelData] = {}
|
||||
self._model_data: ModelData | None = None # Active model data for current operation
|
||||
self._parser_method_dict: dict = {}
|
||||
self.inputs: dict = {}
|
||||
self._parser = None
|
||||
self._load_models()
|
||||
self._constants = None
|
||||
|
||||
@property
|
||||
def constants(self):
|
||||
return self._constants
|
||||
|
||||
@property
|
||||
def parser_method_dict(self) -> dict:
|
||||
"""Returns the dictionary mapping model types to their respective parsing methods."""
|
||||
return self._parser_method_dict
|
||||
|
||||
@parser_method_dict.setter
|
||||
def parser_method_dict(self, value: dict) -> None:
|
||||
"""Sets the dictionary mapping model types to their respective parsing methods."""
|
||||
self._parser_method_dict = value
|
||||
|
||||
def _load_models(self) -> None:
|
||||
"""Loads the active model bundle configuration and sets up ModelData."""
|
||||
bundle = get_active_bundle()
|
||||
if not bundle:
|
||||
raise ValueError("No active model bundle found, why are we being executed?")
|
||||
|
||||
self.models = {model.type.raw: ModelData(model) for model in bundle.models}
|
||||
self.is_20hz = bundle.is20hz
|
||||
self.is_20hz_3d = False
|
||||
|
||||
@property
|
||||
def input_shapes(self) -> ShapeDict:
|
||||
"""Returns the input shapes for the currently active model."""
|
||||
if self._model_data:
|
||||
return self._model_data.input_shapes
|
||||
raise ValueError("Model data is not available. Ensure the model is loaded correctly.")
|
||||
|
||||
@property
|
||||
def output_slices(self) -> SliceDict:
|
||||
"""Returns the output slices for the currently active model."""
|
||||
if self._model_data:
|
||||
return self._model_data.output_slices
|
||||
raise ValueError("Model data is not available. Ensure the model is loaded correctly.")
|
||||
|
||||
@property
|
||||
def vision_input_names(self) -> list[str]:
|
||||
"""Returns the list of vision input names from the input shapes."""
|
||||
if self._model_data:
|
||||
return list(self._model_data.input_shapes.keys())
|
||||
raise ValueError("Model data is not available. Ensure the model is loaded correctly.")
|
||||
|
||||
@abstractmethod
|
||||
def prepare_inputs(self, numpy_inputs: NumpyDict) -> dict:
|
||||
"""
|
||||
Abstract method to prepare inputs for model inference.
|
||||
|
||||
:param numpy_inputs: Dictionary of numpy arrays for non-image inputs.
|
||||
:return: Dictionary of prepared inputs ready for the model.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _run_model(self) -> NumpyDict:
|
||||
"""
|
||||
Abstract method to execute model inference with prepared inputs.
|
||||
|
||||
:return: Dictionary containing the model's raw output arrays.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _slice_outputs(self, model_outputs: np.ndarray) -> NumpyDict:
|
||||
"""
|
||||
Slices the raw model output array based on the output_slices metadata.
|
||||
|
||||
:param model_outputs: The raw numpy array output from the model.
|
||||
:return: A dictionary where keys are output names and values are sliced numpy arrays.
|
||||
"""
|
||||
if not self._model_data:
|
||||
raise ValueError("Model data is not available. Ensure the model is loaded correctly.")
|
||||
sliced_outputs = {k: model_outputs[np.newaxis, v] for k, v in self._model_data.output_slices.items()}
|
||||
if SEND_RAW_PRED:
|
||||
sliced_outputs['raw_pred'] = model_outputs.copy() # Optionally include the full raw output
|
||||
return sliced_outputs
|
||||
|
||||
def run_model(self) -> NumpyDict:
|
||||
"""
|
||||
Executes the model inference pipeline: runs the model and parses outputs.
|
||||
|
||||
:return: Dictionary containing the final parsed model outputs.
|
||||
"""
|
||||
return self._run_model() # Parsing is handled within specific runner implementations
|
||||
@@ -1,91 +0,0 @@
|
||||
import os
|
||||
from abc import ABC
|
||||
|
||||
import numpy as np
|
||||
from openpilot.sunnypilot.modeld_v2.parse_model_outputs import Parser as CombinedParser
|
||||
from openpilot.sunnypilot.modeld_v2.parse_model_outputs_split import Parser as SplitParser
|
||||
from openpilot.sunnypilot.models.runners.constants import ModelType, NumpyDict
|
||||
from openpilot.sunnypilot.models.runners.model_runner import ModularRunner
|
||||
from openpilot.system.hardware.hw import Paths
|
||||
|
||||
|
||||
SEND_RAW_PRED = os.getenv('SEND_RAW_PRED')
|
||||
CUSTOM_MODEL_PATH = Paths.model_root()
|
||||
|
||||
|
||||
class OffPolicyTinygrad(ModularRunner, ABC):
|
||||
"""
|
||||
A TinygradRunner specialized for off-policy models.
|
||||
|
||||
Uses a SplitParser to handle outputs specific to the off-policy part of a split model setup.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._off_policy_parser = SplitParser()
|
||||
self.parser_method_dict[ModelType.offPolicy] = self._parse_off_policy_outputs
|
||||
|
||||
def _parse_off_policy_outputs(self, model_outputs: np.ndarray) -> NumpyDict:
|
||||
"""Parses off-policy model outputs using SplitParser."""
|
||||
result: NumpyDict = self._off_policy_parser.parse_policy_outputs(self._slice_outputs(model_outputs))
|
||||
return result
|
||||
|
||||
|
||||
class OnPolicyTinygrad(ModularRunner, ABC):
|
||||
"""
|
||||
A TinygradRunner specialized for on-policy models.
|
||||
|
||||
Uses a SplitParser to handle outputs specific to the on-policy part of a split model setup.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._on_policy_parser = SplitParser()
|
||||
self.parser_method_dict[ModelType.onPolicy] = self._parse_on_policy_outputs
|
||||
|
||||
def _parse_on_policy_outputs(self, model_outputs: np.ndarray) -> NumpyDict:
|
||||
"""Parses on-policy model outputs using SplitParser."""
|
||||
result: NumpyDict = self._on_policy_parser.parse_policy_outputs(self._slice_outputs(model_outputs))
|
||||
return result
|
||||
|
||||
|
||||
class PolicyTinygrad(ModularRunner, ABC):
|
||||
"""
|
||||
A TinygradRunner specialized for policy-only models.
|
||||
|
||||
Uses a SplitParser to handle outputs specific to the policy part of a split model setup.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._policy_parser = SplitParser()
|
||||
self.parser_method_dict[ModelType.policy] = self._parse_policy_outputs
|
||||
|
||||
def _parse_policy_outputs(self, model_outputs: np.ndarray) -> NumpyDict:
|
||||
"""Parses policy model outputs using SplitParser."""
|
||||
result: NumpyDict = self._policy_parser.parse_policy_outputs(self._slice_outputs(model_outputs))
|
||||
return result
|
||||
|
||||
class VisionTinygrad(ModularRunner, ABC):
|
||||
"""
|
||||
A TinygradRunner specialized for vision-only models.
|
||||
|
||||
Uses a SplitParser to handle outputs specific to the vision part of a split model setup.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._vision_parser = SplitParser()
|
||||
self.parser_method_dict[ModelType.vision] = self._parse_vision_outputs
|
||||
|
||||
def _parse_vision_outputs(self, model_outputs: np.ndarray) -> NumpyDict:
|
||||
"""Parses vision model outputs using SplitParser."""
|
||||
result: NumpyDict = self._vision_parser.parse_vision_outputs(self._slice_outputs(model_outputs))
|
||||
return result
|
||||
|
||||
class SupercomboTinygrad(ModularRunner, ABC):
|
||||
"""
|
||||
A TinygradRunner specialized for vision-only models.
|
||||
|
||||
Uses a SplitParser to handle outputs specific to the vision part of a split model setup.
|
||||
"""
|
||||
def __init__(self):
|
||||
self._supercombo_parser = CombinedParser()
|
||||
self.parser_method_dict[ModelType.supercombo] = self._parse_supercombo_outputs
|
||||
|
||||
def _parse_supercombo_outputs(self, model_outputs: np.ndarray) -> NumpyDict:
|
||||
"""Parses vision model outputs using SplitParser."""
|
||||
result: NumpyDict = self._supercombo_parser.parse_outputs(self._slice_outputs(model_outputs))
|
||||
return result
|
||||
@@ -1,179 +0,0 @@
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from openpilot.sunnypilot.models.runners.constants import NumpyDict, ModelType, ShapeDict, CUSTOM_MODEL_PATH, SliceDict
|
||||
from openpilot.sunnypilot.models.runners.model_runner import ModelRunner
|
||||
from openpilot.sunnypilot.models.runners.tinygrad.model_types import PolicyTinygrad, VisionTinygrad, SupercomboTinygrad, OffPolicyTinygrad, OnPolicyTinygrad
|
||||
from openpilot.sunnypilot.models.split_model_constants import SplitModelConstants
|
||||
from openpilot.sunnypilot.modeld_v2.constants import ModelConstants
|
||||
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
|
||||
class TinygradRunner(ModelRunner, SupercomboTinygrad, PolicyTinygrad, VisionTinygrad, OffPolicyTinygrad, OnPolicyTinygrad):
|
||||
"""
|
||||
A ModelRunner implementation for executing Tinygrad models.
|
||||
|
||||
Handles loading Tinygrad model artifacts (.pkl), preparing inputs as Tinygrad
|
||||
Tensors (potentially using QCOM extensions on TICI), running inference,
|
||||
and parsing the outputs.
|
||||
|
||||
:param model_type: The type of model (e.g., supercombo) to load and run.
|
||||
"""
|
||||
def __init__(self, model_type: int = ModelType.supercombo):
|
||||
ModelRunner.__init__(self)
|
||||
SupercomboTinygrad.__init__(self)
|
||||
PolicyTinygrad.__init__(self)
|
||||
VisionTinygrad.__init__(self)
|
||||
OffPolicyTinygrad.__init__(self)
|
||||
OnPolicyTinygrad.__init__(self)
|
||||
self._constants = ModelConstants
|
||||
self._model_data = self.models.get(model_type)
|
||||
if not self._model_data or not self._model_data.model:
|
||||
raise ValueError(f"Model data for type {model_type} not available.")
|
||||
|
||||
artifact_filename = self._model_data.model.artifact.fileName
|
||||
assert artifact_filename.endswith('_tinygrad.pkl'), \
|
||||
f"Invalid model file {artifact_filename} for TinygradRunner"
|
||||
|
||||
model_pkl_path = f"{CUSTOM_MODEL_PATH}/{artifact_filename}"
|
||||
with open(model_pkl_path, "rb") as f:
|
||||
try:
|
||||
# Load the compiled Tinygrad model runner function
|
||||
self.model_run = pickle.load(f)
|
||||
except FileNotFoundError as e:
|
||||
# Provide a helpful error message if the model was built for a different platform
|
||||
assert "/dev/kgsl-3d0" not in str(e), "Model was built on C3 or C3X, but is being loaded on PC"
|
||||
raise
|
||||
|
||||
# Map input names to their required dtype and device from the loaded model
|
||||
self.input_to_dtype = {}
|
||||
self.input_to_device = {}
|
||||
for idx, name in enumerate(self.model_run.captured.expected_names):
|
||||
info = self.model_run.captured.expected_input_info[idx]
|
||||
self.input_to_dtype[name] = info[2] # dtype
|
||||
self.input_to_device[name] = info[3] # device
|
||||
self._policy_cached = False
|
||||
|
||||
@property
|
||||
def vision_input_names(self) -> list[str]:
|
||||
"""Returns the list of vision input names from the input shapes."""
|
||||
return [name for name in self.input_shapes.keys() if 'img' in name]
|
||||
|
||||
|
||||
def prepare_policy_inputs(self, numpy_inputs: NumpyDict):
|
||||
if not self._policy_cached:
|
||||
for key, value in numpy_inputs.items():
|
||||
self.inputs[key] = Tensor(value, device='NPY').realize()
|
||||
self._policy_cached = True
|
||||
|
||||
def prepare_inputs(self, numpy_inputs: NumpyDict) -> dict:
|
||||
"""Prepares all vision and policy inputs for the model."""
|
||||
self.prepare_policy_inputs(numpy_inputs)
|
||||
for key in self.vision_input_names:
|
||||
if key in self.inputs:
|
||||
self.inputs[key] = self.inputs[key].cast(self.input_to_dtype[key])
|
||||
return self.inputs
|
||||
|
||||
def _run_model(self) -> NumpyDict:
|
||||
"""Runs the Tinygrad model inference and parses the outputs."""
|
||||
outputs = self.model_run(**self.inputs).contiguous().realize().uop.base.buffer.numpy().flatten()
|
||||
return self._parse_outputs(outputs)
|
||||
|
||||
def _parse_outputs(self, model_outputs: np.ndarray) -> NumpyDict:
|
||||
"""Parses the raw model outputs using the standard Parser."""
|
||||
if self._model_data is None:
|
||||
raise ValueError("Model data is not available. Ensure the model is loaded correctly.")
|
||||
|
||||
result: NumpyDict = self.parser_method_dict[self._model_data.model.type.raw](model_outputs)
|
||||
return result
|
||||
|
||||
|
||||
class TinygradSplitRunner(ModelRunner):
|
||||
"""
|
||||
A ModelRunner that coordinates separate TinygradVisionRunner and TinygradPolicyRunner instances.
|
||||
|
||||
Manages the execution of split vision and policy models, combining their inputs and outputs.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.is_20hz_3d = True
|
||||
self.vision_runner = TinygradRunner(ModelType.vision)
|
||||
self.policy_runner = TinygradRunner(ModelType.policy) if self.models.get(ModelType.policy) else None
|
||||
self.off_policy_runner = TinygradRunner(ModelType.offPolicy) if self.models.get(ModelType.offPolicy) else None
|
||||
self.on_policy_runner = TinygradRunner(ModelType.onPolicy) if self.models.get(ModelType.onPolicy) else None
|
||||
self._constants = SplitModelConstants
|
||||
|
||||
def _run_model(self) -> NumpyDict:
|
||||
"""Runs both vision and policy models and merges their parsed outputs."""
|
||||
vision_output = self.vision_runner.run_model()
|
||||
outputs = {**vision_output}
|
||||
|
||||
if self.policy_runner:
|
||||
policy_output = self.policy_runner.run_model()
|
||||
outputs.update(policy_output)
|
||||
|
||||
if self.off_policy_runner:
|
||||
off_policy_output = self.off_policy_runner.run_model()
|
||||
if self.on_policy_runner:
|
||||
off_policy_output.pop('plan', None)
|
||||
outputs.update(off_policy_output)
|
||||
|
||||
if self.on_policy_runner:
|
||||
on_policy_output = self.on_policy_runner.run_model()
|
||||
outputs.update(on_policy_output)
|
||||
|
||||
if 'planplus' in outputs and 'plan' in outputs:
|
||||
outputs['plan'] = outputs['plan'] + outputs['planplus']
|
||||
|
||||
return outputs
|
||||
|
||||
@property
|
||||
def vision_input_names(self) -> list[str]:
|
||||
"""Returns the list of vision input names from the vision runner."""
|
||||
return list(self.vision_runner.vision_input_names)
|
||||
|
||||
@property
|
||||
def input_shapes(self) -> ShapeDict:
|
||||
"""Returns the combined input shapes from both vision and policy models."""
|
||||
shapes = {**self.vision_runner.input_shapes}
|
||||
if self.policy_runner:
|
||||
shapes.update(self.policy_runner.input_shapes)
|
||||
if self.off_policy_runner:
|
||||
shapes.update(self.off_policy_runner.input_shapes)
|
||||
if self.on_policy_runner:
|
||||
shapes.update(self.on_policy_runner.input_shapes)
|
||||
return shapes
|
||||
|
||||
@property
|
||||
def output_slices(self) -> SliceDict:
|
||||
"""Returns the combined output slices from both vision and policy models."""
|
||||
slices = {**self.vision_runner.output_slices}
|
||||
if self.policy_runner:
|
||||
slices.update(self.policy_runner.output_slices)
|
||||
if self.off_policy_runner:
|
||||
slices.update(self.off_policy_runner.output_slices)
|
||||
if self.on_policy_runner:
|
||||
slices.update(self.on_policy_runner.output_slices)
|
||||
return slices
|
||||
|
||||
def prepare_inputs(self, numpy_inputs: NumpyDict) -> dict:
|
||||
"""Prepares inputs for both vision and policy models."""
|
||||
if self.policy_runner:
|
||||
self.policy_runner.prepare_policy_inputs(numpy_inputs)
|
||||
|
||||
for key in self.vision_input_names:
|
||||
if key in self.inputs:
|
||||
self.vision_runner.inputs[key] = self.inputs[key].cast(self.vision_runner.input_to_dtype[key])
|
||||
|
||||
inputs = {**self.vision_runner.inputs}
|
||||
if self.policy_runner:
|
||||
inputs.update(self.policy_runner.inputs)
|
||||
|
||||
if self.off_policy_runner:
|
||||
self.off_policy_runner.prepare_policy_inputs(numpy_inputs)
|
||||
inputs.update(self.off_policy_runner.inputs)
|
||||
if self.on_policy_runner:
|
||||
self.on_policy_runner.prepare_policy_inputs(numpy_inputs)
|
||||
inputs.update(self.on_policy_runner.inputs)
|
||||
return inputs
|
||||
@@ -43,6 +43,7 @@ class SplitModelConstants:
|
||||
LANE_LINES_WIDTH = 2
|
||||
ROAD_EDGES_WIDTH = 2
|
||||
PLAN_WIDTH = 15
|
||||
ACTION_WIDTH = 2
|
||||
DESIRE_PRED_WIDTH = 8
|
||||
LAT_PLANNER_SOLUTION_WIDTH = 4
|
||||
DESIRED_CURV_WIDTH = 1
|
||||
|
||||
Submodule tinygrad_repo updated: ac1632ab96...2fecac4e4a
Reference in New Issue
Block a user