revert the .detach() in layernorm (#4904)

* revert the .detach() in layernorm

Detaching the mean is only correct in LayerNorm, where the input is the data; it is not correct in GroupNorm and InstanceNorm, which reuse layernorm.
Added backward tests for weights, bias, and input for these norms.

* bigger atol for llvm

* relax backward test tolerances more
This commit is contained in:
chenyu
2024-06-10 18:02:05 -04:00
committed by GitHub
parent 8b5bcf309a
commit 97b05f567e
2 changed files with 95 additions and 59 deletions

View File

@@ -5,7 +5,8 @@ import torch
from tinygrad import Tensor, Device, TinyJit
from tinygrad.helpers import CI, Context
from tinygrad.ops import BufferOps
from tinygrad.nn import BatchNorm2d, Conv1d,ConvTranspose1d, Conv2d,ConvTranspose2d, Linear, GroupNorm, LayerNorm,LayerNorm2d, Embedding, InstanceNorm
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
from tinygrad.nn import BatchNorm2d, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm
from tinygrad.nn.state import load_state_dict
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
@@ -227,97 +228,132 @@ class TestNN(unittest.TestCase):
def test_groupnorm(self):
BS, H, W, C, G = 20, 10, 10, 6, 3
# create in torch
torch_layer = torch.nn.GroupNorm(G, C).eval()
# create in tinygrad
layer = GroupNorm(G, C)
layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
# create in torch
with torch.no_grad():
torch_layer = torch.nn.GroupNorm(G, C).eval()
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
for _ in range(10):
# forward
x = Tensor.randn(BS, C, H, W, requires_grad=True)
z = layer(x)
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
# test
x = Tensor.randn(BS, C, H, W)
z = layer(x)
torch_x = torch.tensor(x.numpy())
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
# backward
z.sum().backward()
torch_z.sum().backward(retain_graph=True)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
def test_layernorm(self):
N, C, H, W = 20, 5, 10, 10
# create in torch
torch_layer = torch.nn.LayerNorm([H, W]).eval()
# create in tinygrad
layer = LayerNorm([H, W])
layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
# create in torch
with torch.no_grad():
torch_layer = torch.nn.LayerNorm([H, W]).eval()
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
for _ in range(10):
# forward
x = Tensor.randn(N, C, H, W, requires_grad=True)
z = layer(x)
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
# test
x = Tensor.randn(N, C, H, W)
z = layer(x)
torch_x = torch.tensor(x.numpy())
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
# backward
z.sum().backward()
torch_z.sum().backward(retain_graph=True)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
def test_layernorm_2d(self):
N, C, H, W = 20, 5, 10, 10
# create in torch
torch_layer = torch.nn.LayerNorm([C]).eval()
# create in tinygrad
layer = LayerNorm2d(C)
layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
# create in torch
with torch.no_grad():
torch_layer = torch.nn.LayerNorm([C]).eval()
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
for _ in range(10):
# forward
x = Tensor.randn(N, C, H, W, requires_grad=True)
z = layer(x)
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
# test
x = Tensor.randn(N, C, H, W)
z = layer(x)
torch_x = torch.tensor(x.numpy())
torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
# backward
z.sum().backward()
torch_z.sum().backward(retain_graph=True)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
def test_instancenorm_2d(self):
N, C, H, W = 20, 5, 10, 10
N, C, H, W = 20, 10, 10, 10
# create in torch
torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval()
# create in tinygrad
layer = InstanceNorm(C)
layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
# create in torch
with torch.no_grad():
torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval()
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
for _ in range(10):
# forward
x = Tensor.randn(N, C, H, W, requires_grad=True)
z = layer(x)
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
# test
x = Tensor.randn(N, C, H, W)
z = layer(x)
torch_x = torch.tensor(x.numpy())
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
# backward
z.sum().backward()
torch_z.sum().backward(retain_graph=True)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
def test_instancenorm_3d(self):
N, C, D, H, W = 20, 5, 3, 10, 10
N, C, D, H, W = 20, 10, 10, 10, 10
# create in torch
torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval()
# create in tinygrad
layer = InstanceNorm(C)
layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
# create in torch
with torch.no_grad():
torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval()
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
for _ in range(10):
# forward
x = Tensor.randn(N, C, D, H, W, requires_grad=True)
z = layer(x)
torch_x = torch.tensor(x.numpy(), requires_grad=True)
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
# test
x = Tensor.randn(N, C, D, H, W)
z = layer(x)
torch_x = torch.tensor(x.numpy())
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
# backward
z.sum().backward()
torch_z.sum().backward(retain_graph=True)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
def test_embedding(self):
B, T, embed_size, vocab_size = 4, 10, 20, 28

View File

@@ -2600,7 +2600,7 @@ class Tensor:
print(t.mean().item(), t.std().item())
```
"""
y = (self - self.detach().mean(axis, keepdim=True))
y = (self - self.mean(axis, keepdim=True))
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor: