# Mirror of https://github.com/tinygrad/tinygrad.git
# Synced 2026-06-10 06:56:13 +08:00
# 451 lines, 16 KiB, Python
#!/usr/bin/env python
import unittest
import numpy as np

from test.helpers import assert_jit_cache_len, call_is_graph, not_support_multi_device, needs_second_gpu
from test.unit.test_jit import _simple_test
from tinygrad import Tensor, Variable, TinyJit, Device, dtypes
from tinygrad.engine.jit import graph_class
from tinygrad.helpers import JIT, DEV, GlobalCounters
from tinygrad.uop.ops import Ops
from tinygrad.renderer.isa.x86 import X86Renderer

# Tests that run TinyJit-wrapped functions on the default device (and, where available, a second device).
class TestJit(unittest.TestCase):
  # Runs a jitted elementwise add through the shared _simple_test helper.
  def test_simple_jit(self):
    @TinyJit
    def add(a, b): return (a+b).realize()
    _simple_test(add)

  # Jits one full-size kernel plus one kernel over a shrink whose start is a bound Variable,
  # then checks the shapes of both outputs after three calls.
  @unittest.skipUnless(Device.DEFAULT == "CPU", "core_id is a CPU runtimevar")
  def test_hcq_core_id_runtimevar_merge(self):
    N = 262144
    @TinyJit
    def f(x, st):
      y = (x + 1).contiguous().realize()
      z = x.shrink(((st, st + N),)).contiguous().realize()
      return y, z
    x = Tensor.arange(2*N).clone().realize()
    for _ in range(3): y, z = f(x, Variable("a", 0, N).bind(0))
    self.assertEqual(y.shape, (2*N,))
    self.assertEqual(z.shape, (N,))

  # The jitted function reads a slice of its input; every call gets a fresh tensor with different contents.
  def test_jit_input_view(self):
    @TinyJit
    def f(x): return (x[2:5].contiguous() + 1).realize()
    for i in range(5):
      x = (Tensor.arange(10).float() + i * 10).clone().realize()
      np.testing.assert_allclose(f(x).numpy(), x.numpy()[2:5] + 1)

  # GlobalCounters must report non-zero memory and op counts after each of the four calls.
  @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, X86Renderer), "estimates are wrong for x86")
  def test_global_counters_jit(self):
    @TinyJit
    def f(a, b):
      c = (a + b).realize()
      d = (c * 2).realize()
      return (d - a).realize()
    a, b = Tensor.randn(64, 64).realize(), Tensor.randn(64, 64).realize()
    for _ in range(4):
      GlobalCounters.reset()
      f(a, b)
      Device[a.device].synchronize()
      self.assertGreater(GlobalCounters.global_mem, 0)
      self.assertGreater(GlobalCounters.global_ops, 0)

  # In-place += inside the jit: five calls on a zero tensor must leave it holding 5.
  # dtype is a parameter so test_jit_assign_int8 can reuse the body.
  def test_jit_assign(self, dtype=dtypes.float32):
    @TinyJit
    def add(a):
      a += 1
      a.realize()
    a = Tensor.zeros(1, dtype=dtype).contiguous().realize()
    for _ in range(5): add(a)
    self.assertEqual(a.item(), 5)

  def test_jit_assign_int8(self): self.test_jit_assign(dtypes.int8)

  # The jitted function constructs a Tensor from a Python list on every call.
  def test_jit_copyin(self):
    @TinyJit
    def f(a):
      return a + Tensor([1,2,3])
    for _ in range(5):
      b = Tensor.randn(3)
      c = f(b)
      np.testing.assert_allclose(c.numpy(), b.numpy()+[1,2,3], atol=1e-4, rtol=1e-5)

  def test_jit_batch_split(self):
    if Device[Device.DEFAULT].graph is None or JIT >= 2: raise unittest.SkipTest("only test graphs")

    # Create long jit with 83 kernels.
    def f(a, b, c, d, e):
      for _ in range(80):
        a = (a+b).realize()
      y = (a*c).realize()
      z = (y*d).realize()
      w = (z*e)
      return w.realize()

    a, b, c, d, e = (Tensor.randn(10, 10).realize() for _ in range(5))

    jf = TinyJit(f)
    prev = None
    for _ in range(5):
      o = jf(a, b, c, d, e).numpy()
      # Inputs never change, so every call must reproduce the previous result.
      if prev is not None: np.testing.assert_allclose(o, prev, atol=1e-4, rtol=1e-5)
      prev = o

    # Checking that 2 graphs are inited.
    # unittest assertions are used instead of bare assert so the checks survive `python -O`.
    self.assertEqual(len(jf.captured.linear.src), 2)
    for si in jf.captured.linear.src:
      self.assertTrue(call_is_graph(si))

  # A jitted clone() must return the same values as its input.
  def test_jitted_clone(self):
    def f(a): return a.clone().realize()
    jf = TinyJit(f)
    for _ in range(5):
      a = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
      ja = jf(a)
      np.testing.assert_allclose(a.numpy(), ja.numpy(), atol=1e-4, rtol=1e-5)

  # Moves two tensors from device :0 to device :1 inside the jit and compares with the originals.
  @needs_second_gpu
  @unittest.skipIf(not_support_multi_device(), "no multi")
  def test_jitted_transfers(self):
    d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"

    def f(a, b):
      x = a.to(d1)
      y = b.to(d1)
      return x.realize(), y.realize()

    jf = TinyJit(f)
    for _ in range(5):
      a = Tensor.randn(10, 10, device=d0).realize()
      b = Tensor.randn(10, 10, device=d0).realize()
      xc, yc = jf(a, b)
      np.testing.assert_allclose(a.numpy(), xc.numpy(), atol=1e-4, rtol=1e-5)
      np.testing.assert_allclose(b.numpy(), yc.numpy(), atol=1e-4, rtol=1e-5)

  # Inputs live on "CPU" and are moved to the default device inside the jit before add/mul.
  def test_jit_several_devs(self):
    d0, d1 = f"{Device.DEFAULT}:0", "CPU"

    def f(a, b):
      x = a.to(d0).realize()
      y = b.to(d0).realize()
      # NOTE(review): .realize() binds to `y` only, so the sum and product are returned unrealized.
      # Presumably the caller-side .numpy() forces them; confirm whether (x+y).realize() was intended.
      return x+y.realize(), x*y.realize()

    jf = TinyJit(f)
    for _ in range(5):
      a = Tensor.randn(10, 10, device=d1).realize()
      b = Tensor.randn(10, 10, device=d1).realize()
      zc, wc = jf(a, b)
      np.testing.assert_allclose((a.numpy()+b.numpy()), zc.numpy(), atol=1e-4, rtol=1e-5)
      np.testing.assert_allclose((a.numpy()*b.numpy()), wc.numpy(), atol=1e-4, rtol=1e-5)

  # Reduces, offsets, bitcasts to int32 and moves the result to device :1 inside the jit;
  # the reference is the same computation in numpy reinterpreted with .view(np.int32).
  @needs_second_gpu
  @unittest.skipIf(not_support_multi_device(), "no multi")
  def test_jitted_view(self):
    d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"

    def f(a):
      x1 = a.sum(axis=(1,))
      x = (x1 + 5).bitcast(dtypes.int32)
      y = x.to(d1)
      return y.realize()

    jf = TinyJit(f)
    for _ in range(5):
      a = Tensor.randn(10, 1000, device=d0).realize()
      xc = jf(a)
      np.testing.assert_allclose((a.numpy().sum(axis=(1,)) + 5).view(np.int32), xc.numpy(), atol=1e-4, rtol=5e-5)

# Jitted functions that return three outputs; the whole class is currently skipped.
@unittest.skip("Pending multioutput implementation #3607")
class TestMultioutputJit(unittest.TestCase):
  # Calls f five times on fresh random inputs and checks the three outputs against a+b, a-b and a*b.
  def _test(self, f):
    for _ in range(5):
      a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
      out0, out1, out2 = f(a, b)
      np.testing.assert_allclose(out0.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
      np.testing.assert_allclose(out1.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5)
      np.testing.assert_allclose(out2.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5)

  # Every output realized explicitly: a jit cache length of 3 is expected.
  def test_jit_multioutput_realize(self):
    @TinyJit
    def fxn(a, b): return (a+b).realize(), (a-b).realize(), (a*b).realize()
    self._test(fxn)
    assert_jit_cache_len(fxn, 3)

  # No output realized explicitly: a jit cache length of 1 is expected.
  def test_jit_multioutput_norealize(self):
    @TinyJit
    def fxn(a, b): return a+b, a-b, a*b
    self._test(fxn)
    assert_jit_cache_len(fxn, 1)

  # Only the last output realized explicitly: a jit cache length of 2 is expected.
  def test_jit_multioutput_mix(self):
    @TinyJit
    def fxn(a, b): return a+b, a-b, (a*b).realize()
    self._test(fxn)
    assert_jit_cache_len(fxn, 2)

# A device copy (.to) performed inside the jitted function.
class TestCopyInsideJit(unittest.TestCase):
  def test_copy_inside_jit(self):
    @TinyJit
    def add(x,y) -> Tensor: return x.to(Device.DEFAULT)+y
    for _ in range(5):
      # create a Tensor on CPU
      a = Tensor.rand(16,16,device="CPU").realize()
      b = Tensor.rand(16,16).realize()
      out = add(a,b)
      # Reference is computed elementwise in Python from the flattened inputs.
      np.testing.assert_allclose(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])

# Compares TinyJit(..., prune=True) against the default jit on functions that mix a closed-over
# weights tensor with a device copy.
class TestJitPrune(unittest.TestCase):
  # The copy is applied to the jit input; weights already live on the default device.
  def test_prune_w_copy_correct(self):
    weights = Tensor.rand(16).realize()
    def w2(x) -> Tensor: return (weights*2).contiguous() + x.to(Device.DEFAULT)
    w2_noprune = TinyJit(w2)
    w2_prune = TinyJit(w2, prune=True)

    # Three calls on the unpruned jit, then three on the pruned one; both must match the Python reference.
    for jitted in (w2_noprune, w2_prune):
      for _ in range(3):
        a = Tensor.rand(16, device="CPU").realize()
        out = jitted(a)
        np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])

  # The copy is applied to the weights expression, which does not depend on the jit input.
  def test_prune_w_independent_copy_correct(self):
    weights = Tensor.rand(16, device="CPU").realize()
    def w2(x) -> Tensor: return (weights*2).contiguous().to(Device.DEFAULT) + x
    w2_noprune = TinyJit(w2)
    w2_prune = TinyJit(w2, prune=True)

    # Three calls on the unpruned jit, then three on the pruned one; both must match the Python reference.
    for jitted in (w2_noprune, w2_prune):
      for _ in range(3):
        a = Tensor.rand(16).realize()
        out = jitted(a)
        np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])

    # The pruned jit is expected to have a cache length of 1.
    assert_jit_cache_len(w2_prune, 1)

# Tests for free_intermediates() on a captured jit, measured through GlobalCounters.mem_used.
class TestJitFree(unittest.TestCase):
  def test_free_intermediates(self):
    ext_tensor = Tensor([1,24,23,45,1])
    @TinyJit
    def fxn(x:Tensor):
      t1 = (x * 2).contiguous().realize()
      t2 = (t1 + ext_tensor).contiguous().realize()
      out = (t2.sum()).contiguous().realize()
      return out
    for i in range(5):
      inp = Tensor([i,1,2,3,4])
      out = fxn(inp)
      # sum(2*[i,1,2,3,4]) + sum(ext_tensor) = (20 + 2*i) + 94
      self.assertEqual(out.item(), 114+2*i)
    pre_free = GlobalCounters.mem_used
    fxn.captured.free_intermediates()
    savings_after_free = pre_free - GlobalCounters.mem_used

    expected_savings = (len(inp) * inp.dtype.itemsize * 2) + dtypes.float32.itemsize # (t1 and t2) + out

    self.assertGreaterEqual(savings_after_free, expected_savings)
    # The jit must still produce correct results after its intermediates were freed.
    out = fxn(Tensor([11,1,2,3,4]))
    self.assertEqual(out.item(), 136)

    # Try one more time...
    pre_free = GlobalCounters.mem_used
    fxn.captured.free_intermediates()
    fxn.captured.free_intermediates() # 2nd time to validate
    savings_after_free = pre_free - GlobalCounters.mem_used

    self.assertGreaterEqual(savings_after_free, expected_savings)
    out = fxn(Tensor([11,1,2,3,4]))
    self.assertEqual(out.item(), 136)

  # A tensor updated in place by the jit (x += y on a closed-over tensor) must not be released:
  # free_intermediates() is expected to save exactly 0 bytes and x must keep accumulating afterwards.
  def test_updated_not_freed(self):
    x = Tensor([1]).realize()
    @TinyJit
    def fxn(y):
      nonlocal x
      x += y
      return x
    for _ in range(5): fxn(Tensor([1]))
    self.assertEqual(x.item(), 6)
    pre_free = GlobalCounters.mem_used
    fxn.captured.free_intermediates()
    savings_after_free = pre_free - GlobalCounters.mem_used
    self.assertEqual(savings_after_free, 0)
    fxn(Tensor([2]))
    self.assertEqual(x.item(), 8)

# Checks how a captured jit is split into graph / kernel / copy items for the default device's graph class.
class TestJitGraphSplit(unittest.TestCase):
  # One kernel on `device`: asserts the input lives there, then realizes inp + 1.
  def compute(self, device, inp):
    self.assertEqual(inp.device, device, f"Input device {inp.device} does not match expected {device}")
    return (inp + 1.0).contiguous().realize()

  # One copy from `device` to `to_device`: asserts the input lives on `device`, then realizes the moved tensor.
  def copy(self, device, to_device, inp):
    self.assertEqual(inp.device, device, f"Input device {inp.device} does not match expected {device}")
    return inp.to(to_device).realize()

  # Runs f(*args) five times, requiring calls 2-5 to match the first result, then compares the captured
  # items against one of three expectation lists chosen by the default device's graph class:
  #   hcqgraph   - when graph_class(...) is HCQGraph
  #   multigraph - when it is a MultiGraphRunner subclass
  #   graph      - otherwise
  # Each expectation is a dict built by the ji_* helpers. Nothing beyond the numeric check is
  # verified when the device has no graph class.
  def expect(self, f, *args, graph=None, multigraph=None, hcqgraph=None):
    def _numpies(tpl): return tpl.numpy() if isinstance(tpl, Tensor) else tuple(t.numpy() for t in tpl)

    reference = _numpies(f(*args))
    for _ in range(4):
      res = _numpies(f(*args))
      np.testing.assert_allclose(res, reference, atol=1e-4, rtol=1e-5)

    graph_t = graph_class(Device[Device.DEFAULT])
    if graph_t is None: return

    got = f.captured.linear.src
    # Imported here, after the early return above, exactly as the original did.
    from tinygrad.runtime.graph.hcq import HCQGraph
    from tinygrad.engine.jit import MultiGraphRunner
    if graph_t is HCQGraph: validate = hcqgraph
    elif issubclass(graph_t, MultiGraphRunner): validate = multigraph
    else: validate = graph

    # Fail with a readable message instead of a TypeError from len(None).
    self.assertIsNotNone(validate, f"No expectation provided for graph class {graph_t}")
    self.assertEqual(len(got), len(validate), f"Expected {len(validate)} operations, got {len(got)}")
    for want, si in zip(validate, got):
      ast = si.src[0]
      if want["type"] == "graph":
        self.assertTrue(call_is_graph(si), f"Expected graph, got {ast.op}")
        inner_cnt = len(ast.src[0].src)
        self.assertEqual(inner_cnt, want["cnt"], f"Expected {want['cnt']} operations in graph, got {inner_cnt}")
      elif want["type"] == "comp":
        self.assertIn(ast.op, (Ops.SINK, Ops.PROGRAM), f"Expected kernel, got {ast.op}")
      elif want["type"] in ("copy", "xfer"):
        self.assertIs(ast.op, Ops.COPY, f"Expected COPY, got {ast.op}")
      else:
        self.fail(f"Unknown expectation type {want['type']!r}")

  # Expectation builders consumed by expect().
  def ji_graph(self, cnt): return {"type": "graph", "cnt": cnt}
  def ji_comp(self): return {"type": "comp"}
  def ji_copy(self): return {"type": "copy"}
  def ji_xfer(self): return {"type": "xfer"}

  # Three chained kernels on the default device: a single graph of 3 is expected for every graph class.
  def test_jit_split_simple(self):
    @TinyJit
    def f(inp):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute(Device.DEFAULT, op1)
      return op2

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    self.expect(f, inp,
      graph=[self.ji_graph(3)],
      multigraph=[self.ji_graph(3)],
      hcqgraph=[self.ji_graph(3)])

  # One CPU kernel interleaved between default-device kernels.
  def test_jit_cpu_simple(self):
    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")

    @TinyJit
    def f(inp, inp_cpu):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute("CPU", inp_cpu)
      op3 = self.compute(Device.DEFAULT, op1)
      return op2, op3

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    inp_cpu = Tensor.randn(10, 10, device="CPU").realize()
    self.expect(f, inp, inp_cpu,
      graph=[self.ji_graph(2), self.ji_comp(), self.ji_comp()],
      multigraph=[self.ji_graph(2), self.ji_comp(), self.ji_comp()],
      hcqgraph=[self.ji_graph(4)])

  # Two consecutive CPU kernels interleaved between default-device kernels.
  def test_jit_cpu_several(self):
    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")

    @TinyJit
    def f(inp, inp_cpu):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute("CPU", inp_cpu)
      op3 = self.compute("CPU", op2)
      op4 = self.compute(Device.DEFAULT, op1)
      return op3, op4

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    inp_cpu = Tensor.randn(10, 10, device="CPU").realize()
    self.expect(f, inp, inp_cpu,
      graph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
      multigraph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
      hcqgraph=[self.ji_graph(5)])

  # Kernels alternate between the default device and its ":1" sibling.
  def test_jit_multidev(self):
    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test")

    # Opening the second device is the availability probe; any failure means skip.
    try: Device[f"{Device.DEFAULT}:1"]
    except Exception: raise unittest.SkipTest("no multidevice") from None

    @TinyJit
    def f(inp, inp_d1):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute(f"{Device.DEFAULT}:1", inp_d1)
      op3 = self.compute(f"{Device.DEFAULT}:1", op2)
      op4 = self.compute(Device.DEFAULT, op1)
      return op3, op4

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    inp_d1 = Tensor.randn(10, 10, device=f"{Device.DEFAULT}:1").realize()
    self.expect(f, inp, inp_d1,
      graph=[self.ji_graph(2), self.ji_graph(2), self.ji_comp()],
      multigraph=[self.ji_graph(5)],
      hcqgraph=[self.ji_graph(5)])

  # Like test_jit_multidev, plus a device-to-device copy from ":1" back to the default device.
  def test_jit_multidev_xfer(self):
    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU is not a valid default device for this test (zero-copies)")
    if Device.DEFAULT == "METAL": raise unittest.SkipTest("Metal is flaky, with multidevice (same as metal llama 4gpu?)")

    # Opening the second device is the availability probe; any failure means skip.
    try: Device[f"{Device.DEFAULT}:1"]
    except Exception: raise unittest.SkipTest("no multidevice") from None

    @TinyJit
    def f(inp, inp_d1):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.compute(f"{Device.DEFAULT}:1", inp_d1)
      op3 = self.copy(f"{Device.DEFAULT}:1", Device.DEFAULT, op2)
      op4 = self.compute(f"{Device.DEFAULT}:1", op2)
      op5 = self.compute(Device.DEFAULT, op3)
      return op1, op4, op5

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    inp_d1 = Tensor.randn(10, 10, device=f"{Device.DEFAULT}:1").realize()
    self.expect(f, inp, inp_d1,
      graph=[self.ji_graph(2), self.ji_comp(), self.ji_xfer(), self.ji_comp(), self.ji_comp()],
      multigraph=[self.ji_graph(6)],
      hcqgraph=[self.ji_graph(6)])

  # Default-device kernels followed by a copy to CPU and a CPU kernel. Unconditionally skipped.
  @unittest.skip("this fails if you don't have SDMA or are using AMD_DISABLE_SDMA=1")
  @unittest.skipIf(DEV.interface.startswith("MOCK"), "MockGPU does not support parallel copies")
  def test_jit_multidev_copy(self):
    if Device.DEFAULT == "CPU": raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")

    @TinyJit
    def f(inp):
      op0 = self.compute(Device.DEFAULT, inp)
      op1 = self.compute(Device.DEFAULT, op0)
      op2 = self.copy(Device.DEFAULT, "CPU", op1)
      op3 = self.compute("CPU", op2)
      return op3

    inp = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
    self.expect(f, inp,
      graph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
      multigraph=[self.ji_graph(2), self.ji_copy(), self.ji_comp()],
      hcqgraph=[self.ji_graph(4)])

# Run the test classes in this module when the file is executed directly.
if __name__ == '__main__':
  unittest.main()