diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 27e17d42da..076fecbf87 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -87,8 +87,8 @@ class TestMultiTensor(unittest.TestCase): def test_sum_0(self): return self._test_sum_axis(0) def test_sum_1(self): return self._test_sum_axis(1) - def _test_max_axis(self, shard_x): - X = Tensor.arange(16).reshape(4, 4) + def _test_max_axis(self, shard_x, sign=1): + X = Tensor.arange(16).reshape(4, 4) * sign n = X.numpy() X.shard_((d0, d1), shard_x) O = X.max(axis=0) @@ -101,6 +101,9 @@ class TestMultiTensor(unittest.TestCase): def test_max(self): return self._test_max_axis(None) def test_max_0(self): return self._test_max_axis(0) def test_max_1(self): return self._test_max_axis(1) + def test_max_neg(self): return self._test_max_axis(None, sign=-1) + def test_max_0_neg(self): return self._test_max_axis(0, sign=-1) + def test_max_1_neg(self): return self._test_max_axis(1, sign=-1) def _test_matmul_shard_axis(self, shard_x, shard_w, device): X = Tensor.kaiming_uniform(N, N).realize() diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py index 2b0cd54548..e5c1a4878d 100644 --- a/tinygrad/features/multi.py +++ b/tinygrad/features/multi.py @@ -8,6 +8,7 @@ from tinygrad.lazy import LazyBuffer, create_schedule from tinygrad.shape.shapetracker import ShapeTracker, sint def all_reduce(op:ReduceOps, lbs): + # TODO: does this work with uneven shards? add tests if so # TODO: replace this with ring reduce bop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[op] return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]