diff --git a/test/test_schedule.py b/test/test_schedule.py index 89ff9dbd55..c777312622 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -204,7 +204,7 @@ class TestSchedule(unittest.TestCase): def test_fold_conv_batchnorm_optim(self): # this is too high - for optim, cnt in [(nn.optim.Adam, 20), (nn.optim.SGD, 17)]: + for optim, cnt in [(nn.optim.Adam, 19), (nn.optim.SGD, 17)]: with self.subTest(optim=optim.__name__): with Tensor.train(): img = Tensor.ones(1,3,4,4) @@ -609,7 +609,7 @@ class TestSchedule(unittest.TestCase): layer = nn.Linear(768, 768*4) opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4) layer(x).relu().sum().backward() - check_schedule(opt.schedule_step(), 12) + check_schedule(opt.schedule_step(), 11) def test_adam_conv_fuse(self): with Tensor.train(): @@ -618,7 +618,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4) opt.zero_grad() c1(img).relu().sum().backward() - check_schedule(opt.schedule_step(), 12) + check_schedule(opt.schedule_step(), 11) def test_adam_2convs_fuse(self): with Tensor.train(): @@ -628,7 +628,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4) opt.zero_grad() c2(c1(img).relu()).relu().sum().backward() - check_schedule(opt.schedule_step(), 14) + check_schedule(opt.schedule_step(), 13) def test_sgd_conv_fuse(self): with Tensor.train(): diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py index a7991f0bbc..c428dc985b 100644 --- a/tinygrad/nn/optim.py +++ b/tinygrad/nn/optim.py @@ -76,19 +76,20 @@ def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAM class LAMB(Optimizer): def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False): super().__init__(params, lr) - self.eps, self.wd, self.adam = eps, wd, adam - self.b1, self.b2, self.t = (Tensor([x], device=self.device, requires_grad=False).realize() for x in [b1, b2, 0]) - self.m = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False).contiguous() for t in self.params] - self.v = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False).contiguous() for t in self.params] + self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, wd, adam + self.b1_t, self.b2_t = (Tensor([1], dtype=dtypes.float32, device=self.device, requires_grad=False).realize() for _ in [b1, b2]) + self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params] + self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params] def _step(self) -> List[Tensor]: - self.t.assign(self.t + 1) + self.b1_t *= self.b1 + self.b2_t *= self.b2 for i, t in enumerate(self.params): assert t.grad is not None self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad) self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad)) - m_hat = self.m[i] / (1.0 - self.b1 ** self.t) - v_hat = self.v[i] / (1.0 - self.b2 ** self.t) + m_hat = self.m[i] / (1.0 - self.b1_t) + v_hat = self.v[i] / (1.0 - self.b2_t) up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach() if not self.adam: r1 = t.detach().square().sum().sqrt() @@ -97,4 +98,4 @@ class LAMB(Optimizer): else: r = 1.0 t.assign((t.detach() - self.lr * r * up).cast(t.dtype)) - return [self.t] + self.m + self.v + return [self.b1_t, self.b2_t] + self.m + self.v