diff --git a/tinygrad/optim.py b/tinygrad/optim.py index f1ee044fbd..5887767932 100644 --- a/tinygrad/optim.py +++ b/tinygrad/optim.py @@ -7,9 +7,6 @@ class Optimizer: def __init__(self, params): self.params = [x for x in params if x.requires_grad == True] - def num(self, x): - return Tensor([x], gpu=self.params[0].gpu, requires_grad=False) - def zero_grad(self): for param in self.params: param.grad = None @@ -17,7 +14,7 @@ class Optimizer: class SGD(Optimizer): def __init__(self, params, lr=0.001): super(SGD, self).__init__(params) - self.lr = self.num(lr) + self.lr = lr def step(self): for t in self.params: @@ -26,27 +23,27 @@ class SGD(Optimizer): class RMSprop(Optimizer): def __init__(self, params, lr=0.001, decay=0.9, eps=1e-8): super(RMSprop, self).__init__(params) - self.lr, self.decay, self.eps, self.one, self.two = [self.num(x) for x in [lr, decay, eps, 1, 2]] + self.lr, self.decay, self.eps = Tensor([lr], gpu=self.params[0].gpu, requires_grad=False), decay, eps self.v = [Tensor(np.zeros(t.shape, dtype=np.float32), gpu=params[0].gpu, requires_grad=False) for t in self.params] def step(self): for i, t in enumerate(self.params): - self.v[i] = self.decay * self.v[i] + (self.one - self.decay) * t.grad.pow(self.two) + self.v[i] = self.decay * self.v[i] + (1.0 - self.decay) * t.grad.pow(2.0) t -= self.lr.div(self.v[i].sqrt() + self.eps) * t.grad class Adam(Optimizer): def __init__(self, params, lr=0.001, b1=0.9, b2=0.999, eps=1e-8): super(Adam, self).__init__(params) - self.lr, self.b1, self.b2, self.eps, self.t, self.one, self.two = [self.num(x) for x in [lr, b1, b2, eps, 0, 1, 2]] + self.lr, self.b1, self.b2, self.eps, self.t = lr, b1, b2, eps, Tensor([0.0], gpu=self.params[0].gpu, requires_grad=False) self.m = [Tensor(np.zeros(t.shape, dtype=np.float32), gpu=params[0].gpu, requires_grad=False) for t in self.params] self.v = [Tensor(np.zeros(t.shape, dtype=np.float32), gpu=params[0].gpu, requires_grad=False) for t in self.params] def step(self): - self.t = self.t + self.one - a = self.lr * (self.one - self.b2.pow(self.t)).sqrt().div(self.one - self.b1.pow(self.t)) + self.t = self.t + 1.0 + a = self.lr * ((1.0 - self.b2**self.t)**0.5).div(1.0 - self.b1**self.t) for i,t in enumerate(self.params): - self.m[i] = self.b1 * self.m[i] + (self.one - self.b1) * t.grad - self.v[i] = self.b2 * self.v[i] + (self.one - self.b2) * t.grad.pow(self.two) + self.m[i] = self.b1 * self.m[i] + (1.0 - self.b1) * t.grad + self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * t.grad.pow(2.0) t -= a * self.m[i].div(self.v[i].sqrt() + self.eps)