zero_grad there to match readme

This commit is contained in:
George Hotz
2020-12-07 23:12:18 -08:00
parent c63f950348
commit 97fd9c1237

View File

@@ -56,7 +56,6 @@ class TinyConvNet:
def train(model, optim, steps, BS=128, gpu=False):
losses, accuracies = [], []
for i in (t := trange(steps, disable=os.getenv('CI') is not None)):
optim.zero_grad()
samp = np.random.randint(0, X_train.shape[0], size=(BS))
x = Tensor(X_train[samp].reshape((-1, 28*28)).astype(np.float32), gpu=gpu)
@@ -71,6 +70,7 @@ def train(model, optim, steps, BS=128, gpu=False):
# NLL loss function
loss = out.mul(y).mean()
optim.zero_grad()
loss.backward()
optim.step()