mirror of
https://github.com/sunnypilot/sunnypilot.git
synced 2026-06-08 11:25:51 +08:00
modeld: split warp (#38079)
* compiles * runs * dedupe compiling model * always build for both res * fix does not bind loop variable * rm size multiplier
This commit is contained in:
committed by
GitHub
parent
aff9f9ffae
commit
a3cc9c7ac3
@@ -190,7 +190,7 @@ else:
|
||||
np_version = SCons.Script.Value(np.__version__)
|
||||
Export('envCython', 'np_version')
|
||||
|
||||
Export('env', 'arch', 'acados', 'release')
|
||||
Export('env', 'arch', 'acados')
|
||||
|
||||
# Setup cache dir
|
||||
cache_dir = '/data/scons_cache' if arch == "larch64" else '/tmp/scons_cache'
|
||||
|
||||
@@ -7,24 +7,15 @@ from openpilot.common.file_chunker import chunk_file, get_chunk_targets, get_exi
|
||||
from openpilot.common.transformations.camera import _ar_ox_fisheye, _os_fisheye
|
||||
from openpilot.common.transformations.model import MEDMODEL_INPUT_SIZE, DM_INPUT_SIZE
|
||||
from openpilot.selfdrive.modeld.constants import ModelConstants
|
||||
from openpilot.system.hardware import HARDWARE, PC
|
||||
from openpilot.selfdrive.modeld.helpers import TG_INPUT_DEVICES_PATH, usbgpu_present, modeld_pkl_path
|
||||
|
||||
Import('env', 'arch', 'release')
|
||||
|
||||
CAMERA_CONFIGS = [
|
||||
(_ar_ox_fisheye.width, _ar_ox_fisheye.height), # tici: 1928x1208
|
||||
(_os_fisheye.width, _os_fisheye.height), # mici: 1344x760
|
||||
]
|
||||
|
||||
def get_camera_configs():
|
||||
DEVICE_RESOLUTIONS = {
|
||||
"tici": (_ar_ox_fisheye.width, _ar_ox_fisheye.height),
|
||||
"tizi": (_ar_ox_fisheye.width, _ar_ox_fisheye.height),
|
||||
"mici": (_os_fisheye.width, _os_fisheye.height),
|
||||
}
|
||||
if release or PC or 'CI' in os.environ:
|
||||
return set(DEVICE_RESOLUTIONS.values())
|
||||
return [DEVICE_RESOLUTIONS[HARDWARE.get_device_type()]]
|
||||
|
||||
CAMERA_CONFIGS = get_camera_configs()
|
||||
|
||||
Import('env', 'arch')
|
||||
chunker_file = File("#common/file_chunker.py")
|
||||
lenv = env.Clone()
|
||||
|
||||
@@ -101,8 +92,7 @@ for usbgpu in [False, True] if USBGPU else [False]:
|
||||
f'--policy-onnx {File(f"models/{file_prefix}driving_policy.onnx").abspath} '
|
||||
f'--output {target_pkl_path} --frame-skip {frame_skip}')
|
||||
onnx_sizes_sum = sum(os.path.getsize(f) for f in driving_onnx_deps)
|
||||
size_multiplier = 1 if usbgpu else 2 # TODO make weight dedupe work on QCOM
|
||||
chunk_targets = get_chunk_targets(target_pkl_path, estimate_pickle_max_size(onnx_sizes_sum)*size_multiplier)
|
||||
chunk_targets = get_chunk_targets(target_pkl_path, estimate_pickle_max_size(onnx_sizes_sum))
|
||||
def do_chunk(target, source, env, pkl=target_pkl_path, chunks=chunk_targets):
|
||||
chunk_file(pkl, chunks)
|
||||
node = lenv.Command(
|
||||
|
||||
@@ -30,11 +30,10 @@ from tinygrad.helpers import Context
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
|
||||
from openpilot.common.file_chunker import read_file_chunked
|
||||
from openpilot.system.hardware.hw import Paths
|
||||
|
||||
|
||||
NV12Frame = namedtuple("NV12Frame", ['width', 'height', 'stride', 'y_height', 'uv_height', 'size'])
|
||||
WARP_INPUTS = ['img_q', 'big_img_q', 'tfm', 'big_tfm']
|
||||
POLICY_INPUTS = ['feat_q', 'desire_q', 'desire', 'traffic_convention']
|
||||
|
||||
UV_SCALE_MATRIX = np.array([[0.5, 0, 0], [0, 0.5, 0], [0, 0, 1]], dtype=np.float32)
|
||||
UV_SCALE_MATRIX_INV = np.linalg.inv(UV_SCALE_MATRIX)
|
||||
@@ -42,6 +41,10 @@ UV_SCALE_MATRIX_INV = np.linalg.inv(UV_SCALE_MATRIX)
|
||||
WARP_DEV = os.getenv('WARP_DEV')
|
||||
|
||||
|
||||
def make_random_images(keys, shape, device=None):
|
||||
return {k: Tensor.randint(shape, low=0, high=256, dtype='uint8', device=device).realize() for k in keys}
|
||||
|
||||
|
||||
def warp_perspective_tinygrad(src_flat, M_inv, dst_shape, src_shape, stride_pad, border_fill_val=None):
|
||||
w_dst, h_dst = dst_shape
|
||||
h_src, w_src = src_shape
|
||||
@@ -148,55 +151,49 @@ def sample_desire(buf, frame_skip):
|
||||
return buf.reshape(-1, frame_skip, *buf.shape[1:]).max(1).flatten(0, 1).unsqueeze(0)
|
||||
|
||||
|
||||
def make_run_policy(vision_runner, policy_runner, nv12: NV12Frame, model_w, model_h,
|
||||
vision_features_slice, frame_skip, prepare_only=False):
|
||||
def make_warp(nv12, model_w, model_h, frame_skip):
|
||||
frame_prepare = make_frame_prepare(nv12, model_w, model_h)
|
||||
sample_skip_fn = partial(sample_skip, frame_skip=frame_skip)
|
||||
sample_desire_fn = partial(sample_desire, frame_skip=frame_skip)
|
||||
|
||||
def run_policy(img_q, big_img_q, feat_q, desire_q, desire, traffic_convention, tfm, big_tfm, frame, big_frame):
|
||||
def warp_enqueue(img_q, big_img_q, tfm, big_tfm, frame, big_frame):
|
||||
tfm = tfm.to(WARP_DEV)
|
||||
big_tfm = big_tfm.to(WARP_DEV)
|
||||
desire = desire.to(Device.DEFAULT)
|
||||
traffic_convention = traffic_convention.to(Device.DEFAULT)
|
||||
Tensor.realize(tfm, big_tfm, desire, traffic_convention)
|
||||
Tensor.realize(tfm, big_tfm)
|
||||
|
||||
warped_frame = frame_prepare(frame, tfm).unsqueeze(0).to(Device.DEFAULT)
|
||||
warped_big_frame = frame_prepare(big_frame, big_tfm).unsqueeze(0).to(Device.DEFAULT)
|
||||
img = shift_and_sample(img_q, warped_frame, sample_skip_fn)
|
||||
big_img = shift_and_sample(big_img_q, warped_big_frame, sample_skip_fn)
|
||||
return img, big_img
|
||||
return warp_enqueue
|
||||
|
||||
if prepare_only:
|
||||
return img, big_img
|
||||
|
||||
def make_run_policy(vision_runner, policy_runner, vision_features_slice, frame_skip):
|
||||
sample_desire_fn = partial(sample_desire, frame_skip=frame_skip)
|
||||
sample_skip_fn = partial(sample_skip, frame_skip=frame_skip)
|
||||
|
||||
def run_policy(img, big_img, feat_q, desire_q, desire, traffic_convention):
|
||||
desire = desire.to(Device.DEFAULT)
|
||||
traffic_convention = traffic_convention.to(Device.DEFAULT)
|
||||
Tensor.realize(desire, traffic_convention)
|
||||
desire_buf = shift_and_sample(desire_q, desire.reshape(1, 1, -1), sample_desire_fn)
|
||||
vision_out = next(iter(vision_runner({'img': img, 'big_img': big_img}).values())).cast('float32')
|
||||
|
||||
new_feat = vision_out[:, vision_features_slice].reshape(1, -1).unsqueeze(0)
|
||||
feat_buf = shift_and_sample(feat_q, new_feat, sample_skip_fn)
|
||||
desire_buf = shift_and_sample(desire_q, desire.reshape(1, 1, -1), sample_desire_fn)
|
||||
|
||||
inputs = {'features_buffer': feat_buf, 'desire_pulse': desire_buf, 'traffic_convention': traffic_convention}
|
||||
policy_out = next(iter(policy_runner(inputs).values())).cast('float32')
|
||||
|
||||
return vision_out, policy_out
|
||||
return run_policy
|
||||
|
||||
|
||||
def compile_modeld(nv12: NV12Frame, model_w, model_h, prepare_only, frame_skip,
|
||||
vision_runner, policy_runner, vision_metadata, policy_metadata):
|
||||
print(f"Compiling combined policy JIT for {nv12.width}x{nv12.height} (prepare_only={prepare_only})...")
|
||||
|
||||
vision_features_slice = vision_metadata['output_slices']['hidden_state']
|
||||
def compile_jit(jit, make_random_inputs, input_keys, frame_skip, vision_metadata, policy_metadata):
|
||||
vision_input_shapes = vision_metadata['input_shapes']
|
||||
policy_input_shapes = policy_metadata['input_shapes']
|
||||
|
||||
_run = make_run_policy(vision_runner, policy_runner, nv12, model_w, model_h,
|
||||
vision_features_slice, frame_skip, prepare_only)
|
||||
run_policy_jit = TinyJit(_run, prune=True)
|
||||
|
||||
SEED = 42
|
||||
|
||||
def random_inputs_run_fn(fn, seed, test_val=None, test_buffers=None, expect_match=True):
|
||||
def random_inputs_run(fn, seed, test_val=None, test_buffers=None, expect_match=True):
|
||||
input_queues, npy = make_input_queues(vision_input_shapes, policy_input_shapes, frame_skip, Device.DEFAULT)
|
||||
np.random.seed(seed)
|
||||
Tensor.manual_seed(seed)
|
||||
@@ -205,13 +202,11 @@ def compile_modeld(nv12: NV12Frame, model_w, model_h, prepare_only, frame_skip,
|
||||
n_runs = 1 if testing else 3
|
||||
|
||||
for i in range(n_runs):
|
||||
frame = Tensor.randint(nv12.size, low=0, high=256, dtype='uint8', device=WARP_DEV).realize()
|
||||
big_frame = Tensor.randint(nv12.size, low=0, high=256, dtype='uint8', device=WARP_DEV).realize()
|
||||
for v in npy.values():
|
||||
v[:] = np.random.randn(*v.shape).astype(v.dtype)
|
||||
Device.default.synchronize()
|
||||
st = time.perf_counter()
|
||||
outs = fn(**input_queues, frame=frame, big_frame=big_frame)
|
||||
outs = fn(**{k: input_queues[k] for k in input_keys}, **make_random_inputs())
|
||||
mt = time.perf_counter()
|
||||
Device.default.synchronize()
|
||||
et = time.perf_counter()
|
||||
@@ -227,16 +222,15 @@ def compile_modeld(nv12: NV12Frame, model_w, model_h, prepare_only, frame_skip,
|
||||
if test_buffers is not None:
|
||||
match = all(np.array_equal(a, b) for a, b in zip(buffers, test_buffers, strict=True))
|
||||
assert match == expect_match, f"buffers {'differ from' if expect_match else 'match'} baseline (seed={seed})"
|
||||
return fn, val, buffers
|
||||
return val, buffers
|
||||
|
||||
print('capture + replay')
|
||||
run_policy_jit, test_val, test_buffers = random_inputs_run_fn(run_policy_jit, SEED)
|
||||
|
||||
test_val, test_buffers = random_inputs_run(jit, SEED)
|
||||
print('pickle round trip')
|
||||
run_policy_jit = pickle.loads(pickle.dumps(run_policy_jit))
|
||||
random_inputs_run_fn(run_policy_jit, SEED, test_val, test_buffers, expect_match=True)
|
||||
random_inputs_run_fn(run_policy_jit, SEED+1, test_val, test_buffers, expect_match=False)
|
||||
return run_policy_jit
|
||||
jit = pickle.loads(pickle.dumps(jit))
|
||||
random_inputs_run(jit, SEED, test_val, test_buffers, expect_match=True)
|
||||
random_inputs_run(jit, SEED+1, test_val, test_buffers, expect_match=False)
|
||||
return jit
|
||||
|
||||
|
||||
def _parse_size(s):
|
||||
@@ -245,6 +239,8 @@ def _parse_size(s):
|
||||
|
||||
|
||||
def read_file_chunked_to_shm(path):
|
||||
from openpilot.common.file_chunker import read_file_chunked
|
||||
from openpilot.system.hardware.hw import Paths
|
||||
shm_path = os.path.join(Paths.shm_path(), os.path.basename(path))
|
||||
atexit.register(lambda: os.path.exists(shm_path) and os.remove(shm_path))
|
||||
with open(shm_path, 'wb') as f:
|
||||
@@ -255,6 +251,7 @@ def read_file_chunked_to_shm(path):
|
||||
if __name__ == "__main__":
|
||||
from tinygrad.nn.onnx import OnnxRunner
|
||||
from openpilot.system.camerad.cameras.nv12_info import get_nv12_info
|
||||
from openpilot.selfdrive.modeld.get_model_metadata import make_metadata_dict
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument('--model-size', type=_parse_size, required=True, help='model input WxH')
|
||||
p.add_argument('--camera-resolutions', type=_parse_size, nargs='+', required=True,
|
||||
@@ -266,23 +263,26 @@ if __name__ == "__main__":
|
||||
args = p.parse_args()
|
||||
|
||||
out = defaultdict(dict)
|
||||
# init runners once so weights are shared
|
||||
from get_model_metadata import make_metadata_dict
|
||||
vision_path, policy_path = read_file_chunked_to_shm(args.vision_onnx), read_file_chunked_to_shm(args.policy_onnx)
|
||||
model_w, model_h = args.model_size
|
||||
|
||||
vision_runner = OnnxRunner(vision_path)
|
||||
policy_runner = OnnxRunner(policy_path)
|
||||
out['metadata']['vision'] = make_metadata_dict(vision_path)
|
||||
out['metadata']['policy'] = make_metadata_dict(policy_path)
|
||||
vision_metadata, policy_metadata = make_metadata_dict(vision_path), make_metadata_dict(policy_path)
|
||||
|
||||
run_policy_jit = TinyJit(make_run_policy(vision_runner, policy_runner, vision_metadata['output_slices']['hidden_state'], args.frame_skip), prune=True)
|
||||
|
||||
out['metadata']['vision'], out['metadata']['policy'] = vision_metadata, policy_metadata
|
||||
|
||||
make_random_model_inputs = partial(make_random_images, keys=['img', 'big_img'], shape=vision_metadata['input_shapes']['img'])
|
||||
out['run_policy'] = compile_jit(run_policy_jit, make_random_model_inputs, POLICY_INPUTS, args.frame_skip, vision_metadata, policy_metadata)
|
||||
|
||||
for cam_w, cam_h in args.camera_resolutions:
|
||||
nv12 = NV12Frame(cam_w, cam_h, *get_nv12_info(cam_w, cam_h))
|
||||
model_w, model_h = args.model_size
|
||||
out[(cam_w,cam_h)] = {
|
||||
name: compile_modeld(nv12, model_w, model_h, prepare_only, args.frame_skip,
|
||||
vision_runner, policy_runner, out['metadata']['vision'], out['metadata']['policy'])
|
||||
for name, prepare_only in [('warp_enqueue', True), ('run_policy', False)]
|
||||
}
|
||||
make_random_warp_inputs = partial(make_random_images, keys=['frame', 'big_frame'], shape=nv12.size, device=WARP_DEV)
|
||||
warp_enqueue = TinyJit(make_warp(nv12, model_w, model_h, args.frame_skip), prune=True)
|
||||
out[(cam_w,cam_h)] = compile_jit(warp_enqueue, make_random_warp_inputs, WARP_INPUTS, args.frame_skip, vision_metadata, policy_metadata)
|
||||
|
||||
with open(args.output, "wb") as f:
|
||||
pickle.dump(out, f)
|
||||
print(f"Saved combined JIT to {args.output} ({os.path.getsize(args.output) / 1e6:.2f} MB)")
|
||||
print(f"Saved JITs to {args.output} ({os.path.getsize(args.output) / 1e6:.2f} MB)")
|
||||
|
||||
@@ -20,7 +20,7 @@ from openpilot.common.transformations.model import get_warp_matrix
|
||||
from openpilot.selfdrive.controls.lib.desire_helper import DesireHelper
|
||||
from openpilot.selfdrive.controls.lib.drive_helpers import get_accel_from_plan, smooth_value, get_curvature_from_plan
|
||||
from openpilot.selfdrive.modeld.parse_model_outputs import Parser
|
||||
from openpilot.selfdrive.modeld.compile_modeld import make_input_queues
|
||||
from openpilot.selfdrive.modeld.compile_modeld import make_input_queues, WARP_INPUTS, POLICY_INPUTS
|
||||
from openpilot.selfdrive.modeld.fill_model_msg import fill_model_msg, fill_pose_msg, PublishState
|
||||
from openpilot.common.file_chunker import read_file_chunked, get_manifest_path
|
||||
from openpilot.selfdrive.modeld.constants import ModelConstants, Plan
|
||||
@@ -93,12 +93,8 @@ class ModelState:
|
||||
self._blob_cache : dict[int, Tensor] = {}
|
||||
self.parser = Parser()
|
||||
self.frame_buf_params = {k: get_nv12_info(cam_w, cam_h) for k in ('img', 'big_img')}
|
||||
self.run_policy = jits[(cam_w,cam_h)]['run_policy']
|
||||
self.warp_enqueue = jits[(cam_w,cam_h)]['warp_enqueue']
|
||||
self.warp_enqueue(
|
||||
**self.input_queues,
|
||||
frame=Tensor(np.zeros(self.frame_buf_params['img'][3], dtype=np.uint8), device=self.WARP_DEV).contiguous().realize(),
|
||||
big_frame=Tensor(np.zeros(self.frame_buf_params['big_img'][3], dtype=np.uint8), device=self.WARP_DEV).contiguous().realize())
|
||||
self.run_policy = jits['run_policy']
|
||||
self.warp_enqueue = jits[(cam_w,cam_h)]
|
||||
|
||||
def slice_outputs(self, model_outputs: np.ndarray, output_slices: dict[str, slice]) -> dict[str, np.ndarray]:
|
||||
parsed_model_outputs = {k: model_outputs[np.newaxis, v] for k,v in output_slices.items()}
|
||||
@@ -123,12 +119,13 @@ class ModelState:
|
||||
self.npy['tfm'][:,:] = transforms['img'][:,:]
|
||||
self.npy['big_tfm'][:,:] = transforms['big_img'][:,:]
|
||||
|
||||
img, big_img = self.warp_enqueue(**{k: self.input_queues[k] for k in WARP_INPUTS}, frame=self.full_frames['img'], big_frame=self.full_frames['big_img'])
|
||||
|
||||
if prepare_only:
|
||||
self.warp_enqueue(**self.input_queues, frame=self.full_frames['img'], big_frame=self.full_frames['big_img'])
|
||||
return None
|
||||
|
||||
vision_output, policy_output = self.run_policy(
|
||||
**self.input_queues, frame=self.full_frames['img'], big_frame=self.full_frames['big_img']
|
||||
**{k: self.input_queues[k] for k in POLICY_INPUTS}, img=img, big_img=big_img
|
||||
)
|
||||
|
||||
vision_output = vision_output.numpy().flatten()
|
||||
|
||||
Reference in New Issue
Block a user