diff --git a/tinygrad_repo/examples/openpilot/compile3.py b/tinygrad_repo/examples/openpilot/compile3.py index 3955c7292..109bf8a3f 100644 --- a/tinygrad_repo/examples/openpilot/compile3.py +++ b/tinygrad_repo/examples/openpilot/compile3.py @@ -1,5 +1,4 @@ import os, sys, pickle, time, re -from pathlib import Path import numpy as np if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1" if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2" @@ -132,8 +131,7 @@ def bench(run, inputs): run(**inputs).numpy() if __name__ == "__main__": - local_onnx = Path(OPENPILOT_MODEL).expanduser() - onnx_file = str(local_onnx if local_onnx.exists() else fetch(OPENPILOT_MODEL)) + onnx_file = fetch(OPENPILOT_MODEL) inputs, outputs = compile(onnx_file) with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f) diff --git a/tinygrad_repo/tinygrad/engine/jit.py b/tinygrad_repo/tinygrad/engine/jit.py index 2103c87f6..79fe034d3 100644 --- a/tinygrad_repo/tinygrad/engine/jit.py +++ b/tinygrad_repo/tinygrad/engine/jit.py @@ -166,48 +166,6 @@ def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]): for ei in jit_cache: if any(b in depends for b in ei.bufs): depends.update(get_out_buffers_for_ei(ei)) -def get_ret_tensors(ret:Any) -> list[Tensor]: - if isinstance(ret, Tensor): return [ret] - if isinstance(ret, (tuple, list)): return flatten([get_ret_tensors(x) for x in ret]) - if isinstance(ret, dict): return flatten([get_ret_tensors(x) for x in ret.values()]) - return [] - -def get_jit_outs(jit_cache:list[ExecItem]) -> list[Buffer]: - return flatten([get_out_buffers_for_ei(ei) for ei in jit_cache]) - -def get_ret_output_map(ret:Any, jit_cache:list[ExecItem]) -> list[int|None]: - output_map = {id(buf): idx for idx, buf in enumerate(get_jit_outs(jit_cache))} - ret_output_map: list[int|None] = [] - for t in get_ret_tensors(ret): - realized = t.uop.base.realized - ret_output_map.append(output_map.get(id(realized)) if realized is not None else None) - return ret_output_map - -def get_ret_spec(ret:Any, jit_cache:list[ExecItem]) -> Any: - output_map = {id(buf): idx for idx, buf in enumerate(get_jit_outs(jit_cache))} - if isinstance(ret, Tensor): - realized = ret.uop.base.realized - if realized is not None and (out_idx:=output_map.get(id(realized))) is not None: - return ("tensor", out_idx, ret.uop, ret.requires_grad) - return ("value", ret) - if isinstance(ret, tuple): return ("tuple", tuple(get_ret_spec(x, jit_cache) for x in ret)) - if isinstance(ret, list): return ("list", [get_ret_spec(x, jit_cache) for x in ret]) - if isinstance(ret, dict): return ("dict", [(k, get_ret_spec(v, jit_cache)) for k, v in ret.items()]) - return ("value", ret) - -def rebuild_ret_from_spec(spec:Any, jit_outs:list[Buffer]) -> Any: - tag, payload = spec[0], spec[1:] - if tag == "tensor": - out_idx, template_uop, requires_grad = payload - target_buf = jit_outs[out_idx] - buf_uop = UOp.new_buffer(target_buf.device, target_buf.size, template_uop.base.dtype) - bound = UOp(buf_uop.op, buf_uop.dtype, buf_uop.src, buf_uop.arg, buf_uop.tag, _buffer=target_buf) - return Tensor(template_uop.substitute({template_uop.base: bound}, name="rebuild captured jit ret"), requires_grad=requires_grad) - if tag == "tuple": return tuple(rebuild_ret_from_spec(x, jit_outs) for x in payload[0]) - if tag == "list": return [rebuild_ret_from_spec(x, jit_outs) for x in payload[0]] - if tag == "dict": return {k: rebuild_ret_from_spec(v, jit_outs) for k, v in payload[0]} - return payload[0] - ReturnType = TypeVar('ReturnType') @dataclass class CapturedJit(Generic[ReturnType]): @@ -217,14 +175,10 @@ class CapturedJit(Generic[ReturnType]): extra_view_inputs: list[tuple[int, int, str, int, DType]] expected_names: list[int|str] expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input - ret_output_map: list[int|None]|None = None - ret_spec: Any = None def __reduce__(self): # TODO: free_intermediates here? replan_buffers_memory_layout here? - return self.__class__, ( - self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info, self.ret_output_map, self.ret_spec, - ) + return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info) def __post_init__(self): self._jit_cache: list[ExecItem] = self.jit_cache @@ -235,18 +189,6 @@ class CapturedJit(Generic[ReturnType]): self._input_to_max_reader: dict[int, int] = {} for (j, _), idx in self.input_replace.items(): self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j) self._clear_inputs() - self._rebind_ret_outputs() - - def _rebind_ret_outputs(self): - if self.ret_output_map is None: return - jit_outs = get_jit_outs(self.jit_cache) - for t, out_idx in zip(get_ret_tensors(self.ret), self.ret_output_map): - if out_idx is None or out_idx >= len(jit_outs) or t.uop.base.op is not Ops.BUFFER: continue - target_buf = jit_outs[out_idx] - if t.uop.base.realized is target_buf: continue - buf_uop = UOp.new_buffer(target_buf.device, target_buf.size, t.uop.base.dtype) - new_base = UOp(buf_uop.op, buf_uop.dtype, buf_uop.src, buf_uop.arg, buf_uop.tag, _buffer=target_buf) - t.uop = t.uop.substitute({t.uop.base: new_base}, name="rebind captured jit outputs") def _clear_inputs(self): for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None @@ -302,7 +244,7 @@ class CapturedJit(Generic[ReturnType]): if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels") for ei in self._jit_cache: ei.run(var_vals, jit=True) self._clear_inputs() - return rebuild_ret_from_spec(self.ret_spec, get_jit_outs(self.jit_cache)) if self.ret_spec is not None else self.ret + return self.ret def _prepare_jit_inputs(args, kwargs): input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor] @@ -419,8 +361,7 @@ class TinyJit(Generic[ReturnType]): if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found") # set this for next run - self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info, - get_ret_output_map(ret, jit_cache), get_ret_spec(ret, jit_cache)) + self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info) if self.optimize: self.captured.replan_buffers_memory_layout() elif self.cnt >= 2: # jit exec