From 97b05f567e8e42a2475f8a063fb080b200f6f033 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 10 Jun 2024 18:02:05 -0400 Subject: [PATCH] revert the .detach() in layernorm (#4904) * revert the .detach() in layernorm it's only correct in LayerNorm where input is the data, and not correct in GroupNorm and InstanceNorm that reused layernorm. Added backward tests for weights, bias and input for these norms. * bigger atol for llvm * relax backward more --- test/test_nn.py | 152 ++++++++++++++++++++++++++++----------------- tinygrad/tensor.py | 2 +- 2 files changed, 95 insertions(+), 59 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 561ef8e93e..6a8d461092 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -5,7 +5,8 @@ import torch from tinygrad import Tensor, Device, TinyJit from tinygrad.helpers import CI, Context from tinygrad.ops import BufferOps -from tinygrad.nn import BatchNorm2d, Conv1d,ConvTranspose1d, Conv2d,ConvTranspose2d, Linear, GroupNorm, LayerNorm,LayerNorm2d, Embedding, InstanceNorm +from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding +from tinygrad.nn import BatchNorm2d, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm from tinygrad.nn.state import load_state_dict from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import run_schedule @@ -227,97 +228,132 @@ class TestNN(unittest.TestCase): def test_groupnorm(self): BS, H, W, C, G = 20, 10, 10, 6, 3 + # create in torch + torch_layer = torch.nn.GroupNorm(G, C).eval() + # create in tinygrad layer = GroupNorm(G, C) + layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True) + layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.GroupNorm(G, C).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + for _ in range(10): + # forward + x = Tensor.randn(BS, C, H, W, requires_grad=True) + z = layer(x) + torch_x = torch.tensor(x.numpy(), requires_grad=True) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - # test - x = Tensor.randn(BS, C, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # backward + z.sum().backward() + torch_z.sum().backward(retain_graph=True) + np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4) + np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4) + np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4) def test_layernorm(self): N, C, H, W = 20, 5, 10, 10 + # create in torch + torch_layer = torch.nn.LayerNorm([H, W]).eval() + # create in tinygrad layer = LayerNorm([H, W]) + layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True) + layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.LayerNorm([H, W]).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + for _ in range(10): + # forward + x = Tensor.randn(N, C, H, W, requires_grad=True) + z = layer(x) + torch_x = torch.tensor(x.numpy(), requires_grad=True) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - # test - x = Tensor.randn(N, C, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # backward + z.sum().backward() + torch_z.sum().backward(retain_graph=True) + np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4) + np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4) + np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4) def test_layernorm_2d(self): N, C, H, W = 20, 5, 10, 10 + # create in torch + torch_layer = torch.nn.LayerNorm([C]).eval() + # create in tinygrad layer = LayerNorm2d(C) + layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True) + layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.LayerNorm([C]).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + for _ in range(10): + # forward + x = Tensor.randn(N, C, H, W, requires_grad=True) + z = layer(x) + torch_x = torch.tensor(x.numpy(), requires_grad=True) + torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2) + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - # test - x = Tensor.randn(N, C, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # backward + z.sum().backward() + torch_z.sum().backward(retain_graph=True) + np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4) + np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4) + np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4) def test_instancenorm_2d(self): - N, C, H, W = 20, 5, 10, 10 + N, C, H, W = 20, 10, 10, 10 + + # create in torch + torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval() # create in tinygrad layer = InstanceNorm(C) + layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True) + layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + for _ in range(10): + # forward + x = Tensor.randn(N, C, H, W, requires_grad=True) + z = layer(x) + torch_x = torch.tensor(x.numpy(), requires_grad=True) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - # test - x = Tensor.randn(N, C, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # backward + z.sum().backward() + torch_z.sum().backward(retain_graph=True) + np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3) def test_instancenorm_3d(self): - N, C, D, H, W = 20, 5, 3, 10, 10 + N, C, D, H, W = 20, 10, 10, 10, 10 + + # create in torch + torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval() # create in tinygrad layer = InstanceNorm(C) + layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True) + layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True) - # create in torch - with torch.no_grad(): - torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval() - torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32) - torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32) + for _ in range(10): + # forward + x = Tensor.randn(N, C, D, H, W, requires_grad=True) + z = layer(x) + torch_x = torch.tensor(x.numpy(), requires_grad=True) + torch_z = torch_layer(torch_x) + np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6) - # test - x = Tensor.randn(N, C, D, H, W) - z = layer(x) - torch_x = torch.tensor(x.numpy()) - torch_z = torch_layer(torch_x) - np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3) + # backward + z.sum().backward() + torch_z.sum().backward(retain_graph=True) + np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3) + np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3) def test_embedding(self): B, T, embed_size, vocab_size = 4, 10, 20, 28 diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d7d557fffe..c48291b8c0 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2600,7 +2600,7 @@ class Tensor: print(t.mean().item(), t.std().item()) ``` """ - y = (self - self.detach().mean(axis, keepdim=True)) + y = (self - self.mean(axis, keepdim=True)) return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt()) def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor: