# Mirror of https://github.com/tinygrad/tinygrad.git
# Synced 2026-06-08 14:05:09 +08:00
# 477 lines, 18 KiB, Python
import unittest, random
|
|
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
|
|
from tinygrad.uop.ops import Ops, UOp
|
|
from tinygrad.helpers import getenv, prod, Context
|
|
from tinygrad.nn.state import get_parameters
|
|
from tinygrad.engine.realize import run_linear, compile_linear
|
|
import numpy as np
|
|
from hypothesis import given, strategies as strat, settings
|
|
from test.helpers import not_support_multi_device, needs_second_gpu, slow, call_is_graph
|
|
|
|
# hypothesis profile used by the property-based tests in this file: 200 examples, no per-example deadline,
# and derandomized runs when the DERANDOMIZE_CI environment variable is set
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")

# names of the first six devices of the default backend, e.g. "<DEFAULT>:0" ... "<DEFAULT>:5"
d0, d1, d2, d3, d4, d5 = (f"{Device.DEFAULT}:{i}" for i in range(6))

# device groups used for sharding; note that they start at device 1, not device 0
devices_4 = (d1, d2, d3, d4)
devices_3 = devices_4[:3]
devices_2 = devices_4[:2]
N = 128
|
|
|
|
# shard_x is "data parallel"
|
|
# shard_w is "model parallel"
|
|
|
|
def _test_allreduce(t:Tensor):
  """Build a reference sum and an ADD allreduce of `t` over `devices_4`.

  The reference adds the four 64-row slices t[0:64] .. t[192:256] and repeats the result four times along
  axis 0. The allreduce is applied to `t` sharded along axis 0 over `devices_4`. Both tensors are realized
  before returning.

  Returns:
    (reference, allreduced) as a tuple of Tensors.
  """
  q0, q1, q2, q3 = t[0:64], t[64:128], t[128:192], t[192:256]
  reference = (q0 + q1 + q2 + q3).repeat([4,1]).realize()
  sharded = t.shard(devices_4, 0).realize()
  reduced = Tensor(UOp.allreduce(sharded.uop, Ops.ADD, sharded.device))
  reduced.realize()
  return reference, reduced
|
|
|
|
@unittest.skipIf(not_support_multi_device(), "no multi")
class TestMultiTensor(unittest.TestCase):
  """Tests for tensors that are moved to, or sharded across, several devices of the default backend."""

  # NOTE(review): needs_second_gpu decorates setUp, so it presumably gates every test in this class - confirm in test.helpers
  @needs_second_gpu
  def setUp(self): pass

  def test_to(self):
    # retarget a realized tensor at devices_2 in place: the shape must stay (256,) and it must still be addable
    ones = Tensor.ones(256).contiguous().realize()
    ones.to_(devices_2)
    assert ones.shape == (256,)
    (ones + ones).realize()

  def test_gradient(self):
    # the gradient of sum() with respect to a tensor placed on devices_2 must be realizable
    ones = Tensor.ones(256).contiguous().realize()
    ones.to_(devices_2)
    ones.sum().gradient(ones)[0].realize()

  def test_shard(self):
    # 256 elements sharded over two devices on axis 0: every source of the resulting uop has shape (128,)
    ones = Tensor.ones(256).contiguous().realize()
    ones.shard_(devices_2, 0)
    assert all(piece.shape == (128,) for piece in ones.uop.src)
    (ones + ones).realize()

  @unittest.expectedFailure # TODO: fix
  def test_shard_empty(self):
    # sharding and realizing an empty tensor is expected to leave the kernel counter at zero
    GlobalCounters.reset()
    empty = Tensor.empty(256).shard(devices_2, 0).realize()
    assert GlobalCounters.kernel_count == 0
    (empty + empty).realize()

  # TODO: fix this to not copy on the src device
  @unittest.expectedFailure
  def test_shard_no_recompile(self):
    # all PROGRAM calls of the compiled schedule are expected to share one kernel name
    ones = Tensor.ones(256).contiguous().realize()
    ones.shard_(devices_2, 0)
    doubled = ones + ones
    linear = compile_linear(doubled.schedule_linear())
    kernel_names = {call.src[0].src[0].arg.name for call in linear.src if call.src[0].op is Ops.PROGRAM}
    run_linear(linear)
    self.assertEqual(len(kernel_names), 1, "function was relinearized")

  def test_shard_same_device(self):
    # one of the shard targets is the device the tensor already lives on
    ones = Tensor.ones(256).contiguous().realize()
    ones.shard_((d1, ones.device), 0)
    (ones + ones).realize()

  def test_numpy(self):
    # a sharded tensor converts back to a single numpy array
    ones = Tensor.ones(256)
    ones.shard_(devices_2, 0)
    np.testing.assert_allclose(ones.numpy(), 1)

  def test_four_add(self):
    # add a tensor sharded on axis 1 to one sharded with axis=None, both over four devices
    lhs = Tensor.ones(256, 256).contiguous().realize()
    rhs = Tensor.ones(256, 256).contiguous().realize()
    lhs.shard_(devices_4, 1)
    rhs.shard_(devices_4, None)
    np.testing.assert_allclose((lhs + rhs).numpy(), 2)

  def test_elementwise_dtype(self):
    # the comparison of a product of two shrunk sharded tensors must match the same expression in numpy
    Tensor.manual_seed(0)
    X = Tensor.randn(8, 8).realize()
    W = Tensor.randn(8, 8).realize()
    X.shard_(devices_4, 0)
    W.shard_(devices_4, 0)
    first_two_rows = ((0, 2), None)
    below_two = X.shrink(first_two_rows) * W.shrink(first_two_rows) < 2
    got = below_two.numpy()
    expected = X.numpy()[0:2]*W.numpy()[0:2] < 2
    np.testing.assert_allclose(got, expected)

  def test_shrink_on_shard_axis(self):
    full = Tensor.arange(4*4).reshape(4,4).clone().realize()
    full_np = full.numpy()
    full.shard_(devices_2, 0)
    # only shrink on the device that owns the shard, this is enabled by the mselect simplifier
    for i in range(2):
      lo, hi = i*2, i*2+2
      piece = full[lo:hi].contiguous()
      linear, var_vals = piece.linear_with_vars()
      #kernels = [call for call in linear.src if call.src[0].op is Ops.SINK]
      #self.assertEqual(len(kernels), 1)
      #self.assertEqual(kernels[0].src[1].buffer.device, devices_2[i])
      run_linear(linear, var_vals)
      np.testing.assert_equal(piece.numpy(), full_np[lo:hi])

  @given(strat.sampled_from((devices_2, devices_3)),
         strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
         strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)))
  def test_simple_reduce(self, devices, rop, shard_axis, reduce_axis):
    # reduce a sharded square matrix with sum/prod/max and compare against the same numpy method
    method = {Ops.ADD: "sum", Ops.MUL: "prod", Ops.MAX: "max"}[rop]
    side = 4 * len(devices)
    X = (Tensor.rand(side*side)-1).reshape(side, side).shard_(devices, shard_axis)
    X_np = X.numpy()
    reduced = getattr(X, method)(reduce_axis)
    reduced_np = getattr(X_np, method)(reduce_axis)
    np.testing.assert_allclose(reduced.numpy(), reduced_np, rtol=1e-6, atol=1e-6)

  def test_allreduce_naive(self):
    # allreduce with RING=0 must agree with the reference sum
    with Context(RING=0):
      reference, reduced = _test_allreduce(Tensor.rand(256, 256))
      np.testing.assert_almost_equal(reference.numpy(), reduced.numpy(), decimal=5)

  def test_allreduce_ring(self):
    # allreduce with RING=2 must agree with the reference sum
    with Context(RING=2):
      reference, reduced = _test_allreduce(Tensor.rand(256, 256))
      np.testing.assert_almost_equal(reference.numpy(), reduced.numpy(), decimal=5)

  def test_allreduce_all2all(self):
    # allreduce with ALL2ALL=2 must agree with the reference sum
    with Context(ALL2ALL=2):
      reference, reduced = _test_allreduce(Tensor.rand(256, 256))
      np.testing.assert_almost_equal(reference.numpy(), reduced.numpy(), decimal=5)

  def test_copy_jit(self):
    # a jitted function that copies its input to device 1 of the same backend and adds one
    @TinyJit
    def copy_tensor(x:Tensor):
      backend = x.device.split(':')[0]
      return x.to(f"{backend}:1") + 1

    for _ in range(5):
      src = Tensor.rand(256).realize()
      copied = copy_tensor(src)
      np.testing.assert_equal((src+1).numpy(), copied.numpy())

  def test_allreduce_naive_jit(self):
    # same as test_allreduce_naive, but the helper is wrapped in TinyJit and called five times
    with Context(RING=0):
      jit_allreduce = TinyJit(_test_allreduce)
      for _ in range(5):
        reference, reduced = jit_allreduce(Tensor.rand(256, 256))
        np.testing.assert_almost_equal(reference.numpy(), reduced.numpy(), decimal=5)

  def test_allreduce_ring_jit(self):
    # same as test_allreduce_ring, but the helper is wrapped in TinyJit and called five times
    with Context(RING=2):
      jit_allreduce = TinyJit(_test_allreduce)
      for _ in range(5):
        reference, reduced = jit_allreduce(Tensor.rand(256, 256))
        np.testing.assert_almost_equal(reference.numpy(), reduced.numpy(), decimal=5)

  def test_multitensor_jit_input(self):
    # the tensor is sharded before it is handed to the jitted function
    @TinyJit
    def f(x): return (x+1).contiguous().sum()

    for _ in range(5):
      sharded = Tensor.arange(0, 4).clone().realize().shard((d1,d2), 0).realize()
      total = f(sharded)
      assert total.item() == 1+2+3+4

  def test_multitensor_inside_jit(self):
    # the sharding itself happens inside the jitted function
    @TinyJit
    def f(x): return (x.shard((d1,d2), 0)+1).contiguous().sum()

    for _ in range(5):
      unsharded = Tensor.arange(0, 4).clone().realize()
      total = f(unsharded)
      assert total.item() == 1+2+3+4

  def test_fuzz_allreduce(self):
    # compare RING=0 and RING=2 allreduce on randomly shaped tensors sharded over 2, 3 and 4 devices
    random.seed(41)
    for it in range(2):
      for n in range(2, 4+1):
        # the rank is drawn first, then one random factor per axis; axis 0 is a multiple of the device count
        ndim = random.randint(1, 4)
        shape = tuple((n if i == 0 else 1) * random.randint(1, 10) for i in range(ndim))
        t = Tensor.rand(shape).shard_((d0, d1, d2, d3)[:n], 0)
        with Context(RING=0):
          naive = Tensor(UOp.allreduce(t.uop, Ops.ADD, t.device))
        with Context(RING=2):
          ring = Tensor(UOp.allreduce(t.uop, Ops.ADD, t.device))
        diff = naive - ring
        abs_err = diff.reshape((prod(diff.shape),)).abs()
        mean_err = abs_err.mean().numpy()
        max_err = abs_err.max().numpy()
        assert mean_err < 1e-6, f"big mean error, iteration {it}_{n}"
        assert max_err < 1e-6, f"big max error, iteration {it}_{n}"

  def _test_model_train_step(self, m, fake_image, labels):
    """Run one backward pass unsharded and one sharded over devices_2, then compare the conv1 weight gradients.

    `m` must be callable on `fake_image`, provide `load_from_pretrained()` and expose `conv1.weight`.
    """
    from tinygrad.nn.optim import LARS
    optimizer = LARS(get_parameters(m), 0.1)

    # reference pass with the inputs and parameters as they were passed in
    optimizer.zero_grad()
    m.load_from_pretrained()
    loss = m(fake_image).sparse_categorical_crossentropy(labels, label_smoothing=0.1)
    loss.backward()
    grad = m.conv1.weight.grad.numpy()

    # sharded pass: inputs are split on axis 0, parameters go through shard_(devices_2) with no axis given
    sharded_image = fake_image.shard(devices_2, axis=0)
    sharded_labels = labels.shard(devices_2, axis=0)
    for param in get_parameters(m): param.shard_(devices_2).realize()
    GlobalCounters.reset()
    optimizer.zero_grad()
    sharded_loss = m(sharded_image).sparse_categorical_crossentropy(sharded_labels, label_smoothing=0.1)
    sharded_loss.backward()
    shard_grad = m.conv1.weight.grad.numpy()
    # sometimes there is zeros in these grads... why?
    np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)

  @slow
  def test_data_parallel_resnet_train_step(self):
    # train-step comparison on ResNet18 with a small 2-image batch
    from extra.models.resnet import ResNet18
    fake_image = Tensor.rand((2, 3, 224//16, 224//16))
    labels = Tensor.randint(2, low=0, high=1000)
    model = ResNet18()
    self._test_model_train_step(model, fake_image, labels)

  def test_data_parallel_simple_train_step(self):
    # train-step comparison on a single Linear layer exposed under the attribute name the helper reads (conv1)
    class Model:
      def __init__(self): self.conv1 = nn.Linear(128,128)
      def __call__(self, x): return self.conv1(x)
      def load_from_pretrained(self): pass

    fake_image = Tensor.rand((128,))
    labels = Tensor.randint(2, low=0, high=127)
    model = Model()
    self._test_model_train_step(model, fake_image, labels)

  def test_assign_kv_cache_multi(self):
    bsz, max_context = 2, 8

    class Attn:
      @TinyJit
      def __call__(self, xk:Tensor, start_pos:UOp):
        seqlen = xk.shape[1]
        # the cache is created on first use, placed on devices_2
        if not hasattr(self, "cache_k"):
          self.cache_k = Tensor.zeros(bsz, max_context, 1, 1).shard(devices_2).contiguous().realize()
        # keep the first start_pos cache entries and append the new keys; with start_pos == 0 only the new keys are used
        if start_pos > 0:
          keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous()
        else:
          keys = xk
        # pad axis 1 back up to max_context and write the result into the cache
        padded = keys.pad((None,(0,max_context-start_pos-seqlen),None,None)).contiguous()
        self.cache_k.assign(padded).realize()

    attn = Attn()
    xk = Tensor.ones(bsz, 3, 1, 1).shard(devices_2).contiguous()
    attn(xk, 0)
    for pos in range(3,6):
      # copied from LLaMA
      start_pos = Variable("start_pos", 1, max_context).bind(pos)
      xk = Tensor.ones(bsz, 1, 1, 1).shard(devices_2).contiguous()
      attn(xk, start_pos)

    # per batch row: six filled positions followed by two that are still zero
    cache = attn.cache_k.flatten().numpy()
    expected_row = [1.]*6 + [0.]*2
    np.testing.assert_allclose(cache, expected_row*bsz)

  def test_multi_tensor_jit_graph_assign_updates_each_shard(self):
    @TinyJit
    def jf(out: Tensor) -> Tensor:
      base = Tensor.arange(4, dtype=dtypes.float).clone().shard(devices_2, 0)
      tmp = (base + 1).contiguous().realize()
      out.assign((tmp + 1).contiguous()).realize()
      return out

    def minus_ones(): return Tensor.full((4,), -1.0).shard(devices_2, 0).contiguous()

    out = minus_ones().realize()
    expected = np.arange(4, dtype=np.float32) + 2
    for _ in range(5):
      # reset the output to -1 so stale values from the previous call cannot pass the check
      out.assign(minus_ones()).realize()
      jf(out)
      np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-5)
    assert jf.captured is not None

  @unittest.skip("test broken")
  def test_multi_device_jit_graph(self):
    if Device[d0].graph is None or Device[d1].graph is None: raise unittest.SkipTest("only test graphs")

    @TinyJit
    def jf(a: Tensor, b: Tensor, c: Tensor, d:Tensor):
      # Create 80 entries on device 0: 2 batches.
      for _ in range(40):
        added = (a + b).realize()
        multiplied = (a * b).realize()
        a = (added + multiplied).realize()
      # Create 80 entries on device 1: 2 batches.
      for _ in range(40):
        added = (c + d).realize()
        multiplied = (c * d).realize()
        c = (added + multiplied).realize()
      # Create a copy from device 0 to 1: 1 entry.
      a = a.to(d1).realize()
      # Creates one last entry on device 1: 1 batch.
      return (a + c).realize()

    a = Tensor.randn(10, 10, device=d0).realize()
    b = Tensor.randn(10, 10, device=d0).realize()
    c = Tensor.randn(10, 10, device=d1).realize()
    d = Tensor.randn(10, 10, device=d1).realize()

    ref = jf(a, b, c, d).numpy()
    for _ in range(5):
      again = jf(a, b, c, d).numpy()
      np.testing.assert_allclose(ref, again, atol=1e-4, rtol=1e-5)

    # Checking that 2 graphs per device, 1 copy and 1 last graph on device 1 are created.
    calls = jf.captured.linear.src
    assert len(calls) == 6
    graph_calls = (calls[0], calls[1], calls[2], calls[3], calls[5])
    for call in graph_calls:
      assert call_is_graph(call)
    assert calls[4].src[0].op is Ops.COPY

  def test_rand_on_multiple_devices(self):
    # different devices generate different rand
    rand_on_d0 = Tensor.rand(256, device=d0).realize()
    rand_on_d1 = Tensor.rand(256, device=d1).realize()
    assert not np.allclose(rand_on_d0.numpy(), rand_on_d1.numpy())

  def test_rand_on_multiple_devices_manual_seed(self):
    def draw(dev): return Tensor.rand(2, device=dev).tolist()

    Tensor.manual_seed(123)
    first_d0, first_d1 = draw(d0), draw(d1)

    # manual_seed again gives the same values
    Tensor.manual_seed(123)
    second_d0, second_d1 = draw(d0), draw(d1)
    self.assertEqual(first_d0, second_d0)
    self.assertEqual(first_d1, second_d1)

    # device seed is only determined by init order, so flipping init order flips rands
    Tensor.manual_seed(123)
    flipped_d1, flipped_d0 = draw(d1), draw(d0)
    self.assertEqual(first_d0, flipped_d1)
    self.assertEqual(first_d1, flipped_d0)

  def test_const_like_shrink_on_shard_axis(self):
    # a shrunk const_like of a sharded tensor must schedule nothing and still read back as the constant
    base = Tensor.ones(16, 16, dtype=dtypes.int).shard(devices_2, axis=0)
    twos = base.const_like(2)[:, :8]
    linear, var_vals = twos.linear_with_vars()
    self.assertEqual(len(linear.src), 0)
    run_linear(linear, var_vals)
    self.assertEqual(twos.tolist(), [[2]*8]*16)
|
|
|
|
@unittest.skipIf(not_support_multi_device(), "no multi")
class TestHandleData(unittest.TestCase):
  """Copying a tensor that lives on several devices onto a single device."""

  @needs_second_gpu
  def test_copied_to_device(self):
    device = (d0, d1, d2, d3)
    values = [1, 2, 3, 4]

    def fresh(): return Tensor(values).shard(device).realize()

    # target device that is not one of the source devices: the schedule has exactly one entry
    not_covered = fresh().to(d5)
    sched = not_covered.schedule_linear().src
    assert len(sched) == 1
    # build the tensor again before realizing, because scheduling has a side effect
    not_covered = fresh().to(d5)
    assert not_covered.realize().tolist() == values

    # target device that is one of the source devices
    for target in device:
      covered = fresh().to(target)
      sched = covered.schedule_linear().src
      # TODO: this isn't optimized out anymore
      #assert len(sched) == 0
      # build the tensor again before realizing, because scheduling has a side effect
      covered = fresh().to(target)
      assert covered.realize().tolist() == values
|
|
|
|
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiBufferView(unittest.TestCase):
  """Views that do not touch the shard axis, taken on tensors sharded along axis 1."""

  # NOTE(review): needs_second_gpu decorates setUp, so it presumably gates every test in this class - confirm in test.helpers
  @needs_second_gpu
  def setUp(self): pass

  @staticmethod
  def _arange(*shape):
    """Return an unrealized clone of arange(prod(shape)) reshaped to `shape`."""
    return Tensor.arange(prod(shape)).reshape(*shape).clone()

  def _check(self, a_ref:Tensor, a_multi:Tensor, view_fn):
    """Apply `view_fn` to both tensors and require the contiguous multi-device result to equal the reference.

    When the allocator of every device of the result has an `_offset` attribute, the schedule must also
    contain no call whose first source is an Ops.SINK.
    """
    expected = view_fn(a_ref)
    viewed = view_fn(a_multi).contiguous()
    linear, var_vals = viewed.linear_with_vars()
    if all(hasattr(Device[dev].allocator, "_offset") for dev in viewed.device):
      compiled = [call for call in linear.src if call.src[0].op is Ops.SINK]
      self.assertEqual(len(compiled), 0, f"expected zero compiled kernels, got {len(compiled)}")
    run_linear(linear, var_vals)
    np.testing.assert_equal(viewed.numpy(), expected.numpy())

  @unittest.skip("flaky on LLVM")
  def test_shrink_non_shard_axis(self):
    ref = self._arange(8, 4, 10).realize()
    multi = self._arange(8, 4, 10).shard(devices_2, axis=1).realize()
    self._check(ref, multi, lambda t: t[3])

  def test_shrink_2d(self):
    ref = self._arange(6, 4).realize()
    multi = self._arange(6, 4).shard(devices_2, axis=1).realize()
    self._check(ref, multi, lambda t: t.shrink(((1, 4), None)))

  def test_reshape_then_shrink(self):
    ref = self._arange(8, 6).realize()
    multi = self._arange(8, 6).shard(devices_2, axis=1).realize()
    self._check(ref, multi, lambda t: t.reshape(4, 2, 6)[1])

  def test_chained_shrink(self):
    ref = self._arange(10, 8).realize()
    multi = self._arange(10, 8).shard(devices_2, axis=1).realize()
    self._check(ref, multi, lambda t: t.shrink(((2, 8), None)).shrink(((1, 4), None)))

  def test_4_devices(self):
    # same check as _check, written out inline for a tensor sharded over four devices
    ref = self._arange(8, 12).realize()
    multi = self._arange(8, 12).shard(devices_4, axis=1).realize()
    row = multi[5].contiguous()
    linear, var_vals = row.linear_with_vars()
    if all(hasattr(Device[dev].allocator, "_offset") for dev in row.device):
      compiled = [call for call in linear.src if call.src[0].op is Ops.SINK]
      self.assertEqual(len(compiled), 0)
    run_linear(linear, var_vals)
    np.testing.assert_equal(row.numpy(), ref[5].numpy())
|
|
|
|
@unittest.skipIf(not_support_multi_device(), "need multi")
class TestMultiTransformer(unittest.TestCase):
  """Compare a transformer on the default device against a copy sharded over two devices."""

  @needs_second_gpu
  def test_transformer(self):
    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

    from extra.models.llama import Transformer
    args = {"dim": 32, "n_heads": 1, "n_kv_heads": 1, "n_layers": 2, "norm_eps": 1e-5, "rope_theta": 500000, "vocab_size": 1024,
            "hidden_dim": 32, "max_context": 12}
    real_model = Transformer(**args)
    shard_model = Transformer(**args)

    # copy state
    nn.state.load_state_dict(shard_model, nn.state.get_state_dict(real_model))

    # shard: the first pattern found in the parameter name decides the axis, anything unmatched gets axis=None
    shard_rules = (
      ('scale', None), # from quantized
      ('.attention.', -1),
      ('.feed_forward.w1.', 0),
      ('.feed_forward.w3.', 0),
      ('.feed_forward.', -1),
      ('tok_embeddings.weight', 0),
      ('output.weight', 0),
    )
    for name, weight in nn.state.get_state_dict(shard_model).items():
      axis = next((ax for pattern, ax in shard_rules if pattern in name), None)
      weight.shard_(device, axis=axis)

    last_tok = 0
    for i in range(10):
      real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), i).item()
      shard_tok = shard_model(Tensor([[last_tok]], device=device), i).item()

      # test kv cache
      real_kv = real_model.layers[0].attention.cache_kv.numpy()
      shard_kv = shard_model.layers[0].attention.cache_kv.numpy()
      #print(np.concatenate([kv1[:, :, :, :, 0:1], kv2[:, :, :, :, 0:1]], axis=4))
      np.testing.assert_allclose(real_kv, shard_kv, atol=1e-5, rtol=1e-5, err_msg=f"issue at token {i}")

      # test token
      self.assertEqual(real_tok, shard_tok, f"issue at token {i}")
      last_tok = real_tok

  @unittest.skip("super slow")
  def test_llama1b_full(self):
    from tinygrad.helpers import fetch
    subdir = "llama3-1b-instruct"
    # the tokenizer is fetched but its return value is not used by this test
    fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=subdir)
    model = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
                  "Llama-3.2-1B-Instruct-Q6_K.gguf", subdir=subdir)

    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
    from examples.llama3 import build_transformer
    real_model = build_transformer(model, model_size="1B", device=Device.DEFAULT)
    shard_model = build_transformer(model, model_size="1B", device=device)

    # both models must produce the same first token for the same start token
    last_tok = 0
    real_tok = real_model(Tensor([[last_tok]], device=Device.DEFAULT), 0)
    shard_tok = shard_model(Tensor([[last_tok]], device=device), 0)
    self.assertEqual(real_tok.item(), shard_tok.item())
|
|
|
|
# run every test in this module when the file is executed directly
if __name__ == '__main__':
  unittest.main()
|