diff --git a/test/unit/test_function.py b/test/unit/test_function.py
index 127c89076a..6fa5b50fee 100644
--- a/test/unit/test_function.py
+++ b/test/unit/test_function.py
@@ -2,7 +2,7 @@ import numpy as np
 import unittest
 from tinygrad.function import function
 from tinygrad import Tensor, GlobalCounters
-from tinygrad.uop.ops import UOp, KernelInfo
+from tinygrad.uop.ops import UOp, Ops, KernelInfo
 
 class TestFunction(unittest.TestCase):
   def test_simple(self):
@@ -466,6 +466,39 @@ class TestFunctionTuple(unittest.TestCase):
     Tensor.realize(a.grad)
     np.testing.assert_allclose(a.grad.numpy(), [2., 2., 2., 2.])
 
+  def test_custom_kernel_precompile_no_copy_kernel(self):
+    # a precompiled function wrapping a custom kernel should schedule exactly one kernel (no extra copy kernel)
+    def my_kernel(C:UOp, A:UOp) -> UOp:
+      i = UOp.range(A.shape[0], 0)
+      return C[i].store(A[i] * 2.0).end(i).sink(arg=KernelInfo(name="my_kernel"))
+
+    def my_grad(d_c:UOp, call:UOp):
+      return (None, (Tensor(d_c) * 2.0).uop)
+
+    @function(precompile=True, precompile_backward=True)
+    def f(a:Tensor):
+      c = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
+      c = Tensor.custom_kernel(c, a, fxn=my_kernel, grad_fxn=my_grad)[0]
+      return c
+
+    def count_kernels(t:Tensor):
+      # counts the calls whose body is a SINK; a call with a tuple of devices counts once per device
+      linear, _ = t.linear_with_vars()
+      return sum((len(call.device) if isinstance(call.device, tuple) else 1)
+                 for call in linear.src if call.src[0].op is Ops.SINK)
+
+    a = Tensor([1., 2., 3., 4.], requires_grad=True).contiguous()
+    Tensor.realize(a)
+    c = f(a)
+
+    self.assertEqual(count_kernels(c), 1)
+
+    c.sum().backward()
+    Tensor.realize(a.grad)
+    np.testing.assert_allclose(a.grad.numpy(), [2., 2., 2., 2.])
+    # also check the forward result (my_kernel stores A*2), so a body that drops the custom kernel cannot pass
+    np.testing.assert_allclose(c.numpy(), [2., 4., 6., 8.])
+
 class TestFunctionGrad(unittest.TestCase):
   def test_function_grad_ops(self, precompile=False, precompile_backward=False):
     N = 64
diff --git a/tinygrad/callify.py b/tinygrad/callify.py
index 13b85cbf0b..79566aebc1 100644
--- a/tinygrad/callify.py
+++ b/tinygrad/callify.py
@@ -99,7 +99,23 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
   resolved = [c.gettuple(i) for i in range(len(srcs))]
   outs = tuple(r.empty_like() for r in resolved)
   targets = [o.param_like(len(c.src)-1+i).shrink_to(s.shape) for i,(o,s) in enumerate(zip(outs, srcs))]
-  fxn = UOp.sink(*[t.after(t.store(s)) for t,s in zip(targets, srcs)])
+
+  # an output whose AFTER-stripped source has a CONTIGUOUS/BUFFER base of the target's shape gets that base replaced
+  # by the target through `subs`; any other output is stored into its target
+  subs:dict[UOp, UOp] = {}
+  items:list[UOp] = []
+  for s, t in zip(srcs, targets):
+    # NOTE(review): `s` is rebound to the AFTER-stripped node, so both branches below use the stripped source,
+    # whereas this code previously stored the un-stripped `s` -- confirm that dropping the AFTER wrappers is intended
+    while s.op is Ops.AFTER: s = s.src[0]
+    base = s.base
+    # NOTE(review): only shapes are compared; presumably `s` can be a non-identity view of `base` -- verify
+    if base.op in {Ops.CONTIGUOUS, Ops.BUFFER} and base.shape == t.shape and base not in subs:
+      subs[base] = t.after(t.store(base.src[0])) if base.op is Ops.CONTIGUOUS else t
+      items.append(s)
+    else:
+      items.append(t.after(t.store(s)))
+  fxn = UOp.sink(*(x.substitute(subs) for x in items))
   # body switches from TUPLE to SINK, so the node becomes an opaque CALL (not FUNCTION)
   new_call = UOp(Ops.CALL, c.dtype, (fxn, *input_buffers, *outs), c.arg)