From ab1b7c4d09d2e1d68f3289658d73ff2bf83f3389 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 18 Jan 2024 20:25:35 -0500 Subject: [PATCH] fix allreduce for max (#3175) * test cases to show allreduce for max is incorrect * oh fixed --- test/test_multitensor.py | 27 +++++++++++++++++++++++---- tinygrad/features/multi.py | 7 ++++--- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index c82e39c587..27e17d42da 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -73,15 +73,34 @@ class TestMultiTensor(unittest.TestCase): O = X + W np.testing.assert_allclose(O.numpy(), 2) - def _test_simple_reduce_axis(self, shard_x): + def _test_sum_axis(self, shard_x): X = Tensor.ones(256, 256).contiguous().realize() X.shard_((d0, d1), shard_x) + O = X.sum(axis=0) + np.testing.assert_allclose(O.numpy(), 256) O = X.sum(axis=1) np.testing.assert_allclose(O.numpy(), 256) + O = X.sum() + np.testing.assert_allclose(O.numpy(), 256*256) - def test_simple_reduce(self): return self._test_simple_reduce_axis(None) - def test_simple_reduce_0(self): return self._test_simple_reduce_axis(0) - def test_simple_reduce_1(self): return self._test_simple_reduce_axis(1) + def test_sum(self): return self._test_sum_axis(None) + def test_sum_0(self): return self._test_sum_axis(0) + def test_sum_1(self): return self._test_sum_axis(1) + + def _test_max_axis(self, shard_x): + X = Tensor.arange(16).reshape(4, 4) + n = X.numpy() + X.shard_((d0, d1), shard_x) + O = X.max(axis=0) + np.testing.assert_allclose(O.numpy(), n.max(0)) + O = X.max(axis=1) + np.testing.assert_allclose(O.numpy(), n.max(1)) + O = X.max() + np.testing.assert_allclose(O.numpy(), n.max()) + + def test_max(self): return self._test_max_axis(None) + def test_max_0(self): return self._test_max_axis(0) + def test_max_1(self): return self._test_max_axis(1) def _test_matmul_shard_axis(self, shard_x, shard_w, device): X = Tensor.kaiming_uniform(N, N).realize() diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py index 705fcc7b31..2b0cd54548 100644 --- a/tinygrad/features/multi.py +++ b/tinygrad/features/multi.py @@ -7,9 +7,10 @@ from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps from tinygrad.lazy import LazyBuffer, create_schedule from tinygrad.shape.shapetracker import ShapeTracker, sint -def all_reduce(lbs): +def all_reduce(op:ReduceOps, lbs): # TODO: replace this with ring reduce - return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs] + bop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[op] + return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs] def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]: if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}") @@ -73,7 +74,7 @@ class MultiLazyBuffer: def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> MultiLazyBuffer: if self.axis is not None and new_shape[self.axis] == 1: # all-reduce on sharded axes - return MultiLazyBuffer(all_reduce([x.r(op, new_shape) for x in self.lbs]), None) + return MultiLazyBuffer(all_reduce(op, [x.r(op, new_shape) for x in self.lbs]), None) # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape, x)) for x in self.lbs], self.axis)