fix allreduce for max (#3175)

* test cases to show allreduce for max is incorrect

* oh fixed
This commit is contained in:
chenyu
2024-01-18 20:25:35 -05:00
committed by GitHub
parent c51c90bcd4
commit ab1b7c4d09
2 changed files with 27 additions and 7 deletions

View File

@@ -73,15 +73,34 @@ class TestMultiTensor(unittest.TestCase):
O = X + W
np.testing.assert_allclose(O.numpy(), 2)
def _test_simple_reduce_axis(self, shard_x):
def _test_sum_axis(self, shard_x):
X = Tensor.ones(256, 256).contiguous().realize()
X.shard_((d0, d1), shard_x)
O = X.sum(axis=0)
np.testing.assert_allclose(O.numpy(), 256)
O = X.sum(axis=1)
np.testing.assert_allclose(O.numpy(), 256)
O = X.sum()
np.testing.assert_allclose(O.numpy(), 256*256)
def test_simple_reduce(self): return self._test_simple_reduce_axis(None)
def test_simple_reduce_0(self): return self._test_simple_reduce_axis(0)
def test_simple_reduce_1(self): return self._test_simple_reduce_axis(1)
def test_sum(self): return self._test_sum_axis(None)
def test_sum_0(self): return self._test_sum_axis(0)
def test_sum_1(self): return self._test_sum_axis(1)
def _test_max_axis(self, shard_x):
X = Tensor.arange(16).reshape(4, 4)
n = X.numpy()
X.shard_((d0, d1), shard_x)
O = X.max(axis=0)
np.testing.assert_allclose(O.numpy(), n.max(0))
O = X.max(axis=1)
np.testing.assert_allclose(O.numpy(), n.max(1))
O = X.max()
np.testing.assert_allclose(O.numpy(), n.max())
def test_max(self): return self._test_max_axis(None)
def test_max_0(self): return self._test_max_axis(0)
def test_max_1(self): return self._test_max_axis(1)
def _test_matmul_shard_axis(self, shard_x, shard_w, device):
X = Tensor.kaiming_uniform(N, N).realize()

View File

@@ -7,9 +7,10 @@ from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
from tinygrad.lazy import LazyBuffer, create_schedule
from tinygrad.shape.shapetracker import ShapeTracker, sint
def all_reduce(lbs):
def all_reduce(op:ReduceOps, lbs):
# TODO: replace this with ring reduce
return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
bop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[op]
return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
@@ -73,7 +74,7 @@ class MultiLazyBuffer:
def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> MultiLazyBuffer:
if self.axis is not None and new_shape[self.axis] == 1:
# all-reduce on sharded axes
return MultiLazyBuffer(all_reduce([x.r(op, new_shape) for x in self.lbs]), None)
return MultiLazyBuffer(all_reduce(op, [x.r(op, new_shape) for x in self.lbs]), None)
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape, x)) for x in self.lbs], self.axis)