Files
tinygrad/test/unit/test_allreduce.py
qazal a83710396c support mselect input to CALL, less kernels in allreduce (#16567)
* support mselect input to CALL, less kernels in allreduce

* resolve mstack
2026-06-11 18:10:47 +09:00

79 lines
3.2 KiB
Python

import unittest
from tinygrad import Tensor, dtypes
from tinygrad.helpers import Context
from tinygrad.uop.ops import Ops
class TestRingAllReduce(unittest.TestCase):
  """Schedule-shape and result checks for reducing a tensor sharded over several devices.

  Each test sets the RING / ALL2ALL context variables, builds a tensor sharded on
  axis 0, and inspects the COPY and SINK entries found in the first element
  returned by linear_with_vars() (or, for the last test, the computed values).
  """

  @staticmethod
  def _entries_with_op(linear, op):
    # Entries of linear.src whose first source carries the requested op.
    return [entry for entry in linear.src if entry.src[0].op is op]

  @staticmethod
  def _device_pairs(copy_entries):
    # One (src[1] buffer device, src[2] buffer device) tuple per copy entry;
    # test_schedule_naive unpacks these as (dst, src).
    return [(entry.src[1].buffer.device, entry.src[2].buffer.device) for entry in copy_entries]

  def test_schedule_ring(self):
    # RING=2 on four CPU devices: check how many copies are scheduled and how many distinct device pairs they use.
    num_devices = 4
    devices = tuple(f"CPU:{i}" for i in range(num_devices))
    with Context(RING=2):
      sharded = Tensor.empty(num_devices, num_devices*100).shard(devices, axis=0).realize()
      linear = sharded.sum(0).linear_with_vars()[0]
      pairs = self._device_pairs(self._entries_with_op(linear, Ops.COPY))
      # expected: N*(N-1) copies for the scatter reduce plus N*(N-1) for the allgather
      self.assertEqual(len(pairs), num_devices*(num_devices-1)*2)
      # only N distinct device pairs appear, i.e. the copy topology forms a ring
      self.assertEqual(len(set(pairs)), num_devices)

  def test_schedule_all2all(self):
    # ALL2ALL=2 on four CPU devices: pin the exact number of COPY and SINK entries.
    num_devices = 4
    devices = tuple(f"CPU:{i}" for i in range(num_devices))
    with Context(ALL2ALL=2):
      sharded = Tensor.empty(num_devices, num_devices*100).shard(devices, axis=0).realize()
      reduced = sharded.sum(0).mul(2.0).contiguous()
      linear = reduced.linear_with_vars()[0]
      copy_entries = self._entries_with_op(linear, Ops.COPY)
      sink_entries = self._entries_with_op(linear, Ops.SINK)
      self.assertEqual(len(copy_entries), 24)
      self.assertEqual(len(sink_entries), 26)

  @Context(RING=0, ALL2ALL=0)
  def test_schedule_naive(self):
    # RING and ALL2ALL both 0 on four NULL devices: N*(N-1) copies and two SINK entries are expected.
    num_devices = 4
    devices = tuple(f"NULL:{i}" for i in range(num_devices))
    sharded = Tensor.empty(num_devices, 4096).shard(devices, axis=0).realize()
    linear = sharded.sum(0).linear_with_vars()[0]
    copy_entries = self._entries_with_op(linear, Ops.COPY)
    sink_entries = self._entries_with_op(linear, Ops.SINK)
    pairs = self._device_pairs(copy_entries)
    self.assertEqual(len(pairs), num_devices*(num_devices-1))
    self.assertEqual(len(sink_entries), 2)
    # no copy has the same device on both ends
    self.assertTrue(all(dst != src for dst, src in pairs))

  def test_correct_ring(self):
    # RING=2: summing an all-ones (4, 400) tensor over axis 0 must yield 400 fours.
    num_devices = 4
    devices = tuple(f"CPU:{i}" for i in range(num_devices))
    with Context(RING=2):
      ones = Tensor.ones(num_devices, num_devices*100).contiguous()
      sharded = ones.shard(devices, axis=0).realize()
      total = sharded.sum(0)
      self.assertListEqual(total.tolist(), [4] * (num_devices*100))
class TestAllreduceCast(unittest.TestCase):
  """Checks which scalar dtypes the allreduce COPY entries use with ALLREDUCE_CAST set to 1 versus 0."""

  def _get_copy_dtypes(self, dtype, allreduce_cast):
    # Shard a 4x4 tensor of `dtype` over two CPU devices, sum over axis 0 with
    # ALLREDUCE_CAST=allreduce_cast (RING=0, SCACHE=0), and return the set of
    # scalar dtypes of src[1]'s buffer across all COPY entries.
    devices = tuple(f"CPU:{i}" for i in range(2))
    with Context(ALLREDUCE_CAST=allreduce_cast, RING=0, SCACHE=0):
      sharded = Tensor.empty(4, 4, dtype=dtype).shard(devices, axis=0)
      linear = sharded.sum(0).linear_with_vars()[0]
      found = set()
      for entry in linear.src:
        if entry.src[0].op is Ops.COPY:
          found.add(entry.src[1].buffer.dtype.scalar())
      return found

  def _check_float_only_without_cast(self, dtype):
    # float32 must be absent from the copy dtypes with ALLREDUCE_CAST=1 and present with ALLREDUCE_CAST=0.
    cast_on = self._get_copy_dtypes(dtype, allreduce_cast=1)
    self.assertNotIn(dtypes.float, cast_on)
    cast_off = self._get_copy_dtypes(dtype, allreduce_cast=0)
    self.assertIn(dtypes.float, cast_off)

  def test_allreduce_cast_bf16(self):
    # with ALLREDUCE_CAST, allreduce copies stay in bfloat16 instead of promoting to float32
    self._check_float_only_without_cast(dtypes.bfloat16)

  def test_allreduce_cast_half(self):
    # same expectation for half as for bfloat16
    self._check_float_only_without_cast(dtypes.half)

  def test_allreduce_cast_float32_noop(self):
    # float32 should not be affected by ALLREDUCE_CAST (no promotion happens):
    # the copy dtypes must be the same with the flag on and off
    self.assertEqual(
      self._get_copy_dtypes(dtypes.float, allreduce_cast=1),
      self._get_copy_dtypes(dtypes.float, allreduce_cast=0),
    )
if __name__ == '__main__':
  # Run the tests in this module only when the file is executed directly, not when it is imported.
  unittest.main()