diff --git a/extra/jit.py b/extra/jit.py index ec5163a8c5..9cebf0f567 100644 --- a/extra/jit.py +++ b/extra/jit.py @@ -1,5 +1,6 @@ from typing import Callable, List, Tuple import itertools +from tinygrad.lazy import Device from tinygrad.tensor import Tensor from tinygrad.ops import DEBUG, GlobalCounters @@ -12,20 +13,24 @@ class TinyJit: self.input_replace = {} def __call__(self, *args, **kwargs): + if Device.DEFAULT != "GPU": return self.fxn(*args, **kwargs) # only jit on the GPU input_tensors = {k:v.realize().lazydata.realized._buf for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)} + assert len(input_tensors) != 0, "no inputs to JIT" if self.cnt >= 2: for a,idx in self.input_replace.items(): a._buf = input_tensors[idx] for prg, args in self.jit_cache: prg(*args) - else: - if self.cnt == 1: GlobalCounters.cache = [] - self.ret = self.fxn(*args, **kwargs).realize() - if self.cnt == 1: - self.jit_cache = GlobalCounters.cache - GlobalCounters.cache = None + elif self.cnt == 1: + GlobalCounters.cache = [] + self.ret = self.fxn(*args, **kwargs) + self.jit_cache = GlobalCounters.cache + GlobalCounters.cache = None + assert len(self.jit_cache) != 0, "didn't JIT anything!" - # get the inputs for replacement - for prg, args in self.jit_cache: # pylint: disable=E1133 - self.input_replace.update({a:[k for k,v in input_tensors.items() if v == a._buf][0] for a in args if a._buf in input_tensors.values()}) - assert set(self.input_replace.values()) == set(input_tensors.keys()), "some input tensors not found" + # get the inputs for replacement + for prg, args in self.jit_cache: # pylint: disable=E1133 + self.input_replace.update({a:[k for k,v in input_tensors.items() if v == a._buf][0] for a in args if a._buf in input_tensors.values()}) + assert set(self.input_replace.values()) == set(input_tensors.keys()), "some input tensors not found" + elif self.cnt == 0: + self.ret = self.fxn(*args, **kwargs) self.cnt += 1 return self.ret diff --git a/openpilot/compile.py b/openpilot/compile.py index fabcbb61b0..43463a86e1 100644 --- a/openpilot/compile.py +++ b/openpilot/compile.py @@ -63,7 +63,7 @@ def model_exec(run_onnx, using_graph, **inputs): ret = next(iter(run_onnx(inputs).values())) GlobalCounters.cache = [] # don't cache pre-realize if using_graph: graph.GRAPH = True - return ret + return ret.realize() def compile(dat, output_fn): Tensor.no_grad = True diff --git a/test/test_jit.py b/test/test_jit.py index 7d809ad834..04d25a3a47 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -8,7 +8,7 @@ from extra.jit import TinyJit class TestJit(unittest.TestCase): def test_simple_jit(self): @TinyJit - def add(a, b): return a+b + def add(a, b): return (a+b).realize() for _ in range(3): a = Tensor.randn(10, 10) b = Tensor.randn(10, 10) @@ -17,7 +17,7 @@ class TestJit(unittest.TestCase): def test_kwargs_jit(self): @TinyJit - def add_kwargs(first, second): return first+second + def add_kwargs(first, second): return (first+second).realize() for _ in range(3): a = Tensor.randn(10, 10) b = Tensor.randn(10, 10) @@ -26,12 +26,12 @@ class TestJit(unittest.TestCase): def test_array_jit(self): @TinyJit - def add_array(arr): return arr[0]+arr[1] + def add_array(a, arr): return (a+arr[0]).realize() for i in range(3): a = Tensor.randn(10, 10) b = Tensor.randn(10, 10) a.realize(), b.realize() - c = add_array([a,b]) + c = add_array(a, [b]) if i == 2: # should fail once jitted since jit can't handle arrays np.testing.assert_equal(np.any(np.not_equal(c.numpy(),a.numpy()+b.numpy())), True) diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py index b949831cae..0b319c632c 100644 --- a/test/test_speed_v_torch.py +++ b/test/test_speed_v_torch.py @@ -11,6 +11,7 @@ from tinygrad.ops import GlobalCounters from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d from tinygrad.helpers import colored, getenv +from extra.jit import TinyJit try: from tinygrad.runtime.opencl import CL except ImportError: @@ -39,6 +40,7 @@ def helper_test_speed(f1, *args): del ret GlobalCounters.global_ops = 0 GlobalCounters.global_mem = 0 + args = [(x+1).realize() if isinstance(x,Tensor) else (None if x is None else (x+1)) for x in args] # cache defeats st = time.monotonic() ret = f1(*args) if CL is not None and ret.device in ["GPU"]: @@ -52,22 +54,22 @@ def helper_test_speed(f1, *args): save_ops, save_mem = GlobalCounters.global_ops, GlobalCounters.global_mem return ret.cpu().numpy(), np.min(ets) -def helper_test_generic_square(name, N, f1, f2): +def helper_test_generic_square(name, N, f1, f2, onearg=False): torch.manual_seed(0) torch_a = (torch.rand(N, N) - 0.5).to(torch_device) - torch_b = (torch.rand(N, N) - 0.5).to(torch_device) + torch_b = (torch.rand(N, N) - 0.5).to(torch_device) if not onearg else None tiny_a = Tensor(torch_a.cpu().numpy()) - tiny_b = Tensor(torch_b.cpu().numpy()) + tiny_b = Tensor(torch_b.cpu().numpy()) if not onearg else None - helper_test_generic(f"{name:30s} {N:4d}x{N:4d}", partial(f1, torch_a, torch_b), partial(f2, tiny_a, tiny_b)) + helper_test_generic(f"{name:30s} {N:4d}x{N:4d}", f1, (torch_a, torch_b), TinyJit(lambda a,b:f2(a,b).realize()), (tiny_a, tiny_b)) prefix = None -def helper_test_generic(name, f1, f2): +def helper_test_generic(name, f1, f1_args, f2, f2_args): global prefix with torch.no_grad(): - val_torch, et_torch = helper_test_speed(f1) - val_tinygrad, et_tinygrad = helper_test_speed(lambda: f2().realize()) + val_torch, et_torch = helper_test_speed(f1, *f1_args) + val_tinygrad, et_tinygrad = helper_test_speed(f2, *f2_args) desc = "faster" if et_torch > et_tinygrad else "slower" flops = save_ops*1e-6 @@ -92,24 +94,24 @@ class TestSpeed(unittest.TestCase): def test_sum(self): def f(a, b): return a.sum() - helper_test_generic_square('sum', 4096, f, f) + helper_test_generic_square('sum', 4096, f, f, onearg=True) def test_partial_sum(self): R = 256 def f(a, b): return a.reshape(int(4096//R), int(4096*R)).sum(axis=1) - helper_test_generic_square('partial_sum', 4096, f, f) + helper_test_generic_square('partial_sum', 4096, f, f, onearg=True) def test_array_packing(self): N = 2048 def f(a, b): return a.reshape(N, N // 32, 32).permute(1,0,2).contiguous() - helper_test_generic_square('array_packing', N, f, f) + helper_test_generic_square('array_packing', N, f, f, onearg=True) def test_permute(self): for N in [1024, 4096]: # this is a 64MB tensor, M1 L1 cache is 128kB # to fit easily in L1, rotations should be 128x128 chunks. 128x128 is also the AMX size def f(a, b): return a.permute(1,0).contiguous() - helper_test_generic_square('permute', N, f, f) + helper_test_generic_square('permute', N, f, f, onearg=True) def test_double_permute(self): N = 64 @@ -117,23 +119,23 @@ class TestSpeed(unittest.TestCase): torch_a = (torch.rand(N, N, N, N) - 0.5).to(torch_device) tiny_a = Tensor(torch_a.cpu().numpy()) def f(a): return a.permute(1,0,3,2).contiguous() - helper_test_generic(f"double_permute {tiny_a.shape}", partial(f, torch_a), partial(f, tiny_a)) + helper_test_generic(f"double_permute {tiny_a.shape}", f, (torch_a,), TinyJit(lambda a: f(a).realize()), (tiny_a,)) def test_neg(self): def f(a, b): return -a - helper_test_generic_square('neg', 4096, f, f) + helper_test_generic_square('neg', 4096, f, f, onearg=True) def test_exp(self): def f(a, b): return a.exp() - helper_test_generic_square('exp', 2048, f, f) + helper_test_generic_square('exp', 2048, f, f, onearg=True) def test_relu(self): def f(a, b): return a.relu() - helper_test_generic_square('relu', 4096, f, f) + helper_test_generic_square('relu', 4096, f, f, onearg=True) def test_max(self): def f(a, b): return a.max() - helper_test_generic_square('max', 4096, f, f) + helper_test_generic_square('max', 4096, f, f, onearg=True) def test_mul_sum(self): def f(a, b): return (a*b).sum() @@ -146,11 +148,11 @@ class TestSpeed(unittest.TestCase): def test_add_constant(self): def f(a, b): return a+2.0 - helper_test_generic_square('add_constant', 4096, f, f) + helper_test_generic_square('add_constant', 4096, f, f, onearg=True) def test_add_constant_zero(self): def f(a, b): return a+0.0 - helper_test_generic_square('add_constant_zero', 4096, f, f) + helper_test_generic_square('add_constant_zero', 4096, f, f, onearg=True) def test_add_sq(self): def f(a, b): return a*a + b*b @@ -194,9 +196,9 @@ class TestSpeed(unittest.TestCase): tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None, padding=1) tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy()) - def f1(): return torch_conv(torch_dat.permute(0,3,1,2)) - def f2(): return tiny_conv(tiny_dat.permute(0,3,1,2)).realize() - helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, f2) + def f1(torch_dat): return torch_conv(torch_dat.permute(0,3,1,2)) + def f2(tiny_dat): return tiny_conv(tiny_dat.permute(0,3,1,2)).realize() + helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,)) def test_conv2d(self): torch.manual_seed(0) @@ -211,9 +213,9 @@ class TestSpeed(unittest.TestCase): tiny_conv = Conv2d(in_chans, out_chans, 3, bias=None) tiny_conv.weight = Tensor(torch_conv.weight.detach().cpu().numpy()) - def f1(): return torch_conv(torch_dat) - def f2(): return tiny_conv(tiny_dat).realize() - helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, f2) + def f1(torch_dat): return torch_conv(torch_dat) + def f2(tiny_dat): return tiny_conv(tiny_dat).realize() + helper_test_generic(f"conv bs:{bs:3d} chans:{in_chans:3d} -> {out_chans:3d}", f1, (torch_dat,), TinyJit(f2), (tiny_dat,)) if __name__ == '__main__': unittest.main()