Files
onepilot/tinygrad_repo/test/test_optim.py
T
carrot 77a8919349 TR16 Model, fix radar routine (#211)
* UV+DTR model

* DTR model.. again.

* fix naviGPS

* fix radar...

* fix..

* test

* fix..

* carrot serv

* fix..

* fix.. fleet

* fix.. radar

* fix atc

* Steam Powered model..

* fix.. radarLatFactor range.. 200->500

* fix.. dbc..

* side

* SP v2

* brake light

* fix brakelight

* fix..

* add datetime...

* fix..

* fix..

* fix..

* fix..

* blind spot

* fix tz

* fix..

* ff

* radarLatFactor

* fix.. bsd

* Revert "fix.. bsd"

This reverts commit 1d0d1434470e1b92c65eaffaeb8dd7cd779f85ee.

* fix.. bsd side..

* test

* fix.. e2e conditions

* Revert "test"

This reverts commit 0ce791dbd66c17260366ed1a4df2626c602dbb7d.

* TR16

* fix cut-in detect threshold  3.4 -> 2.6

* fix.. jerk_l limit 5->10

* fix..

* fix.. gm

* fix.. OPTIMA_H mass

* fix.. radar..

* fix radar..

* fix..

* Radar...

* fix..

* fix..

* fix..

* fix.. radartrack 3

* fix..

* fix..

* fix..

* merge..

* fix.. canfd

* fix..

* fix..

* fix..

* fix.. radard

* new cut_in

* Revert "new cut_in"

This reverts commit b9b6e9b33318fe1ce7d626468139b17848efcdcd.

* fix..

* new cut_in detect...

* fix.. disp..

* fix..

* fix..

* fix.. center radar..

* fix.. radar y_sane..

* fix..

* fix..

* hkg jerk 10 -> 5

* fix..

* fix..

* fix.. radar dbc..

* fix..

* fix.. jLead filter..

* test new radar interface..

* fix..

* fix..

* test time...

* Revert "test time..."

This reverts commit 63e9187736985c4dc4b4f3736674ba7cda6adc3f.

* fix radar..

* fix..

* FireHose model..

* tinygrad

* Update interface.py

* fix..

* fix.. nff toyota corolla_tss2

* fix..

* fix..

* fix.. radar

* fix..

* fix.. radar, y_gate

* fix.. radar..

* fix.. for clone..

* scc radar enable at low speed..

* fix.. settings..

* fix.

* fix..

* fix.. radarTimeStep.

* TR16 model again..

* RELEASE.md

* fix cut-in detection...

* fix.. registeration timeout 15sec..

* fix..

* fix.. radar processing.

* fix..

* fix..

* fix..

* fix..

* fix..

* fix..
2025-09-05 15:43:10 +09:00

163 lines
7.9 KiB
Python

import numpy as np
import torch
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.nn.optim import Adam, SGD, AdamW, Muon
from tinygrad.helpers import CI
from tinygrad.device import is_dtype_supported
from extra.torch_muon import SingleDeviceMuon as TorchMuon
# Seed NumPy's global RNG so the initial values drawn below are the same on every run.
np.random.seed(1337)
# Shared float32 starting values; the nets below take a .copy() of these for each instance.
x_init = np.random.randn(1,4).astype(np.float32)  # shape (1, 4)
W_init = np.random.randn(4,4).astype(np.float32)  # shape (4, 4)
m_init = np.random.randn(1,4).astype(np.float32)  # shape (1, 4), used by TinyNet only
class TeenyNet:
  """Smallest test model: the loss is the sum of the elementwise product of x and W.

  `tensor` is the constructor used to build the parameters; it is called as
  `tensor(array, requires_grad=True)`.
  """
  def __init__(self, tensor):
    # every instance gets its own copy of the shared initial values
    x0, W0 = x_init.copy(), W_init.copy()
    self.x = tensor(x0, requires_grad=True)
    self.W = tensor(W0, requires_grad=True)

  def forward(self):
    """Return the scalar sum of x * W."""
    product = self.x * self.W
    return product.sum()
class TinyNet:
  """Small test model: relu(x @ W), then log_softmax over axis 1, then `* m + m`, then summed.

  `tensor` is the constructor used to build the tensors; x and W are created with
  `requires_grad=True`, m is created without that argument.
  """
  def __init__(self, tensor):
    # every instance gets its own copy of the shared initial values
    x0, W0, m0 = x_init.copy(), W_init.copy(), m_init.copy()
    self.x = tensor(x0, requires_grad=True)
    self.W = tensor(W0, requires_grad=True)
    self.m = tensor(m0)

  def forward(self):
    """Return the scalar loss described in the class docstring."""
    hidden = self.x.matmul(self.W).relu()
    log_probs = hidden.log_softmax(1)
    return log_probs.mul(self.m).add(self.m).sum()
def step(tensor, optim, steps=1, teeny=False, **kwargs):
  """Build a fresh net, run `steps` optimizer updates on it and return the resulting parameters.

  tensor: tensor constructor handed to the net (e.g. `Tensor` or `torch.tensor`).
  optim:  optimizer class, instantiated as `optim([net.x, net.W], **kwargs)`.
  steps:  number of forward/backward/update iterations.
  teeny:  use TeenyNet when truthy, TinyNet otherwise.
  Returns a tuple `(x, W)` of the detached parameters converted with `.numpy()`.
  """
  net_cls = TeenyNet if teeny else TinyNet
  model = net_cls(tensor)
  optimizer = optim([model.x, model.W], **kwargs)
  for _ in range(steps):
    loss = model.forward()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  return tuple(param.detach().numpy() for param in (model.x, model.W))
@unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
class TestOptim(unittest.TestCase):
  """Check tinygrad optimizers against reference implementations.

  Each comparison runs `step` once with the tinygrad optimizer and once with the reference
  (torch.optim for SGD/Adam/AdamW, extra.torch_muon for Muon) from identical initial values,
  then compares the resulting x and W with np.testing.assert_allclose.
  """
  def setUp(self):
    # optimizers are exercised with Tensor.training enabled; remember the previous value for tearDown
    self.old_training = Tensor.training
    Tensor.training = True
  def tearDown(self):
    Tensor.training = self.old_training

  def _test_optim(self, tinygrad_optim, torch_optim, steps, opts, atol, rtol):
    """Run both optimizers for `steps` steps with keyword options `opts` and compare x and W.

    `opts` is forwarded to `step`, so besides optimizer kwargs it may carry `teeny`.
    `atol`/`rtol` are passed straight to np.testing.assert_allclose.
    """
    for x,y in zip(step(Tensor, tinygrad_optim, steps, **opts),
                   step(torch.tensor, torch_optim, steps, **opts)):
      np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

  def _test_sgd(self, steps, opts, atol, rtol): self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol)
  def _test_adam(self, steps, opts, atol, rtol): self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol)
  def _test_adamw(self, steps, opts, atol, rtol): self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol)
  #TODO: use torch.muon when it comes out
  def _test_muon(self, steps, opts, atol, rtol): self._test_optim(Muon, TorchMuon, steps, opts, atol, rtol)

  def test_multistep_sgd_high_lr_teeny(self): self._test_sgd(2, {'lr': 1.1, 'teeny': True}, 1e-6, 1e-5)
  def test_multistep_adam_high_lr_teeny(self): self._test_adam(2, {'lr': 1.1, 'teeny': True}, 2e-4, 5e-4)
  def test_multistep_muon_high_lr_teeny(self): self._test_muon(2, {'lr': 1.1, 'teeny': True}, 2e-4, 5e-4)

  def test_sgd(self): self._test_sgd(1, {'lr': 0.001}, 1e-6, 0)
  def test_sgd_high_lr(self): self._test_sgd(1, {'lr': 10}, 1e-6, 1e-5)
  def test_sgd_wd(self): self._test_sgd(1, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0)
  def test_sgd_high_lr_wd(self): self._test_sgd(1, {'lr': 10, 'weight_decay': 0.1}, 1e-6, 1e-5)

  def test_multistep_sgd(self): self._test_sgd(10, {'lr': 0.001}, 1e-6, 0)
  def test_multistep_sgd_high_lr(self): self._test_sgd(10, {'lr': 10}, 1e-6, 3e-4)
  def test_multistep_sgd_wd(self): self._test_sgd(10, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0)
  def test_multistep_sgd_high_lr_wd(self): self._test_sgd(10, {'lr': 9, 'weight_decay': 0.1}, 1e-6, 3e-4)

  def test_multistep_sgd_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9}, 1e-6, 0)
  def test_multistep_sgd_high_lr_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9}, 1e-5, 3e-4)
  def test_multistep_sgd_momentum_wd(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0.1}, 1e-6, 0)
  def test_multistep_sgd_high_lr_momentum_wd(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'weight_decay': 0.1}, 1e-5, 3e-4)

  def test_multistep_sgd_nesterov_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True}, 1e-5, 0)
  def test_multistep_sgd_high_lr_nesterov_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'nesterov': True}, 1e-5, 3e-4)
  def test_multistep_sgd_nesterov_momentum_wd(self):
    self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 0)
  def test_multistep_sgd_high_lr_nesterov_momentum_wd(self):
    self._test_sgd(10, {'lr': 9, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 3e-4)

  def test_muon(self): self._test_muon(1, {'lr': 0.001}, 1e-6, 0)
  def test_muon_high_lr(self): self._test_muon(1, {'lr': 10}, 1e-6, 3e-4)
  def test_muon_wd(self): self._test_muon(1, {'lr': 0.001, 'weight_decay': 0.01}, 1e-6, 0)
  def test_muon_high_lr_wd(self): self._test_muon(1, {'lr': 10, 'weight_decay': 0.01}, 1e-6, 3e-4)

  # NOTE: momentum set to 0.95 by default, nesterov set to True by default
  def test_multistep_muon_momentum_wd(self): self._test_muon(10, {'lr': 0.001, 'weight_decay': 0.01}, 1e-5, 0)
  # ns defaults are numerically unstable, but it is tolerable in real training (see nsteps/nparam tests)
  def test_multistep_muon_high_lr_momentum_wd(self): self._test_muon(10, {'lr': 10, 'weight_decay': 0.01}, 1e-1, 3e-4)
  def test_multistep_muon_no_nesterov_momentum(self): self._test_muon(10, {'lr': 0.001, 'nesterov': False}, 1e-5, 0)
  def test_multistep_muon_high_lr_no_nesterov_momentum(self): self._test_muon(10, {'lr': 10, 'nesterov': False}, 0.5e-1, 1e-1)

  def test_muon_ns_steps(self): self._test_muon(1, {'lr': 0.001, 'ns_steps': 3}, 1e-6, 0)
  def test_muon_high_lr_ns_steps(self): self._test_muon(1, {'lr': 10, 'ns_steps': 3}, 1e-5, 3e-4)
  def test_muon_ns_params(self): self._test_muon(1, {'lr': 0.001,'ns_params': (2.0,-1.5,0.5)}, 1e-6, 0)
  def test_muon_high_lr_ns_params(self): self._test_muon(1, {'lr': 10,'ns_params': (2.0,-1.5,0.5)}, 1e-5, 3e-4)

  def test_muon_momentum_wd_ns_steps_ns_params(self):
    self._test_muon(10, {'lr': 0.001, 'momentum': 0.90, 'weight_decay': 0.01, 'ns_steps': 3, 'ns_params': (2.0,-1.5,0.5)}, 1e-5, 0)
  def test_multistep_muon_high_lr_momentum_wd_ns_steps_ns_params(self):
    self._test_muon(10, {'lr': 10, 'momentum': 0.90, 'weight_decay': 0.01, 'ns_steps': 3, 'ns_params': (2.0,-1.5,0.5)}, 1e-5, 3e-4)

  def test_adam(self): self._test_adam(1, {'lr': 0.001}, 1e-5, 0)
  def test_adam_high_lr(self): self._test_adam(1, {'lr': 10}, 1e-4, 1e-4)
  def test_adamw(self): self._test_adamw(1, {'lr': 0.001}, 1e-5, 0)
  def test_adamw_high_lr(self): self._test_adamw(1, {'lr': 10}, 1e-4, 1e-4)

  def test_multistep_adam(self): self._test_adam(10, {'lr': 0.001}, 1e-5, 0)
  def test_multistep_adam_high_lr(self): self._test_adam(10, {'lr': 10}, 2e-3, 5e-4)

  def test_multistep_adamw(self): self._test_adamw(10, {'lr': 0.001}, 1e-5, 0)
  def test_multistep_adamw_high_lr(self): self._test_adamw(10, {'lr': 10}, 5e-4, 2e-3)

  def test_duped_weights(self):
    """Passing the same tensor twice to an optimizer must give the same loss as passing it once."""
    for Opt in [Adam, AdamW, SGD]:
      losses = []
      for i in range(2):
        w = Tensor(x_init.copy())
        # first run: parameter listed once; second run: the same parameter listed twice
        opt = Opt([w], lr=0.1) if i == 0 else Opt([w, w], lr=0.1)

        loss = None
        for _ in range(3):
          loss = w.sum()
          opt.zero_grad()
          loss.backward()
          opt.step()
        losses.append(loss.numpy())

      np.testing.assert_allclose(losses[0], losses[1], atol=1e-4, rtol=0)

  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_mixed_precision(self):
    """Run SGD/Adam/AdamW with half as the default float and a very large learning rate."""
    old_default_float, dtypes.default_float = dtypes.default_float, dtypes.half
    try:
      # weight update would overflow without upcasting
      self._test_sgd(10, {'lr': 1e10}, 1e-6, 3e-4)
      self._test_adam(1, {'lr': 1e10}, 1e-4, 1e-4)
      self._test_adamw(1, {'lr': 1e10}, 1e-4, 1e-4)
    finally:
      # restore even when an assertion above fails, otherwise every later test in this process would run with half
      dtypes.default_float = old_default_float

  def test_assert_tensor_train(self):
    """optimizer.step must raise RuntimeError while Tensor.training is False and succeed once it is True."""
    t = Tensor.ones((1,1), requires_grad=True)
    optimizer = Adam([t])
    optimizer.zero_grad()
    old_state = Tensor.training
    t.sum().backward()
    Tensor.training = False
    self.assertRaises(RuntimeError, optimizer.step)
    Tensor.training = True
    optimizer.step()
    # tearDown also resets Tensor.training, so a failure above does not leak the flag
    Tensor.training = old_state
# Run the test suite when this file is executed directly rather than imported.
if __name__ == "__main__":
  unittest.main()