so that was a lie

This commit is contained in:
firestar5683
2026-06-05 01:39:46 -05:00
parent 038b83ac4a
commit fee42e4d7b
2 changed files with 4 additions and 65 deletions

View File

@@ -1,5 +1,4 @@
import os, sys, pickle, time, re
from pathlib import Path
import numpy as np
if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
@@ -132,8 +131,7 @@ def bench(run, inputs):
run(**inputs).numpy()
if __name__ == "__main__":
local_onnx = Path(OPENPILOT_MODEL).expanduser()
onnx_file = str(local_onnx if local_onnx.exists() else fetch(OPENPILOT_MODEL))
onnx_file = fetch(OPENPILOT_MODEL)
inputs, outputs = compile(onnx_file)
with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f)

View File

@@ -166,48 +166,6 @@ def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
for ei in jit_cache:
if any(b in depends for b in ei.bufs): depends.update(get_out_buffers_for_ei(ei))
def get_ret_tensors(ret:Any) -> list[Tensor]:
if isinstance(ret, Tensor): return [ret]
if isinstance(ret, (tuple, list)): return flatten([get_ret_tensors(x) for x in ret])
if isinstance(ret, dict): return flatten([get_ret_tensors(x) for x in ret.values()])
return []
def get_jit_outs(jit_cache:list[ExecItem]) -> list[Buffer]:
return flatten([get_out_buffers_for_ei(ei) for ei in jit_cache])
def get_ret_output_map(ret:Any, jit_cache:list[ExecItem]) -> list[int|None]:
output_map = {id(buf): idx for idx, buf in enumerate(get_jit_outs(jit_cache))}
ret_output_map: list[int|None] = []
for t in get_ret_tensors(ret):
realized = t.uop.base.realized
ret_output_map.append(output_map.get(id(realized)) if realized is not None else None)
return ret_output_map
def get_ret_spec(ret:Any, jit_cache:list[ExecItem]) -> Any:
output_map = {id(buf): idx for idx, buf in enumerate(get_jit_outs(jit_cache))}
if isinstance(ret, Tensor):
realized = ret.uop.base.realized
if realized is not None and (out_idx:=output_map.get(id(realized))) is not None:
return ("tensor", out_idx, ret.uop, ret.requires_grad)
return ("value", ret)
if isinstance(ret, tuple): return ("tuple", tuple(get_ret_spec(x, jit_cache) for x in ret))
if isinstance(ret, list): return ("list", [get_ret_spec(x, jit_cache) for x in ret])
if isinstance(ret, dict): return ("dict", [(k, get_ret_spec(v, jit_cache)) for k, v in ret.items()])
return ("value", ret)
def rebuild_ret_from_spec(spec:Any, jit_outs:list[Buffer]) -> Any:
tag, payload = spec[0], spec[1:]
if tag == "tensor":
out_idx, template_uop, requires_grad = payload
target_buf = jit_outs[out_idx]
buf_uop = UOp.new_buffer(target_buf.device, target_buf.size, template_uop.base.dtype)
bound = UOp(buf_uop.op, buf_uop.dtype, buf_uop.src, buf_uop.arg, buf_uop.tag, _buffer=target_buf)
return Tensor(template_uop.substitute({template_uop.base: bound}, name="rebuild captured jit ret"), requires_grad=requires_grad)
if tag == "tuple": return tuple(rebuild_ret_from_spec(x, jit_outs) for x in payload[0])
if tag == "list": return [rebuild_ret_from_spec(x, jit_outs) for x in payload[0]]
if tag == "dict": return {k: rebuild_ret_from_spec(v, jit_outs) for k, v in payload[0]}
return payload[0]
ReturnType = TypeVar('ReturnType')
@dataclass
class CapturedJit(Generic[ReturnType]):
@@ -217,14 +175,10 @@ class CapturedJit(Generic[ReturnType]):
extra_view_inputs: list[tuple[int, int, str, int, DType]]
expected_names: list[int|str]
expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input
ret_output_map: list[int|None]|None = None
ret_spec: Any = None
def __reduce__(self):
# TODO: free_intermediates here? replan_buffers_memory_layout here?
return self.__class__, (
self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info, self.ret_output_map, self.ret_spec,
)
return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info)
def __post_init__(self):
self._jit_cache: list[ExecItem] = self.jit_cache
@@ -235,18 +189,6 @@ class CapturedJit(Generic[ReturnType]):
self._input_to_max_reader: dict[int, int] = {}
for (j, _), idx in self.input_replace.items(): self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j)
self._clear_inputs()
self._rebind_ret_outputs()
def _rebind_ret_outputs(self):
if self.ret_output_map is None: return
jit_outs = get_jit_outs(self.jit_cache)
for t, out_idx in zip(get_ret_tensors(self.ret), self.ret_output_map):
if out_idx is None or out_idx >= len(jit_outs) or t.uop.base.op is not Ops.BUFFER: continue
target_buf = jit_outs[out_idx]
if t.uop.base.realized is target_buf: continue
buf_uop = UOp.new_buffer(target_buf.device, target_buf.size, t.uop.base.dtype)
new_base = UOp(buf_uop.op, buf_uop.dtype, buf_uop.src, buf_uop.arg, buf_uop.tag, _buffer=target_buf)
t.uop = t.uop.substitute({t.uop.base: new_base}, name="rebind captured jit outputs")
def _clear_inputs(self):
for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None
@@ -302,7 +244,7 @@ class CapturedJit(Generic[ReturnType]):
if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
for ei in self._jit_cache: ei.run(var_vals, jit=True)
self._clear_inputs()
return rebuild_ret_from_spec(self.ret_spec, get_jit_outs(self.jit_cache)) if self.ret_spec is not None else self.ret
return self.ret
def _prepare_jit_inputs(args, kwargs):
input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
@@ -419,8 +361,7 @@ class TinyJit(Generic[ReturnType]):
if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
# set this for next run
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info,
get_ret_output_map(ret, jit_cache), get_ret_spec(ret, jit_cache))
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info)
if self.optimize: self.captured.replan_buffers_memory_layout()
elif self.cnt >= 2:
# jit exec