diff --git a/extra/training.py b/extra/training.py index 8b1782e31f..e889242881 100644 --- a/extra/training.py +++ b/extra/training.py @@ -20,7 +20,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric losses, accuracies = [], [] for i in (t := trange(steps, disable=os.getenv('CI') is not None)): samp = np.random.randint(0, X_train.shape[0], size=(BS)) - x = Tensor(transform(X_train[samp])) + x = Tensor(transform(X_train[samp]), requires_grad=False) y = target_transform(Y_train[samp]) # network diff --git a/tinygrad/ops/ops_gpu.py b/tinygrad/ops/ops_gpu.py index 5f8105daff..8da2389e9a 100644 --- a/tinygrad/ops/ops_gpu.py +++ b/tinygrad/ops/ops_gpu.py @@ -158,8 +158,8 @@ class Matmul(Function): def backward(ctx, grad_output): input, weight = ctx.saved_tensors - grad_input = matmul(grad_output, weight, buffer_new(input.shape), transpose_b=True) - grad_weight = matmul(input, grad_output, buffer_new(weight.shape), transpose_a=True) + grad_input = matmul(grad_output, weight, buffer_new(input.shape), transpose_b=True) if ctx.needs_input_grad[0] else None + grad_weight = matmul(input, grad_output, buffer_new(weight.shape), transpose_a=True) if ctx.needs_input_grad[1] else None return grad_input, grad_weight class Conv2D(Function): @@ -191,6 +191,6 @@ class Conv2D(Function): rcout = cout//ctx.groups conv_args = H, W, ctx.groups, rcout, cin, oy, ox, iy, ix, ys, xs, bs - dw = convdw(x, grad_output, buffer_new((cout, cin, H, W)), conv_args) - dx = convdx(w, grad_output, buffer_new((bs, cin_, iy, ix), zero=True), conv_args) + dx = convdx(w, grad_output, buffer_new((bs, cin_, iy, ix), zero=True), conv_args) if ctx.needs_input_grad[0] else None + dw = convdw(x, grad_output, buffer_new((cout, cin, H, W)), conv_args) if ctx.needs_input_grad[1] else None return dx, dw diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index c554e8d872..cd6648b9ff 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -161,8 +161,9 @@ class Tensor: assert (t0.grad is not None) with ProfileOp(t0._ctx, t0._ctx.__class__.__name__, [t0.grad], backward=True) as po: grads = t0._ctx.backward(t0._ctx, t0.grad.data) - po.output = grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None + grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None for g in ([grads] if len(t0._ctx.parents) == 1 else grads)] + po.output = [x for x in grads if x is not None] # backward can return None if no required gradient, don't profile it for t, g in zip(t0._ctx.parents, grads): if g is not None and t.requires_grad: assert g.shape == t.shape, \ @@ -382,9 +383,10 @@ class Function: # overwrite with passed params for k, v in kwargs.items(): setattr(ctx, k, v) + ctx.needs_input_grad = [t.requires_grad for t in x] with ProfileOp(ctx, ctx.__class__.__name__, x) as po: ret = Tensor(self.forward(ctx, *[t.data for t in x], **kwargs), - device=ctx.device, requires_grad=any(t.requires_grad for t in x)) + device=ctx.device, requires_grad=any(ctx.needs_input_grad)) po.output = [ret] if ret.requires_grad: ret._ctx = ctx