Files
firestar5683 d97100bd14 tiny my BUTT
2026-06-23 12:01:44 -05:00

553 lines
21 KiB
Python

#!/usr/bin/env python3
import argparse
import atexit
import math
import os
import pickle
import tempfile
import time
from collections import namedtuple
from functools import partial
import numpy as np
def _patch_tinygrad_fetch_fw():
import hashlib
import pathlib
import zstandard
from tinygrad import helpers
original_fetch_fw = getattr(helpers, "fetch_fw", None)
if original_fetch_fw is None:
return
def fetch_fw(path, name, sha256):
firmware_path = pathlib.Path(f"/lib/firmware/{path}/{name}.zst")
if firmware_path.is_file():
blob = zstandard.ZstdDecompressor().stream_reader(firmware_path.read_bytes()).read()
if hashlib.sha256(blob).hexdigest() == sha256:
return blob
return original_fetch_fw(path, name, sha256)
helpers.fetch_fw = fetch_fw
_patch_tinygrad_fetch_fw()
from tinygrad.device import Device
from tinygrad.engine.jit import TinyJit
from tinygrad.helpers import Context
from tinygrad.tensor import Tensor
ARTIFACT_FORMAT_VERSION = 1
MODEL_TYPES = ("vision_policy", "vision_multi_policy", "supercombo")
NV12Frame = namedtuple("NV12Frame", ["width", "height", "stride", "y_height", "uv_height", "size"])
WARP_INPUTS = ("img_q", "big_img_q", "tfm", "big_tfm")
SPLIT_POLICY_INPUTS = ("feat_q", "desire_q", "packed_npy_inputs")
SUPERCOMBO_POLICY_INPUTS = ("feat_q", "desire_q", "packed_npy_inputs")
WARP_DEV = os.getenv("WARP_DEV")
def _detect_desire_key(input_shapes):
return next((key for key in input_shapes if key.startswith("desire")), None)
def _detect_vision_keys(input_shapes):
image_keys = sorted(key for key in input_shapes if "img" in key)
road_key = next((key for key in image_keys if "big" not in key), None)
wide_key = next((key for key in image_keys if "big" in key), None)
if road_key is None or wide_key is None:
raise ValueError(f"Cannot determine road/wide image keys from {list(input_shapes)}")
return road_key, wide_key
def derive_frame_skip(input_shapes):
features_shape = input_shapes.get("features_buffer")
if features_shape is None:
return 1
return 1 if features_shape[1] >= 99 else 4
def make_random_images(keys, shape, device=None):
return {key: Tensor.randint(shape, low=0, high=256, dtype="uint8", device=device).realize() for key in keys}
def make_random_blob_images(keys, size, device=None):
keepalive: list[np.ndarray] = []
def make_inputs():
nonlocal keepalive
keepalive = []
tensors = {}
for key in keys:
frame = (32 * np.random.randn(size).astype(np.float32) + 128).clip(0, 255).astype(np.uint8)
keepalive.append(frame)
tensors[key] = Tensor.from_blob(frame.ctypes.data, (size,), dtype="uint8", device=device).realize()
return tensors
return make_inputs
def warp_perspective_tinygrad(src_flat, matrix_inverse, dst_shape, src_shape, stride_pad, border_fill_val=None):
width_dst, height_dst = dst_shape
height_src, width_src = src_shape
x = Tensor.arange(width_dst).reshape(1, width_dst).expand(height_dst, width_dst).reshape(-1)
y = Tensor.arange(height_dst).reshape(height_dst, 1).expand(height_dst, width_dst).reshape(-1)
src_x = matrix_inverse[0, 0] * x + matrix_inverse[0, 1] * y + matrix_inverse[0, 2]
src_y = matrix_inverse[1, 0] * x + matrix_inverse[1, 1] * y + matrix_inverse[1, 2]
src_w = matrix_inverse[2, 0] * x + matrix_inverse[2, 1] * y + matrix_inverse[2, 2]
src_x = src_x / src_w
src_y = src_y / src_w
x_round = Tensor.round(src_x)
y_round = Tensor.round(src_y)
x_nn_clipped = x_round.clip(0, width_src - 1).cast("int")
y_nn_clipped = y_round.clip(0, height_src - 1).cast("int")
sampled = src_flat[y_nn_clipped * (width_src + stride_pad) + x_nn_clipped]
if border_fill_val is None:
return sampled
in_bounds = ((x_round >= 0) & (x_round <= width_src - 1) &
(y_round >= 0) & (y_round <= height_src - 1)).cast(sampled.dtype)
return sampled * in_bounds + Tensor(border_fill_val, dtype=sampled.dtype) * (1 - in_bounds)
def frames_to_tensor(frames):
height = (frames.shape[0] * 2) // 3
width = frames.shape[1]
return Tensor.cat(
frames[0:height:2, 0::2],
frames[1:height:2, 0::2],
frames[0:height:2, 1::2],
frames[1:height:2, 1::2],
frames[height:height + height // 4].reshape((height // 2, width // 2)),
frames[height + height // 4:height + height // 2].reshape((height // 2, width // 2)),
dim=0,
).reshape((6, height // 2, width // 2))
def make_frame_prepare(nv12: NV12Frame, model_w, model_h):
cam_w, cam_h, stride, y_height, uv_height, _ = nv12
uv_offset = stride * y_height
stride_pad = stride - cam_w
def frame_prepare(input_frame, matrix_inverse):
matrix_inverse_uv = matrix_inverse * Tensor(
[[1.0, 1.0, 0.5], [1.0, 1.0, 0.5], [2.0, 2.0, 1.0]],
device=WARP_DEV,
)
uv = input_frame[uv_offset:uv_offset + uv_height * stride].reshape(uv_height, stride)
with Context(SPLIT_REDUCEOP=0):
y = warp_perspective_tinygrad(
input_frame[:cam_h * stride], matrix_inverse, (model_w, model_h), (cam_h, cam_w), stride_pad,
).realize()
u = warp_perspective_tinygrad(
uv[:cam_h // 2, :cam_w:2].flatten(), matrix_inverse_uv,
(model_w // 2, model_h // 2), (cam_h // 2, cam_w // 2), 0,
).realize()
v = warp_perspective_tinygrad(
uv[:cam_h // 2, 1:cam_w:2].flatten(), matrix_inverse_uv,
(model_w // 2, model_h // 2), (cam_h // 2, cam_w // 2), 0,
).realize()
return frames_to_tensor(y.cat(u).cat(v).reshape((model_h * 3 // 2, model_w)))
return frame_prepare
def make_warp_input_queues(vision_input_shapes, frame_skip, device):
road_key, _ = _detect_vision_keys(vision_input_shapes)
image_shape = vision_input_shapes[road_key]
frame_count = image_shape[1] // 6
image_buffer_shape = (frame_skip * (frame_count - 1) + 1, 6, image_shape[2], image_shape[3])
npy = {
"tfm": np.zeros((3, 3), dtype=np.float32),
"big_tfm": np.zeros((3, 3), dtype=np.float32),
}
queues = {
"img_q": Tensor(np.zeros(image_buffer_shape, dtype=np.uint8), device=device).contiguous().realize(),
"big_img_q": Tensor(np.zeros(image_buffer_shape, dtype=np.uint8), device=device).contiguous().realize(),
**{key: Tensor(value, device="NPY").realize() for key, value in npy.items()},
}
return queues, npy
def _packed_policy_shapes(input_shapes, include_prev_feature=False):
desire_key = _detect_desire_key(input_shapes)
if desire_key is None:
raise ValueError(f"No desire input found in {list(input_shapes)}")
shapes = {"desire": (input_shapes[desire_key][2],)}
for key, shape in input_shapes.items():
if key in ("features_buffer", desire_key) or "img" in key:
continue
shapes[key] = tuple(shape)
if include_prev_feature:
features_shape = input_shapes["features_buffer"]
shapes["prev_feat"] = (features_shape[0], features_shape[2])
return shapes, [math.prod(shape) for shape in shapes.values()]
def make_split_input_queues(vision_input_shapes, policy_input_shapes, frame_skip, device):
queues, npy = make_warp_input_queues(vision_input_shapes, frame_skip, device)
features_shape = policy_input_shapes["features_buffer"]
desire_key = _detect_desire_key(policy_input_shapes)
desire_shape = policy_input_shapes[desire_key]
packed_shapes, packed_sizes = _packed_policy_shapes(policy_input_shapes)
packed_inputs = np.zeros(sum(packed_sizes), dtype=np.float32)
npy.update({
key: value.reshape(shape)
for (key, shape), value in zip(
packed_shapes.items(), np.split(packed_inputs, np.cumsum(packed_sizes[:-1])), strict=True,
)
})
queues.update({
"feat_q": Tensor(
np.zeros((frame_skip * (features_shape[1] - 1) + 1, features_shape[0], features_shape[2]), dtype=np.float32),
device=device,
).contiguous().realize(),
"desire_q": Tensor(
np.zeros((frame_skip * desire_shape[1], desire_shape[0], desire_shape[2]), dtype=np.float32),
device=device,
).contiguous().realize(),
"packed_npy_inputs": Tensor(packed_inputs, device="NPY").realize(),
})
return queues, npy
def make_supercombo_input_queues(input_shapes, frame_skip, device):
queues, npy = make_warp_input_queues(input_shapes, frame_skip, device)
features_shape = input_shapes["features_buffer"]
desire_key = _detect_desire_key(input_shapes)
desire_shape = input_shapes[desire_key]
packed_shapes, packed_sizes = _packed_policy_shapes(input_shapes, include_prev_feature=True)
packed_inputs = np.zeros(sum(packed_sizes), dtype=np.float32)
npy.update({
key: value.reshape(shape)
for (key, shape), value in zip(
packed_shapes.items(), np.split(packed_inputs, np.cumsum(packed_sizes[:-1])), strict=True,
)
})
queues.update({
"feat_q": Tensor(
np.zeros((frame_skip * features_shape[1], features_shape[0], features_shape[2]), dtype=np.float32),
device=device,
).contiguous().realize(),
"desire_q": Tensor(
np.zeros((frame_skip * desire_shape[1], desire_shape[0], desire_shape[2]), dtype=np.float32),
device=device,
).contiguous().realize(),
"packed_npy_inputs": Tensor(packed_inputs, device="NPY").realize(),
})
return queues, npy
def shift_and_sample(buffer, new_value, sample_fn):
buffer.assign(buffer[1:].cat(new_value, dim=0).contiguous())
return sample_fn(buffer)
def sample_skip(buffer, frame_skip):
return buffer[::frame_skip].contiguous().flatten(0, 1).unsqueeze(0)
def sample_desire(buffer, frame_skip):
return buffer.reshape(-1, frame_skip, *buffer.shape[1:]).max(1).flatten(0, 1).unsqueeze(0)
def make_warp(nv12, model_w, model_h, frame_skip):
frame_prepare = make_frame_prepare(nv12, model_w, model_h)
sample_skip_fn = partial(sample_skip, frame_skip=frame_skip)
def warp_enqueue(img_q, big_img_q, tfm, big_tfm, frame, big_frame):
tfm = tfm.to(WARP_DEV)
big_tfm = big_tfm.to(WARP_DEV)
Tensor.realize(tfm, big_tfm)
warped = Tensor.cat(
frame_prepare(frame, tfm).unsqueeze(0),
frame_prepare(big_frame, big_tfm).unsqueeze(0),
).to(Device.DEFAULT)
img = shift_and_sample(img_q, warped[0:1], sample_skip_fn)
big_img = shift_and_sample(big_img_q, warped[1:2], sample_skip_fn)
return img, big_img
return warp_enqueue
def make_run_split_policy(vision_runner, policy_runners, metadata, policy_order, frame_skip):
sample_desire_fn = partial(sample_desire, frame_skip=frame_skip)
sample_skip_fn = partial(sample_skip, frame_skip=frame_skip)
vision_metadata = metadata["vision"]
policy_metadata = metadata[policy_order[0]]
vision_features_slice = vision_metadata["output_slices"]["hidden_state"]
desire_key = _detect_desire_key(policy_metadata["input_shapes"])
packed_shapes, packed_sizes = _packed_policy_shapes(policy_metadata["input_shapes"])
road_key, wide_key = _detect_vision_keys(vision_metadata["input_shapes"])
def run_policy(img, big_img, feat_q, desire_q, packed_npy_inputs):
packed_npy_inputs = packed_npy_inputs.to(Device.DEFAULT).realize()
unpacked = {
key: tensor.reshape(shape)
for (key, shape), tensor in zip(
packed_shapes.items(), packed_npy_inputs.split(packed_sizes), strict=True,
)
}
desire_buffer = shift_and_sample(
desire_q, unpacked.pop("desire").reshape(1, 1, -1), sample_desire_fn,
)
vision_output = next(iter(vision_runner({road_key: img, wide_key: big_img}).values())).cast("float32")
new_feature = vision_output[:, vision_features_slice].reshape(1, -1).unsqueeze(0)
features_buffer = shift_and_sample(feat_q, new_feature, sample_skip_fn)
policy_inputs = {
"features_buffer": features_buffer,
desire_key: desire_buffer,
**unpacked,
}
policy_outputs = [
next(iter(policy_runners[key](policy_inputs).values())).cast("float32")
for key in policy_order
]
return (vision_output, *policy_outputs)
return run_policy
def make_run_supercombo(model_runner, metadata, frame_skip):
input_shapes = metadata["model"]["input_shapes"]
output_slices = metadata["model"]["output_slices"]
sample_desire_fn = partial(sample_desire, frame_skip=frame_skip)
sample_skip_fn = partial(sample_skip, frame_skip=frame_skip)
desire_key = _detect_desire_key(input_shapes)
packed_shapes, packed_sizes = _packed_policy_shapes(input_shapes, include_prev_feature=True)
road_key, wide_key = _detect_vision_keys(input_shapes)
def run_policy(img, big_img, feat_q, desire_q, packed_npy_inputs):
packed_npy_inputs = packed_npy_inputs.to(Device.DEFAULT).realize()
unpacked = {
key: tensor.reshape(shape)
for (key, shape), tensor in zip(
packed_shapes.items(), packed_npy_inputs.split(packed_sizes), strict=True,
)
}
desire_buffer = shift_and_sample(
desire_q, unpacked.pop("desire").reshape(1, 1, -1), sample_desire_fn,
)
previous_feature = unpacked.pop("prev_feat")
features_buffer = shift_and_sample(
feat_q, previous_feature.reshape(1, 1, -1), sample_skip_fn,
)
model_inputs = {
road_key: img,
wide_key: big_img,
"features_buffer": features_buffer,
desire_key: desire_buffer,
**unpacked,
}
model_output = next(iter(model_runner(model_inputs).values())).cast("float32")
return model_output,
return run_policy
def compile_jit(jit, make_random_inputs, input_keys, make_queues):
seed = 42
def random_inputs_run(fn, current_seed, test_values=None, test_buffers=None, expect_match=True):
input_queues, npy = make_queues(Device.DEFAULT)
np.random.seed(current_seed)
Tensor.manual_seed(current_seed)
testing = test_values is not None or test_buffers is not None
run_count = 1 if testing else 3
for index in range(run_count):
for value in npy.values():
value[:] = np.random.randn(*value.shape).astype(value.dtype)
Device.default.synchronize()
random_inputs = make_random_inputs()
start = time.perf_counter()
outputs = fn(**{key: input_queues[key] for key in input_keys}, **random_inputs)
mid = time.perf_counter()
Device.default.synchronize()
end = time.perf_counter()
print(f" [{index + 1}/{run_count}] enqueue {(mid - start) * 1e3:6.2f} ms -- total {(end - start) * 1e3:6.2f} ms")
if index == 0:
values = [np.copy(value.numpy()) for value in outputs]
buffers = [np.copy(value.numpy()) for value in input_queues.values()]
if not all(np.isfinite(value).all() for value in values):
raise ValueError("Compiled JIT produced non-finite outputs")
if test_values is not None:
match = all(np.array_equal(lhs, rhs) for lhs, rhs in zip(values, test_values, strict=True))
assert match == expect_match, f"outputs {'differ from' if expect_match else 'match'} baseline (seed={current_seed})"
if test_buffers is not None:
match = all(np.array_equal(lhs, rhs) for lhs, rhs in zip(buffers, test_buffers, strict=True))
assert match == expect_match, f"buffers {'differ from' if expect_match else 'match'} baseline (seed={current_seed})"
return values, buffers
print("capture + replay")
test_values, test_buffers = random_inputs_run(jit, seed)
print("pickle round trip")
jit = pickle.loads(pickle.dumps(jit))
random_inputs_run(jit, seed, test_values, test_buffers, expect_match=True)
random_inputs_run(jit, seed + 1, test_values, test_buffers, expect_match=False)
return jit
def _parse_size(value):
width, height = value.lower().split("x")
return int(width), int(height)
def read_file_chunked_to_shm(path):
from openpilot.common.file_chunker import read_file_chunked
from openpilot.system.hardware.hw import Paths
with tempfile.NamedTemporaryFile(prefix="compile_modeld_", dir=Paths.shm_path(), delete=False) as output:
output.write(read_file_chunked(path))
temporary_path = output.name
atexit.register(lambda: os.path.exists(temporary_path) and os.remove(temporary_path))
return temporary_path
def validate_metadata(metadata):
output_shapes = metadata.get("output_shapes", {})
output_shape = output_shapes.get("outputs")
if not output_shape or len(output_shape) < 2:
raise ValueError(f"Invalid model output shape metadata: {output_shapes}")
output_size = output_shape[-1]
for name, output_slice in metadata.get("output_slices", {}).items():
start, stop, step = output_slice.indices(output_size)
if step != 1 or start < 0 or stop < start or stop > output_size:
raise ValueError(f"Invalid output slice {name}={output_slice} for output size {output_size}")
def main():
from tinygrad.nn.onnx import OnnxRunner
from openpilot.selfdrive.modeld.get_model_metadata import make_metadata_dict
from openpilot.system.camerad.cameras.nv12_info import get_nv12_info
parser = argparse.ArgumentParser()
parser.add_argument("--model-type", choices=MODEL_TYPES, required=True)
parser.add_argument("--model-size", type=_parse_size, required=True)
parser.add_argument("--camera-resolutions", type=_parse_size, nargs="+", required=True)
parser.add_argument("--frame-skip", type=int)
parser.add_argument("--behavior-version")
parser.add_argument("--output", required=True)
parser.add_argument("--vision-onnx")
parser.add_argument("--policy-onnx")
parser.add_argument("--off-policy-onnx")
parser.add_argument("--on-policy-onnx")
parser.add_argument("--supercombo-onnx")
args = parser.parse_args()
output = {
"format_version": ARTIFACT_FORMAT_VERSION,
"model_type": args.model_type,
"metadata": {},
}
if args.behavior_version:
output["behavior_version"] = args.behavior_version
if args.model_type == "supercombo":
if not args.supercombo_onnx:
parser.error("--supercombo-onnx is required for supercombo")
model_path = read_file_chunked_to_shm(args.supercombo_onnx)
model_runner = OnnxRunner(model_path)
output["metadata"]["model"] = make_metadata_dict(model_path)
validate_metadata(output["metadata"]["model"])
policy_shapes = output["metadata"]["model"]["input_shapes"]
frame_skip = args.frame_skip or derive_frame_skip(policy_shapes)
make_policy_queues = partial(make_supercombo_input_queues, policy_shapes, frame_skip)
run_policy = make_run_supercombo(model_runner, output["metadata"], frame_skip)
image_shapes = policy_shapes
policy_input_keys = SUPERCOMBO_POLICY_INPUTS
else:
if not args.vision_onnx:
parser.error("--vision-onnx is required for split models")
policy_paths = {}
if args.policy_onnx:
policy_paths["policy"] = args.policy_onnx
if args.off_policy_onnx:
policy_paths["off_policy"] = args.off_policy_onnx
if args.on_policy_onnx:
policy_paths["on_policy"] = args.on_policy_onnx
if args.model_type == "vision_policy" and set(policy_paths) != {"policy"}:
parser.error("vision_policy requires --policy-onnx")
if args.model_type == "vision_multi_policy" and not policy_paths:
parser.error("vision_multi_policy requires at least one policy ONNX")
vision_path = read_file_chunked_to_shm(args.vision_onnx)
resolved_policy_paths = {key: read_file_chunked_to_shm(path) for key, path in policy_paths.items()}
vision_runner = OnnxRunner(vision_path)
policy_runners = {key: OnnxRunner(path) for key, path in resolved_policy_paths.items()}
output["metadata"]["vision"] = make_metadata_dict(vision_path)
validate_metadata(output["metadata"]["vision"])
for key, path in resolved_policy_paths.items():
output["metadata"][key] = make_metadata_dict(path)
validate_metadata(output["metadata"][key])
policy_order = [key for key in ("on_policy", "off_policy", "policy") if key in policy_runners]
output["policy_order"] = policy_order
first_policy_shapes = output["metadata"][policy_order[0]]["input_shapes"]
for key in policy_order[1:]:
if output["metadata"][key]["input_shapes"] != first_policy_shapes:
raise ValueError(f"Policy input shapes differ between {policy_order[0]} and {key}")
frame_skip = args.frame_skip or derive_frame_skip(first_policy_shapes)
make_policy_queues = partial(
make_split_input_queues,
output["metadata"]["vision"]["input_shapes"],
first_policy_shapes,
frame_skip,
)
run_policy = make_run_split_policy(
vision_runner, policy_runners, output["metadata"], policy_order, frame_skip,
)
image_shapes = output["metadata"]["vision"]["input_shapes"]
policy_input_keys = SPLIT_POLICY_INPUTS
output["frame_skip"] = frame_skip
output["policy_input_keys"] = policy_input_keys
run_policy_jit = TinyJit(run_policy, prune=True)
road_key, wide_key = _detect_vision_keys(image_shapes)
make_random_model_inputs = partial(
make_random_images,
keys=[road_key, wide_key],
shape=image_shapes[road_key],
)
output["run_policy"] = compile_jit(
run_policy_jit, make_random_model_inputs, policy_input_keys, make_policy_queues,
)
model_w, model_h = args.model_size
for cam_w, cam_h in args.camera_resolutions:
nv12 = NV12Frame(cam_w, cam_h, *get_nv12_info(cam_w, cam_h))
warp_enqueue = TinyJit(make_warp(nv12, model_w, model_h, frame_skip), prune=True)
make_random_warp_inputs = make_random_blob_images(
keys=["frame", "big_frame"], size=nv12.size, device=WARP_DEV,
)
make_warp_queues = partial(make_warp_input_queues, image_shapes, frame_skip)
output[(cam_w, cam_h)] = compile_jit(
warp_enqueue, make_random_warp_inputs, WARP_INPUTS, make_warp_queues,
)
with open(args.output, "wb") as artifact_file:
pickle.dump(output, artifact_file)
print(f"Saved JITs to {args.output} ({os.path.getsize(args.output) / 1e6:.2f} MB)")
return 0
if __name__ == "__main__":
raise SystemExit(main())