diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index 1ee27bc6d4..3350e9d716 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -11,7 +11,7 @@ from extra.lr_scheduler import OneCycleLR from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit from tinygrad.nn.state import get_state_dict, get_parameters from tinygrad.nn import optim -from tinygrad.helpers import Context, BEAM, WINO, getenv, colored +from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod from tinygrad.features.multi import MultiLazyBuffer BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000) @@ -28,45 +28,62 @@ else: dtypes.default_float = dtypes.float32 np_dtype = np.float32 -class BatchNorm(nn.BatchNorm2d): +class UnsyncedBatchNorm: + def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1, num_devices=len(GPUS)): + self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum + self.num_devices = num_devices + + if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz) + else: self.weight, self.bias = None, None + + self.running_mean, self.running_var = Tensor.zeros(num_devices, sz, requires_grad=False), Tensor.ones(num_devices, sz, requires_grad=False) + self.num_batches_tracked = Tensor.zeros(1, requires_grad=False) + + def __call__(self, x:Tensor): + if isinstance(x.lazydata, MultiLazyBuffer): assert x.lazydata.axis is None or x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices + + rshape, x = x.shape, x.reshape(self.num_devices, -1, *x.shape[1:]) + batch_mean, batch_invstd = self.calc_stats(x) + ret = x.batchnorm( + self.weight.reshape(1, -1).expand((self.num_devices, -1)), + self.bias.reshape(1, -1).expand((self.num_devices, -1)), + batch_mean, batch_invstd, axis=(0, 2)) + return ret.reshape(rshape) + + def calc_stats(self, x:Tensor): + if Tensor.training: + # This requires two full memory accesses to x + # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh + # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm + batch_mean = x.mean(axis=(1,3,4)) + y = (x - batch_mean.reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1])) + batch_var = (y*y).mean(axis=(1,3,4)) + batch_invstd = batch_var.add(self.eps).pow(-0.5) + + # NOTE: wow, this is done all throughout training in most PyTorch models + if self.track_running_stats: + self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach()) + self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * prod(y.shape[1:])/(prod(y.shape[1:])-y.shape[2]) * batch_var.detach()) + self.num_batches_tracked += 1 + else: + batch_mean = self.running_mean + # NOTE: this can be precomputed for static inference. we expand it here so it fuses + batch_invstd = self.running_var.reshape(self.running_var.shape[0], 1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt() + return batch_mean, batch_invstd + +class BatchNorm(nn.BatchNorm2d if getenv("SYNCBN") else UnsyncedBatchNorm): def __init__(self, num_features): super().__init__(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True) self.weight.requires_grad = False self.bias.requires_grad = True -class UnsyncedBatchNorm: - def __init__(self, num_features, num_devices=len(GPUS)): - self.bns:List[BatchNorm] = [] - for _ in range(num_devices): - bn = BatchNorm(num_features) - self.bns.append(bn) - - def __call__(self, x:Tensor): - if len(self.bns) == 1: return self.bns[0](x) - - bn_ts = [] - assert isinstance(x.lazydata, MultiLazyBuffer) - for bound, bn in zip(x.lazydata.bounds, self.bns): - # TODO: __getitem__ does not work - # xi = x[bound] - xi = x.shrink((bound, None, None, None)) - bni = bn(xi) - bn_ts.append(bni) - # TODO: what do we want to do for inference? average weight? pick any one? - # a good start would be to check each mean/std are similar - return bn_ts[0].cat(*bn_ts[1:]) - class ConvGroup: def __init__(self, channels_in, channels_out): self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False) self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False) - if getenv("SYNCBN"): - self.norm1 = BatchNorm(channels_out) - self.norm2 = BatchNorm(channels_out) - else: - self.norm1 = UnsyncedBatchNorm(channels_out) - self.norm2 = UnsyncedBatchNorm(channels_out) + self.norm1 = BatchNorm(channels_out) + self.norm2 = BatchNorm(channels_out) def __call__(self, x): x = self.conv1(x) @@ -288,8 +305,11 @@ def train_cifar(): X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float) if len(GPUS) > 1: - for x in get_parameters(model): - x.to_(GPUS) + for k, x in get_state_dict(model).items(): + if not getenv('SYNCBN') and ('running_mean' in k or 'running_bias' in k): + x.shard_(GPUS, axis=0) + else: + x.to_(GPUS) # parse the training params into bias and non-bias params_dict = get_state_dict(model) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index a0133ecc32..6262f28fa2 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -541,6 +541,30 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): out.mean().backward() optim.step() + def test_unsynced_backprop_sync_weights(self): + from extra.lr_scheduler import OneCycleLR + from examples.hlb_cifar10 import UnsyncedBatchNorm + from tinygrad.features.multi import MultiLazyBuffer + GPUS = (d1, d2) + + with Tensor.train(): + conv = nn.Conv2d(3, 16, 3) + bn = UnsyncedBatchNorm(16, num_devices=len(GPUS)) + + for p in get_parameters([conv, bn]): + if not isinstance(p.lazydata, MultiLazyBuffer): + p.shard_(GPUS) + optim = nn.optim.Adam(get_parameters([conv, bn])) + lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10) + lr_sched.step() + + fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0) + + out = bn(conv(fake_image)) + optim.zero_grad() + out.mean().backward() + optim.step() + @given(strat.sampled_from((False, True))) def test_batchnorm(self, is_training): devices = [f"{Device.DEFAULT}:{i}" for i in range(4)] @@ -564,13 +588,14 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): bn_ts[0].cat(*bn_ts[1:]).numpy() def test_synced_vs_unsynced_bn(self): - from examples.hlb_cifar10 import BatchNorm, UnsyncedBatchNorm + from examples.hlb_cifar10 import UnsyncedBatchNorm + from tinygrad.nn import BatchNorm2d devices = [f"{Device.DEFAULT}:{i}" for i in range(4)] x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0) with Tensor.train(): - synced_bn = BatchNorm(8) - unsynced_bn = UnsyncedBatchNorm(8) + synced_bn = BatchNorm2d(8) + unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices)) for p in get_parameters([synced_bn, unsynced_bn]): p.shard_(devices) diff --git a/test/test_nn.py b/test/test_nn.py index b2605d5193..215e34cfea 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -62,6 +62,24 @@ class TestNN(unittest.TestCase): def test_batchnorm2d_training(self): self.test_batchnorm2d(True) + def test_batchnorm_axis(self): + sz = (2, 4, 3, 2, 2) + x = Tensor.randn(sz) + weight = Tensor.randn(2, 3) + bias = Tensor.randn(2, 3) + mean = Tensor.randn(2, 3) + invstd = Tensor.randn(2, 3) + a = (x.batchnorm(weight, bias, mean, invstd, axis=(0, 2)) + .permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2)) + b = (x.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2) + .batchnorm(weight.flatten(), bias.flatten(), mean.flatten(), invstd.flatten())) + t_x = torch.tensor(x.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2).numpy()) + t_weight, t_bias = torch.tensor(weight.flatten().numpy()), torch.tensor(bias.flatten().numpy()) + t_mean, t_invstd = torch.tensor(mean.flatten().numpy()), torch.tensor(invstd.flatten().numpy()) + torch.nn.functional.batch_norm(t_x, t_mean, 1.0 / t_invstd**2, t_weight, t_bias) + + np.testing.assert_allclose(a.numpy(), b.numpy()) + def test_linear(self): def _test_linear(x, in_dim, out_dim): # create in tinygrad diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 51948cd8c1..e643b52917 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -927,11 +927,12 @@ class Tensor: y = (self - self.mean(axis, keepdim=True)) return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt()) - def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor: - shape = (1, -1) + (1,) * (self.ndim-2) + def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor: + axis_ = argfix(axis) + shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape)) x = self - mean.reshape(shape) if weight: x = x * weight.reshape(shape) - ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == 1 else invstd) + ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd) return (ret + bias.reshape(shape)) if bias else ret def dropout(self, p=0.5) -> Tensor: