diff --git a/test/test_nn.py b/test/test_nn.py index 561ef8e93e..6a8d461092 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -5,7 +5,8 @@ import torch from tinygrad import Tensor, Device, TinyJit from tinygrad.helpers import CI, Context from tinygrad.ops import BufferOps -from tinygrad.nn import BatchNorm2d, Conv1d,ConvTranspose1d, Conv2d,ConvTranspose2d, Linear, GroupNorm, LayerNorm,LayerNorm2d, Embedding, InstanceNorm +from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding +from tinygrad.nn import BatchNorm2d, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm from tinygrad.nn.state import load_state_dict from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import run_schedule @@ -227,97 +228,132 @@ class TestNN(unittest.TestCase): def test_groupnorm(self): BS, H, W, C, G = 20, 10, 10, 6, 3 + # create in torch + torch_layer = torch.nn.GroupNorm(G, C).eval() + # create in tinygrad layer = GroupNorm(G, C) + layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True) + layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.GroupNorm(G, C).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + for _ in range(10): + # forward + x = Tensor.randn(BS, C, H, W, requires_grad=True) + z = layer(x) + torch_x = torch.tensor(x.numpy(), requires_grad=True) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - # test - x = Tensor.randn(BS, C, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # backward + z.sum().backward() + torch_z.sum().backward(retain_graph=True) + np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4) + np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4) + np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4) def test_layernorm(self): N, C, H, W = 20, 5, 10, 10 + # create in torch + torch_layer = torch.nn.LayerNorm([H, W]).eval() + # create in tinygrad layer = LayerNorm([H, W]) + layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True) + layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.LayerNorm([H, W]).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + for _ in range(10): + # forward + x = Tensor.randn(N, C, H, W, requires_grad=True) + z = layer(x) + torch_x = torch.tensor(x.numpy(), requires_grad=True) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - # test - x = Tensor.randn(N, C, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # backward + z.sum().backward() + torch_z.sum().backward(retain_graph=True) + np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4) + np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4) + np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4) def test_layernorm_2d(self): N, C, H, W = 20, 5, 10, 10 + # create in torch + torch_layer = torch.nn.LayerNorm([C]).eval() + # create in tinygrad layer = LayerNorm2d(C) + layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True) + layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.LayerNorm([C]).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + for _ in range(10): + # forward + x = Tensor.randn(N, C, H, W, requires_grad=True) + z = layer(x) + torch_x = torch.tensor(x.numpy(), requires_grad=True) + torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2) + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - # test - x = Tensor.randn(N, C, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # backward + z.sum().backward() + torch_z.sum().backward(retain_graph=True) + np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4) + np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4) + np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4) def test_instancenorm_2d(self): - N, C, H, W = 20, 5, 10, 10 + N, C, H, W = 20, 10, 10, 10 + + # create in torch + torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval() # create in tinygrad layer = InstanceNorm(C) + layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True) + layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + for _ in range(10): + # forward + x = Tensor.randn(N, C, H, W, requires_grad=True) + z = layer(x) + torch_x = torch.tensor(x.numpy(), requires_grad=True) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - # test - x = Tensor.randn(N, C, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # backward + z.sum().backward() + torch_z.sum().backward(retain_graph=True) + np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3) def test_instancenorm_3d(self): - N, C, D, H, W = 20, 5, 3, 10, 10 + N, C, D, H, W = 20, 10, 10, 10, 10 + + # create in torch + torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval() # create in tinygrad layer = InstanceNorm(C) + layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True) + layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + for _ in range(10): + # forward + x = Tensor.randn(N, C, D, H, W, requires_grad=True) + z = layer(x) + torch_x = torch.tensor(x.numpy(), requires_grad=True) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - # test - x = Tensor.randn(N, C, D, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # backward + z.sum().backward() + torch_z.sum().backward(retain_graph=True) + np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3) def test_embedding(self): B, T, embed_size, vocab_size = 4, 10, 20, 28 diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d7d557fffe..c48291b8c0 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2600,7 +2600,7 @@ class Tensor: print(t.mean().item(), t.std().item()) ``` """ - y = (self - self.detach().mean(axis, keepdim=True)) + y = (self - self.mean(axis, keepdim=True)) return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt()) def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor: