mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
replace pow in LAMB by tracking b1**t and b2**t per step (#4582)
* replace pow in LAMB by tracking b1**t and b2**t per step * remove t, add [self.b1_t, self.b2_t] to return * adam has one less kernel
This commit is contained in:
@@ -204,7 +204,7 @@ class TestSchedule(unittest.TestCase):
|
||||
|
||||
def test_fold_conv_batchnorm_optim(self):
|
||||
# this is too high
|
||||
for optim, cnt in [(nn.optim.Adam, 20), (nn.optim.SGD, 17)]:
|
||||
for optim, cnt in [(nn.optim.Adam, 19), (nn.optim.SGD, 17)]:
|
||||
with self.subTest(optim=optim.__name__):
|
||||
with Tensor.train():
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
@@ -609,7 +609,7 @@ class TestSchedule(unittest.TestCase):
|
||||
layer = nn.Linear(768, 768*4)
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
|
||||
layer(x).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 12)
|
||||
check_schedule(opt.schedule_step(), 11)
|
||||
|
||||
def test_adam_conv_fuse(self):
|
||||
with Tensor.train():
|
||||
@@ -618,7 +618,7 @@ class TestSchedule(unittest.TestCase):
|
||||
opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
|
||||
opt.zero_grad()
|
||||
c1(img).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 12)
|
||||
check_schedule(opt.schedule_step(), 11)
|
||||
|
||||
def test_adam_2convs_fuse(self):
|
||||
with Tensor.train():
|
||||
@@ -628,7 +628,7 @@ class TestSchedule(unittest.TestCase):
|
||||
opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
|
||||
opt.zero_grad()
|
||||
c2(c1(img).relu()).relu().sum().backward()
|
||||
check_schedule(opt.schedule_step(), 14)
|
||||
check_schedule(opt.schedule_step(), 13)
|
||||
|
||||
def test_sgd_conv_fuse(self):
|
||||
with Tensor.train():
|
||||
|
||||
@@ -76,19 +76,20 @@ def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAM
|
||||
class LAMB(Optimizer):
|
||||
def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
|
||||
super().__init__(params, lr)
|
||||
self.eps, self.wd, self.adam = eps, wd, adam
|
||||
self.b1, self.b2, self.t = (Tensor([x], device=self.device, requires_grad=False).realize() for x in [b1, b2, 0])
|
||||
self.m = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False).contiguous() for t in self.params]
|
||||
self.v = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False).contiguous() for t in self.params]
|
||||
self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, wd, adam
|
||||
self.b1_t, self.b2_t = (Tensor([1], dtype=dtypes.float32, device=self.device, requires_grad=False).realize() for _ in [b1, b2])
|
||||
self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
|
||||
self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
|
||||
|
||||
def _step(self) -> List[Tensor]:
|
||||
self.t.assign(self.t + 1)
|
||||
self.b1_t *= self.b1
|
||||
self.b2_t *= self.b2
|
||||
for i, t in enumerate(self.params):
|
||||
assert t.grad is not None
|
||||
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
|
||||
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
|
||||
m_hat = self.m[i] / (1.0 - self.b1 ** self.t)
|
||||
v_hat = self.v[i] / (1.0 - self.b2 ** self.t)
|
||||
m_hat = self.m[i] / (1.0 - self.b1_t)
|
||||
v_hat = self.v[i] / (1.0 - self.b2_t)
|
||||
up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
|
||||
if not self.adam:
|
||||
r1 = t.detach().square().sum().sqrt()
|
||||
@@ -97,4 +98,4 @@ class LAMB(Optimizer):
|
||||
else:
|
||||
r = 1.0
|
||||
t.assign((t.detach() - self.lr * r * up).cast(t.dtype))
|
||||
return [self.t] + self.m + self.v
|
||||
return [self.b1_t, self.b2_t] + self.m + self.v
|
||||
|
||||
Reference in New Issue
Block a user