mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
fix allreduce for max (#3175)
* test cases to show allreduce for max is incorrect * oh fixed
This commit is contained in:
@@ -73,15 +73,34 @@ class TestMultiTensor(unittest.TestCase):
|
||||
O = X + W
|
||||
np.testing.assert_allclose(O.numpy(), 2)
|
||||
|
||||
def _test_simple_reduce_axis(self, shard_x):
|
||||
def _test_sum_axis(self, shard_x):
|
||||
X = Tensor.ones(256, 256).contiguous().realize()
|
||||
X.shard_((d0, d1), shard_x)
|
||||
O = X.sum(axis=0)
|
||||
np.testing.assert_allclose(O.numpy(), 256)
|
||||
O = X.sum(axis=1)
|
||||
np.testing.assert_allclose(O.numpy(), 256)
|
||||
O = X.sum()
|
||||
np.testing.assert_allclose(O.numpy(), 256*256)
|
||||
|
||||
def test_simple_reduce(self): return self._test_simple_reduce_axis(None)
|
||||
def test_simple_reduce_0(self): return self._test_simple_reduce_axis(0)
|
||||
def test_simple_reduce_1(self): return self._test_simple_reduce_axis(1)
|
||||
def test_sum(self): return self._test_sum_axis(None)
|
||||
def test_sum_0(self): return self._test_sum_axis(0)
|
||||
def test_sum_1(self): return self._test_sum_axis(1)
|
||||
|
||||
def _test_max_axis(self, shard_x):
|
||||
X = Tensor.arange(16).reshape(4, 4)
|
||||
n = X.numpy()
|
||||
X.shard_((d0, d1), shard_x)
|
||||
O = X.max(axis=0)
|
||||
np.testing.assert_allclose(O.numpy(), n.max(0))
|
||||
O = X.max(axis=1)
|
||||
np.testing.assert_allclose(O.numpy(), n.max(1))
|
||||
O = X.max()
|
||||
np.testing.assert_allclose(O.numpy(), n.max())
|
||||
|
||||
def test_max(self): return self._test_max_axis(None)
|
||||
def test_max_0(self): return self._test_max_axis(0)
|
||||
def test_max_1(self): return self._test_max_axis(1)
|
||||
|
||||
def _test_matmul_shard_axis(self, shard_x, shard_w, device):
|
||||
X = Tensor.kaiming_uniform(N, N).realize()
|
||||
|
||||
@@ -7,9 +7,10 @@ from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
|
||||
from tinygrad.lazy import LazyBuffer, create_schedule
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, sint
|
||||
|
||||
def all_reduce(lbs):
|
||||
def all_reduce(op:ReduceOps, lbs):
|
||||
# TODO: replace this with ring reduce
|
||||
return [functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
|
||||
bop = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX}[op]
|
||||
return [functools.reduce(lambda x,y: x.e(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
|
||||
|
||||
def to_sharded(lbs:List[LazyBuffer], axis:int) -> List[LazyBuffer]:
|
||||
if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}")
|
||||
@@ -73,7 +74,7 @@ class MultiLazyBuffer:
|
||||
def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> MultiLazyBuffer:
|
||||
if self.axis is not None and new_shape[self.axis] == 1:
|
||||
# all-reduce on sharded axes
|
||||
return MultiLazyBuffer(all_reduce([x.r(op, new_shape) for x in self.lbs]), None)
|
||||
return MultiLazyBuffer(all_reduce(op, [x.r(op, new_shape) for x in self.lbs]), None)
|
||||
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
|
||||
return MultiLazyBuffer([x.r(op, self._shape_to_single_shard(new_shape, x)) for x in self.lbs], self.axis)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user