mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
add device options for tests in multigpu (#3121)
This commit is contained in:
@@ -9,6 +9,8 @@ import numpy as np
|
||||
# Device names are "<default backend>:<index>". Index 0 gets its own name;
# d0..d3 are indices 1..4 and are the devices the tests shard across.
d_zero = f"{Device.DEFAULT}:0"
d0, d1, d2, d3 = (f"{Device.DEFAULT}:{idx}" for idx in range(1, 5))

# Ready-made device tuples handed to Tensor.shard by the tests.
devices_2 = (d0, d1)
devices_3 = (d0, d1, d2)

# Side length of the square N x N matrices built by the matmul tests.
N = 128
|
||||
|
||||
# shard_x is "data parallel"
|
||||
@@ -73,39 +75,39 @@ class TestMultiTensor(unittest.TestCase):
|
||||
# Both tests delegate to _test_simple_reduce_axis, passing 0 and 1 respectively.
def test_simple_reduce_0(self):
  return self._test_simple_reduce_axis(0)

def test_simple_reduce_1(self):
  return self._test_simple_reduce_axis(1)
|
||||
|
||||
def _test_matmul_shard_axis(self, shard_x, shard_w, device):
  """Check that a sharded matmul matches NumPy's matmul of the unsharded inputs.

  Args:
    shard_x: axis argument forwarded to ``X.shard`` (callers in this file
      pass ``None``, 0 or 1).
    shard_w: axis argument forwarded to ``W.shard`` (same set of values).
    device: shard target forwarded unchanged to ``Tensor.shard``; callers in
      this file pass the ``devices_2`` tuple.

  Raises:
    AssertionError: if the sharded result differs from the NumPy reference
      by more than ``atol=1e-5``.
  """
  X = Tensor.kaiming_uniform(N, N).realize()
  W = Tensor.kaiming_uniform(N, N).realize()
  # Placement comes only from `device`, so the same check can be reused for
  # any device tuple instead of a fixed pair.
  Xs = X.shard(device, shard_x)
  Ws = W.shard(device, shard_w)
  O = (Xs@Ws)
  # Move the result back to the default device before comparing to NumPy.
  np.testing.assert_allclose(X.numpy() @ W.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)
|
||||
|
||||
def _test_double_matmul_shard_axis(self, shard_x, shard_w, device):
  """Check that two chained sharded matmuls match the NumPy reference.

  Computes ``(X @ W1) @ W2`` on sharded tensors and compares it with the
  same expression evaluated on the unsharded NumPy arrays.

  Args:
    shard_x: axis argument forwarded to ``X.shard``.
    shard_w: axis argument forwarded to both ``W1.shard`` and ``W2.shard``.
    device: shard target forwarded unchanged to ``Tensor.shard``; callers in
      this file pass the ``devices_2`` tuple.

  Raises:
    AssertionError: if the sharded result differs from the NumPy reference
      by more than ``atol=1e-5``.
  """
  X = Tensor.kaiming_uniform(N, N).realize()
  W1 = Tensor.kaiming_uniform(N, N).realize()
  W2 = Tensor.kaiming_uniform(N, N).realize()
  # All three operands are placed on the caller-supplied devices; both weight
  # matrices share the same shard axis.
  Xs = X.shard(device, shard_x)
  W1s = W1.shard(device, shard_w)
  W2s = W2.shard(device, shard_w)
  O = (Xs@W1s)@W2s
  # Move the result back to the default device before comparing to NumPy.
  np.testing.assert_allclose((X.numpy() @ W1.numpy()) @ W2.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)
|
||||
|
||||
# Single matmul with at most one operand given a shard axis. Arguments are
# (shard_x, shard_w, device); every case runs on the two-device tuple.
# Each test is defined exactly once and always supplies the device argument
# that _test_matmul_shard_axis requires.
def test_matmul_shard_none(self): return self._test_matmul_shard_axis(None, None, devices_2)
def test_matmul_shard_X_0(self): return self._test_matmul_shard_axis(0, None, devices_2)
def test_matmul_shard_X_1(self): return self._test_matmul_shard_axis(1, None, devices_2)
def test_matmul_shard_W_0(self): return self._test_matmul_shard_axis(None, 0, devices_2)
def test_matmul_shard_W_1(self): return self._test_matmul_shard_axis(None, 1, devices_2)
|
||||
|
||||
# Single matmul with both operands given a shard axis. Arguments are
# (shard_x, shard_w, device); every case runs on the two-device tuple.
# Each test is defined exactly once and always supplies the device argument
# that _test_matmul_shard_axis requires.
def test_matmul_shard_0_0(self): return self._test_matmul_shard_axis(0, 0, devices_2)
def test_matmul_shard_0_1(self): return self._test_matmul_shard_axis(0, 1, devices_2)
def test_matmul_shard_1_0(self): return self._test_matmul_shard_axis(1, 0, devices_2)
def test_matmul_shard_1_1(self): return self._test_matmul_shard_axis(1, 1, devices_2)
|
||||
|
||||
# Two chained matmuls with exactly one of (input, weights) given a shard
# axis. Arguments are (shard_x, shard_w, device); every case runs on the
# two-device tuple. Each test is defined exactly once and always supplies the
# device argument that _test_double_matmul_shard_axis requires.
def test_double_matmul_shard_X_0(self): return self._test_double_matmul_shard_axis(0, None, devices_2)
def test_double_matmul_shard_X_1(self): return self._test_double_matmul_shard_axis(1, None, devices_2)
def test_double_matmul_shard_W_0(self): return self._test_double_matmul_shard_axis(None, 0, devices_2)
def test_double_matmul_shard_W_1(self): return self._test_double_matmul_shard_axis(None, 1, devices_2)
|
||||
|
||||
def test_conv_data_shard(self):
|
||||
conv = nn.Conv2d(3, 16, 3, bias=False)
|
||||
@@ -177,14 +179,8 @@ class TestMultiTensor(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
|
||||
@unittest.skipIf(Device.DEFAULT == "GPU", "GPU requires cl_khr_fp16")
|
||||
def test_llama_attention(self):
|
||||
bs = 1
|
||||
seq_len = 1
|
||||
dim = 128
|
||||
n_heads = 4
|
||||
n_kv_heads = 4
|
||||
max_context = 32
|
||||
|
||||
def _test_llama_attention(self, device):
|
||||
bs, seq_len, dim, n_heads, n_kv_heads, max_context = 1, 1, 128, 4, 4, 32
|
||||
freqs_cis = Tensor.rand(1, seq_len, 1, (dim//n_heads)//2, 2).half()
|
||||
mask = None
|
||||
start_pos = 0
|
||||
@@ -198,11 +194,13 @@ class TestMultiTensor(unittest.TestCase):
|
||||
layer_sharded.wk.weight.assign(layer.wk.weight.shard((d0, d1), axis=0)).realize()
|
||||
layer_sharded.wv.weight.assign(layer.wv.weight.shard((d0, d1), axis=0)).realize()
|
||||
layer_sharded.wo.weight.assign(layer.wo.weight.shard((d0, d1), axis=0)).realize()
|
||||
x_sharded = x.shard((d0, d1), axis=None).realize()
|
||||
freqs_cis_sharded = freqs_cis.shard((d0, d1), axis=None).realize()
|
||||
x_sharded = x.shard(devices_2, axis=None).realize()
|
||||
freqs_cis_sharded = freqs_cis.shard(devices_2, axis=None).realize()
|
||||
y_sharded = layer_sharded(x_sharded, start_pos, freqs_cis_sharded, mask)
|
||||
|
||||
np.testing.assert_allclose(y.numpy(), y_sharded.numpy(), atol=1e-6, rtol=1e-6)
|
||||
np.testing.assert_allclose(y.numpy(), y_sharded.numpy(), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_llama_attention(self):
  # Delegate to the parameterised helper, passing the two-device tuple.
  return self._test_llama_attention(devices_2)
|
||||
|
||||
def test_data_parallel_resnet(self):
|
||||
import sys, pathlib
|
||||
|
||||
Reference in New Issue
Block a user