fix: don't copy precompiled custom kernel outputs (#16084)

This commit is contained in:
wozeparrot
2026-05-07 17:02:38 -04:00
committed by GitHub
parent f9083cf901
commit 4d1a9dca41
2 changed files with 42 additions and 2 deletions

View File

@@ -2,7 +2,7 @@ import numpy as np
import unittest
from tinygrad.function import function
from tinygrad import Tensor, GlobalCounters
from tinygrad.uop.ops import UOp, KernelInfo
from tinygrad.uop.ops import UOp, Ops, KernelInfo
class TestFunction(unittest.TestCase):
def test_simple(self):
@@ -466,6 +466,35 @@ class TestFunctionTuple(unittest.TestCase):
Tensor.realize(a.grad)
np.testing.assert_allclose(a.grad.numpy(), [2., 2., 2., 2.])
def test_custom_kernel_precompile_no_copy_kernel(self):
def my_kernel(C:UOp, A:UOp) -> UOp:
i = UOp.range(A.shape[0], 0)
return C[i].store(A[i] * 2.0).end(i).sink(arg=KernelInfo(name="my_kernel"))
def my_grad(d_c:UOp, call:UOp):
return (None, (Tensor(d_c) * 2.0).uop)
@function(precompile=True, precompile_backward=True)
def f(a:Tensor):
c = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
c = Tensor.custom_kernel(c, a, fxn=my_kernel, grad_fxn=my_grad)[0]
return c
def count_kernels(t:Tensor):
linear, _ = t.linear_with_vars()
return sum((len(call.device) if isinstance(call.device, tuple) else 1)
for call in linear.src if call.src[0].op is Ops.SINK)
a = Tensor([1., 2., 3., 4.], requires_grad=True).contiguous()
Tensor.realize(a)
c = f(a)
self.assertEqual(count_kernels(c), 1)
c.sum().backward()
Tensor.realize(a.grad)
np.testing.assert_allclose(a.grad.numpy(), [2., 2., 2., 2.])
class TestFunctionGrad(unittest.TestCase):
def test_function_grad_ops(self, precompile=False, precompile_backward=False):
N = 64

View File

@@ -99,7 +99,18 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
resolved = [c.gettuple(i) for i in range(len(srcs))]
outs = tuple(r.empty_like() for r in resolved)
targets = [o.param_like(len(c.src)-1+i).shrink_to(s.shape) for i,(o,s) in enumerate(zip(outs, srcs))]
fxn = UOp.sink(*[t.after(t.store(s)) for t,s in zip(targets, srcs)])
subs:dict[UOp, UOp] = {}
items:list[UOp] = []
for s, t in zip(srcs, targets):
while s.op is Ops.AFTER: s = s.src[0]
base = s.base
if base.op in {Ops.CONTIGUOUS, Ops.BUFFER} and base.shape == t.shape and base not in subs:
subs[base] = t.after(t.store(base.src[0])) if base.op is Ops.CONTIGUOUS else t
items.append(s)
else:
items.append(t.after(t.store(s)))
fxn = UOp.sink(*(x.substitute(subs) for x in items))
# body switches from TUPLE to SINK, so the node becomes an opaque CALL (not FUNCTION)
new_call = UOp(Ops.CALL, c.dtype, (fxn, *input_buffers, *outs), c.arg)