LayerNorm2d for 2 lines

This commit is contained in:
George Hotz
2023-03-20 16:58:43 -07:00
parent 128ca160ac
commit d6f4219952
3 changed files with 31 additions and 8 deletions

View File

@@ -1,5 +1,5 @@
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, LayerNorm, Linear
from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
class Block:
def __init__(self, dim):
@@ -18,8 +18,8 @@ class Block:
class ConvNeXt:
def __init__(self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]):
self.downsample_layers = [
[Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm((dims[0], 1, 1), eps=1e-6)],
*[[LayerNorm((dims[i], 1, 1), eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
[Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm2d(dims[0], eps=1e-6)],
*[[LayerNorm2d(dims[i], eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
]
self.stages = [[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))]
self.norm = LayerNorm(dims[-1])

View File

@@ -2,7 +2,7 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm
from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm, LayerNorm2d
import torch
class TestNN(unittest.TestCase):
@@ -76,7 +76,7 @@ class TestNN(unittest.TestCase):
def test_conv2d(self):
BS, C1, H, W = 4, 16, 224, 224
C2, K, S, P = 64, 7, 2, 1
# create in tinygrad
layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
@@ -131,5 +131,24 @@ class TestNN(unittest.TestCase):
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
def test_layernorm_2d(self):
N, C, H, W = 20, 5, 10, 10
# create in tinygrad
layer = LayerNorm2d(C)
# create in torch
with torch.no_grad():
torch_layer = torch.nn.LayerNorm([C]).eval()
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
# test
x = Tensor.randn(N, C, H, W)
z = layer(x)
torch_x = torch.tensor(x.cpu().numpy())
torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
if __name__ == '__main__':
unittest.main()

View File

@@ -73,11 +73,15 @@ class GroupNorm:
class LayerNorm:
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(normalized_shape))), eps, elementwise_affine
self.weight, self.bias = (Tensor.ones(*normalized_shape), Tensor.zeros(*normalized_shape)) if elementwise_affine else (None, None)
self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)
def __call__(self, x:Tensor):
assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
x = x.layernorm(eps=self.eps, axis=self.axis)
if not self.elementwise_affine: return x
return x * self.weight + self.bias
class LayerNorm2d(LayerNorm):
def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)