mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
revert the .detach() in layernorm (#4904)
* revert the .detach() in layernorm it's only correct in LayerNorm where input is the data, and not correct in GroupNorm and InstanceNorm that reused layernorm. Added backward tests for weights, bias and input for these norms. * bigger atol for llvm * relax backward more
This commit is contained in:
152
test/test_nn.py
152
test/test_nn.py
@@ -5,7 +5,8 @@ import torch
|
||||
from tinygrad import Tensor, Device, TinyJit
|
||||
from tinygrad.helpers import CI, Context
|
||||
from tinygrad.ops import BufferOps
|
||||
from tinygrad.nn import BatchNorm2d, Conv1d,ConvTranspose1d, Conv2d,ConvTranspose2d, Linear, GroupNorm, LayerNorm,LayerNorm2d, Embedding, InstanceNorm
|
||||
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
|
||||
from tinygrad.nn import BatchNorm2d, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm
|
||||
from tinygrad.nn.state import load_state_dict
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
@@ -227,97 +228,132 @@ class TestNN(unittest.TestCase):
|
||||
def test_groupnorm(self):
|
||||
BS, H, W, C, G = 20, 10, 10, 6, 3
|
||||
|
||||
# create in torch
|
||||
torch_layer = torch.nn.GroupNorm(G, C).eval()
|
||||
|
||||
# create in tinygrad
|
||||
layer = GroupNorm(G, C)
|
||||
layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
|
||||
layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.GroupNorm(G, C).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
for _ in range(10):
|
||||
# forward
|
||||
x = Tensor.randn(BS, C, H, W, requires_grad=True)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy(), requires_grad=True)
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
|
||||
|
||||
# test
|
||||
x = Tensor.randn(BS, C, H, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
|
||||
# backward
|
||||
z.sum().backward()
|
||||
torch_z.sum().backward(retain_graph=True)
|
||||
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
|
||||
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
|
||||
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
|
||||
|
||||
def test_layernorm(self):
|
||||
N, C, H, W = 20, 5, 10, 10
|
||||
|
||||
# create in torch
|
||||
torch_layer = torch.nn.LayerNorm([H, W]).eval()
|
||||
|
||||
# create in tinygrad
|
||||
layer = LayerNorm([H, W])
|
||||
layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
|
||||
layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.LayerNorm([H, W]).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
for _ in range(10):
|
||||
# forward
|
||||
x = Tensor.randn(N, C, H, W, requires_grad=True)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy(), requires_grad=True)
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
|
||||
|
||||
# test
|
||||
x = Tensor.randn(N, C, H, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
|
||||
# backward
|
||||
z.sum().backward()
|
||||
torch_z.sum().backward(retain_graph=True)
|
||||
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
|
||||
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
|
||||
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
|
||||
|
||||
def test_layernorm_2d(self):
|
||||
N, C, H, W = 20, 5, 10, 10
|
||||
|
||||
# create in torch
|
||||
torch_layer = torch.nn.LayerNorm([C]).eval()
|
||||
|
||||
# create in tinygrad
|
||||
layer = LayerNorm2d(C)
|
||||
layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
|
||||
layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.LayerNorm([C]).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
for _ in range(10):
|
||||
# forward
|
||||
x = Tensor.randn(N, C, H, W, requires_grad=True)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy(), requires_grad=True)
|
||||
torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
|
||||
|
||||
# test
|
||||
x = Tensor.randn(N, C, H, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy())
|
||||
torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
|
||||
# backward
|
||||
z.sum().backward()
|
||||
torch_z.sum().backward(retain_graph=True)
|
||||
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
|
||||
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
|
||||
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
|
||||
|
||||
def test_instancenorm_2d(self):
|
||||
N, C, H, W = 20, 5, 10, 10
|
||||
N, C, H, W = 20, 10, 10, 10
|
||||
|
||||
# create in torch
|
||||
torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval()
|
||||
|
||||
# create in tinygrad
|
||||
layer = InstanceNorm(C)
|
||||
layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
|
||||
layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
for _ in range(10):
|
||||
# forward
|
||||
x = Tensor.randn(N, C, H, W, requires_grad=True)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy(), requires_grad=True)
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
|
||||
|
||||
# test
|
||||
x = Tensor.randn(N, C, H, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
|
||||
# backward
|
||||
z.sum().backward()
|
||||
torch_z.sum().backward(retain_graph=True)
|
||||
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
|
||||
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
|
||||
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_instancenorm_3d(self):
|
||||
N, C, D, H, W = 20, 5, 3, 10, 10
|
||||
N, C, D, H, W = 20, 10, 10, 10, 10
|
||||
|
||||
# create in torch
|
||||
torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval()
|
||||
|
||||
# create in tinygrad
|
||||
layer = InstanceNorm(C)
|
||||
layer.weight = Tensor(torch_layer.weight.detach().numpy(), requires_grad=True)
|
||||
layer.bias = Tensor(torch_layer.bias.detach().numpy(), requires_grad=True)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
for _ in range(10):
|
||||
# forward
|
||||
x = Tensor.randn(N, C, D, H, W, requires_grad=True)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy(), requires_grad=True)
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
|
||||
|
||||
# test
|
||||
x = Tensor.randn(N, C, D, H, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.numpy())
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
|
||||
# backward
|
||||
z.sum().backward()
|
||||
torch_z.sum().backward(retain_graph=True)
|
||||
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
|
||||
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
|
||||
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_embedding(self):
|
||||
B, T, embed_size, vocab_size = 4, 10, 20, 28
|
||||
|
||||
@@ -2600,7 +2600,7 @@ class Tensor:
|
||||
print(t.mean().item(), t.std().item())
|
||||
```
|
||||
"""
|
||||
y = (self - self.detach().mean(axis, keepdim=True))
|
||||
y = (self - self.mean(axis, keepdim=True))
|
||||
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
|
||||
|
||||
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user