mirror of
https://github.com/firestar5683/StarPilot.git
synced 2026-06-08 08:34:49 +08:00
so that was a lie
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import os, sys, pickle, time, re
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
|
||||
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
|
||||
@@ -132,8 +131,7 @@ def bench(run, inputs):
|
||||
run(**inputs).numpy()
|
||||
|
||||
if __name__ == "__main__":
|
||||
local_onnx = Path(OPENPILOT_MODEL).expanduser()
|
||||
onnx_file = str(local_onnx if local_onnx.exists() else fetch(OPENPILOT_MODEL))
|
||||
onnx_file = fetch(OPENPILOT_MODEL)
|
||||
inputs, outputs = compile(onnx_file)
|
||||
|
||||
with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f)
|
||||
|
||||
@@ -166,48 +166,6 @@ def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
|
||||
for ei in jit_cache:
|
||||
if any(b in depends for b in ei.bufs): depends.update(get_out_buffers_for_ei(ei))
|
||||
|
||||
def get_ret_tensors(ret:Any) -> list[Tensor]:
|
||||
if isinstance(ret, Tensor): return [ret]
|
||||
if isinstance(ret, (tuple, list)): return flatten([get_ret_tensors(x) for x in ret])
|
||||
if isinstance(ret, dict): return flatten([get_ret_tensors(x) for x in ret.values()])
|
||||
return []
|
||||
|
||||
def get_jit_outs(jit_cache:list[ExecItem]) -> list[Buffer]:
|
||||
return flatten([get_out_buffers_for_ei(ei) for ei in jit_cache])
|
||||
|
||||
def get_ret_output_map(ret:Any, jit_cache:list[ExecItem]) -> list[int|None]:
|
||||
output_map = {id(buf): idx for idx, buf in enumerate(get_jit_outs(jit_cache))}
|
||||
ret_output_map: list[int|None] = []
|
||||
for t in get_ret_tensors(ret):
|
||||
realized = t.uop.base.realized
|
||||
ret_output_map.append(output_map.get(id(realized)) if realized is not None else None)
|
||||
return ret_output_map
|
||||
|
||||
def get_ret_spec(ret:Any, jit_cache:list[ExecItem]) -> Any:
|
||||
output_map = {id(buf): idx for idx, buf in enumerate(get_jit_outs(jit_cache))}
|
||||
if isinstance(ret, Tensor):
|
||||
realized = ret.uop.base.realized
|
||||
if realized is not None and (out_idx:=output_map.get(id(realized))) is not None:
|
||||
return ("tensor", out_idx, ret.uop, ret.requires_grad)
|
||||
return ("value", ret)
|
||||
if isinstance(ret, tuple): return ("tuple", tuple(get_ret_spec(x, jit_cache) for x in ret))
|
||||
if isinstance(ret, list): return ("list", [get_ret_spec(x, jit_cache) for x in ret])
|
||||
if isinstance(ret, dict): return ("dict", [(k, get_ret_spec(v, jit_cache)) for k, v in ret.items()])
|
||||
return ("value", ret)
|
||||
|
||||
def rebuild_ret_from_spec(spec:Any, jit_outs:list[Buffer]) -> Any:
|
||||
tag, payload = spec[0], spec[1:]
|
||||
if tag == "tensor":
|
||||
out_idx, template_uop, requires_grad = payload
|
||||
target_buf = jit_outs[out_idx]
|
||||
buf_uop = UOp.new_buffer(target_buf.device, target_buf.size, template_uop.base.dtype)
|
||||
bound = UOp(buf_uop.op, buf_uop.dtype, buf_uop.src, buf_uop.arg, buf_uop.tag, _buffer=target_buf)
|
||||
return Tensor(template_uop.substitute({template_uop.base: bound}, name="rebuild captured jit ret"), requires_grad=requires_grad)
|
||||
if tag == "tuple": return tuple(rebuild_ret_from_spec(x, jit_outs) for x in payload[0])
|
||||
if tag == "list": return [rebuild_ret_from_spec(x, jit_outs) for x in payload[0]]
|
||||
if tag == "dict": return {k: rebuild_ret_from_spec(v, jit_outs) for k, v in payload[0]}
|
||||
return payload[0]
|
||||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
@dataclass
|
||||
class CapturedJit(Generic[ReturnType]):
|
||||
@@ -217,14 +175,10 @@ class CapturedJit(Generic[ReturnType]):
|
||||
extra_view_inputs: list[tuple[int, int, str, int, DType]]
|
||||
expected_names: list[int|str]
|
||||
expected_input_info: list[tuple[UOp, tuple[Variable, ...], DType, str]] # (view, variables, dtype, device) per input
|
||||
ret_output_map: list[int|None]|None = None
|
||||
ret_spec: Any = None
|
||||
|
||||
def __reduce__(self):
|
||||
# TODO: free_intermediates here? replan_buffers_memory_layout here?
|
||||
return self.__class__, (
|
||||
self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info, self.ret_output_map, self.ret_spec,
|
||||
)
|
||||
return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs, self.expected_names, self.expected_input_info)
|
||||
|
||||
def __post_init__(self):
|
||||
self._jit_cache: list[ExecItem] = self.jit_cache
|
||||
@@ -235,18 +189,6 @@ class CapturedJit(Generic[ReturnType]):
|
||||
self._input_to_max_reader: dict[int, int] = {}
|
||||
for (j, _), idx in self.input_replace.items(): self._input_to_max_reader[idx] = max(self._input_to_max_reader.get(idx, -1), j)
|
||||
self._clear_inputs()
|
||||
self._rebind_ret_outputs()
|
||||
|
||||
def _rebind_ret_outputs(self):
|
||||
if self.ret_output_map is None: return
|
||||
jit_outs = get_jit_outs(self.jit_cache)
|
||||
for t, out_idx in zip(get_ret_tensors(self.ret), self.ret_output_map):
|
||||
if out_idx is None or out_idx >= len(jit_outs) or t.uop.base.op is not Ops.BUFFER: continue
|
||||
target_buf = jit_outs[out_idx]
|
||||
if t.uop.base.realized is target_buf: continue
|
||||
buf_uop = UOp.new_buffer(target_buf.device, target_buf.size, t.uop.base.dtype)
|
||||
new_base = UOp(buf_uop.op, buf_uop.dtype, buf_uop.src, buf_uop.arg, buf_uop.tag, _buffer=target_buf)
|
||||
t.uop = t.uop.substitute({t.uop.base: new_base}, name="rebind captured jit outputs")
|
||||
|
||||
def _clear_inputs(self):
|
||||
for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None
|
||||
@@ -302,7 +244,7 @@ class CapturedJit(Generic[ReturnType]):
|
||||
if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
|
||||
for ei in self._jit_cache: ei.run(var_vals, jit=True)
|
||||
self._clear_inputs()
|
||||
return rebuild_ret_from_spec(self.ret_spec, get_jit_outs(self.jit_cache)) if self.ret_spec is not None else self.ret
|
||||
return self.ret
|
||||
|
||||
def _prepare_jit_inputs(args, kwargs):
|
||||
input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
|
||||
@@ -419,8 +361,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
|
||||
|
||||
# set this for next run
|
||||
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info,
|
||||
get_ret_output_map(ret, jit_cache), get_ret_spec(ret, jit_cache))
|
||||
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, expected_input_info)
|
||||
if self.optimize: self.captured.replan_buffers_memory_layout()
|
||||
elif self.cnt >= 2:
|
||||
# jit exec
|
||||
|
||||
Reference in New Issue
Block a user