mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 15:35:51 +08:00
LayerNorm2d for 2 lines
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d, LayerNorm, Linear
|
||||
from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear
|
||||
|
||||
class Block:
|
||||
def __init__(self, dim):
|
||||
@@ -18,8 +18,8 @@ class Block:
|
||||
class ConvNeXt:
|
||||
def __init__(self, in_chans=3, num_classes=1000, depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]):
|
||||
self.downsample_layers = [
|
||||
[Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm((dims[0], 1, 1), eps=1e-6)],
|
||||
*[[LayerNorm((dims[i], 1, 1), eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
|
||||
[Conv2d(in_chans, dims[0], kernel_size=4, stride=4), LayerNorm2d(dims[0], eps=1e-6)],
|
||||
*[[LayerNorm2d(dims[i], eps=1e-6), Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2)] for i in range(len(dims)-1)]
|
||||
]
|
||||
self.stages = [[Block(dims[i]) for _ in range(depths[i])] for i in range(len(dims))]
|
||||
self.norm = LayerNorm(dims[-1])
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm
|
||||
from tinygrad.nn import BatchNorm2d, Conv2d, Linear, GroupNorm, LayerNorm, LayerNorm2d
|
||||
import torch
|
||||
|
||||
class TestNN(unittest.TestCase):
|
||||
@@ -76,7 +76,7 @@ class TestNN(unittest.TestCase):
|
||||
def test_conv2d(self):
|
||||
BS, C1, H, W = 4, 16, 224, 224
|
||||
C2, K, S, P = 64, 7, 2, 1
|
||||
|
||||
|
||||
# create in tinygrad
|
||||
layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
|
||||
@@ -131,5 +131,24 @@ class TestNN(unittest.TestCase):
|
||||
torch_z = torch_layer(torch_x)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
|
||||
|
||||
def test_layernorm_2d(self):
|
||||
N, C, H, W = 20, 5, 10, 10
|
||||
|
||||
# create in tinygrad
|
||||
layer = LayerNorm2d(C)
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
torch_layer = torch.nn.LayerNorm([C]).eval()
|
||||
torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
|
||||
torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)
|
||||
|
||||
# test
|
||||
x = Tensor.randn(N, C, H, W)
|
||||
z = layer(x)
|
||||
torch_x = torch.tensor(x.cpu().numpy())
|
||||
torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
|
||||
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -73,11 +73,15 @@ class GroupNorm:
|
||||
|
||||
class LayerNorm:
|
||||
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
|
||||
normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
|
||||
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(normalized_shape))), eps, elementwise_affine
|
||||
self.weight, self.bias = (Tensor.ones(*normalized_shape), Tensor.zeros(*normalized_shape)) if elementwise_affine else (None, None)
|
||||
self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
|
||||
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
|
||||
self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
|
||||
x = x.layernorm(eps=self.eps, axis=self.axis)
|
||||
if not self.elementwise_affine: return x
|
||||
return x * self.weight + self.bias
|
||||
|
||||
class LayerNorm2d(LayerNorm):
|
||||
def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||
|
||||
Reference in New Issue
Block a user