UnsyncedBatchNorm with synced trainable weights for hlb cifar (#3472)

* UnsyncedBatchNorm with synced trainable weights for hlb cifar

* multitensor reshape tests

* test mlb assign change axis

* E501

* argfix axis

* don't import batchnorm from hlb_cifar in test_multitensor

* pass num_devices to UnsyncedBatchNorm in test, allow UnsyncedBatchNorm to be used with LB

* add backprop test for UnsyncedBatchNorm

* break out MLB assign and reshape changes

* manually shard running mean and running var

* don't shard unless syncbn=0

* replace nn.BatchNorm2d with UnsyncedBatchNorm

* don't increment num_batches_tracked if not tracking running stats

* update tests

* oops

* Revert "oops"

This reverts commit 5e8a67a535.

* Revert "update tests"

This reverts commit 7ebf65d89a.

* Revert "don't increment num_batches_tracked if not tracking running stats"

This reverts commit 78de0ea9ee.

* Revert "replace nn.BatchNorm2d with UnsyncedBatchNorm"

This reverts commit d03da53da7.

* don't increment num_batches_tracked if not tracking running stats

* oops

* test_batchnorm_axis

* compare against torch

* types

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
David Hou
2024-02-29 19:52:07 -08:00
committed by GitHub
parent 5a6e151844
commit e5385eecfc
4 changed files with 102 additions and 38 deletions

View File

@@ -11,7 +11,7 @@ from extra.lr_scheduler import OneCycleLR
from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
from tinygrad.nn.state import get_state_dict, get_parameters
from tinygrad.nn import optim
from tinygrad.helpers import Context, BEAM, WINO, getenv, colored
from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
from tinygrad.features.multi import MultiLazyBuffer
BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
@@ -28,45 +28,62 @@ else:
dtypes.default_float = dtypes.float32
np_dtype = np.float32
class BatchNorm(nn.BatchNorm2d):
class UnsyncedBatchNorm:
  """Batch norm that keeps separate statistics per device but one shared weight/bias.

  The input is reshaped so that its leading axis has length `num_devices`; batch
  statistics and running statistics are then computed independently for every
  slice along that axis, while the single (sz,) weight and bias are broadcast to
  all slices.

  Args:
    sz: number of features; size of `weight`, `bias` and of each row of the running stats.
    eps: value added to the variance before taking the inverse square root.
    affine: when False, no weight/bias is created and the output is only normalized.
    track_running_stats: when True, running mean/var are updated during training.
    momentum: mixing factor for the running statistics update.
    num_devices: number of independent statistic groups (defaults to `len(GPUS)`).
  """
  def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1, num_devices=len(GPUS)):
    self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
    self.num_devices = num_devices

    if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
    else: self.weight, self.bias = None, None

    # one row of running statistics per device
    self.running_mean, self.running_var = Tensor.zeros(num_devices, sz, requires_grad=False), Tensor.ones(num_devices, sz, requires_grad=False)
    self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)

  def __call__(self, x:Tensor):
    """Normalize `x` and return a tensor with the same shape as the input.

    `calc_stats` reduces over axes (1, 3, 4) of the reshaped tensor, so `x` is
    expected to be 4-D here (presumably batch, channels, height, width -- confirm
    against callers), with a leading dimension divisible by `num_devices`.
    """
    # parentheses only spell out the precedence the expression already had
    if isinstance(x.lazydata, MultiLazyBuffer):
      assert x.lazydata.axis is None or (x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices)

    # split the leading axis into (num_devices, leading // num_devices)
    rshape, x = x.shape, x.reshape(self.num_devices, -1, *x.shape[1:])
    batch_mean, batch_invstd = self.calc_stats(x)
    # weight/bias are None when affine=False; Tensor.batchnorm accepts None for both,
    # so only broadcast them across the device axis when they exist
    weight = None if self.weight is None else self.weight.reshape(1, -1).expand((self.num_devices, -1))
    bias = None if self.bias is None else self.bias.reshape(1, -1).expand((self.num_devices, -1))
    ret = x.batchnorm(weight, bias, batch_mean, batch_invstd, axis=(0, 2))
    return ret.reshape(rshape)

  def calc_stats(self, x:Tensor):
    """Return `(mean, invstd)` for the reshaped 5-D `x`.

    In training mode these come from the current batch (and the running stats are
    updated when `track_running_stats` is set); otherwise they come from the
    stored running statistics.
    """
    if Tensor.training:
      # This requires two full memory accesses to x
      # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
      # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
      batch_mean = x.mean(axis=(1,3,4))
      y = (x - batch_mean.reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1]))
      batch_var = (y*y).mean(axis=(1,3,4))
      batch_invstd = batch_var.add(self.eps).pow(-0.5)

      # NOTE: wow, this is done all throughout training in most PyTorch models
      if self.track_running_stats:
        self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
        # the batch variance is rescaled by N/(N - C) with N = prod(y.shape[1:]) and C = y.shape[2]
        self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * prod(y.shape[1:])/(prod(y.shape[1:])-y.shape[2]) * batch_var.detach())
        self.num_batches_tracked += 1
    else:
      batch_mean = self.running_mean
      # NOTE: this can be precomputed for static inference. we expand it here so it fuses
      batch_invstd = self.running_var.reshape(self.running_var.shape[0], 1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt()
    return batch_mean, batch_invstd
class BatchNorm(nn.BatchNorm2d if getenv("SYNCBN") else UnsyncedBatchNorm):
  """Batch norm layer with fixed hyperparameters.

  The base class is picked once, when this class statement executes: `nn.BatchNorm2d`
  if the SYNCBN environment flag is set, `UnsyncedBatchNorm` otherwise.
  """
  def __init__(self, num_features):
    # running statistics are disabled; eps/momentum are fixed for every instance
    super().__init__(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
    # the weight is frozen, only the bias receives gradients
    self.weight.requires_grad = False
    self.bias.requires_grad = True
class UnsyncedBatchNorm:
def __init__(self, num_features, num_devices=len(GPUS)):
self.bns:List[BatchNorm] = []
for _ in range(num_devices):
bn = BatchNorm(num_features)
self.bns.append(bn)
def __call__(self, x:Tensor):
if len(self.bns) == 1: return self.bns[0](x)
bn_ts = []
assert isinstance(x.lazydata, MultiLazyBuffer)
for bound, bn in zip(x.lazydata.bounds, self.bns):
# TODO: __getitem__ does not work
# xi = x[bound]
xi = x.shrink((bound, None, None, None))
bni = bn(xi)
bn_ts.append(bni)
# TODO: what do we want to do for inference? average weight? pick any one?
# a good start would be to check each mean/std are similar
return bn_ts[0].cat(*bn_ts[1:])
class ConvGroup:
def __init__(self, channels_in, channels_out):
self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
if getenv("SYNCBN"):
self.norm1 = BatchNorm(channels_out)
self.norm2 = BatchNorm(channels_out)
else:
self.norm1 = UnsyncedBatchNorm(channels_out)
self.norm2 = UnsyncedBatchNorm(channels_out)
self.norm1 = BatchNorm(channels_out)
self.norm2 = BatchNorm(channels_out)
def __call__(self, x):
x = self.conv1(x)
@@ -288,8 +305,11 @@ def train_cifar():
X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float)
if len(GPUS) > 1:
for x in get_parameters(model):
x.to_(GPUS)
for k, x in get_state_dict(model).items():
  # UnsyncedBatchNorm stores its running statistics as (num_devices, sz) tensors, one row
  # per device, so both buffers are split along axis 0; everything else is copied to all GPUs.
  # The buffers are named running_mean and running_var (there is no running_bias).
  if not getenv('SYNCBN') and ('running_mean' in k or 'running_var' in k):
    x.shard_(GPUS, axis=0)
  else:
    x.to_(GPUS)
# parse the training params into bias and non-bias
params_dict = get_state_dict(model)

View File

@@ -541,6 +541,30 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
out.mean().backward()
optim.step()
def test_unsynced_backprop_sync_weights(self):
from extra.lr_scheduler import OneCycleLR
from examples.hlb_cifar10 import UnsyncedBatchNorm
from tinygrad.features.multi import MultiLazyBuffer
GPUS = (d1, d2)
with Tensor.train():
conv = nn.Conv2d(3, 16, 3)
bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
for p in get_parameters([conv, bn]):
if not isinstance(p.lazydata, MultiLazyBuffer):
p.shard_(GPUS)
optim = nn.optim.Adam(get_parameters([conv, bn]))
lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
lr_sched.step()
fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)
out = bn(conv(fake_image))
optim.zero_grad()
out.mean().backward()
optim.step()
@given(strat.sampled_from((False, True)))
def test_batchnorm(self, is_training):
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
@@ -564,13 +588,14 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
bn_ts[0].cat(*bn_ts[1:]).numpy()
def test_synced_vs_unsynced_bn(self):
from examples.hlb_cifar10 import BatchNorm, UnsyncedBatchNorm
from examples.hlb_cifar10 import UnsyncedBatchNorm
from tinygrad.nn import BatchNorm2d
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
with Tensor.train():
synced_bn = BatchNorm(8)
unsynced_bn = UnsyncedBatchNorm(8)
synced_bn = BatchNorm2d(8)
unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
for p in get_parameters([synced_bn, unsynced_bn]):
p.shard_(devices)

View File

@@ -62,6 +62,24 @@ class TestNN(unittest.TestCase):
def test_batchnorm2d_training(self):
self.test_batchnorm2d(True)
def test_batchnorm_axis(self):
  # batchnorm over axes (0, 2) of a 5-D tensor must match the default axis-1 batchnorm
  # on the equivalent 4-D tensor (axes 0 and 2 merged into one channel axis of size 6),
  # and both must match torch's batch_norm on that 4-D tensor
  sz = (2, 4, 3, 2, 2)
  x = Tensor.randn(sz)
  weight = Tensor.randn(2, 3)
  bias = Tensor.randn(2, 3)
  mean = Tensor.randn(2, 3)
  invstd = Tensor.randn(2, 3)
  a = (x.batchnorm(weight, bias, mean, invstd, axis=(0, 2))
       .permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2))
  b = (x.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2)
       .batchnorm(weight.flatten(), bias.flatten(), mean.flatten(), invstd.flatten()))
  t_x = torch.tensor(x.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2).numpy())
  t_weight, t_bias = torch.tensor(weight.flatten().numpy()), torch.tensor(bias.flatten().numpy())
  t_mean, t_invstd = torch.tensor(mean.flatten().numpy()), torch.tensor(invstd.flatten().numpy())
  # torch takes a variance, which loses the sign of the random invstd, so the sign is folded
  # into the weight; eps=0 because the tinygrad call applies invstd without adding an epsilon
  t_out = torch.nn.functional.batch_norm(t_x, t_mean, 1.0 / t_invstd**2,
                                         t_weight * torch.sign(t_invstd), t_bias, eps=0.0)

  np.testing.assert_allclose(a.numpy(), b.numpy())
  # the torch path goes through 1/invstd**2 and a square root, so allow float32 rounding slack
  np.testing.assert_allclose(a.numpy(), t_out.numpy(), rtol=1e-4, atol=1e-4)
def test_linear(self):
def _test_linear(x, in_dim, out_dim):
# create in tinygrad

View File

@@ -927,11 +927,12 @@ class Tensor:
y = (self - self.mean(axis, keepdim=True))
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor:
shape = (1, -1) + (1,) * (self.ndim-2)
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
axis_ = argfix(axis)
shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape))
x = self - mean.reshape(shape)
if weight: x = x * weight.reshape(shape)
ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == 1 else invstd)
ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd)
return (ret + bias.reshape(shape)) if bias else ret
def dropout(self, p=0.5) -> Tensor: