add device options for tests in multigpu (#3121)

This commit is contained in:
Yixiang Gao
2024-01-14 15:17:47 -08:00
committed by GitHub
parent 79f4627fbc
commit c13d51da1d

View File

@@ -9,6 +9,8 @@ import numpy as np
d_zero = f"{Device.DEFAULT}:0"
d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2"
d2, d3 = f"{Device.DEFAULT}:3", f"{Device.DEFAULT}:4"
devices_2 = (d0, d1)
devices_3 = (d0, d1, d2)
N = 128
# shard_x is "data parallel"
@@ -73,39 +75,39 @@ class TestMultiTensor(unittest.TestCase):
def test_simple_reduce_0(self): return self._test_simple_reduce_axis(0)
def test_simple_reduce_1(self): return self._test_simple_reduce_axis(1)
def _test_matmul_shard_axis(self, shard_x, shard_w):
def _test_matmul_shard_axis(self, shard_x, shard_w, device):
X = Tensor.kaiming_uniform(N, N).realize()
W = Tensor.kaiming_uniform(N, N).realize()
Xs = X.shard((d0, d1), shard_x)
Ws = W.shard((d0, d1), shard_w)
Xs = X.shard(device, shard_x)
Ws = W.shard(device, shard_w)
O = (Xs@Ws)
np.testing.assert_allclose(X.numpy() @ W.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)
def _test_double_matmul_shard_axis(self, shard_x, shard_w):
def _test_double_matmul_shard_axis(self, shard_x, shard_w, device):
X = Tensor.kaiming_uniform(N, N).realize()
W1 = Tensor.kaiming_uniform(N, N).realize()
W2 = Tensor.kaiming_uniform(N, N).realize()
Xs = X.shard((d0, d1), shard_x)
W1s = W1.shard((d0, d1), shard_w)
W2s = W2.shard((d0, d1), shard_w)
Xs = X.shard(device, shard_x)
W1s = W1.shard(device, shard_w)
W2s = W2.shard(device, shard_w)
O = (Xs@W1s)@W2s
np.testing.assert_allclose((X.numpy() @ W1.numpy()) @ W2.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5)
def test_matmul_shard_none(self): return self._test_matmul_shard_axis(None, None)
def test_matmul_shard_X_0(self): return self._test_matmul_shard_axis(0, None)
def test_matmul_shard_X_1(self): return self._test_matmul_shard_axis(1, None)
def test_matmul_shard_W_0(self): return self._test_matmul_shard_axis(None, 0)
def test_matmul_shard_W_1(self): return self._test_matmul_shard_axis(None, 1)
def test_matmul_shard_none(self): return self._test_matmul_shard_axis(None, None, devices_2)
def test_matmul_shard_X_0(self): return self._test_matmul_shard_axis(0, None, devices_2)
def test_matmul_shard_X_1(self): return self._test_matmul_shard_axis(1, None, devices_2)
def test_matmul_shard_W_0(self): return self._test_matmul_shard_axis(None, 0, devices_2)
def test_matmul_shard_W_1(self): return self._test_matmul_shard_axis(None, 1, devices_2)
def test_matmul_shard_0_0(self): return self._test_matmul_shard_axis(0, 0)
def test_matmul_shard_0_1(self): return self._test_matmul_shard_axis(0, 1)
def test_matmul_shard_1_0(self): return self._test_matmul_shard_axis(1, 0)
def test_matmul_shard_1_1(self): return self._test_matmul_shard_axis(1, 1)
def test_matmul_shard_0_0(self): return self._test_matmul_shard_axis(0, 0, devices_2)
def test_matmul_shard_0_1(self): return self._test_matmul_shard_axis(0, 1, devices_2)
def test_matmul_shard_1_0(self): return self._test_matmul_shard_axis(1, 0, devices_2)
def test_matmul_shard_1_1(self): return self._test_matmul_shard_axis(1, 1, devices_2)
def test_double_matmul_shard_X_0(self): return self._test_double_matmul_shard_axis(0, None)
def test_double_matmul_shard_X_1(self): return self._test_double_matmul_shard_axis(1, None)
def test_double_matmul_shard_W_0(self): return self._test_double_matmul_shard_axis(None, 0)
def test_double_matmul_shard_W_1(self): return self._test_double_matmul_shard_axis(None, 1)
def test_double_matmul_shard_X_0(self): return self._test_double_matmul_shard_axis(0, None, devices_2)
def test_double_matmul_shard_X_1(self): return self._test_double_matmul_shard_axis(1, None, devices_2)
def test_double_matmul_shard_W_0(self): return self._test_double_matmul_shard_axis(None, 0, devices_2)
def test_double_matmul_shard_W_1(self): return self._test_double_matmul_shard_axis(None, 1, devices_2)
def test_conv_data_shard(self):
conv = nn.Conv2d(3, 16, 3, bias=False)
@@ -177,14 +179,8 @@ class TestMultiTensor(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
@unittest.skipIf(Device.DEFAULT == "GPU", "GPU requires cl_khr_fp16")
def test_llama_attention(self):
bs = 1
seq_len = 1
dim = 128
n_heads = 4
n_kv_heads = 4
max_context = 32
def _test_llama_attention(self, device):
bs, seq_len, dim, n_heads, n_kv_heads, max_context = 1, 1, 128, 4, 4, 32
freqs_cis = Tensor.rand(1, seq_len, 1, (dim//n_heads)//2, 2).half()
mask = None
start_pos = 0
@@ -198,11 +194,13 @@ class TestMultiTensor(unittest.TestCase):
layer_sharded.wk.weight.assign(layer.wk.weight.shard((d0, d1), axis=0)).realize()
layer_sharded.wv.weight.assign(layer.wv.weight.shard((d0, d1), axis=0)).realize()
layer_sharded.wo.weight.assign(layer.wo.weight.shard((d0, d1), axis=0)).realize()
x_sharded = x.shard((d0, d1), axis=None).realize()
freqs_cis_sharded = freqs_cis.shard((d0, d1), axis=None).realize()
x_sharded = x.shard(devices_2, axis=None).realize()
freqs_cis_sharded = freqs_cis.shard(devices_2, axis=None).realize()
y_sharded = layer_sharded(x_sharded, start_pos, freqs_cis_sharded, mask)
np.testing.assert_allclose(y.numpy(), y_sharded.numpy(), atol=1e-6, rtol=1e-6)
np.testing.assert_allclose(y.numpy(), y_sharded.numpy(), atol=1e-4, rtol=1e-4)
def test_llama_attention(self): return self._test_llama_attention(devices_2)
def test_data_parallel_resnet(self):
import sys, pathlib