# File: tinygrad/test/unit/test_multitensor.py
# Snapshot: 2026-06-05 18:38:46 -04:00 (962 lines, 38 KiB, Python)

import unittest, numpy as np
from tinygrad import Tensor, Variable, Context, Device, TinyJit, GlobalCounters, dtypes, UOp, nn, getenv
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.uop.ops import Ops
from test.helpers import not_support_multi_device, needs_second_gpu, slow
from hypothesis import given, strategies as strat, settings
# hypothesis profile shared by every property-based test in this module
settings.register_profile(
  "my_profile",
  max_examples=200,
  deadline=None,
  derandomize=getenv("DERANDOMIZE_CI", False),
)
settings.load_profile("my_profile")

# six device names on the default backend, "<DEFAULT>:0" through "<DEFAULT>:5"
d0, d1, d2, d3, d4, d5 = (f"{Device.DEFAULT}:{i}" for i in range(6))

# device groups used for sharding; note that they start at d1, not d0
devices_2, devices_3, devices_4 = (d1, d2), (d1, d2, d3), (d1, d2, d3, d4)

# side length of the square matrices used by the matmul tests
N = 128
@unittest.skipIf(not_support_multi_device(), "no multi")
class TestMultiTensor(unittest.TestCase):
  """Tests for tensors sharded over several devices.

  Covers the shard/shard_like helpers, elementwise, reduce and matmul ops, nn layers and
  backprop, TinyJit with sharded inputs, and the rand_like/full_like/dropout/clone helpers.
  """
  # no-op setUp; presumably exists so that needs_second_gpu gates every test in the class
  @needs_second_gpu
  def setUp(self): pass

  # shrinking a realized, sharded arange to either half returns the matching values
  def test_arange_shrink(self):
    x = Tensor.arange(4)
    self.assertEqual(x.shard(devices_2, 0).realize().shrink(((2, 4),)).tolist(), [2, 3])
    self.assertEqual(x.shard(devices_2, 0).realize().shrink(((0, 2),)).tolist(), [0, 1])

  # shard_like must copy device and axis from its argument
  def test_shard_like(self):
    X = Tensor.ones(256).shard(devices_2, 0)
    Y = Tensor.zeros(256).shard_like(X)
    self.assertEqual(Y.device, X.device)
    self.assertEqual(Y.uop.axis, 0)
    # also test with axis=None
    X2 = Tensor.ones(256).shard(devices_2, axis=None)
    Y2 = Tensor.zeros(256).shard_like(X2)
    self.assertEqual(Y2.device, X2.device)
    self.assertEqual(Y2.uop.axis, None)
    # test with single device
    X3 = Tensor.ones(256)
    Y3 = Tensor.zeros(256).shard_like(X3)
    self.assertEqual(Y3.device, X3.device)
    # cannot shard_like multi unless it's a no-op
    X4 = Tensor.ones(256).shard(devices_2, 0)
    Y4 = Tensor.ones(256).shard(devices_2, 0).shard_like(X4)
    self.assertEqual(Y4.device, X4.device)
    self.assertEqual(Y4.uop.axis, 0)
    with self.assertRaises(RuntimeError):
      Tensor.ones(256).shard(devices_2, None).shard_like(X4)

  # helper: shard a length-n ones tensor over devices_2 on axis 0, apply `op`,
  # and compare the realized result (as a list) with `out`
  def _test_shard_op(self, op, out, n=4):
    t = Tensor.ones(n).contiguous().realize().shard(devices_2, 0)
    r = op(t).realize()
    #assert t.uop.is_realized, "shard didn't realize"
    self.assertEqual(r.tolist(), out)
  def test_shard_reshape(self): self._test_shard_op(lambda t:t.reshape(2, 2), [[1.,1.],[1.,1.]])
  def test_shard_elementwise(self): self._test_shard_op(lambda t:(t+t).reshape(2, 2), [[2.,2.],[2.,2.]])

  # add a tensor built from UOp.const (as a scalar, then reshaped and expanded to (4,)) to a sharded tensor
  def test_alu_deviceless_const(self):
    s = Tensor([1.0, 2, 3, 4]).shard((f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"), axis=0)
    np.testing.assert_equal((s + Tensor(UOp.const(dtypes.float, 1.0))).numpy(), [2, 3, 4, 5])
    np.testing.assert_equal((s + Tensor(UOp.const(dtypes.float, 1.0)).reshape((1,)).expand((4,))).numpy(), [2, 3, 4, 5])

  # reduce over each axis of the reshaped (2, 3) tensor
  def test_shard_reduce(self):
    self._test_shard_op(lambda t:t.reshape(2, 3).sum(axis=1), [3.,3.], n=6)
    self._test_shard_op(lambda t:t.reshape(2, 3).sum(axis=0), [2.,2.,2.], n=6)

  # 256 elements over three devices: shard_ is expected to raise RuntimeError
  def test_shard_not_multiple(self):
    X = Tensor.ones(256).contiguous().realize()
    with self.assertRaises(RuntimeError):
      X.shard_(devices_3, 0)

  # a Tensor built from a multi-device uop keeps the devices and the data;
  # asking for a different dtype is expected to raise AssertionError
  def test_tensor_from_multi(self):
    X = Tensor([1, 2], dtype=dtypes.int).shard_(devices_2, 0)
    Y = Tensor(X.uop)
    self.assertEqual(Y.device, devices_2)
    np.testing.assert_equal(X.numpy(), Y.numpy())
    with self.assertRaises(AssertionError):
      _ = Tensor(X.uop, dtype=dtypes.float)

  def test_sharded_arange(self):
    sharded_arange = Tensor.arange(1000).clone().shard(devices_2, 0)
    sharded_arange.realize()
    np.testing.assert_equal(sharded_arange.numpy(), np.arange(1000))

  # smoke test: only checks that the sharded sum realizes
  def test_shard_plus_one_sum(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d1, d2), 0)
    (X + 1).sum().realize()

  # same smoke test, but the device tuple includes d0
  def test_shard_plus_one_sum_d0(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_((d0, d2), 0)
    (X + 1).sum().realize()

  # helper: add two ones tensors sharded with the given axes (an int axis or None)
  def _test_simple_add_axis(self, shard_x, shard_w):
    X = Tensor.ones(256).contiguous().realize()
    W = Tensor.ones(256).contiguous().realize()
    X.shard_((d1, d2), shard_x)
    W.shard_((d1, d2), shard_w)
    O = X + W
    np.testing.assert_allclose(O.numpy(), 2)
  def test_simple_add(self): return self._test_simple_add_axis(None, None)
  def test_simple_add_X(self): return self._test_simple_add_axis(0, None)
  def test_simple_add_W(self): return self._test_simple_add_axis(None, 0)
  def test_simple_add_XW(self): return self._test_simple_add_axis(0, 0)

  # the total must be 10 for every shard axis and for both sampled RING values
  @given(strat.sampled_from((0, 1, None)), strat.sampled_from((0, 2)))
  def test_allreduce_shard_ring_sum(self, axis, use_ring):
    t = Tensor([1, 2, 3, 4]).reshape(2, 2)
    with Context(RING=use_ring):
      np.testing.assert_equal(t.shard(devices_2, axis=axis).sum().item(), 10)

  # moving a sharded tensor to the default device must use the same
  # number of kernels with RING=0 and RING=2
  def test_multiple_to_single_device(self):
    kernel_counts = {}
    for ring in (0, 2):
      GlobalCounters.reset()
      with Context(RING=ring, SCACHE=0):
        t = Tensor.arange(32).clone().shard(devices_4, 0).to(Device.DEFAULT)
        t.realize()
      kernel_counts[ring] = GlobalCounters.kernel_count
      self.assertEqual(t.device, Device.DEFAULT)
      np.testing.assert_equal(t.numpy(), np.arange(32))
    self.assertEqual(kernel_counts[0], kernel_counts[2])

  # same gather, additionally bounding kernel count (< 2*ndev) and global memory (<= 4*nbytes)
  def test_to_single_device_gather_memory(self):
    nrows, ncols = 64, 1024
    nbytes = nrows*ncols*4  # presumably 4 bytes per element -- TODO confirm the arange dtype
    for devs in (devices_2, devices_4):
      ndev = len(devs)
      for axis in (0, 1):
        sh = Tensor.arange(nrows*ncols).reshape(nrows, ncols).clone().shard(devs, axis).realize()
        kernels, mem = {}, {}
        for ring in (0, 2):
          GlobalCounters.reset()
          with Context(RING=ring, SCACHE=0):
            t = sh.to(Device.DEFAULT)
            t.realize()
          kernels[ring], mem[ring] = GlobalCounters.kernel_count, GlobalCounters.global_mem
          self.assertEqual(t.device, Device.DEFAULT)
          np.testing.assert_equal(t.numpy(), np.arange(nrows*ncols).reshape(nrows, ncols))
        self.assertEqual(kernels[0], kernels[2])
        self.assertEqual(mem[0], mem[2])
        self.assertLess(kernels[0], 2*ndev)
        self.assertLessEqual(mem[0], 4*nbytes)

  # a jitted reduce over the axis the input is sharded on, called five times with fresh inputs
  def test_multitensor_jit_input_reduce_shard_axis(self):
    @TinyJit
    def f(x): return x.sum(0).realize()
    for _ in range(5):
      tt = Tensor.ones(2, 64).contiguous().realize().shard((d1,d2), 0).realize()
      out = f(tt)
      np.testing.assert_allclose(out.numpy(), np.full(64, 2.0))

  # helper: compare a sharded X@W against the numpy product of the unsharded operands
  def _test_matmul_shard_axis(self, shard_x, shard_w, device):
    X = Tensor.kaiming_uniform(N, N).realize()
    W = Tensor.kaiming_uniform(N, N).realize()
    Xs = X.shard(device, shard_x)
    Ws = W.shard(device, shard_w)
    O = (Xs@Ws)
    with np.errstate(all='ignore'):
      np.testing.assert_allclose(X.numpy() @ W.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)

  # helper: same comparison for (X@W1)@W2, with both weights sharded on shard_w
  def _test_double_matmul_shard_axis(self, shard_x, shard_w, device):
    X = Tensor.kaiming_uniform(N, N).realize()
    W1 = Tensor.kaiming_uniform(N, N).realize()
    W2 = Tensor.kaiming_uniform(N, N).realize()
    Xs = X.shard(device, shard_x)
    W1s = W1.shard(device, shard_w)
    W2s = W2.shard(device, shard_w)
    O = (Xs@W1s)@W2s
    with np.errstate(all='ignore'):
      np.testing.assert_allclose((X.numpy() @ W1.numpy()) @ W2.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)

  def test_matmul_shard_none(self): return self._test_matmul_shard_axis(None, None, devices_2)
  def test_matmul_shard_X_0(self): return self._test_matmul_shard_axis(0, None, devices_2)
  def test_matmul_shard_X_1(self): return self._test_matmul_shard_axis(1, None, devices_2)
  def test_matmul_shard_W_0(self): return self._test_matmul_shard_axis(None, 0, devices_2)
  def test_matmul_shard_W_1(self): return self._test_matmul_shard_axis(None, 1, devices_2)
  def test_matmul_shard_0_0(self): return self._test_matmul_shard_axis(0, 0, devices_2)
  def test_matmul_shard_0_1(self): return self._test_matmul_shard_axis(0, 1, devices_2)
  def test_matmul_shard_1_0(self): return self._test_matmul_shard_axis(1, 0, devices_2)
  def test_matmul_shard_1_1(self): return self._test_matmul_shard_axis(1, 1, devices_2)

  def test_double_matmul_shard_X_0(self): return self._test_double_matmul_shard_axis(0, None, devices_2)
  def test_double_matmul_shard_X_1(self): return self._test_double_matmul_shard_axis(1, None, devices_2)
  def test_double_matmul_shard_W_0(self): return self._test_double_matmul_shard_axis(None, 0, devices_2)
  def test_double_matmul_shard_W_1(self): return self._test_double_matmul_shard_axis(None, 1, devices_2)

  # conv parameters sharded without an axis, input sharded on the batch axis; only checks that it runs
  def test_conv_data_shard(self):
    conv = nn.Conv2d(3, 16, 3, bias=False)
    for p in get_parameters(conv): p.shard_(devices_2)
    fake_image = Tensor.rand((2, 3, 32, 32)).shard(devices_2, axis=0)
    out = conv(fake_image)
    out.numpy()

  # same as above with the default bias
  def test_conv_bias_data_shard(self):
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_(devices_2)
    fake_image = Tensor.rand((2, 3, 32, 32)).shard(devices_2, axis=0)
    out = conv(fake_image)
    out.numpy()

  # one Adam step through a conv with sharded parameters and batch-sharded input
  def test_backprop_conv(self):
    with Tensor.train():
      conv = nn.Conv2d(3, 16, 3)
      for p in get_parameters(conv): p.shard_(devices_2)
      optim = nn.optim.Adam(get_parameters(conv))
      fake_image = Tensor.rand((2, 3, 32, 32)).shard(devices_2, axis=0)
      out = conv(fake_image)
      optim.zero_grad()
      out.mean().backward()
      #for p in get_parameters(conv): p.grad.realize()
      optim.step()
      out.numpy()

  # rerun the backprop test with WINO=1
  def test_backprop_conv_wino(self):
    with Context(WINO=1): self.test_backprop_conv()

  # gradient of mean(x*w) w.r.t. w: column sums of x divided by 8 elements
  def test_backward_sum(self):
    x = Tensor([[1.,2,3,4], [5,6,7,8]]).shard(devices_2, axis=0)
    w = Tensor([1.,2,3,4]).shard(devices_2)
    out = x * w
    out.mean().backward()
    tst = w.grad.numpy()
    np.testing.assert_allclose(tst, [0.75, 1., 1.25, 1.5])

  # smoke test: the scheduler steps when the optimizer's parameters are sharded
  def test_lr_scheduler_OneCycleLR(self):
    from extra.lr_scheduler import OneCycleLR
    conv = nn.Conv2d(3, 16, 3)
    for p in get_parameters(conv): p.shard_(devices_2)
    optim = nn.optim.SGD(get_parameters(conv))
    lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
    lr_sched.step()

  # embedding with the weight sharded on axis 1 must match the unsharded layer
  def test_embedding(self):
    B, T, embed_size, vocab_size = 4, 10, 20, 28

    layer = nn.Embedding(vocab_size, embed_size)
    x = Tensor(np.random.randint(0, vocab_size, (B, T), dtype=np.int32))
    z = layer(x)

    layer_sharded = nn.Embedding(vocab_size, embed_size)
    layer_sharded.weight.replace(layer.weight.shard(devices_2, axis=1)).realize()
    x_sharded = x.shard(devices_2, axis=None)
    z_shard = layer_sharded(x_sharded)

    np.testing.assert_allclose(z.numpy(), z_shard.numpy(), atol=1e-6, rtol=1e-6)

  # weight gradients must match between the unsharded layer and the sharded one;
  # shard_weight_axis defaults to None when unittest calls this directly
  def test_embedding_backward(self, shard_weight_axis=None):
    B, T, embed_size, vocab_size = 4, 10, 20, 28

    layer = nn.Embedding(vocab_size, embed_size)
    x = Tensor(np.random.randint(0, vocab_size, (B, T), dtype=np.int32))
    z = layer(x)
    z.sum().backward()
    grad = layer.weight.grad.numpy()

    layer_sharded = nn.Embedding(vocab_size, embed_size)
    layer_sharded.weight.replace(layer.weight.shard(devices_2, axis=shard_weight_axis)).realize()
    x_sharded = x.shard(devices_2, axis=None)
    z_shard = layer_sharded(x_sharded)
    z_shard.sum().backward()
    grad_shard = layer_sharded.weight.grad.numpy()

    np.testing.assert_allclose(grad, grad_shard, atol=1e-6, rtol=1e-6)

  def test_embedding_backward_shard_weight(self): self.test_embedding_backward(shard_weight_axis=1)

  def test_rmsnorm(self):
    B, T, embed_size = 4, 10, 20

    norm = nn.RMSNorm(embed_size)
    x = Tensor.rand((B, T, embed_size)).contiguous().realize()
    y = norm(x)

    # for norm layers, the correct way to shard weights is duplication
    norm_sharded = nn.RMSNorm(embed_size)
    norm_sharded.weight.shard_(devices_2, axis=None).realize()

    # if x is being sharded, then all-reduce is involved
    x_sharded = x.shard(devices_2, axis=2).realize()
    y_shard = norm_sharded(x_sharded).realize()
    np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)

    # if x is being duplicated, then the operations remain inside each GPU
    # which is the common case
    x_sharded = x.shard(devices_2, axis=None).realize()
    y_shard = norm_sharded(x_sharded).realize()
    np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)

  # causal attention with q, k and v sharded on the batch axis must match the unsharded result
  def test_sdpa_causal_shard_batch(self):
    B, H, T, D = 4, 2, 10, 16
    q = Tensor.rand(B, H, T, D)
    k = Tensor.rand(B, H, T, D)
    v = Tensor.rand(B, H, T, D)
    q_shard = q.shard(devices_2, axis=0)
    k_shard = k.shard(devices_2, axis=0)
    v_shard = v.shard(devices_2, axis=0)
    Tensor.realize(q, k, v, q_shard, k_shard, v_shard)
    y = Tensor.scaled_dot_product_attention(q, k, v, is_causal=True).realize()
    y_shard = Tensor.scaled_dot_product_attention(q_shard, k_shard, v_shard, is_causal=True).realize()
    np.testing.assert_allclose(y_shard.numpy(), y.numpy(), atol=1e-6, rtol=1e-6)

  # NOTE: this is failing on LLVM CI, no idea why. Works locally.
  @slow
  def test_data_parallel_resnet(self):
    from extra.models.resnet import ResNet18

    fake_image = Tensor.rand((2, 3, 224//16, 224//16))
    fake_image_sharded = fake_image.shard(devices_2, axis=0)
    m = ResNet18()
    m.load_from_pretrained()
    # reference output is computed before the parameters are sharded in place
    real_output = m(fake_image).log_softmax().numpy()

    for p in get_parameters(m): p.shard_(devices_2).realize()
    GlobalCounters.reset()
    shard_output = m(fake_image_sharded).log_softmax().realize()
    shard_output_np = shard_output.numpy()
    np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)

  # sharded tensors passed as JIT arguments; the JIT must have captured after the loop
  def test_multi_tensor_jit_param(self):
    @TinyJit
    def jf(a, b) -> Tensor:
      return (a + b).realize()

    for _ in range(5):
      a = Tensor.ones(256).contiguous().realize()
      b = Tensor.ones(256).contiguous().realize()
      a.shard_(devices_2)
      b.shard_(devices_2)
      c = jf(a, b)
      np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
    assert jf.captured is not None

  # sharded tensors created inside the jitted function body
  def test_multi_tensor_jit_body(self):
    @TinyJit
    def jf() -> Tensor:
      a = Tensor.ones(256).contiguous().realize()
      b = Tensor.ones(256).contiguous().realize()
      a.shard_(devices_2)
      b.shard_(devices_2)
      return (a + b).realize()

    for _ in range(5):
      r = jf()
      np.testing.assert_allclose(r.numpy(), np.ones(256)+np.ones(256), atol=1e-4, rtol=1e-5)
    assert jf.captured is not None

  def test_multitensor_jit_in_list(self):
    # test MULTI tensor inside a list container - exercises the container unpacking + MULTI unpacking
    @TinyJit
    def f(a, arr): return (a + arr[0]).realize()
    for i in range(5):
      a = Tensor.full((4,), i).contiguous().realize().shard(devices_2, 0).realize()
      b = Tensor.ones(4).contiguous().realize().shard(devices_2, 0).realize()
      out = f(a, [b])
      np.testing.assert_allclose(out.numpy(), np.full(4, i) + np.ones(4), atol=1e-4, rtol=1e-5)

  def test_multitensor_jit_multiple_inputs(self):
    # test multiple MULTI tensors as inputs - each gets unpacked to component UOps
    @TinyJit
    def f(a, b, c): return (a + b + c).realize()
    for i in range(5):
      a = Tensor.full((4,), i).contiguous().realize().shard(devices_2, 0).realize()
      b = Tensor.full((4,), i*2).contiguous().realize().shard(devices_2, 0).realize()
      c = Tensor.ones(4).contiguous().realize().shard(devices_2, 0).realize()
      out = f(a, b, c)
      np.testing.assert_allclose(out.numpy(), np.full(4, i) + np.full(4, i*2) + np.ones(4), atol=1e-4, rtol=1e-5)

  def test_multitensor_jit_different_sharding(self):
    # test MULTI tensors with different sharding - one sharded on axis 0, one broadcast (axis=None)
    @TinyJit
    def f(a, b): return (a + b).realize()
    for i in range(5):
      a = Tensor.full((4, 4), i).contiguous().realize().shard(devices_2, 0).realize()
      b = Tensor.full((4, 4), i*2).contiguous().realize().shard(devices_2, None).realize()
      out = f(a, b)
      np.testing.assert_allclose(out.numpy(), np.full((4, 4), i) + np.full((4, 4), i*2), atol=1e-4, rtol=1e-5)

  # inspect the schedule of a sharded batchnorm: ignoring COPY calls, the first call must cover
  # all four devices and every remaining call must share a single src[0]
  def test_bn_ast_on_devices(self):
    t = Tensor.empty((16, 64, 112, 112)).shard(devices_4, axis=0)
    bn = nn.BatchNorm2d(64)
    for p in get_parameters(bn): p.shard_(devices_4).realize()

    out = bn(t)
    scheds = [call for call in out.schedule_linear().src if call.src[0].op is not Ops.COPY and set(call.device) <= set(devices_4)]
    self.assertEqual(set(scheds[0].device), set(devices_4), "should have ast on each shard device")
    self.assertEqual(len(set(s.src[0] for s in scheds)), 1)

  # flip on an axis that is not the sharded one
  def test_flip(self):
    rng = Tensor.rand((10, 10, 10))
    t0 = rng.shard(devices_2, axis=1)
    out = t0.flip(0) + 1
    self.assertTrue((rng.flip(0)+1).allclose(out.to(rng.device)).item())

  @unittest.skip("flaky")
  def test_reshape_on_axis(self):
    t0 = Tensor.rand((26, 15, 7)).shard(devices_3, axis=1)

    # test split and rejoin to the right
    t1 = t0.reshape((26, 3, 5, 7))
    t2 = t0.reshape((26, 3, 35))
    t3 = t1.reshape((26, 15, 7))
    t4 = t2.reshape((26, 105,))

    for t in [t0, t1, t2, t3, t4]:
      assert t.uop.axis == 1
      np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten())

    # test shape-one axis
    t5 = t4.reshape((26, 1, 105))
    assert t5.uop.axis == 2
    # `t` here is the last item of the loop above, i.e. t4
    np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten())

    # test split and rejoin to the right and reshape to the left
    t5 = t0.reshape((2, 13, 3, 5, 7))
    t6 = t0.reshape((13, 2, 3, 7, 5))
    t7 = t0.reshape((1, 13, 2, 3, 1, 7, 5))
    assert t5.uop.axis == 2
    assert t6.uop.axis == 2
    assert t7.uop.axis == 3
    np.testing.assert_allclose(t5.numpy().flatten(), t0.numpy().flatten())
    np.testing.assert_allclose(t6.numpy().flatten(), t0.numpy().flatten())
    np.testing.assert_allclose(t7.numpy().flatten(), t0.numpy().flatten())

    # test no left join
    with self.assertRaises((AssertionError, ValueError)):
      t0.reshape((26*15,7)).contiguous().schedule_linear()

  # it doesn't work like this anymore
  # NOTE: this never failed in assign_multi, it failed tensor spec because MULTI was never pushed in the graph
  @unittest.skip("this test is broken")
  def test_mlb_assign_change_axis(self):
    t_none = Tensor.zeros((16, 16)).shard(devices_2).contiguous().realize()
    t_zero = Tensor.ones((16, 16)).shard(devices_2, axis=0)
    with self.assertRaises(RuntimeError):
      # don't allow assigns that change axes
      t_none.assign(t_zero)
      t_none.schedule_linear()

  def test_init_rand_with_multiple_devices_fail(self):
    # init rand with multi device is not allowed
    with self.assertRaises(ValueError):
      Tensor.rand(256, device=devices_2)

  # rand_like must keep shape, device, dtype and shard axis; axis defaults to None when run by unittest
  def test_rand_like_on_shard(self, axis=None):
    t = Tensor.empty((16, 16)).shard(devices_2, axis=axis)
    t2 = Tensor.rand_like(t)
    self.assertEqual(t.shape, t2.shape)
    self.assertEqual(t.device, t2.device)
    self.assertEqual(t.dtype, t2.dtype)
    self.assertEqual(t.uop.axis, t2.uop.axis)
    t2.realize()

  def test_rand_like_on_shard_axis(self): self.test_rand_like_on_shard(0)

  # rand_like on the result of an ALU op: sharded+sharded, then sharded+axis=None
  def test_rand_like_from_alu(self):
    a = Tensor.ones(4, 4).shard(devices_4, axis=0)
    aa = a + a
    self.assertEqual(aa.device, devices_4)
    self.assertEqual(aa.uop.axis, 0)
    raa = aa.rand_like()
    self.assertEqual(raa.device, devices_4)
    self.assertEqual(raa.uop.axis, 0)

    b = Tensor.empty(4, 4).shard(devices_4, axis=None)
    ab = a + b
    self.assertEqual(ab.device, devices_4)
    self.assertEqual(ab.uop.axis, 0)
    rab = ab.rand_like()
    self.assertEqual(rab.device, devices_4)
    self.assertEqual(rab.uop.axis, 0)

  # same property checks as test_rand_like_on_shard, without realizing the result
  def test_rand_like_none_shard(self):
    t = Tensor.empty((16, 16)).shard(devices_2)
    t2 = Tensor.rand_like(t)
    self.assertEqual(t.shape, t2.shape)
    self.assertEqual(t.device, t2.device)
    self.assertEqual(t.dtype, t2.dtype)
    self.assertEqual(t.uop.axis, t2.uop.axis)

  # an explicit dtype argument overrides the source dtype
  def test_rand_like_arg_dtype(self):
    t = Tensor.empty((16, 16), dtype=dtypes.int32).shard(devices_2, axis=1)
    t2 = Tensor.rand_like(t, dtype=dtypes.float32)
    self.assertEqual(t.dtype, dtypes.int32)
    self.assertEqual(t2.dtype, dtypes.float32)

  # asking rand_like for a different device tuple is expected to raise RuntimeError
  def test_rand_like_arg_device(self):
    # axis=None
    t = Tensor.empty((16, 16)).shard((d1, d2), axis=None)
    with self.assertRaises(RuntimeError):
      Tensor.rand_like(t, device=(d3, d4))

    # axis=1
    t = Tensor.empty((16, 16)).shard((d1, d2), axis=1)
    with self.assertRaises(RuntimeError):
      Tensor.rand_like(t, device=(d3, d4))

  # full_like must keep shape, device, dtype and shard axis; axis defaults to None when run by unittest
  def test_full_like_on_shard(self, axis=None):
    t = Tensor.empty((16, 16)).shard(devices_2, axis=axis)
    t2 = Tensor.full_like(t, 1.0)
    self.assertEqual(t.shape, t2.shape)
    self.assertEqual(t.device, t2.device)
    self.assertEqual(t.dtype, t2.dtype)
    self.assertEqual(t.uop.axis, t2.uop.axis)
    t2.realize()

  def test_full_like_on_shard_axis(self): self.test_full_like_on_shard(0)

  # dropout(0.5) on a tensor moved to devices_2: every output is 0 or 2,
  # and the number of zeros (the first sorted unique value) must lie in (96, 160)
  def test_dropout_on_shard(self):
    with Tensor.train():
      X = Tensor.ones(256).to(devices_2)
      output = X.dropout(0.5).numpy()
      unique, counts = np.unique(output, return_counts=True)
      assert set(unique) == {0, 2}, unique
      assert 96 < counts[0] < 160, counts[0]

  # same check on 512 elements sharded on axis 0; zeros must lie in (192, 320)
  def test_dropout_on_shard_axis(self):
    with Tensor.train():
      X = Tensor.ones(512).shard(devices_2, axis=0)
      output = X.dropout(0.5).numpy()
      unique, counts = np.unique(output, return_counts=True)
      assert set(unique) == {0, 2}, unique
      assert 192 < counts[0] < 320, counts[0]

  @unittest.skip("TODO: this requires forced_realize to be deleted.")
  def test_shard_memory(self):
    devices = (d0, d1, d2, d3)
    t = Tensor.zeros(16, 16).contiguous()
    t.shard_(devices, axis=0).realize()
    # each of the four shards is expected to be a base buffer holding 4*16 elements
    assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.uop.src])

  # a clone keeps device, axis and data, and modifying it in place leaves the original untouched
  def test_clone(self):
    for axis in (None, 0):
      t = Tensor.arange(16).reshape(4, 4).clone().shard(devices_2, axis=axis).contiguous().realize()
      t_clone = t.clone().realize()
      self.assertEqual(t_clone.device, t.device)
      self.assertEqual(t_clone.uop.axis, axis)
      self.assertEqual(t_clone.tolist(), t.tolist())
      t_clone += 1
      self.assertNotEqual(t_clone.tolist(), t.tolist())

  @unittest.skip("RANGEIFY doesn't support multi const folding")
  def test_multi_const_folding(self):
    with Context(TRACK_MATCH_STATS=0):
      a = Tensor.arange(3).clone().realize()
      zeros = Tensor.zeros(3).realize()
    b = a.to(devices_2)*zeros.to(devices_2)
    sched = b.schedule_linear().src
    # multiplying by zeros is expected to fold away, leaving nothing to schedule
    self.assertEqual(len(sched), 0)
    self.assertListEqual(b.tolist(), [0, 0, 0])
@unittest.skipIf(not_support_multi_device(), "no multi")
class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
  """Tests that shrink an (8, 8) tensor sharded over four devices along its sharded axis (axis 0)."""
  # no-op setUp; presumably exists so that needs_second_gpu gates every test in the class
  @needs_second_gpu
  def setUp(self): pass

  # shrink a multitensor on sharded axis
  def test_shrink_bad_args(self):
    t = Tensor.arange(64).reshape(8, 8).clone().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)

    with self.assertRaises(AssertionError):
      # sharded axis shrink on non-device boundary is not allowed
      a = t.shrink(((0, 3), (0, 8))).contiguous()
      a.schedule_linear()

    # shrinking to the first two rows, which is accepted here, combined with a column shrink
    a = t.shrink(((0, 2), (2, 4)))
    assert a.shape == (2, 2)
    ref = Tensor.arange(64).reshape(8, 8).shrink(((0, 2), (2, 4)))
    np.testing.assert_equal(a.numpy(), ref.numpy())

    a = t.shrink(((0, 2), (0, 8))).contiguous()
    a.schedule_linear()
    assert a.shape == (2, 8)

    # pad the (2, 8) piece back to the full (8, 8) shape
    p = a.pad(((0, 6), (0, 0))).contiguous()
    p.schedule_linear()
    assert p.shape == (8, 8)

  # for each two-row slice i, compare ops on the shrunk sharded tensor `a`
  # against the same ops on `b`, a plain tensor built from the numpy slice
  @given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16]))
  def test_ops(self, dtype):
    # skip dtypes the default device's renderer does not list as supported
    if dtype not in Device[Device.DEFAULT].renderer.supported_dtypes(): return
    t = Tensor.arange(64).reshape(8, 8).clone().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
    for i in range(4):
      print(f"{i=}")
      a = t.shrink(((0+2*i,2+2*i),None))
      b = Tensor(t.numpy()[0+2*i:2+2*i])
      assert a.shape == b.shape == (2, 8)

      np.testing.assert_allclose(a.numpy(), b.numpy())
      # cast
      np.testing.assert_allclose(a.float().numpy(), b.float().numpy())

      # elementwise
      np.testing.assert_allclose(a.exp().numpy(), b.exp().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.reciprocal().numpy(), b.reciprocal().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.pow(-0.5).numpy(), b.pow(-0.5).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose((a+a).numpy(), (b+b).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_equal((a+1).numpy(), (b+1).numpy())
      np.testing.assert_equal((1+a).numpy(), (1+b).numpy())
      np.testing.assert_allclose((a.where(a+a, a)).numpy(), (b.where(b+b, b)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose((a.where(1, 0)).numpy(), (b.where(1, 0)).numpy(), rtol=1e-7, atol=1e-3)

      # reduce
      np.testing.assert_allclose(a.max().numpy(), b.max().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.sum().numpy(), b.sum().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.mean().numpy(), b.mean().numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.max(0).numpy(), b.max(0).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.sum(0).numpy(), b.sum(0).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.mean(0).numpy(), b.mean(0).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.max(1).numpy(), b.max(1).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.sum(1).numpy(), b.sum(1).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.mean(1).numpy(), b.mean(1).numpy(), rtol=1e-7, atol=1e-3)

      # pad it back
      np.testing.assert_allclose(a.pad(((2*i, 2*(4-i-1)), None)).numpy(), b.pad(((2*i, 2*(4-i-1)), None)).numpy(), rtol=1e-7, atol=1e-3)

      # other movement
      np.testing.assert_allclose(a.pad((None, (1, 1))).numpy(), b.pad((None, (1, 1))).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.shrink((None, (1, 3))).numpy(), b.shrink((None, (1, 3))).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.permute((1, 0)).numpy(), b.permute((1, 0)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.reshape((2, 2, 4)).numpy(), b.reshape((2, 2, 4)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3)
      np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3)

  # add two different two-row slices of the same sharded tensor, directly and after padding back to 8 rows
  def test_add_two_partitions(self):
    t = Tensor.arange(64).reshape(8, 8).clone().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)

    a = t.shrink(((2, 4), None))
    b = t.shrink(((6, 8), None))
    na = t.numpy()[2:4]
    nb = t.numpy()[6:8]
    np.testing.assert_equal(a.numpy(), na)
    np.testing.assert_equal(b.numpy(), nb)
    np.testing.assert_equal((a+b).numpy(), na+nb)

    # each slice padded back into its original rows; the other rows stay zero
    c = a.pad(((2, 4), None)) + b.pad(((6, 0), None))
    c.realize()
    expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb])
    np.testing.assert_equal(c.numpy(), expected)

  # add a different tensor (filled with i, sharded without an axis) to slice i, then concatenate the slices
  def test_add_different_tensors(self):
    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    x = Tensor.arange(64).reshape(8, 8).clone().realize().shard(devices, axis=0)

    to_add = []
    for i in range(len(devices)):
      to_add.append((Tensor.ones(2, 8) * i).shard(devices))

    added:list[Tensor] = []
    for bound, a in zip(x.uop.bounds, to_add):
      added.append(x[bound[0]:bound[1]] + a)

    output = added[0].cat(*added[1:])
    # row r of the expected result has r//2 added to it
    expected = np.arange(64).reshape((8,8)) + np.array([[0,0,1,1,2,2,3,3] for _ in range(8)]).T
    np.testing.assert_allclose(output.numpy(), expected)
@unittest.skipIf(not_support_multi_device(), "no multi")
class TestBatchNorm(unittest.TestCase):
  """Batchnorm on sharded inputs: per-slice ("unsynced") batchnorms, backprop through them, and scheduling."""
  # no-op setUp; presumably exists so that needs_second_gpu gates every test in the class
  @needs_second_gpu
  def setUp(self): pass

  # two independent conv+bn stacks, each fed one half of a batch-sharded input;
  # smoke test for one Adam step through the concatenated output
  def test_unsynced_backprop_conv_bn(self):
    with Tensor.train():
      from extra.lr_scheduler import OneCycleLR

      convs = [nn.Conv2d(3, 16, 3), nn.Conv2d(3, 16, 3)]
      bns = [nn.BatchNorm2d(16), nn.BatchNorm2d(16)]

      for p in get_parameters(convs + bns):
        p.shard_((d1, d2))
      optim = nn.optim.Adam(get_parameters(convs + bns))
      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
      lr_sched.step()

      fake_image = Tensor.rand((8, 3, 32, 32)).shard((d1, d2), axis=0)

      f1 = fake_image.shrink(((0, 4), None, None, None))
      f2 = fake_image.shrink(((4, 8), None, None, None))

      out1 = bns[0](convs[0](f1))
      out2 = bns[1](convs[1](f2))
      out = out1.cat(out2)
      optim.zero_grad()
      out.mean().backward()
      optim.step()
      out.numpy()

  # same idea with a local BatchNorm wrapper that holds one BatchNorm2d per GPU
  def test_unsynced_backprop_standalone_bn(self):
    from extra.lr_scheduler import OneCycleLR
    GPUS = (d1, d2)

    class BatchNorm:
      def __init__(self, num_features):
        self.bns:list[nn.BatchNorm2d] = []
        for _ in GPUS:
          bn = nn.BatchNorm2d(num_features, track_running_stats=False, eps=1e-12, momentum=0.85, affine=True)
          self.bns.append(bn)

      def __call__(self, x:Tensor):
        bn_ts = []
        # split the batch axis evenly, one slice per wrapped batchnorm
        each = x.shape[0]//len(self.bns)
        for i, bn in enumerate(self.bns):
          xi = x.shrink(((each*(i), each*(i+1)), None, None, None))
          bni = bn(xi)
          bn_ts.append(bni)
        return bn_ts[0].cat(*bn_ts[1:])

    with Tensor.train():
      conv = nn.Conv2d(3, 16, 3)
      bn = BatchNorm(16)

      for p in get_parameters([conv, bn]):
        p.shard_(GPUS)
      optim = nn.optim.Adam(get_parameters([conv, bn]))
      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
      lr_sched.step()

      fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)

      out = bn(conv(fake_image))
      optim.zero_grad()
      out.mean().backward()
      optim.step()

  # uses UnsyncedBatchNorm from the hlb_cifar10 example: running stats are sharded on axis 0,
  # every other state entry is moved to all GPUs with to_
  def test_unsynced_backprop_sync_weights(self):
    from extra.lr_scheduler import OneCycleLR
    from examples.hlb_cifar10 import UnsyncedBatchNorm
    GPUS = (d1, d2)

    with Tensor.train():
      conv = nn.Conv2d(3, 16, 3)
      bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))

      for k, p in get_state_dict([conv, bn]).items():
        if 'running_mean' in k or 'running_var' in k:
          p.shard_(GPUS, axis=0)
        else:
          p.to_(GPUS)
      optim = nn.optim.Adam(get_parameters([conv, bn]))
      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
      lr_sched.step()

      fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)

      out = bn(conv(fake_image))
      optim.zero_grad()
      out.mean().backward()
      optim.step()

  # one BatchNorm2d per device applied to that device's slice (taken from x.uop.bounds),
  # in both training and eval mode; only checks that the concatenated result can be read back
  @given(strat.sampled_from((False, True)))
  def test_batchnorm(self, is_training):
    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    x = Tensor.arange(4096).reshape(8, 8, 8, 8).clone().realize().shard(devices, axis=0)

    with Tensor.train(is_training):
      bns = []
      for _ in range(len(devices)):
        bn = nn.BatchNorm2d(8)
        for p in get_parameters(bn):
          p.shard_(devices)
        bns.append(bn)

      bn_ts = []
      for bound, bn in zip(x.uop.bounds, bns):
        bni = bn(x[bound[0]:bound[1]])
        bn_ts.append(bni)

      bn_ts[0].cat(*bn_ts[1:]).numpy()

  # schedule a regular BatchNorm2d and an UnsyncedBatchNorm on the same sharded input;
  # currently only asserts that both schedules are non-empty
  def test_synced_vs_unsynced_bn(self):
    from examples.hlb_cifar10 import UnsyncedBatchNorm
    from tinygrad.nn import BatchNorm2d
    devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
    x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)

    with Tensor.train():
      synced_bn = BatchNorm2d(8)
      unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))

      for p in get_parameters(synced_bn):
        p.shard_(devices)
      for k, p in get_state_dict(unsynced_bn).items():
        if 'running_mean' in k or 'running_var' in k:
          p.shard_(devices, axis=0)
        else:
          p.to_(devices)

      synced_out = synced_bn(x)
      synced_si = list(synced_out.schedule_linear().src)
      unsynced_out = unsynced_bn(x)
      unsynced_si = list(unsynced_out.schedule_linear().src)

    # TODO: test synced / unsynced batchnorm cross device kernel and copies
    assert synced_si
    assert unsynced_si
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiFromUnrenderable(unittest.TestCase):
  """Sharding a tensor that was created from a numpy array."""
  @needs_second_gpu
  def test_from_npy(self):
    # uint32 numpy data -> Tensor -> shard over (d0, d1) on axis 0 -> add one
    host_data = np.arange(100, dtype=np.uint32)
    source = Tensor(host_data)
    incremented = source.shard((d0, d1), axis=0) + 1
    result = incremented.numpy()
    np.testing.assert_equal(result, np.arange(100)+1)
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiAssign(unittest.TestCase):
  """Tests for Tensor.assign on tensors sharded over two devices.

  The tests differ mainly in where contiguous()/realize() sit relative to shard(), so the
  exact call chains below are the thing under test.
  """
  # the two devices "<DEFAULT>:0" and "<DEFAULT>:1"
  device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

  # no-op setUp; presumably exists so that needs_second_gpu gates every test in the class
  @needs_second_gpu
  def setUp(self): pass

  # target and source are both realized after sharding
  def test_multi_assign_realized(self):
    out = Tensor.zeros(4).shard(self.device, 0).contiguous().realize()
    ones = Tensor.ones(4).shard(self.device, 0).contiguous().realize()
    out.assign(ones).realize()
    self.assertListEqual(out.tolist(), [1,1,1,1])

  # target is realized before sharding and not realized afterwards
  def test_multi_assign_unrealized(self):
    out = Tensor.zeros(4).contiguous().realize().shard(self.device, 0)
    ones = Tensor.ones(4).shard(self.device, 0).contiguous().realize()
    out.assign(ones).realize()
    self.assertListEqual(out.tolist(), [1,1,1,1])

  # target and source are both realized before sharding and not afterwards
  def test_multi_assign_both_unrealized(self):
    out = Tensor.zeros(4).contiguous().realize().shard(self.device, 0)
    ones = Tensor.ones(4).contiguous().realize().shard(self.device, 0)
    out.assign(ones).realize()
    self.assertListEqual(out.tolist(), [1,1,1,1])

  # assign a python scalar
  def test_multi_assign_scalar(self):
    out = Tensor.ones(4).shard(self.device, 0).contiguous().realize()
    out.assign(0).realize()
    self.assertListEqual(out.tolist(), [0,0,0,0])

  # assign the result of const_like on the target itself
  def test_multi_assign_const_like(self):
    out = Tensor.ones(4).shard(self.device, 0).contiguous().realize()
    out.assign(out.const_like(7)).realize()
    self.assertListEqual(out.tolist(), [7,7,7,7])

  # assign into a single column of a realized sharded tensor
  def test_multi_assign_piece(self):
    out = Tensor.zeros(4,4).shard(self.device, 0).contiguous().realize()
    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
    out[:, 2:3].assign(ones).realize()
    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])

  # same column assign; target is realized, sharded, then realized again without contiguous() after the shard
  def test_multi_assign_piece_noncontig(self):
    out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0).realize()
    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
    out[:, 2:3].assign(ones).realize()
    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])

  # same column assign with the target not realized after sharding; marked as an expected failure
  @unittest.expectedFailure
  def test_multi_assign_piece_unrealized(self):
    out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0)
    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
    out[:, 2:3].assign(ones).realize()
    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])

  # column index given by a Variable bound to 2
  def test_multi_assign_var_offset(self):
    out = Tensor.zeros(4,4).contiguous().realize().shard(self.device, 0).realize()
    ones = Tensor.ones(4,1).shard(self.device, 0).contiguous().realize()
    vi = Variable("i", 0, 3).bind(2)
    out[:, vi:vi+1].assign(ones).realize()
    self.assertListEqual(out.tolist(), [[0,0,1,0], [0,0,1,0], [0,0,1,0], [0,0,1,0]])

  # rerun the JIT test below with axis=None sharding
  def test_multi_assign_var_offset_jit_none(self): self.test_multi_assign_var_offset_jit(None)

  # jitted column assign with a variable offset; shard_axis defaults to 0 when run by unittest
  def test_multi_assign_var_offset_jit(self, shard_axis=0):
    out = Tensor.zeros(4,6).contiguous().realize().shard(self.device, shard_axis).realize()
    ones = Tensor.ones(4,1).shard(self.device, shard_axis).contiguous().realize()

    @TinyJit
    def f(out:Tensor, vi):
      out[:, vi:vi+1].assign(ones).realize()
      # bump the source after every call, so column i ends up holding the value i
      ones.assign(ones+1).realize()

    vi = Variable("i", 0, 5)
    for i in range(1,5):
      GlobalCounters.reset()
      f(out, vi.bind(i))
    self.assertListEqual(out.tolist(), [[0,1,2,3,4,0]]*4)
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiSetitem(unittest.TestCase):
  """__setitem__ on a 16-element arange sharded over four devices."""
  # the four devices "<DEFAULT>:0" .. "<DEFAULT>:3"
  device = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))

  # no-op setUp; presumably exists so that needs_second_gpu gates every test in the class
  @needs_second_gpu
  def setUp(self): pass

  # build the realized arange(16) and shard it on `axis` (None is passed through unchanged)
  def _t(self, axis):
    base = Tensor.arange(16).clone().realize()
    return base.shard(self.device, axis=axis)

  def test_setitem_scalar_axis0(self):
    sharded = self._t(0)
    sharded[1] = 99
    expected = list(range(16))
    expected[1] = 99
    self.assertListEqual(sharded.tolist(), expected)

  def test_setitem_scalar_axis_none(self):
    sharded = self._t(None)
    sharded[1] = 99
    expected = list(range(16))
    expected[1] = 99
    self.assertListEqual(sharded.tolist(), expected)

  # indices 2..5 straddle the boundary between the first two 4-element shards
  def test_setitem_slice_cross_shard(self):
    sharded = self._t(0)
    sharded[2:6] = 99
    expected = list(range(16))
    expected[2:6] = [99]*4
    self.assertListEqual(sharded.tolist(), expected)

  def test_setitem_full_slice(self):
    sharded = self._t(0)
    sharded[:] = 42
    expected = [42]*16
    self.assertListEqual(sharded.tolist(), expected)

  # stride 4 hits the first element of every shard
  def test_setitem_stride(self):
    sharded = self._t(0)
    sharded[::4] = 0
    expected = list(range(16))
    expected[::4] = [0]*4
    self.assertListEqual(sharded.tolist(), expected)

  # index 13 lives in the last shard only
  def test_setitem_single_shard(self):
    sharded = self._t(0)
    sharded[13] = 99
    expected = list(range(16))
    expected[13] = 99
    self.assertListEqual(sharded.tolist(), expected)

  # the assigned value is itself a tensor sharded without an axis
  def test_setitem_tensor_value_replicated(self):
    sharded = self._t(0)
    value = Tensor([90, 91, 92, 93]).shard(self.device)
    sharded[2:6] = value
    expected = list(range(16))
    expected[2:6] = [90, 91, 92, 93]
    self.assertListEqual(sharded.tolist(), expected)

  # the assigned value is sharded on axis 0, one element per device
  def test_setitem_tensor_value_sharded_aligned(self):
    sharded = self._t(0)
    value = Tensor([90, 91, 92, 93]).shard(self.device, axis=0)
    sharded[::4] = value
    expected = list(range(16))
    expected[::4] = [90, 91, 92, 93]
    self.assertListEqual(sharded.tolist(), expected)
def helper_test_shard_op(shps, fxn, atol=1e-6, rtol=1e-3):
  """Check that `fxn` gives the same result on a single-device tensor and on its sharded copy.

  For every shape in `shps`, a random tensor is created, sharded over `devices_2` on axis 0,
  and `fxn` is applied to both versions. The two outputs must agree in shape, dtype and
  values (within `atol`/`rtol`).

  Args:
    shps: iterable of shapes to test.
    fxn: callable taking a Tensor and returning a Tensor.
    atol: absolute tolerance passed to np.testing.assert_allclose.
    rtol: relative tolerance passed to np.testing.assert_allclose.

  Raises:
    Exception: when any comparison fails. The message names the failing output shape, and the
      original error is attached as the cause.
  """
  for shp in shps:
    single_in = Tensor.randn(shp)
    multi_in = single_in.shard(devices_2, axis=0)

    single_out = fxn(single_in).numpy()
    multi_out = fxn(multi_in).numpy()

    try:
      assert single_out.shape == multi_out.shape, f"shape mismatch: single={single_out.shape} | multi={multi_out.shape}"
      assert single_out.dtype == multi_out.dtype, f"dtype mismatch: single={single_out.dtype} | multi={multi_out.dtype}"
      np.testing.assert_allclose(single_out, multi_out, atol=atol, rtol=rtol)
    except Exception as e:
      # chain explicitly so the original assertion and its traceback are reported as the direct cause
      raise Exception(f"Failed shape {single_out.shape}: {e}") from e
@unittest.skipIf(not_support_multi_device(), "no multi")
class TestTensorOps(unittest.TestCase):
  """Runs selected tensor ops through helper_test_shard_op."""
  @needs_second_gpu
  def test_interpolate(self):
    shapes = [(4,16,16),(4,24,24)]
    def resize(x): return Tensor.interpolate(x, (19,19))
    helper_test_shard_op(shapes, resize)

  @needs_second_gpu
  def test_bitcast(self):
    shapes = [(256,), (256,)]
    def as_int(x): return x.bitcast(dtypes.int)
    helper_test_shard_op(shapes, as_int)
# run the module's tests when executed as a script
if __name__ == '__main__': unittest.main()