mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
UnsyncedBatchNorm with synced trainable weights for hlb cifar (#3472)
* UnsyncedBatchNorm with synced trainable weights for hlb cifar * multitensor reshape tests * test mlb assign change axis * E501 * argfix axis * don't import batchnorm from hlb_cifar in test_multitensor * pass num_devices to UnsyncedBatchNorm in test, allow UnsyncedBatchNorm to be used with LB * add backprop test for UnsyncedBatchNorm * break out MLB assign and reshape changes * manually shard running mean and running var * don't shard unless syncbn=0 * replace nn.BatchNorm2d with UnsyncedBatchNorm * don't increment num_batches_tracked if not tracking running stats * update tests * oops * Revert "oops" This reverts commit 5e8a67a535. * Revert "update tests" This reverts commit 7ebf65d89a. * Revert "don't increment num_batches_tracked if not tracking running stats" This reverts commit 78de0ea9ee. * Revert "replace nn.BatchNorm2d with UnsyncedBatchNorm" This reverts commit d03da53da7. * don't increment num_batches_tracked if not tracking running stats * oops * test_batchnorm_axis * compare against torch * types --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -11,7 +11,7 @@ from extra.lr_scheduler import OneCycleLR
|
||||
from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
|
||||
from tinygrad.nn.state import get_state_dict, get_parameters
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.helpers import Context, BEAM, WINO, getenv, colored
|
||||
from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
|
||||
from tinygrad.features.multi import MultiLazyBuffer
|
||||
|
||||
BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
|
||||
@@ -28,45 +28,62 @@ else:
|
||||
dtypes.default_float = dtypes.float32
|
||||
np_dtype = np.float32
|
||||
|
||||
class BatchNorm(nn.BatchNorm2d):
|
||||
class UnsyncedBatchNorm:
  """Batch normalization whose statistics are computed separately for each device group.

  The input is viewed as `num_devices` groups along axis 0 and every group is normalized with its own
  mean / inverse std. `weight` and `bias` are single `(sz,)` tensors shared by all groups, while
  `running_mean` / `running_var` hold one row per group, shape `(num_devices, sz)`.

  Statistics are reduced over axes (1, 3, 4) of the 5-D grouped view, so a 4-D input is expected.
  """
  def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1, num_devices=len(GPUS)):
    self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
    self.num_devices = num_devices

    # affine=False leaves both as None; __call__ forwards None to Tensor.batchnorm in that case
    if affine: self.weight, self.bias = Tensor.ones(sz), Tensor.zeros(sz)
    else: self.weight, self.bias = None, None

    # one row of running statistics per device group
    self.running_mean, self.running_var = Tensor.zeros(num_devices, sz, requires_grad=False), Tensor.ones(num_devices, sz, requires_grad=False)
    self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)

  def __call__(self, x:Tensor):
    """Normalize `x` per device group and return a tensor with the same shape as `x`."""
    if isinstance(x.lazydata, MultiLazyBuffer):
      assert x.lazydata.axis is None or (x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices)

    # split the leading axis into (num_devices, per-group batch) so each group gets its own statistics
    rshape, x = x.shape, x.reshape(self.num_devices, -1, *x.shape[1:])
    batch_mean, batch_invstd = self.calc_stats(x)
    # the shared (sz,) parameters are broadcast to one row per group; without affine they stay None
    weight = None if self.weight is None else self.weight.reshape(1, -1).expand((self.num_devices, -1))
    bias = None if self.bias is None else self.bias.reshape(1, -1).expand((self.num_devices, -1))
    ret = x.batchnorm(weight, bias, batch_mean, batch_invstd, axis=(0, 2))
    return ret.reshape(rshape)

  def calc_stats(self, x:Tensor):
    """Return `(mean, invstd)` for the grouped 5-D view `x`.

    While `Tensor.training` is set they are computed from the batch (and the running statistics are updated
    when `track_running_stats` is on); otherwise they come from the stored running statistics.
    """
    if Tensor.training:
      # This requires two full memory accesses to x
      # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
      # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
      batch_mean = x.mean(axis=(1,3,4))
      y = (x - batch_mean.reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1]))
      batch_var = (y*y).mean(axis=(1,3,4))
      batch_invstd = batch_var.add(self.eps).pow(-0.5)

      # NOTE: wow, this is done all throughout training in most PyTorch models
      if self.track_running_stats:
        # n/(n - C) == m/(m - 1) with m = samples per channel in one group: unbiased variance for the running estimate
        n = prod(y.shape[1:])
        self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach())
        self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * n/(n-y.shape[2]) * batch_var.detach())
        self.num_batches_tracked += 1
    else:
      batch_mean = self.running_mean
      # NOTE: this can be precomputed for static inference. we expand it here so it fuses
      batch_invstd = self.running_var.reshape(self.running_var.shape[0], 1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt()
    return batch_mean, batch_invstd
|
||||
|
||||
class BatchNorm(nn.BatchNorm2d if getenv("SYNCBN") else UnsyncedBatchNorm):
  """Batch norm preset for this model; the base class is picked by the SYNCBN environment flag.

  Running statistics are not tracked, the weight is frozen and only the bias receives gradients.
  """
  def __init__(self, num_features):
    super().__init__(num_features, eps=1e-12, affine=True, track_running_stats=False, momentum=0.85)
    self.weight.requires_grad, self.bias.requires_grad = False, True
|
||||
|
||||
class UnsyncedBatchNorm:
|
||||
def __init__(self, num_features, num_devices=len(GPUS)):
|
||||
self.bns:List[BatchNorm] = []
|
||||
for _ in range(num_devices):
|
||||
bn = BatchNorm(num_features)
|
||||
self.bns.append(bn)
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
if len(self.bns) == 1: return self.bns[0](x)
|
||||
|
||||
bn_ts = []
|
||||
assert isinstance(x.lazydata, MultiLazyBuffer)
|
||||
for bound, bn in zip(x.lazydata.bounds, self.bns):
|
||||
# TODO: __getitem__ does not work
|
||||
# xi = x[bound]
|
||||
xi = x.shrink((bound, None, None, None))
|
||||
bni = bn(xi)
|
||||
bn_ts.append(bni)
|
||||
# TODO: what do we want to do for inference? average weight? pick any one?
|
||||
# a good start would be to check each mean/std are similar
|
||||
return bn_ts[0].cat(*bn_ts[1:])
|
||||
|
||||
class ConvGroup:
|
||||
def __init__(self, channels_in, channels_out):
|
||||
self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
|
||||
self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
if getenv("SYNCBN"):
|
||||
self.norm1 = BatchNorm(channels_out)
|
||||
self.norm2 = BatchNorm(channels_out)
|
||||
else:
|
||||
self.norm1 = UnsyncedBatchNorm(channels_out)
|
||||
self.norm2 = UnsyncedBatchNorm(channels_out)
|
||||
self.norm1 = BatchNorm(channels_out)
|
||||
self.norm2 = BatchNorm(channels_out)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv1(x)
|
||||
@@ -288,8 +305,11 @@ def train_cifar():
|
||||
X_test, Y_test = X_test.cast(dtypes.default_float), Y_test.cast(dtypes.default_float)
|
||||
|
||||
if len(GPUS) > 1:
  for k, x in get_state_dict(model).items():
    # without SYNCBN the batchnorm running stats carry one row per device (shape (num_devices, sz)),
    # so both running_mean and running_var are split along axis 0; every other tensor goes to GPUS via to_
    if not getenv('SYNCBN') and ('running_mean' in k or 'running_var' in k):
      x.shard_(GPUS, axis=0)
    else:
      x.to_(GPUS)
|
||||
|
||||
# parse the training params into bias and non-bias
|
||||
params_dict = get_state_dict(model)
|
||||
|
||||
@@ -541,6 +541,30 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
||||
out.mean().backward()
|
||||
optim.step()
|
||||
|
||||
def test_unsynced_backprop_sync_weights(self):
  """Smoke test: one Adam step through conv -> UnsyncedBatchNorm on a batch sharded over two devices.

  There are no assertions; the test passes if the forward, backward and optimizer step run.
  """
  from extra.lr_scheduler import OneCycleLR
  from examples.hlb_cifar10 import UnsyncedBatchNorm
  from tinygrad.features.multi import MultiLazyBuffer
  GPUS = (d1, d2)

  with Tensor.train():
    conv = nn.Conv2d(3, 16, 3)
    bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
    layers = [conv, bn]

    # shard_ (no axis given) every parameter that is not already backed by a MultiLazyBuffer
    for param in get_parameters(layers):
      if isinstance(param.lazydata, MultiLazyBuffer): continue
      param.shard_(GPUS)
    optim = nn.optim.Adam(get_parameters(layers))
    lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
    lr_sched.step()

    # batch of 8 split along axis 0 across the devices
    images = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)

    out = bn(conv(images))
    optim.zero_grad()
    out.mean().backward()
    optim.step()
|
||||
|
||||
@given(strat.sampled_from((False, True)))
|
||||
def test_batchnorm(self, is_training):
|
||||
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
|
||||
@@ -564,13 +588,14 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
||||
bn_ts[0].cat(*bn_ts[1:]).numpy()
|
||||
|
||||
def test_synced_vs_unsynced_bn(self):
|
||||
from examples.hlb_cifar10 import BatchNorm, UnsyncedBatchNorm
|
||||
from examples.hlb_cifar10 import UnsyncedBatchNorm
|
||||
from tinygrad.nn import BatchNorm2d
|
||||
devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
|
||||
x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
|
||||
|
||||
with Tensor.train():
|
||||
synced_bn = BatchNorm(8)
|
||||
unsynced_bn = UnsyncedBatchNorm(8)
|
||||
synced_bn = BatchNorm2d(8)
|
||||
unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
|
||||
|
||||
for p in get_parameters([synced_bn, unsynced_bn]):
|
||||
p.shard_(devices)
|
||||
|
||||
@@ -62,6 +62,24 @@ class TestNN(unittest.TestCase):
|
||||
def test_batchnorm2d_training(self):
|
||||
self.test_batchnorm2d(True)
|
||||
|
||||
def test_batchnorm_axis(self):
  """batchnorm over axis=(0, 2) must match the default-axis batchnorm on the flattened layout and torch.

  The (2, 4, 3, 2, 2) input is normalized with (2, 3) parameters over axes 0 and 2, then compared against
  the same data permuted/reshaped to (4, 6, 2, 2) with flattened parameters, and against
  torch.nn.functional.batch_norm on that layout.
  """
  sz = (2, 4, 3, 2, 2)
  x = Tensor.randn(sz)
  weight = Tensor.randn(2, 3)
  bias = Tensor.randn(2, 3)
  mean = Tensor.randn(2, 3)
  # strictly positive: torch takes a variance, and 1/invstd**2 would drop the sign of a negative invstd
  invstd = (Tensor.rand(2, 3) + 0.5).realize()
  a = (x.batchnorm(weight, bias, mean, invstd, axis=(0, 2))
       .permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2))
  b = (x.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2)
       .batchnorm(weight.flatten(), bias.flatten(), mean.flatten(), invstd.flatten()))
  t_x = torch.tensor(x.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2).numpy())
  t_weight, t_bias = torch.tensor(weight.flatten().numpy()), torch.tensor(bias.flatten().numpy())
  t_mean, t_invstd = torch.tensor(mean.flatten().numpy()), torch.tensor(invstd.flatten().numpy())
  # eps=0.0 because invstd is passed to tinygrad as-is, with no epsilon folded in
  t_out = torch.nn.functional.batch_norm(t_x, t_mean, 1.0 / t_invstd**2, t_weight, t_bias, eps=0.0)

  np.testing.assert_allclose(a.numpy(), b.numpy())
  np.testing.assert_allclose(a.numpy(), t_out.numpy(), rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_linear(self):
|
||||
def _test_linear(x, in_dim, out_dim):
|
||||
# create in tinygrad
|
||||
|
||||
@@ -927,11 +927,12 @@ class Tensor:
|
||||
y = (self - self.mean(axis, keepdim=True))
|
||||
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
|
||||
|
||||
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
  """Apply `(self - mean) * weight * invstd + bias`, with the parameters laid out along `axis`.

  `mean`, `weight` and `bias` are reshaped to keep the size of `self` on every dimension listed in `axis`
  and 1 elsewhere. `weight` and `bias` are skipped when None. `invstd` is reshaped the same way only when
  its number of dimensions equals the number of axes; otherwise it is multiplied in as given.
  Entries of `axis` are matched against dimension indices directly, so they must be non-negative.
  """
  axis_ = argfix(axis)
  shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape))
  x = self - mean.reshape(shape)
  # explicit None checks: the optional parameters are Tensors, whose truthiness is not a presence test
  if weight is not None: x = x * weight.reshape(shape)
  ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd)
  return (ret + bias.reshape(shape)) if bias is not None else ret
|
||||
|
||||
def dropout(self, p=0.5) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user