mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-14 08:57:24 +08:00
69 lines
2.7 KiB
Python
69 lines
2.7 KiB
Python
import unittest
|
|
from tinygrad import Tensor, TinyJit
|
|
from tinygrad.nn.state import get_parameters
|
|
from examples.mlperf.models.flat_llama import apply_grad
|
|
|
|
class FlatModel:
  """Small stack of residual two-matmul blocks with all layer weights stored stacked.

  ``w1`` and ``w2`` carry a leading ``n_layers`` axis and are sliced by layer index
  in the forward pass; ``scale`` and ``bias`` are applied once after the last block.
  """

  def __init__(self, n_layers:int, dim:int, hidden:int):
    """Allocate the stacked weights.

    Args:
      n_layers: number of blocks; also the leading axis of ``w1`` and ``w2``.
      dim: feature width of the input and of every block's output.
      hidden: inner width between the two matmuls of a block.
    """
    self.n_layers = n_layers
    up_shape = (n_layers, dim, hidden)
    down_shape = (n_layers, hidden, dim)
    self.w1 = Tensor.uniform(*up_shape, low=-0.1, high=0.1)
    self.w2 = Tensor.uniform(*down_shape, low=-0.1, high=0.1)
    self.scale = Tensor.uniform(dim, low=0.9, high=1.1)
    self.bias = Tensor.zeros(dim).contiguous()

  def __call__(self, x:Tensor) -> Tensor:
    """Run ``x`` through every block and return ``(h * scale + bias).sum()``."""
    h = x
    for layer in range(self.n_layers):
      # one block: matmul -> relu -> matmul, then added back onto the block's own input
      inner = (h @ self.w1[layer]).relu()
      h = inner @ self.w2[layer] + h
    return (h * self.scale + self.bias).sum()
|
|
|
|
class TestApplyGradE2E(unittest.TestCase):
  """End-to-end checks that gradients routed through ``apply_grad`` match ``backward()``.

  Each test builds a ``FlatModel``, passes the gradients of its loss to ``apply_grad``
  together with pre-allocated zero buffers, and compares those buffers against the
  ``.grad`` tensors found on the same parameters after calling ``backward()``.
  NOTE(review): ``apply_grad`` is defined elsewhere; these tests only rely on the
  buffers ending up close to the ``backward()`` result.
  """

  @staticmethod
  def _zero_grad_buffers(model):
    """Map every parameter of ``model`` to a realized, contiguous zero tensor of the same shape and dtype."""
    return {p: Tensor.zeros(p.shape, dtype=p.dtype).contiguous().realize() for p in get_parameters(model)}

  def _run_with_apply_grad(self, model, xs):
    """Feed each input of ``xs`` through ``model`` and hand its gradients to ``apply_grad``.

    Returns the gradient buffers as a list ordered like ``get_parameters(model)``.
    """
    grads = self._zero_grad_buffers(model)
    for x in xs:
      loss = model(x)
      # gradients are requested for the dict keys, so they pair up with the values in dict order
      for buf, g in zip(grads.values(), loss.gradient(*grads)):
        apply_grad(buf, g.uop)
      # realize once per input, loss and all buffers together
      Tensor.realize(loss, *grads.values())
    return [grads[param] for param in get_parameters(model)]

  def _run_reference(self, model, xs):
    """Call ``backward()`` on the loss of every input and return each parameter's ``.grad``."""
    for x in xs:
      model(x).backward()
    return [param.grad for param in get_parameters(model)]

  def _assert_close(self, got, expected, atol, rtol):
    """Assert pairwise ``allclose`` between two gradient sequences.

    NOTE: ``zip`` stops at the shorter sequence, so a length mismatch is not reported here.
    """
    for g, e in zip(got, expected):
      matches = g.allclose(e, atol=atol, rtol=rtol).item()
      self.assertTrue(matches, f"grad mismatch (max abs diff {(g - e).abs().max().item()})")

  def _assert_match(self, model, xs, atol, rtol):
    """Run both gradient paths on the same model and inputs and require them to agree."""
    # the apply_grad path runs first, the backward() reference second
    got = self._run_with_apply_grad(model, xs)
    expected = self._run_reference(model, xs)
    self._assert_close(got, expected, atol, rtol)

  def test_e2e_single_step(self):
    """A single input batch through a 3-layer model."""
    model = FlatModel(n_layers=3, dim=8, hidden=16)
    Tensor.realize(*get_parameters(model))
    batch = Tensor.randn(2, 8).realize()
    self._assert_match(model, [batch], atol=1e-4, rtol=1e-4)

  def test_e2e_multi_step_accumulation(self):
    """Three input batches through a 4-layer model, compared after all of them have run."""
    model = FlatModel(n_layers=4, dim=8, hidden=16)
    Tensor.realize(*get_parameters(model))
    batches = [Tensor.randn(2, 8).realize() for _ in range(3)]
    self._assert_match(model, batches, atol=1e-4, rtol=1e-4)

  def test_e2e_jit(self):
    """Same flow as the multi-step test, with the forward/backward step wrapped in ``TinyJit``.

    Tolerances here are 1e-3, looser than the 1e-4 used by the non-JIT tests.
    """
    model = FlatModel(n_layers=3, dim=8, hidden=16)
    Tensor.realize(*get_parameters(model))
    grads = self._zero_grad_buffers(model)

    @TinyJit
    def fwd_bwd(x:Tensor):
      # closes over `model` and `grads`; only `x` changes between calls
      loss = model(x)
      for buf, g in zip(grads.values(), loss.gradient(*grads)):
        apply_grad(buf, g.uop)
      Tensor.realize(loss, *grads.values())

    xs = [Tensor.randn(2, 8).realize() for _ in range(3)]
    # NOTE(review): presumably several calls are made to cover both JIT capture and replay; confirm
    for x in xs:
      fwd_bwd(x)
    jit_grads = [grads[param] for param in get_parameters(model)]
    expected = self._run_reference(model, xs)
    self._assert_close(jit_grads, expected, atol=1e-3, rtol=1e-3)
|
|
|
|
|
|
# Run the test cases in this module when the file is executed directly.
if __name__ == "__main__":
  unittest.main()
|