mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
fix: don't copy precompiled custom kernel outputs (#16084)
This commit is contained in:
@@ -2,7 +2,7 @@ import numpy as np
|
||||
import unittest
|
||||
from tinygrad.function import function
|
||||
from tinygrad import Tensor, GlobalCounters
|
||||
from tinygrad.uop.ops import UOp, KernelInfo
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
|
||||
class TestFunction(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
@@ -466,6 +466,35 @@ class TestFunctionTuple(unittest.TestCase):
|
||||
Tensor.realize(a.grad)
|
||||
np.testing.assert_allclose(a.grad.numpy(), [2., 2., 2., 2.])
|
||||
|
||||
def test_custom_kernel_precompile_no_copy_kernel(self):
  # a precompiled function wrapping a custom kernel is expected to contribute exactly one
  # kernel to the linearized graph, and the custom gradient must still flow through it
  def my_kernel(C:UOp, A:UOp) -> UOp:
    # elementwise over the first axis of A: C[i] = A[i] * 2.0
    i = UOp.range(A.shape[0], 0)
    doubled = A[i] * 2.0
    return C[i].store(doubled).end(i).sink(arg=KernelInfo(name="my_kernel"))

  def my_grad(d_c:UOp, call:UOp):
    # first entry (for the output argument) is None; the input's gradient is d_c * 2.0
    grad_a = Tensor(d_c) * 2.0
    return (None, grad_a.uop)

  @function(precompile=True, precompile_backward=True)
  def f(a:Tensor):
    out = Tensor.invalids(*a.shape, dtype=a.dtype, device=a.device)
    return Tensor.custom_kernel(out, a, fxn=my_kernel, grad_fxn=my_grad)[0]

  def count_kernels(t:Tensor):
    # tally the entries of the linear graph whose first source is a SINK;
    # an entry whose device is a tuple counts once per element of that tuple
    linear, _ = t.linear_with_vars()
    total = 0
    for call in linear.src:
      if call.src[0].op is not Ops.SINK: continue
      total += len(call.device) if isinstance(call.device, tuple) else 1
    return total

  x = Tensor([1., 2., 3., 4.], requires_grad=True).contiguous()
  Tensor.realize(x)
  y = f(x)

  # only the custom kernel itself should be counted
  n_kernels = count_kernels(y)
  self.assertEqual(n_kernels, 1)

  # d(sum(2*x))/dx == 2 for every element
  y.sum().backward()
  Tensor.realize(x.grad)
  np.testing.assert_allclose(x.grad.numpy(), [2., 2., 2., 2.])
|
||||
|
||||
class TestFunctionGrad(unittest.TestCase):
|
||||
def test_function_grad_ops(self, precompile=False, precompile_backward=False):
|
||||
N = 64
|
||||
|
||||
@@ -99,7 +99,18 @@ def transform_precompiled_call(c:UOp) -> UOp|None:
|
||||
resolved = [c.gettuple(i) for i in range(len(srcs))]
|
||||
outs = tuple(r.empty_like() for r in resolved)
|
||||
targets = [o.param_like(len(c.src)-1+i).shrink_to(s.shape) for i,(o,s) in enumerate(zip(outs, srcs))]
|
||||
fxn = UOp.sink(*[t.after(t.store(s)) for t,s in zip(targets, srcs)])
|
||||
|
||||
subs:dict[UOp, UOp] = {}
|
||||
items:list[UOp] = []
|
||||
for s, t in zip(srcs, targets):
|
||||
while s.op is Ops.AFTER: s = s.src[0]
|
||||
base = s.base
|
||||
if base.op in {Ops.CONTIGUOUS, Ops.BUFFER} and base.shape == t.shape and base not in subs:
|
||||
subs[base] = t.after(t.store(base.src[0])) if base.op is Ops.CONTIGUOUS else t
|
||||
items.append(s)
|
||||
else:
|
||||
items.append(t.after(t.store(s)))
|
||||
fxn = UOp.sink(*(x.substitute(subs) for x in items))
|
||||
|
||||
# body switches from TUPLE to SINK, so the node becomes an opaque CALL (not FUNCTION)
|
||||
new_call = UOp(Ops.CALL, c.dtype, (fxn, *input_buffers, *outs), c.arg)
|
||||
|
||||
Reference in New Issue
Block a user