replace pow in LAMB by tracking b1**t and b2**t per step (#4582)

* replace pow in LAMB by tracking b1**t and b2**t per step

* remove t, add [self.b1_t, self.b2_t] to return

* adam has one less kernel
This commit is contained in:
chenyu
2024-05-14 13:08:22 -04:00
committed by GitHub
parent 9b02aef45a
commit 7afca52796
2 changed files with 13 additions and 12 deletions

View File

@@ -204,7 +204,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim(self):
# this is too high
for optim, cnt in [(nn.optim.Adam, 20), (nn.optim.SGD, 17)]:
for optim, cnt in [(nn.optim.Adam, 19), (nn.optim.SGD, 17)]:
with self.subTest(optim=optim.__name__):
with Tensor.train():
img = Tensor.ones(1,3,4,4)
@@ -609,7 +609,7 @@ class TestSchedule(unittest.TestCase):
layer = nn.Linear(768, 768*4)
opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
layer(x).relu().sum().backward()
check_schedule(opt.schedule_step(), 12)
check_schedule(opt.schedule_step(), 11)
def test_adam_conv_fuse(self):
with Tensor.train():
@@ -618,7 +618,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
opt.zero_grad()
c1(img).relu().sum().backward()
check_schedule(opt.schedule_step(), 12)
check_schedule(opt.schedule_step(), 11)
def test_adam_2convs_fuse(self):
with Tensor.train():
@@ -628,7 +628,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
check_schedule(opt.schedule_step(), 14)
check_schedule(opt.schedule_step(), 13)
def test_sgd_conv_fuse(self):
with Tensor.train():

View File

@@ -76,19 +76,20 @@ def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAM
class LAMB(Optimizer):
def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False):
super().__init__(params, lr)
self.eps, self.wd, self.adam = eps, wd, adam
self.b1, self.b2, self.t = (Tensor([x], device=self.device, requires_grad=False).realize() for x in [b1, b2, 0])
self.m = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False).contiguous() for t in self.params]
self.v = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False).contiguous() for t in self.params]
self.b1, self.b2, self.eps, self.wd, self.adam = b1, b2, eps, wd, adam
self.b1_t, self.b2_t = (Tensor([1], dtype=dtypes.float32, device=self.device, requires_grad=False).realize() for _ in [b1, b2])
self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
def _step(self) -> List[Tensor]:
self.t.assign(self.t + 1)
self.b1_t *= self.b1
self.b2_t *= self.b2
for i, t in enumerate(self.params):
assert t.grad is not None
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
m_hat = self.m[i] / (1.0 - self.b1 ** self.t)
v_hat = self.v[i] / (1.0 - self.b2 ** self.t)
m_hat = self.m[i] / (1.0 - self.b1_t)
v_hat = self.v[i] / (1.0 - self.b2_t)
up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
if not self.adam:
r1 = t.detach().square().sum().sqrt()
@@ -97,4 +98,4 @@ class LAMB(Optimizer):
else:
r = 1.0
t.assign((t.detach() - self.lr * r * up).cast(t.dtype))
return [self.t] + self.m + self.v
return [self.b1_t, self.b2_t] + self.m + self.v