mirror of
https://github.com/firestar5683/StarPilot.git
synced 2026-06-28 01:52:06 +08:00
68 lines
2.2 KiB
Python
68 lines
2.2 KiB
Python
#!/usr/bin/env python3
|
|
import argparse
import contextlib
import os
import pickle
import time

from tinygrad.device import Device
from tinygrad.engine.jit import TinyJit
from tinygrad.tensor import Tensor

from openpilot.selfdrive.modeld.compile_modeld import NV12Frame, _parse_size, warp_perspective_tinygrad
from openpilot.system.camerad.cameras.nv12_info import get_nv12_info
|
|
|
|
|
|
def make_warp_dm(nv12: NV12Frame, dm_w: int, dm_h: int):
  """Return a closure that warps one camera frame into the DM input layout.

  ``nv12`` must unpack to exactly six fields. Only the first three
  (width, height, row stride) are used here.

  The returned ``warp_dm(input_frame, matrix_inverse)`` does the following:

  - moves ``matrix_inverse`` to ``Device.DEFAULT`` and realizes it;
  - warps the first ``height * stride`` elements of ``input_frame`` to
    ``(dm_w, dm_h)`` with ``warp_perspective_tinygrad``;
  - reshapes the result to ``(-1, dm_h * dm_w)``.
  """
  src_w, src_h, row_stride, _, _, _ = nv12

  # Elements each row carries beyond the visible width.
  row_padding = row_stride - src_w
  # Leading slice of the frame handed to the warp: src_h full-stride rows
  # (presumably the Y plane, which NV12 stores first -- confirm against caller).
  plane_len = src_h * row_stride
  out_size = (dm_w, dm_h)
  in_shape = (src_h, src_w)
  flat_len = dm_h * dm_w

  def warp_dm(input_frame, matrix_inverse):
    # The transform may be created on another device (e.g. "NPY"),
    # so move it to the default device and materialize it before warping.
    on_device = matrix_inverse.to(Device.DEFAULT).realize()
    warped = warp_perspective_tinygrad(
      input_frame[:plane_len],
      on_device,
      out_size,
      in_shape,
      row_padding,
      # NOTE(review): presumably the value written for samples outside the
      # source image; 16 looks like limited-range luma black -- confirm.
      border_fill_val=16,
    )
    return warped.reshape(-1, flat_len)

  return warp_dm
|
|
|
|
|
|
def compile_dm_warp(nv12: NV12Frame, dm_w: int, dm_h: int, pkl_path: str, *, iterations: int = 10) -> None:
  """Trace the DM warp with TinyJit, benchmark it, and pickle the JIT to ``pkl_path``.

  The warp built by ``make_warp_dm`` is called ``iterations`` times. Each call
  gets a random uint8 frame of ``nv12.size`` elements and a random 3x3 matrix
  created on the "NPY" device. After each call the enqueue time and the total
  (device-synchronized) time are printed.

  Args:
    nv12: camera frame layout; ``width``, ``height`` and ``size`` are read here.
    dm_w: output width of the warp.
    dm_h: output height of the warp.
    pkl_path: destination of the pickled ``TinyJit``. The pickle is written to
      ``pkl_path + ".tmp"`` first and moved into place with ``os.replace``, so a
      failed dump never leaves a truncated artifact at ``pkl_path``.
    iterations: number of warm-up/benchmark calls (keyword-only; default 10).

  Raises:
    ValueError: if ``iterations`` is less than 2.
  """
  # TinyJit runs the function eagerly on its first call and captures kernels on
  # the second, so at least two calls are needed before the JIT can be pickled.
  # NOTE(review): confirm against the pinned tinygrad version.
  if iterations < 2:
    raise ValueError(f"iterations must be >= 2 for TinyJit to capture the warp, got {iterations}")

  print(f"Compiling DM warp for {nv12.width}x{nv12.height} -> {dm_w}x{dm_h}...")

  warp_dm_jit = TinyJit(make_warp_dm(nv12, dm_w, dm_h), prune=True)

  for index in range(iterations):
    # Inputs are built and realized before the timed region starts.
    frame = Tensor.randint(nv12.size, low=0, high=256, dtype="uint8").realize()
    matrix_inverse = Tensor(Tensor.randn(3, 3).mul(8).realize().numpy(), device="NPY")

    # Drain outstanding device work so the timings cover only this call.
    Device.default.synchronize()
    start = time.perf_counter()
    warp_dm_jit(frame, matrix_inverse).realize()
    queued = time.perf_counter()
    # Wait for the enqueued work to finish so `end` includes execution time.
    Device.default.synchronize()
    end = time.perf_counter()

    print(f" [{index + 1}/{iterations}] enqueue {(queued - start) * 1e3:6.2f} ms -- total {(end - start) * 1e3:6.2f} ms")

  # Write atomically: dump to a sibling temp file, then rename over the target.
  tmp_path = f"{pkl_path}.tmp"
  try:
    with open(tmp_path, "wb") as artifact_file:
      pickle.dump(warp_dm_jit, artifact_file)
    os.replace(tmp_path, pkl_path)
  except BaseException:
    # Remove the partial temp file; any previous artifact at pkl_path is untouched.
    with contextlib.suppress(OSError):
      os.remove(tmp_path)
    raise

  print(f" saved {pkl_path}")
|
|
|
|
|
|
def main() -> int:
  """Command-line entry point: compile the DM warp and pickle it.

  Reads three arguments:

  - ``--camera-resolution`` and ``--warp-to``, both converted by ``_parse_size``
    and unpacked here as (width, height);
  - ``--output``, the destination path.

  It builds the ``NV12Frame`` for the camera size from ``get_nv12_info`` and
  passes everything to ``compile_dm_warp``.

  Returns:
    0, used as the process exit status.
  """
  cli = argparse.ArgumentParser()
  # Registered in this order so --help output matches the flag order.
  for flag, options in (
    ("--camera-resolution", {"type": _parse_size, "required": True, "help": "Camera resolution WxH"}),
    ("--warp-to", {"type": _parse_size, "required": True, "help": "DM input resolution WxH"}),
    ("--output", {"required": True}),
  ):
    cli.add_argument(flag, **options)
  opts = cli.parse_args()

  width, height = opts.camera_resolution
  frame_layout = NV12Frame(width, height, *get_nv12_info(width, height))
  target_w, target_h = opts.warp_to

  compile_dm_warp(frame_layout, target_w, target_h, opts.output)
  return 0
|
|
|
|
|
|
if __name__ == "__main__":
  # Propagate main()'s return value as the process exit status.
  exit_status = main()
  raise SystemExit(exit_status)
|