diff --git a/examples/custom_function.py b/examples/custom_function.py deleted file mode 100644 index 77d37a20cf..0000000000 --- a/examples/custom_function.py +++ /dev/null @@ -1,89 +0,0 @@ -# this is an example of how you can write terrible DSP compute breaking ops like warpPerspective -# here we use a CUSTOM op to write atan2 -import numpy as np -from tinygrad.helpers import prod - -# *** first, we implement the atan2 op at the lowest level *** -# `atan2_op` can handle both GPUBuffers and CPUBuffers - -from tinygrad.ops import ASTRunner, DeviceBuffer -from tinygrad.runtime.ops_gpu import GPUBuffer -from tinygrad.runtime.ops_cpu import CPUBuffer - -def atan2_op(a:DeviceBuffer, b:DeviceBuffer) -> DeviceBuffer: - assert prod(a.shape) == prod(b.shape) and type(a) == type(b), "shape or type mismatch" - if isinstance(a, GPUBuffer): - ret = GPUBuffer(a.shape) - ASTRunner("atan2", """ - __kernel void atan2(global float *c, global float *a, global float *b) { - int idx = get_global_id(0); - c[idx] = atan2(a[idx], b[idx]); - }""", global_size=[prod(ret.shape)]).build(GPUBuffer.runtime_type).exec([ret, a.contiguous(), b.contiguous()]) - return ret - elif isinstance(a, CPUBuffer): - return CPUBuffer(np.arctan2(a._buf, b._buf)) - else: - raise NotImplementedError(f"no atan2 implemented for {type(a)}") - -# *** second, we write the ATan2 mlop *** -# NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative -# In general, it is also optional to write a backward function, just your backward pass won't work without it - -from tinygrad.ops import ASTRunner, LazyOp, LoadOps, BinaryOps, UnaryOps -from tinygrad.lazy import LazyBuffer -from tinygrad.tensor import Function - -class ATan2(Function): - def forward(self, a, b): - self.a, self.b = a, b - ast = LazyOp(LoadOps.CUSTOM, (a, b), atan2_op) - return LazyBuffer(a.device, a.shape, LoadOps, ast) - def backward(self, grad_output): - denom = (self.a.binary_op(BinaryOps.MUL, self.a)).binary_op(BinaryOps.ADD, self.b.binary_op(BinaryOps.MUL, self.b)) - return grad_output.binary_op(BinaryOps.MUL, self.b.binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \ - grad_output.binary_op(BinaryOps.MUL, self.a.unary_op(UnaryOps.NEG).binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[1] else None - -# *** third, we use our lovely new mlop *** - -from tinygrad.tensor import Tensor - -if __name__ == "__main__": - # create some random Tensors, permute them just because we can - a = Tensor.randn(4,4,requires_grad=True).permute(1,0) - b = Tensor.randn(4,4,requires_grad=True).permute(1,0) - - # run the forward pass. note: up until the .numpy(), it's all lazy - c = ATan2.apply(a, b) - print(c.numpy()) - - # check the forward pass (in numpy) - np.testing.assert_allclose(c.numpy(), np.arctan2(a.numpy(), b.numpy()), atol=1e-5) - - # run the backward pass - c.mean().backward() - assert a.grad is not None and b.grad is not None, "tinygrad didn't compute gradients" - print(a.grad.numpy()) - print(b.grad.numpy()) - - # check the backward pass (in torch) - import torch - ta, tb = torch.tensor(a.numpy(), requires_grad=True), torch.tensor(b.numpy(), requires_grad=True) - tc = torch.atan2(ta, tb) - tc.mean().backward() - assert ta.grad is not None and tb.grad is not None, "torch didn't compute gradients" - np.testing.assert_allclose(a.grad.numpy(), ta.grad.numpy(), atol=1e-5) - np.testing.assert_allclose(b.grad.numpy(), tb.grad.numpy(), atol=1e-5) - - # custom ops even work in the JIT! - from tinygrad.jit import TinyJit - - @TinyJit - def jitted_atan2(a, b): - return ATan2.apply(a, b).realize() - - for i in range(5): - a = Tensor.randn(4,4,requires_grad=True).permute(1,0) - b = Tensor.randn(4,4,requires_grad=True).permute(1,0) - c = jitted_atan2(a, b) - np.testing.assert_allclose(c.numpy(), np.arctan2(a.numpy(), b.numpy()), atol=1e-5) - diff --git a/test/test_custom_function.py b/test/test_custom_function.py new file mode 100644 index 0000000000..2838c080d7 --- /dev/null +++ b/test/test_custom_function.py @@ -0,0 +1,107 @@ +# this is an example of how you can write terrible DSP compute breaking ops like warpPerspective +# here we use a CUSTOM op to write atan2 + +import unittest +import numpy as np +from typing import Optional +from tinygrad.helpers import prod + +# *** first, we implement the atan2 op at the lowest level *** +# `atan2_op` can handle both GPUBuffers and CPUBuffers + +from tinygrad.ops import ASTRunner +from tinygrad.runtime.ops_gpu import GPUBuffer +from tinygrad.runtime.ops_cpu import CPUBuffer + +def atan2_gpu(a:GPUBuffer, b:GPUBuffer) -> GPUBuffer: + ret = GPUBuffer(a.shape) + ASTRunner("atan2", """ + __kernel void atan2(global float *c, global float *a, global float *b) { + int idx = get_global_id(0); + c[idx] = atan2(a[idx], b[idx]); + }""", global_size=[prod(ret.shape)]).build(GPUBuffer.runtime_type).exec([ret, a.contiguous(), b.contiguous()]) + return ret + +def atan2_cpu(a:CPUBuffer, b:CPUBuffer) -> CPUBuffer: + return CPUBuffer(np.arctan2(a._buf, b._buf)) + +def atan2_dispatch(a, b): + assert prod(a.shape) == prod(b.shape) and type(a) == type(b), "shape or type mismatch" + if isinstance(a, GPUBuffer): return atan2_gpu(a, b) + elif isinstance(a, CPUBuffer): return atan2_cpu(a, b) + else: raise NotImplementedError(f"no atan2 implemented for {type(a)}") + +# *** second, we write the ATan2 mlop *** +# NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative +# In general, it is also optional to write a backward function, just your backward pass won't work without it + +from tinygrad.ops import ASTRunner, LazyOp, LoadOps, BinaryOps, UnaryOps +from tinygrad.lazy import LazyBuffer +from tinygrad.tensor import Function + +class ATan2(Function): + def forward(self, a:LazyBuffer, b:LazyBuffer) -> LazyBuffer: + self.a, self.b = a, b + ast = LazyOp(LoadOps.CUSTOM, (a, b), atan2_dispatch) + return LazyBuffer(a.device, a.shape, LoadOps, ast) + def backward(self, grad_output:LazyBuffer) -> tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: + denom = (self.a.binary_op(BinaryOps.MUL, self.a)).binary_op(BinaryOps.ADD, self.b.binary_op(BinaryOps.MUL, self.b)) + return grad_output.binary_op(BinaryOps.MUL, self.b.binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \ + grad_output.binary_op(BinaryOps.MUL, self.a.unary_op(UnaryOps.NEG).binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[1] else None + +# *** third, we use our lovely new mlop in some tests *** + +from tinygrad.tensor import Tensor, Device + +@unittest.skipUnless(Device.DEFAULT in ["CPU", "GPU"], "atan2 is only implemented for CPU and GPU") +class TestCustomFunction(unittest.TestCase): + def test_atan2_forward(self): + # create some random Tensors, permute them just because we can + a = Tensor.randn(4,4,requires_grad=True).permute(1,0) + b = Tensor.randn(4,4,requires_grad=True).permute(1,0) + + # run the forward pass. note: up until the .numpy(), it's all lazy + c = ATan2.apply(a, b) + print(c.numpy()) + + # check the forward pass (in numpy) + np.testing.assert_allclose(c.numpy(), np.arctan2(a.numpy(), b.numpy()), atol=1e-5) + + # fun fact, this never actually calls forward, so it works in all the backends + def test_atan2_backward(self): + # have to go forward before we can go backward + a = Tensor.randn(4,4,requires_grad=True).permute(1,0) + b = Tensor.randn(4,4,requires_grad=True).permute(1,0) + c = ATan2.apply(a, b) + + # run the backward pass + c.mean().backward() + assert a.grad is not None and b.grad is not None, "tinygrad didn't compute gradients" + print(a.grad.numpy()) + print(b.grad.numpy()) + + # check the backward pass (in torch) + import torch + ta, tb = torch.tensor(a.numpy(), requires_grad=True), torch.tensor(b.numpy(), requires_grad=True) + tc = torch.atan2(ta, tb) + tc.mean().backward() + assert ta.grad is not None and tb.grad is not None, "torch didn't compute gradients" + np.testing.assert_allclose(a.grad.numpy(), ta.grad.numpy(), atol=1e-5) + np.testing.assert_allclose(b.grad.numpy(), tb.grad.numpy(), atol=1e-5) + + def test_atan2_jit(self): + # custom ops even work in the JIT! + from tinygrad.jit import TinyJit + + @TinyJit + def jitted_atan2(a:Tensor, b:Tensor) -> Tensor: + return ATan2.apply(a, b).realize() + + for _ in range(5): + a = Tensor.randn(4,4,requires_grad=True).permute(1,0) + b = Tensor.randn(4,4,requires_grad=True).permute(1,0) + c = jitted_atan2(a, b) + np.testing.assert_allclose(c.numpy(), np.arctan2(a.numpy(), b.numpy()), atol=1e-5) + +if __name__ == "__main__": + unittest.main()