From c13d51da1d4e2a439ff5efd015df1db0e02c551f Mon Sep 17 00:00:00 2001 From: Yixiang Gao Date: Sun, 14 Jan 2024 15:17:47 -0800 Subject: [PATCH] add device options for tests in multigpu (#3121) --- test/test_multitensor.py | 60 +++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 5130dbdbe3..e5d7ac57b3 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -9,6 +9,8 @@ import numpy as np d_zero = f"{Device.DEFAULT}:0" d0, d1 = f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2" d2, d3 = f"{Device.DEFAULT}:3", f"{Device.DEFAULT}:4" +devices_2 = (d0, d1) +devices_3 = (d0, d1, d2) N = 128 # shard_x is "data parallel" @@ -73,39 +75,39 @@ class TestMultiTensor(unittest.TestCase): def test_simple_reduce_0(self): return self._test_simple_reduce_axis(0) def test_simple_reduce_1(self): return self._test_simple_reduce_axis(1) - def _test_matmul_shard_axis(self, shard_x, shard_w): + def _test_matmul_shard_axis(self, shard_x, shard_w, device): X = Tensor.kaiming_uniform(N, N).realize() W = Tensor.kaiming_uniform(N, N).realize() - Xs = X.shard((d0, d1), shard_x) - Ws = W.shard((d0, d1), shard_w) + Xs = X.shard(device, shard_x) + Ws = W.shard(device, shard_w) O = (Xs@Ws) np.testing.assert_allclose(X.numpy() @ W.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5) - def _test_double_matmul_shard_axis(self, shard_x, shard_w): + def _test_double_matmul_shard_axis(self, shard_x, shard_w, device): X = Tensor.kaiming_uniform(N, N).realize() W1 = Tensor.kaiming_uniform(N, N).realize() W2 = Tensor.kaiming_uniform(N, N).realize() - Xs = X.shard((d0, d1), shard_x) - W1s = W1.shard((d0, d1), shard_w) - W2s = W2.shard((d0, d1), shard_w) + Xs = X.shard(device, shard_x) + W1s = W1.shard(device, shard_w) + W2s = W2.shard(device, shard_w) O = (Xs@W1s)@W2s np.testing.assert_allclose((X.numpy() @ W1.numpy()) @ W2.numpy(), O.to(Device.DEFAULT).numpy(), atol=1e-5) - def test_matmul_shard_none(self): return self._test_matmul_shard_axis(None, None) - def test_matmul_shard_X_0(self): return self._test_matmul_shard_axis(0, None) - def test_matmul_shard_X_1(self): return self._test_matmul_shard_axis(1, None) - def test_matmul_shard_W_0(self): return self._test_matmul_shard_axis(None, 0) - def test_matmul_shard_W_1(self): return self._test_matmul_shard_axis(None, 1) + def test_matmul_shard_none(self): return self._test_matmul_shard_axis(None, None, devices_2) + def test_matmul_shard_X_0(self): return self._test_matmul_shard_axis(0, None, devices_2) + def test_matmul_shard_X_1(self): return self._test_matmul_shard_axis(1, None, devices_2) + def test_matmul_shard_W_0(self): return self._test_matmul_shard_axis(None, 0, devices_2) + def test_matmul_shard_W_1(self): return self._test_matmul_shard_axis(None, 1, devices_2) - def test_matmul_shard_0_0(self): return self._test_matmul_shard_axis(0, 0) - def test_matmul_shard_0_1(self): return self._test_matmul_shard_axis(0, 1) - def test_matmul_shard_1_0(self): return self._test_matmul_shard_axis(1, 0) - def test_matmul_shard_1_1(self): return self._test_matmul_shard_axis(1, 1) + def test_matmul_shard_0_0(self): return self._test_matmul_shard_axis(0, 0, devices_2) + def test_matmul_shard_0_1(self): return self._test_matmul_shard_axis(0, 1, devices_2) + def test_matmul_shard_1_0(self): return self._test_matmul_shard_axis(1, 0, devices_2) + def test_matmul_shard_1_1(self): return self._test_matmul_shard_axis(1, 1, devices_2) - def test_double_matmul_shard_X_0(self): return self._test_double_matmul_shard_axis(0, None) - def test_double_matmul_shard_X_1(self): return self._test_double_matmul_shard_axis(1, None) - def test_double_matmul_shard_W_0(self): return self._test_double_matmul_shard_axis(None, 0) - def test_double_matmul_shard_W_1(self): return self._test_double_matmul_shard_axis(None, 1) + def test_double_matmul_shard_X_0(self): return self._test_double_matmul_shard_axis(0, None, devices_2) + def test_double_matmul_shard_X_1(self): return self._test_double_matmul_shard_axis(1, None, devices_2) + def test_double_matmul_shard_W_0(self): return self._test_double_matmul_shard_axis(None, 0, devices_2) + def test_double_matmul_shard_W_1(self): return self._test_double_matmul_shard_axis(None, 1, devices_2) def test_conv_data_shard(self): conv = nn.Conv2d(3, 16, 3, bias=False) @@ -177,14 +179,8 @@ class TestMultiTensor(unittest.TestCase): @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault") @unittest.skipIf(Device.DEFAULT == "GPU", "GPU requires cl_khr_fp16") - def test_llama_attention(self): - bs = 1 - seq_len = 1 - dim = 128 - n_heads = 4 - n_kv_heads = 4 - max_context = 32 - + def _test_llama_attention(self, device): + bs, seq_len, dim, n_heads, n_kv_heads, max_context = 1, 1, 128, 4, 4, 32 freqs_cis = Tensor.rand(1, seq_len, 1, (dim//n_heads)//2, 2).half() mask = None start_pos = 0 @@ -198,11 +194,13 @@ class TestMultiTensor(unittest.TestCase): layer_sharded.wk.weight.assign(layer.wk.weight.shard((d0, d1), axis=0)).realize() layer_sharded.wv.weight.assign(layer.wv.weight.shard((d0, d1), axis=0)).realize() layer_sharded.wo.weight.assign(layer.wo.weight.shard((d0, d1), axis=0)).realize() - x_sharded = x.shard((d0, d1), axis=None).realize() - freqs_cis_sharded = freqs_cis.shard((d0, d1), axis=None).realize() + x_sharded = x.shard(devices_2, axis=None).realize() + freqs_cis_sharded = freqs_cis.shard(devices_2, axis=None).realize() y_sharded = layer_sharded(x_sharded, start_pos, freqs_cis_sharded, mask) - np.testing.assert_allclose(y.numpy(), y_sharded.numpy(), atol=1e-6, rtol=1e-6) + np.testing.assert_allclose(y.numpy(), y_sharded.numpy(), atol=1e-4, rtol=1e-4) + + def test_llama_attention(self): return self._test_llama_attention(devices_2) def test_data_parallel_resnet(self): import sys, pathlib