diff --git a/test/test_custom_function.py b/test/test_custom_function.py index 2838c080d7..4cccf82863 100644 --- a/test/test_custom_function.py +++ b/test/test_custom_function.py @@ -3,17 +3,19 @@ import unittest import numpy as np -from typing import Optional +from typing import Optional, Tuple from tinygrad.helpers import prod # *** first, we implement the atan2 op at the lowest level *** -# `atan2_op` can handle both GPUBuffers and CPUBuffers +# `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers -from tinygrad.ops import ASTRunner -from tinygrad.runtime.ops_gpu import GPUBuffer +from tinygrad.ops import ASTRunner, CompiledBuffer from tinygrad.runtime.ops_cpu import CPUBuffer -def atan2_gpu(a:GPUBuffer, b:GPUBuffer) -> GPUBuffer: +# we don't always have GPU support, so the type signature is the abstract CompiledBuffer instead of GPUBuffer +def atan2_gpu(a:CompiledBuffer, b:CompiledBuffer) -> CompiledBuffer: + from tinygrad.runtime.ops_gpu import GPUBuffer + assert type(a) == GPUBuffer and type(b) == GPUBuffer, "gpu function requires GPUBuffers" ret = GPUBuffer(a.shape) ASTRunner("atan2", """ __kernel void atan2(global float *c, global float *a, global float *b) { @@ -25,12 +27,6 @@ def atan2_gpu(a:GPUBuffer, b:GPUBuffer) -> GPUBuffer: def atan2_cpu(a:CPUBuffer, b:CPUBuffer) -> CPUBuffer: return CPUBuffer(np.arctan2(a._buf, b._buf)) -def atan2_dispatch(a, b): - assert prod(a.shape) == prod(b.shape) and type(a) == type(b), "shape or type mismatch" - if isinstance(a, GPUBuffer): return atan2_gpu(a, b) - elif isinstance(a, CPUBuffer): return atan2_cpu(a, b) - else: raise NotImplementedError(f"no atan2 implemented for {type(a)}") - # *** second, we write the ATan2 mlop *** # NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative # In general, it is also optional to write a backward function, just your backward pass won't work without it @@ -41,10 +37,11 @@ from tinygrad.tensor import Function class ATan2(Function): def forward(self, a:LazyBuffer, b:LazyBuffer) -> LazyBuffer: + assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch" self.a, self.b = a, b - ast = LazyOp(LoadOps.CUSTOM, (a, b), atan2_dispatch) + ast = LazyOp(LoadOps.CUSTOM, (a, b), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device]) return LazyBuffer(a.device, a.shape, LoadOps, ast) - def backward(self, grad_output:LazyBuffer) -> tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: + def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: denom = (self.a.binary_op(BinaryOps.MUL, self.a)).binary_op(BinaryOps.ADD, self.b.binary_op(BinaryOps.MUL, self.b)) return grad_output.binary_op(BinaryOps.MUL, self.b.binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \ grad_output.binary_op(BinaryOps.MUL, self.a.unary_op(UnaryOps.NEG).binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[1] else None