add test cases for negative entry max allreduce (#3177)

This commit is contained in:
chenyu
2024-01-18 22:26:51 -05:00
committed by GitHub
parent ab1b7c4d09
commit c4faedebf3
2 changed files with 6 additions and 2 deletions

View File

@@ -87,8 +87,8 @@ class TestMultiTensor(unittest.TestCase):
def test_sum_0(self): return self._test_sum_axis(0)
def test_sum_1(self): return self._test_sum_axis(1)
def _test_max_axis(self, shard_x):
X = Tensor.arange(16).reshape(4, 4)
def _test_max_axis(self, shard_x, sign=1):
X = Tensor.arange(16).reshape(4, 4) * sign
n = X.numpy()
X.shard_((d0, d1), shard_x)
O = X.max(axis=0)
@@ -101,6 +101,9 @@ class TestMultiTensor(unittest.TestCase):
def test_max(self): return self._test_max_axis(None)
def test_max_0(self): return self._test_max_axis(0)
def test_max_1(self): return self._test_max_axis(1)
def test_max_neg(self): return self._test_max_axis(None, sign=-1)
def test_max_0_neg(self): return self._test_max_axis(0, sign=-1)
def test_max_1_neg(self): return self._test_max_axis(1, sign=-1)
def _test_matmul_shard_axis(self, shard_x, shard_w, device):
X = Tensor.kaiming_uniform(N, N).realize()

View File

@@ -8,6 +8,7 @@ from tinygrad.lazy import LazyBuffer, create_schedule
from tinygrad.shape.shapetracker import ShapeTracker, sint
def all_reduce(op:ReduceOps, lbs):
# TODO: does this work with uneven shards? add tests if so
# TODO: replace this with ring reduce
bop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[op]
return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]