mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
* support custom UOp kernels * no number * multioutput works * backward kernel runs * move kernel class * grad later * work * no tags in kernel graph * test arange * arange + contig * delete comment
66 lines
2.2 KiB
Python
66 lines
2.2 KiB
Python
import unittest
|
|
from typing import Callable
|
|
from tinygrad import Tensor, UOp
|
|
from tinygrad.uop.ops import KernelInfo
|
|
|
|
def custom_arange_kernel(C:UOp):
  """Build a kernel graph that writes each loop index into `C` at that index.

  A single range of length `C.size` drives the store; the index is cast to
  `C.dtype.base` before being stored. The result is a sink carrying a
  KernelInfo named `custom_arange_{C.size}`.
  """
  idx = UOp.range(C.size, 0)
  dst = C[idx]
  stored = dst.store(idx.cast(C.dtype.base))
  # close the range before sinking so the store is the body of the loop
  looped = stored.end(idx)
  return looped.sink(arg=KernelInfo(name=f"custom_arange_{C.size}"))
|
|
|
|
def custom_add_one_kernel(B:UOp, A:UOp):
  """Build a kernel graph that stores `A[i] + 1` into `B[i]` for every index.

  `B` and `A` must report the same `size`; the range is built over `A.size`.
  The returned sink carries a KernelInfo named `add_one_{A.size}`.
  """
  assert B.size == A.size
  idx = UOp.range(A.size, 0)
  dst = B[idx]
  incremented = A[idx] + 1
  looped = dst.store(incremented).end(idx)
  return looped.sink(arg=KernelInfo(name=f"add_one_{A.size}"))
|
|
|
|
def custom_elementwise_add_kernel(C:UOp, A:UOp, B:UOp):
  """Build a kernel graph that stores `A[i] + B[i]` into `C[i]`.

  The range is built over `C.size`. The sink carries a KernelInfo named
  `custom_add_kernel_{C.size}` and is passed through `simplify()` before
  being returned.
  """
  idx = UOp.range(C.size, 0)
  dst = C[idx]
  total = A[idx]+B[idx]
  looped = dst.store(total).end(idx)
  sunk = looped.sink(arg=KernelInfo(name=f"custom_add_kernel_{C.size}"))
  return sunk.simplify()
|
|
|
|
def custom_elementwise_addmul_kernel(C:UOp, D:UOp, A:UOp, B:UOp):
  """Build a two-output kernel graph: `C[i] = A[i] + B[i]` and `D[i] = A[i] * B[i]`.

  `C` and `D` must report the same `size`; one range over `C.size` is shared
  by both stores. The stores are grouped, ended on that range, sunk with a
  KernelInfo named `custom_addmul_kernel_{C.size}`, and simplified.
  """
  assert C.size == D.size
  idx = UOp.range(C.size, 0)
  # both stores share the same range so they end up in one grouped loop body
  add_store = C[idx].store(A[idx]+B[idx])
  mul_store = D[idx].store(A[idx]*B[idx])
  grouped = UOp.group(add_store, mul_store)
  sunk = grouped.end(idx).sink(arg=KernelInfo(name=f"custom_addmul_kernel_{C.size}"))
  return sunk.simplify()
|
|
|
|
def _kernel(tensors:list[Tensor], fxn:Callable) -> list[Tensor]:
  """Call `UOp.custom_kernel` on the uops of `tensors` and wrap every result in a Tensor.

  `fxn` is forwarded unchanged as the `fxn` keyword of `UOp.custom_kernel`.
  """
  source_uops = [t.uop for t in tensors]
  results = UOp.custom_kernel(*source_uops, fxn=fxn)
  return [Tensor(u) for u in results]
|
|
|
|
class TestCustomKernel(unittest.TestCase):
  """Runs the custom UOp kernels defined in this file through `_kernel` and checks the output tensors.

  All checks use unittest assertion methods rather than bare `assert`, so they
  are not stripped when Python runs with -O and the class is consistent.
  """

  def test_simple(self):
    # ones + ones through the elementwise add kernel should be 2 everywhere
    a = Tensor.ones(16, 16).contiguous()
    b = Tensor.ones(16, 16).contiguous()
    c = Tensor.empty(16, 16)

    # the first returned tensor corresponds to the first tensor passed in (the output buffer)
    c = _kernel([c,a,b], fxn=custom_elementwise_add_kernel)[0]

    out = c.flatten().tolist()
    self.assertTrue(all(x == 2 for x in out), "all 2")

  def test_multioutput(self):
    # one kernel writes two outputs: c = a + b and d = a * b, with a = b = 3
    a = Tensor.full((16, 16), 3.).contiguous()
    b = Tensor.full((16, 16), 3.).contiguous()
    c = Tensor.empty(16, 16)
    d = Tensor.empty(16, 16)

    c,d = _kernel([c,d,a,b], custom_elementwise_addmul_kernel)[:2]
    # realize both outputs in one call before reading them back
    Tensor.realize(c,d)

    self.assertTrue(all(x == 6 for x in c.flatten().tolist()), "all 6")
    self.assertTrue(all(x == 9 for x in d.flatten().tolist()), "all 9")

  def test_arange(self):
    # the custom arange kernel must match Tensor.arange elementwise
    ref = Tensor.arange(100)
    tst = Tensor.empty_like(ref)
    tst = _kernel([tst], custom_arange_kernel)[0]
    self.assertTrue((ref == tst).all().item())

  def test_noncontig(self):
    # the kernel input b = a + 1 is passed without calling .contiguous() on it
    a = Tensor.ones(16, 16).contiguous()
    tst = Tensor.empty_like(a)
    b = a+1
    b_p1 = _kernel([tst, b], custom_add_one_kernel)[0]
    # (1 + 1) + 1 == 3
    self.assertTrue((b_p1 == 3).all().item())
|
|
|
|
# run the test suite when this file is executed directly
if __name__ == '__main__': unittest.main()
|