From 97fd9c1237c65b215903eb6c1064a1784827efe8 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 7 Dec 2020 23:12:18 -0800 Subject: [PATCH] zero_grad there to match readme --- test/test_mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_mnist.py b/test/test_mnist.py index 41aeae17d5..481abfa674 100644 --- a/test/test_mnist.py +++ b/test/test_mnist.py @@ -56,7 +56,6 @@ class TinyConvNet: def train(model, optim, steps, BS=128, gpu=False): losses, accuracies = [], [] for i in (t := trange(steps, disable=os.getenv('CI') is not None)): - optim.zero_grad() samp = np.random.randint(0, X_train.shape[0], size=(BS)) x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), gpu=gpu) @@ -71,6 +70,7 @@ def train(model, optim, steps, BS=128, gpu=False): # NLL loss function loss = out.mul(y).mean() + optim.zero_grad() loss.backward() optim.step()