mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
only need to check the min_lr for the nan bug
This commit is contained in:
@@ -135,30 +135,18 @@ class TestRealWorld(unittest.TestCase):
|
||||
#Device.DEFAULT = old_default
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
|
||||
@unittest.skipIf(Device.DEFAULT in ["LLVM", "CLANG"] and CI, "too long on CI LLVM and CLANG")
|
||||
def test_train_cifar_half(self):
|
||||
# @unittest.skipIf(Device.DEFAULT in ["LLVM", "CLANG"] and CI, "too long on CI LLVM and CLANG")
|
||||
def test_train_cifar_hyp(self):
|
||||
dtypes.default_float = dtypes.float16
|
||||
with Tensor.train():
|
||||
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
|
||||
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
|
||||
# optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
|
||||
# initial_div_factor = hyp['opt']['initial_div_factor']
|
||||
# final_lr_ratio = hyp['opt']['final_lr_ratio']
|
||||
# pct_start = hyp['opt']['percent_start']
|
||||
# lr_scheduler = OneCycleLR(optimizer, max_lr=hyp['opt']['bias_lr'], pct_start=pct_start, div_factor=initial_div_factor,
|
||||
# final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=4)
|
||||
BS = 32 if CI else 512
|
||||
|
||||
@TinyJit
|
||||
def train(X):
|
||||
out = model(X)
|
||||
loss = out.mean()
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# lr_scheduler.step()
|
||||
|
||||
helper_test("train_cifar_half", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (1.0/48)*BS, 145 if CI else 157)
|
||||
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=hyp['opt']['momentum'], nesterov=True, weight_decay=hyp['opt']['bias_decay'])
|
||||
initial_div_factor = hyp['opt']['initial_div_factor']
|
||||
final_lr_ratio = hyp['opt']['final_lr_ratio']
|
||||
pct_start = hyp['opt']['percent_start']
|
||||
lr_scheduler = OneCycleLR(optimizer, max_lr=hyp['opt']['bias_lr'], pct_start=pct_start, div_factor=initial_div_factor,
|
||||
final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=4)
|
||||
assert not np.isnan(lr_scheduler.min_lr.numpy()), "lr too small or initial_div_facotr too big for half"
|
||||
|
||||
dtypes.default_float = dtypes.float32
|
||||
|
||||
|
||||
Reference in New Issue
Block a user