diff --git a/test/test_ops.py b/test/test_ops.py index 49b873f4b5..cf7b0f5307 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -69,7 +69,7 @@ class TestOps(unittest.TestCase): (torch.div, Tensor.div), (torch.pow, Tensor.pow)]: for shapes in [((5,13,24,16), (5,1,24,1)), ((1,3,1,7,1), (2,1,5,1,8))]: with self.subTest(op=torch_op.__name__, shapes=shapes): - helper_test_op(shapes, torch_op, tinygrad_op, gpu=self.gpu, forward_only=True) + helper_test_op(shapes, torch_op, tinygrad_op, gpu=self.gpu, forward_only=self.gpu) def test_broadcast_partial(self): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 62203df291..f825441818 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -4,25 +4,31 @@ import numpy as np from .tensor import Function, register # ************* basic ops ************* +def adBC(out, in_sh): #adjoint operation to broadcast is sum. Need to sum all axis with 1 = in_sh[i] < out.shape[i] + return out.sum(axis=tuple([i for i in range(len(in_sh)) if in_sh[i]==1 and out.shape[i]>1])).reshape(in_sh) class Add(Function): @staticmethod def forward(ctx, x, y): + ctx.save_for_backward(x.shape,y.shape) return x+y @staticmethod def backward(ctx, grad_output): - return grad_output, grad_output + shape_x, shape_y = ctx.saved_tensors + return adBC(grad_output, shape_x), adBC(grad_output, shape_y) register('add', Add) class Sub(Function): @staticmethod def forward(ctx, x, y): + ctx.save_for_backward(x.shape,y.shape) return x-y @staticmethod def backward(ctx, grad_output): - return grad_output, -grad_output + shape_x, shape_y = ctx.saved_tensors + return adBC(grad_output, shape_x), adBC( -grad_output, shape_y) register('sub', Sub) class Mul(Function): @@ -34,7 +40,7 @@ class Mul(Function): @staticmethod def backward(ctx, grad_output): x,y = ctx.saved_tensors - return y*grad_output, x*grad_output + return adBC(y*grad_output, x.shape), adBC(x*grad_output, y.shape) register('mul', Mul) class Div(Function): @@ -46,7 +52,7 @@ class Div(Function): @staticmethod def backward(ctx, grad_output): x,y = ctx.saved_tensors - return grad_output / y, -x * grad_output / y**2 + return adBC(grad_output / y, x.shape), adBC(-x * grad_output / y**2, y.shape) # TODO: registering this breaks the default div on the GPU #register('div', Div) @@ -59,7 +65,7 @@ class Pow(Function): @staticmethod def backward(ctx, grad_output): x,y = ctx.saved_tensors - return y * (x**(y-1.0)) * grad_output, (x**y) * np.log(x) * grad_output + return adBC(y * (x**(y-1.0)) * grad_output,x.shape), adBC((x**y) * np.log(x) * grad_output,y.shape) register('pow', Pow) class Sum(Function): @@ -206,7 +212,6 @@ class Conv2D(Function): ret[:,g] += np.tensordot(tx[:,g], tw[g], ((1,4,5),(1,2,3))) return np.moveaxis(ret,4,2).reshape(bs, cout, oy, ox) - @staticmethod def backward(ctx, grad_output): bs,_,oy,ox = grad_output.shape