From 8a928ed2f308d2e779403f45c87cdc047e57951a Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 1 Jun 2023 21:24:11 -0700 Subject: [PATCH] nn init matches torch (#901) --- examples/hlb_cifar10.py | 6 +++++- examples/hlb_cifar10_torch.py | 3 +++ test/test_randomness.py | 13 ++++++------- tinygrad/nn/__init__.py | 20 ++++++++++++-------- tinygrad/tensor.py | 6 +++--- 5 files changed, 29 insertions(+), 19 deletions(-) diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index 2fcc0a02a4..7c9c120c06 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -85,14 +85,18 @@ def train_cifar(): model = SpeedyResNet() # init weights with torch + # TODO: it doesn't learn with the tinygrad weights, likely since kaiming init if getenv("TORCHWEIGHTS"): from examples.hlb_cifar10_torch import SpeedyResNet as SpeedyResNetTorch torch_model = SpeedyResNetTorch() model_state_dict = optim.get_state_dict(model) torch_state_dict = torch_model.state_dict() for k,v in torch_state_dict.items(): - print(f"initting {k} from torch") + old_mean_std = model_state_dict[k].mean().numpy(), model_state_dict[k].std().numpy() model_state_dict[k].assign(Tensor(v.detach().numpy())).realize() + new_mean_std = model_state_dict[k].mean().numpy(), model_state_dict[k].std().numpy() + print(f"initted {k:40s} {str(model_state_dict[k].shape):20s} from torch mean:{old_mean_std[0]:8.5f} -> {new_mean_std[0]:8.5f} std:{old_mean_std[1]:8.5f} -> {new_mean_std[1]:8.5f}") + exit(0) if getenv("ADAM"): optimizer = optim.Adam(optim.get_parameters(model), lr=Tensor([0.001]).realize()) diff --git a/examples/hlb_cifar10_torch.py b/examples/hlb_cifar10_torch.py index 9a2a9afafd..b68f1c0638 100644 --- a/examples/hlb_cifar10_torch.py +++ b/examples/hlb_cifar10_torch.py @@ -8,6 +8,9 @@ from torch import optim from datasets import fetch_cifar from tinygrad.helpers import getenv +# allow TF32 +torch.set_float32_matmul_precision('high') + OSX = platform.system() == "Darwin" device = 'mps' if OSX else 'cuda' diff --git a/test/test_randomness.py b/test/test_randomness.py index d28e789647..203f15577a 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -43,14 +43,14 @@ def normal_test(func, shape=(20, 23), alpha=0.05): y = np.random.randn(*shape).flatten() return kstest(x, y) >= alpha -def equal_distribution(tiny_func, torch_func, numpy_func, shape=(20, 23), alpha=0.05): +def equal_distribution(tiny_func, torch_func, numpy_func=None, shape=(20, 23), alpha=0.05): Tensor.manual_seed(1337) torch.manual_seed(1337) np.random.seed(1337) x = tiny_func(*shape).cpu().numpy().flatten() - y = numpy_func(shape).flatten() + if numpy_func is not None: y = numpy_func(shape).flatten() z = torch_func(shape).numpy().flatten() - return kstest(x, y) >= alpha and kstest(x, z) >= alpha + return (numpy_func is None or kstest(x, y) >= alpha) and kstest(x, z) >= alpha class TestRandomness(unittest.TestCase): def test_rand(self): @@ -73,13 +73,12 @@ class TestRandomness(unittest.TestCase): self.assertFalse(normal_test(Tensor.glorot_uniform)) self.assertTrue(equal_distribution(Tensor.glorot_uniform, lambda x: torch.nn.init.xavier_uniform_(torch.empty(x)), lambda x: (np.random.rand(*x) * 2 - 1) * math.sqrt(6 / (x[0] + math.prod(x[1:]))))) - def test_kaiming_uniform(self, shape=(20, 23), a=0.01): + def test_kaiming_uniform(self): Tensor.manual_seed(1337) torch.manual_seed(1337) np.random.seed(1337) - - bound = (math.sqrt(3.0) * (math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(shape[1] * np.prod(shape[2:])))) - self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), lambda x: np.random.uniform(low=-bound, high=bound, size=shape))) + for shape in [(128, 64, 3, 3), (20, 24)]: + self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape)) if __name__ == "__main__": unittest.main() diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index b54d7b1fe0..5f1c7336a0 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -1,3 +1,4 @@ +import math from typing import Optional, Union, Tuple from tinygrad.tensor import Tensor from tinygrad.helpers import prod @@ -35,11 +36,12 @@ class BatchNorm2d: return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd) class Conv2d: - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, initialization: str='kaiming_uniform'): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) self.stride, self.padding, self.dilation, self.groups = stride, padding, dilation, groups - self.weight = getattr(Tensor, initialization)(out_channels, in_channels//groups, *self.kernel_size) - self.bias = Tensor.zeros(out_channels) if bias else None + self.weight = Tensor.kaiming_uniform(out_channels, in_channels//groups, *self.kernel_size, a=math.sqrt(5)) + bound = 1 / math.sqrt(prod(self.weight.shape[1:])) + self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None def __call__(self, x): return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups) @@ -48,16 +50,18 @@ class ConvTranspose2d: def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True): self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) self.stride, self.padding, self.output_padding, self.dilation, self.groups = stride, padding, output_padding, dilation, groups - self.weight = Tensor.glorot_uniform(in_channels, out_channels//groups, *self.kernel_size) - self.bias = Tensor.zeros(out_channels) if bias else None + self.weight = Tensor.kaiming_uniform(in_channels, out_channels//groups, *self.kernel_size, a=math.sqrt(5)) + bound = 1 / math.sqrt(prod(self.weight.shape[1:])) + self.bias = Tensor.uniform(out_channels, low=-bound, high=bound) if bias else None def __call__(self, x): return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride, dilation=self.dilation, groups=self.groups) class Linear: - def __init__(self, in_features, out_features, bias=True, initialization: str='kaiming_uniform'): - self.weight = getattr(Tensor, initialization)(out_features, in_features) - self.bias = Tensor.zeros(out_features) if bias else None + def __init__(self, in_features, out_features, bias=True): + self.weight = Tensor.kaiming_uniform(out_features, in_features, a=math.sqrt(5)) + bound = 1 / math.sqrt(self.weight.shape[1]) + self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None def __call__(self, x): return x.linear(self.weight.transpose(), self.bias) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index ac5cd2de92..03caa132fa 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -185,7 +185,7 @@ class Tensor: # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_ @staticmethod def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor: - bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(shape[1] * prod(shape[2:])) + bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:])) return Tensor.uniform(*shape, low=-bound, high=bound) # ***** toposort and backward pass ***** @@ -251,7 +251,7 @@ class Tensor: # - Strides > 1 and < 0 are now allowed!: # - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional) # - Idea of stride < 0 support: - # - Do the slice first, flip the axes were slice.step is negative, do slice.step -> -slice.step. Go to steps below. + # - Do the slice first, flip the axes were slice.step is negative, do slice.step -> -slice.step. Go to steps below. # - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink): # - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s]. # - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s] @@ -368,7 +368,7 @@ class Tensor: out = self.sum(axis=axis, keepdim=keepdim) return out * (prod(out.shape)/prod(self.shape)) def std(self, axis=None, keepdim=False, correction=1): - square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim) + square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim) return (square_sum / (prod(self.shape)/prod(square_sum.shape)-correction)).sqrt() def _softmax(self, axis): m = self - self.max(axis=axis, keepdim=True)