Files
tinygrad/test/backend/test_multitensor.py
2026-06-05 18:38:46 -04:00

477 lines
18 KiB
Python

import unittest, random
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
from tinygrad.uop.ops import Ops, UOp
from tinygrad.helpers import getenv, prod, Context
from tinygrad.nn.state import get_parameters
from tinygrad.engine.realize import run_linear, compile_linear
import numpy as np
from hypothesis import given, strategies as strat, settings
from test.helpers import not_support_multi_device, needs_second_gpu, slow, call_is_graph
# hypothesis profile used by this module: 200 examples, no deadline, derandomized when DERANDOMIZE_CI is set
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")

# device names "<DEFAULT>:0" .. "<DEFAULT>:5"
d0, d1, d2, d3, d4, d5 = (f"{Device.DEFAULT}:{i}" for i in range(6))

# device groups used for sharding; note they start at index 1, not 0
devices_2 = (d1, d2)
devices_3 = (d1, d2, d3)
devices_4 = (d1, d2, d3, d4)

N = 128

# shard_x is "data parallel"
# shard_w is "model parallel"
def _test_allreduce(t:Tensor):
  # Returns (reference, allreduced) for `t`; callers in this file pass a (256, 256) tensor.
  # reference: add the four 64-row slices of the unsharded tensor, then repeat the sum 4x along axis 0
  q0, q1, q2, q3 = t[0:64], t[64:128], t[128:192], t[192:256]
  expected = (q0 + q1 + q2 + q3).repeat([4,1]).realize()
  # under test: shard on axis 0 over devices_4 and ADD-allreduce across those devices
  sharded = t.shard(devices_4, 0).realize()
  reduced = Tensor(UOp.allreduce(sharded.uop, Ops.ADD, sharded.device))
  reduced.realize()
  return expected, reduced
@unittest.skipIf(not_support_multi_device(), "no multi")
class TestMultiTensor(unittest.TestCase):
  # Tests for tensors placed on / sharded across several devices:
  # moving and sharding, reductions, allreduce, TinyJit capture, train steps and rand seeding.

  @needs_second_gpu
  def setUp(self):
    pass

  def test_to(self):
    # to_ with a device tuple keeps the logical shape, and the tensor can still be used in an op
    x = Tensor.ones(256).contiguous().realize()
    x.to_(devices_2)
    assert x.shape == (256,)
    (x + x).realize()

  def test_gradient(self):
    # the gradient of a sum w.r.t. a tensor placed on two devices must realize
    x = Tensor.ones(256).contiguous().realize()
    x.to_(devices_2)
    x.sum().gradient(x)[0].realize()

  def test_shard(self):
    x = Tensor.ones(256).contiguous().realize()
    x.shard_(devices_2, 0)
    # 256 elements sharded on axis 0 over two devices: every source uop has 128 elements
    for shard_uop in x.uop.src:
      assert shard_uop.shape == (128,)
    (x + x).realize()

  @unittest.expectedFailure # TODO: fix
  def test_shard_empty(self):
    GlobalCounters.reset()
    x = Tensor.empty(256).shard(devices_2, 0).realize()
    # the desired behavior: sharding an empty tensor runs no kernel
    assert GlobalCounters.kernel_count == 0
    (x + x).realize()

  # TODO: fix this to not copy on the src device
  @unittest.expectedFailure
  def test_shard_no_recompile(self):
    x = Tensor.ones(256).contiguous().realize()
    x.shard_(devices_2, 0)
    linear = compile_linear((x + x).schedule_linear())
    # collect the distinct program names before running; one name means the same program ran on every shard
    programs = [call.src[0] for call in linear.src if call.src[0].op is Ops.PROGRAM]
    unique_names = {program.src[0].arg.name for program in programs}
    run_linear(linear)
    self.assertEqual(len(unique_names), 1, "function was relinearized")

  def test_shard_same_device(self):
    # one of the shard targets is the device the tensor already lives on
    x = Tensor.ones(256).contiguous().realize()
    x.shard_((d1, x.device), 0)
    (x + x).realize()

  def test_numpy(self):
    x = Tensor.ones(256)
    x.shard_((d1, d2), 0)
    np.testing.assert_allclose(x.numpy(), 1)

  def test_four_add(self):
    # x is sharded on axis 1, w is sharded with axis=None; their sum must be 2 everywhere
    x = Tensor.ones(256, 256).contiguous().realize()
    w = Tensor.ones(256, 256).contiguous().realize()
    x.shard_(devices_4, 1)
    w.shard_(devices_4, None)
    np.testing.assert_allclose((x + w).numpy(), 2)

  def test_elementwise_dtype(self):
    Tensor.manual_seed(0)
    x = Tensor.randn(8, 8).realize()
    w = Tensor.randn(8, 8).realize()
    x.shard_(devices_4, 0)
    w.shard_(devices_4, 0)
    # keep rows 0..1 only, multiply, then compare against 2 (the result is a boolean tensor)
    first_two_rows = ((0, 2), None)
    out = x.shrink(first_two_rows) * w.shrink(first_two_rows) < 2
    np.testing.assert_allclose(out.numpy(), x.numpy()[0:2]*w.numpy()[0:2] < 2)

  def test_shrink_on_shard_axis(self):
    x = Tensor.arange(4*4).reshape(4,4).clone().realize()
    x_np = x.numpy()
    x.shard_(devices_2, 0)
    # only shrink on the device that owns the shard, this is enabled by the mselect simplifier
    for i in range(2):
      lo, hi = i*2, i*2+2
      xt = x[lo:hi].contiguous()
      linear, var_vals = xt.linear_with_vars()
      #kernels = [call for call in linear.src if call.src[0].op is Ops.SINK]
      #self.assertEqual(len(kernels), 1)
      #self.assertEqual(kernels[0].src[1].buffer.device, devices_2[i])
      run_linear(linear, var_vals)
      np.testing.assert_equal(xt.numpy(), x_np[lo:hi])

  @given(strat.sampled_from((devices_2, devices_3)),
         strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
         strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)))
  def test_simple_reduce(self, devices, rop, shard_axis, reduce_axis):
    # side length is a multiple of the device count so either axis can be sharded evenly
    N = 4 * len(devices)
    sharded = (Tensor.rand(N*N)-1).reshape(N, N).shard_(devices, shard_axis)
    as_numpy = sharded.numpy()
    # Tensor and np.ndarray both expose sum/prod/max taking the axis, so one callable serves both
    method = {Ops.ADD: "sum", Ops.MUL: "prod", Ops.MAX: "max"}[rop]
    def reduce_fn(x): return getattr(x, method)(reduce_axis)
    got = reduce_fn(sharded)
    want = reduce_fn(as_numpy)
    np.testing.assert_allclose(got.numpy(), want, rtol=1e-6, atol=1e-6)

  def test_allreduce_naive(self):
    with Context(RING=0):
      expected, reduced = _test_allreduce(Tensor.rand(256, 256))
      np.testing.assert_almost_equal(expected.numpy(), reduced.numpy(), decimal=5)

  def test_allreduce_ring(self):
    with Context(RING=2):
      expected, reduced = _test_allreduce(Tensor.rand(256, 256))
      np.testing.assert_almost_equal(expected.numpy(), reduced.numpy(), decimal=5)

  def test_allreduce_all2all(self):
    with Context(ALL2ALL=2):
      expected, reduced = _test_allreduce(Tensor.rand(256, 256))
      np.testing.assert_almost_equal(expected.numpy(), reduced.numpy(), decimal=5)

  def test_copy_jit(self):
    @TinyJit
    def copy_tensor(x:Tensor):
      # same backend as the input, device index 1
      backend = x.device.split(':')[0]
      return (x.to(f"{backend}:1") + 1)

    # five calls so the jitted function is called repeatedly with fresh inputs
    for _ in range(5):
      src = Tensor.rand(256).realize()
      copied = copy_tensor(src)
      np.testing.assert_equal((src+1).numpy(), copied.numpy())

  def test_allreduce_naive_jit(self):
    with Context(RING=0):
      jit_allreduce = TinyJit(_test_allreduce)
      for _ in range(5):
        expected, reduced = jit_allreduce(Tensor.rand(256, 256))
        np.testing.assert_almost_equal(expected.numpy(), reduced.numpy(), decimal=5)

  def test_allreduce_ring_jit(self):
    with Context(RING=2):
      jit_allreduce = TinyJit(_test_allreduce)
      for _ in range(5):
        expected, reduced = jit_allreduce(Tensor.rand(256, 256))
        np.testing.assert_almost_equal(expected.numpy(), reduced.numpy(), decimal=5)

  def test_multitensor_jit_input(self):
    # the tensor is sharded before it is passed into the jitted function
    @TinyJit
    def f(x):
      return (x+1).contiguous().sum()

    for _ in range(5):
      sharded_in = Tensor.arange(0, 4).clone().realize().shard((d1,d2), 0).realize()
      total = f(sharded_in)
      assert total.item() == 1+2+3+4

  def test_multitensor_inside_jit(self):
    # the tensor is sharded inside the jitted function
    @TinyJit
    def f(x):
      return (x.shard((d1,d2), 0)+1).contiguous().sum()

    for _ in range(5):
      plain_in = Tensor.arange(0, 4).clone().realize()
      total = f(plain_in)
      assert total.item() == 1+2+3+4

  def test_fuzz_allreduce(self):
    random.seed(41)
    for it in range(2):
      for n in range(2, 4+1):
        # 1..4 dims; the first dim is a multiple of the device count n, the rest are 1..10
        ndim = random.randint(1, 4)
        shape = tuple((n if i == 0 else 1) * random.randint(1, 10) for i in range(ndim))
        t = Tensor.rand(shape).shard_((d0, d1, d2, d3)[:n], 0)
        with Context(RING=0):
          a = Tensor(UOp.allreduce(t.uop, Ops.ADD, t.device))
        with Context(RING=2):
          b = Tensor(UOp.allreduce(t.uop, Ops.ADD, t.device))
        # both settings must agree elementwise
        diff = a - b
        flat_len = (prod(diff.shape),)
        mean_err = diff.reshape(flat_len).abs().mean().numpy()
        max_err = diff.reshape(flat_len).abs().max().numpy()
        assert mean_err < 1e-6, f"big mean error, iteration {it}_{n}"
        assert max_err < 1e-6, f"big max error, iteration {it}_{n}"

  def _test_model_train_step(self, m, fake_image, labels):
    # Runs one backward pass unsharded and one with inputs and parameters sharded over devices_2,
    # then compares the gradient of m.conv1.weight between the two runs.
    from tinygrad.nn.optim import LARS
    optimizer = LARS(get_parameters(m), 0.1)

    # single-device reference
    optimizer.zero_grad()
    m.load_from_pretrained()
    loss = m(fake_image).sparse_categorical_crossentropy(labels, label_smoothing=0.1)
    loss.backward()
    grad = m.conv1.weight.grad.numpy()

    # sharded run: inputs are split on axis 0, parameters are sharded without an axis
    fake_image_sharded = fake_image.shard(devices_2, axis=0)
    labels_sharded = labels.shard(devices_2, axis=0)
    for param in get_parameters(m):
      param.shard_(devices_2).realize()
    GlobalCounters.reset()
    optimizer.zero_grad()
    shard_loss = m(fake_image_sharded).sparse_categorical_crossentropy(labels_sharded, label_smoothing=0.1)
    shard_loss.backward()
    shard_grad = m.conv1.weight.grad.numpy()

    # sometimes there is zeros in these grads... why?
    np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)

  @slow
  def test_data_parallel_resnet_train_step(self):
    from extra.models.resnet import ResNet18
    fake_image = Tensor.rand((2, 3, 224//16, 224//16))
    labels = Tensor.randint(2, low=0, high=1000)
    self._test_model_train_step(ResNet18(), fake_image, labels)

  def test_data_parallel_simple_train_step(self):
    # minimal stand-in model: exposes the conv1 attribute and load_from_pretrained that the helper uses
    class Model:
      def __init__(self):
        self.conv1 = nn.Linear(128,128)
      def __call__(self, x):
        return self.conv1(x)
      def load_from_pretrained(self):
        pass

    fake_image = Tensor.rand((128,))
    labels = Tensor.randint(2, low=0, high=127)
    self._test_model_train_step(Model(), fake_image, labels)

  def test_assign_kv_cache_multi(self):
    bsz, max_context = 2, 8

    class Attn:
      @TinyJit
      def __call__(self, xk:Tensor, start_pos:UOp):
        seqlen = xk.shape[1]
        # the cache is created on the first call only
        if not hasattr(self, "cache_k"):
          self.cache_k = Tensor.zeros(bsz, max_context, 1, 1).shard(devices_2).contiguous().realize()
        # append xk after the first start_pos cached entries, pad back to max_context and write into the cache
        keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
        self.cache_k.assign(keys.pad((None,(0,max_context-start_pos-seqlen),None,None)).contiguous()).realize()

    attn = Attn()
    # first call writes 3 positions starting at 0
    xk = Tensor.ones(bsz, 3, 1, 1).shard(devices_2).contiguous()
    attn(xk, 0)
    # then one position at a time at start positions 3, 4 and 5
    for pos in range(3,6):
      # copied from LLaMA
      start_pos = Variable("start_pos", 1, max_context).bind(pos)
      xk = Tensor.ones(bsz, 1, 1, 1).shard(devices_2).contiguous()
      attn(xk, start_pos)

    cache = attn.cache_k.flatten().numpy()
    np.testing.assert_allclose(cache, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.])

  def test_multi_tensor_jit_graph_assign_updates_each_shard(self):
    @TinyJit
    def jf(out: Tensor) -> Tensor:
      # arange(4) + 1 is realized sharded, then (that + 1) is assigned into `out`
      tmp = (Tensor.arange(4, dtype=dtypes.float).clone().shard(devices_2, 0) + 1).contiguous().realize()
      out.assign((tmp + 1).contiguous()).realize()
      return out

    out = Tensor.full((4,), -1.0).shard(devices_2, 0).contiguous().realize()
    expected = np.arange(4, dtype=np.float32) + 2
    for _ in range(5):
      # reset to -1 before every call so a stale result cannot pass the check
      out.assign(Tensor.full((4,), -1.0).shard(devices_2, 0).contiguous()).realize()
      jf(out)
      np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-5)
    assert jf.captured is not None

  @unittest.skip("test broken")
  def test_multi_device_jit_graph(self):
    if any(Device[dev].graph is None for dev in (d0, d1)):
      raise unittest.SkipTest("only test graphs")

    @TinyJit
    def jf(a: Tensor, b: Tensor, c: Tensor, d:Tensor):
      # Create 80 entries on device 0: 2 batches.
      for _ in range(40):
        a = ((a + b).realize() + (a * b).realize()).realize()
      # Create 80 entries on device 1: 2 batches.
      for _ in range(40):
        c = ((c + d).realize() + (c * d).realize()).realize()
      # Create a copy from device 0 to 1: 1 entry.
      a = a.to(d1).realize()
      # Creates one last entry on device 1: 1 batch.
      return (a + c).realize()

    a = Tensor.randn(10, 10, device=d0).realize()
    b = Tensor.randn(10, 10, device=d0).realize()
    c = Tensor.randn(10, 10, device=d1).realize()
    d = Tensor.randn(10, 10, device=d1).realize()

    ref = jf(a, b, c, d).numpy()
    for _ in range(5):
      got = jf(a, b, c, d).numpy()
      np.testing.assert_allclose(ref, got, atol=1e-4, rtol=1e-5)

    # Checking that 2 graphs per device, 1 copy and 1 last graph on device 1 are created.
    calls = jf.captured.linear.src
    assert len(calls) == 6
    for idx, call in enumerate(calls):
      if idx != 4:
        assert call_is_graph(call)
    assert calls[4].src[0].op is Ops.COPY

  def test_rand_on_multiple_devices(self):
    # different devices generate different rand
    d0_rand = Tensor.rand(256, device=d0).realize()
    d1_rand = Tensor.rand(256, device=d1).realize()
    assert not np.allclose(d0_rand.numpy(), d1_rand.numpy())

  def test_rand_on_multiple_devices_manual_seed(self):
    def draw(*devs):
      # one rand per device, drawn in the given order
      return [Tensor.rand(2, device=dev).tolist() for dev in devs]

    Tensor.manual_seed(123)
    d0_rand, d1_rand = draw(d0, d1)

    # manual_seed again gives the same values
    Tensor.manual_seed(123)
    d0_rand2, d1_rand2 = draw(d0, d1)
    self.assertEqual(d0_rand, d0_rand2)
    self.assertEqual(d1_rand, d1_rand2)

    # device seed is only determined by init order, so flipping init order flips rands
    Tensor.manual_seed(123)
    d1_rand_flip, d0_rand_flip = draw(d1, d0)
    self.assertEqual(d0_rand, d1_rand_flip)
    self.assertEqual(d1_rand, d0_rand_flip)

  def test_const_like_shrink_on_shard_axis(self):
    base = Tensor.ones(16, 16, dtype=dtypes.int).shard(devices_2, axis=0)
    out = base.const_like(2)[:, :8]
    linear, var_vals = out.linear_with_vars()
    # a shrunk constant needs no scheduled calls at all
    self.assertEqual(len(linear.src), 0)
    run_linear(linear, var_vals)
    self.assertEqual(out.tolist(), [[2]*8]*16)
@unittest.skipIf(not_support_multi_device(), "no multi")
class TestHandleData(unittest.TestCase):
  @needs_second_gpu
  def test_copied_to_device(self):
    # Copies a tensor living on four devices to a device outside that set (d5),
    # then to each device inside the set, and checks the values both times.
    device = (d0, d1, d2, d3)
    values = [1, 2, 3, 4]

    def fresh():
      # a newly built and realized multi-device tensor
      return Tensor(values).shard(device).realize()

    # target is not one of the source devices: the schedule holds exactly one call
    t = fresh()
    not_covered = t.to(d5)
    sched = not_covered.schedule_linear().src
    assert len(sched) == 1
    # setup again because create_schedule has side effect
    t = fresh()
    not_covered = t.to(d5)
    assert not_covered.realize().tolist() == values

    for target in device:
      t = fresh()
      covered = t.to(target)
      sched = covered.schedule_linear().src
      # TODO: this isn't optimized out anymore
      #assert len(sched) == 0
      # setup again because create_schedule has side effect
      t = fresh()
      covered = t.to(target)
      assert covered.realize().tolist() == values
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiBufferView(unittest.TestCase):
  # Views (index / shrink / reshape) taken along a non-shard axis of a tensor sharded on axis 1.
  # Every test builds the same data unsharded (ref) and sharded (a) and hands both to _check.

  @needs_second_gpu
  def setUp(self): pass

  def _check(self, a_ref:Tensor, a_multi:Tensor, view_fn):
    # Applies view_fn to the reference and to the sharded tensor, runs the sharded result,
    # and requires it to equal the reference.
    b_ref = view_fn(a_ref)
    b_multi = view_fn(a_multi).contiguous()
    linear, var_vals = b_multi.linear_with_vars()
    # when every target allocator has `_offset`, the linearized schedule must contain no SINK calls
    if all(hasattr(Device[d].allocator, "_offset") for d in b_multi.device):
      compiled = [call for call in linear.src if call.src[0].op is Ops.SINK]
      self.assertEqual(len(compiled), 0, f"expected zero compiled kernels, got {len(compiled)}")
    run_linear(linear, var_vals)
    np.testing.assert_equal(b_multi.numpy(), b_ref.numpy())

  @unittest.skip("flaky on LLVM")
  def test_shrink_non_shard_axis(self):
    ref = Tensor.arange(8*4*10).reshape(8, 4, 10).clone().realize()
    a = Tensor.arange(8*4*10).reshape(8, 4, 10).clone().shard(devices_2, axis=1).realize()
    self._check(ref, a, lambda t: t[3])

  def test_shrink_2d(self):
    ref = Tensor.arange(6*4).reshape(6, 4).clone().realize()
    a = Tensor.arange(6*4).reshape(6, 4).clone().shard(devices_2, axis=1).realize()
    self._check(ref, a, lambda t: t.shrink(((1, 4), None)))

  def test_reshape_then_shrink(self):
    ref = Tensor.arange(8*6).reshape(8, 6).clone().realize()
    a = Tensor.arange(8*6).reshape(8, 6).clone().shard(devices_2, axis=1).realize()
    self._check(ref, a, lambda t: t.reshape(4, 2, 6)[1])

  def test_chained_shrink(self):
    ref = Tensor.arange(10*8).reshape(10, 8).clone().realize()
    a = Tensor.arange(10*8).reshape(10, 8).clone().shard(devices_2, axis=1).realize()
    self._check(ref, a, lambda t: t.shrink(((2, 8), None)).shrink(((1, 4), None)))

  def test_4_devices(self):
    ref = Tensor.arange(8*12).reshape(8, 12).clone().realize()
    a = Tensor.arange(8*12).reshape(8, 12).clone().shard(devices_4, axis=1).realize()
    # same verification as the two-device tests, so go through the shared helper instead of repeating it inline
    self._check(ref, a, lambda t: t[5])
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiTransformer(unittest.TestCase):
  @needs_second_gpu
  def test_transformer(self):
    # Runs a small Transformer unsharded and sharded over two devices from identical weights,
    # comparing the kv cache and the produced token at every step.
    device = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")

    from extra.models.llama import Transformer
    args = {"dim": 32, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024,
            "hidden_dim": 32, "max_context": 12}
    real_model = Transformer(**args)
    shard_model = Transformer(**args)

    # copy state
    nn.state.load_state_dict(shard_model, nn.state.get_state_dict(real_model))

    # shard: the first pattern contained in the parameter name picks the axis, anything else gets axis=None
    shard_rules = (
      ('scale', None),  # from quantized
      ('.attention.', -1),
      ('.feed_forward.w1.', 0),
      ('.feed_forward.w3.', 0),
      ('.feed_forward.', -1),
      ('tok_embeddings.weight', 0),
      ('output.weight', 0),
    )
    for name, param in nn.state.get_state_dict(shard_model).items():
      axis = next((ax for pattern, ax in shard_rules if pattern in name), None)
      param.shard_(device, axis=axis)

    last_tok = 0
    for i in range(10):
      real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i).item()
      shard_tok = shard_model(Tensor([[last_tok]], device=device), i).item()

      # test kv cache
      kv_real = real_model.layers[0].attention.cache_kv.numpy()
      kv_shard = shard_model.layers[0].attention.cache_kv.numpy()
      #print(np.concatenate([kv_real[:, :, :, :, 0:1], kv_shard[:, :, :, :, 0:1]], axis=4))
      np.testing.assert_allclose(kv_real, kv_shard, atol=1e-5, rtol=1e-5, err_msg=f"issue at token {i}")

      # test token
      self.assertEqual(real_tok, shard_tok, f"issue at token {i}")
      # feed the reference model's token to both models on the next step
      last_tok = real_tok

  @unittest.skip("super slow")
  def test_llama1b_full(self):
    # Fetches the tokenizer and a quantized Llama 3.2 1B checkpoint, then compares the first token
    # from a single-device model against a two-device model.
    from tinygrad.helpers import fetch
    subdir = "llama3-1b-instruct"
    fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=subdir)
    model = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
                  "Llama-3.2-1B-Instruct-Q6_K.gguf", subdir=subdir)

    device = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")
    from examples.llama3 import build_transformer
    real_model = build_transformer(model, model_size="1B", device=Device.DEFAULT)
    shard_model = build_transformer(model, model_size="1B", device=device)

    last_tok = 0
    real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), 0)
    shard_tok = shard_model(Tensor([[last_tok]], device=device), 0)
    self.assertEqual(real_tok.item(), shard_tok.item())
# run every test case in this module when the file is executed as a script
if __name__ == "__main__":
  unittest.main()