Files
tinygrad/test/unit/test_allreduce.py
2026-04-13 20:24:12 -07:00

53 lines
2.2 KiB
Python

import unittest
from tinygrad import Tensor, dtypes
from tinygrad.helpers import Context
from tinygrad.uop.ops import Ops
class TestRingAllReduce(unittest.TestCase):
  # Ring allreduce forced on (RING=2) for a sum over the shard axis of an N-way sharded tensor.

  def test_schedule_ring(self):
    # The allreduce must be lowered to N*(N-1)*2 COPYs that all travel along the N links of one device ring.
    with Context(RING=2):
      N = 4
      ds = tuple(f"CPU:{i}" for i in range(N))
      t = Tensor.empty(N, N*100).shard(ds, axis=0).realize()
      schedule = t.sum(0).schedule_with_vars()[0]
      copies = [si for si in schedule if si.ast.op is Ops.COPY]
      pairs = [(c.bufs[0].device, c.bufs[1].device) for c in copies]
      # N*(N-1) scatter reduce, and N*(N-1) allgather
      self.assertEqual(len(pairs), N*(N-1)*2)
      # all copies share exactly N distinct device links ...
      links = set(pairs)
      self.assertEqual(len(links), N)
      # ... and those links close into a single ring: each device appears exactly once on each side of a link,
      # and following the links from any device visits all N devices before returning to it.
      # The cycle check is direction-agnostic. Device names come from the buffers rather than `ds`,
      # since the scheduled names may be canonicalized.
      step = {b: a for a, b in links}
      self.assertEqual(len(step), N)
      self.assertEqual(set(step.values()), set(step))
      start = next(iter(step))
      visited, dev = [], start
      for _ in range(N):
        visited.append(dev)
        dev = step[dev]
      self.assertEqual(dev, start)
      self.assertEqual(len(set(visited)), N)

  def test_correct_ring(self):
    # Numerical check of the ring allreduce. Every element is distinct, so a chunk reduced or gathered into the
    # wrong position changes the result (with an all-ones input every column sums to N regardless of placement).
    with Context(RING=2):
      N = 4
      M = N*100
      ds = tuple(f"CPU:{i}" for i in range(N))
      t = Tensor.arange(N*M).float().reshape(N, M).contiguous().shard(ds, axis=0).realize()
      out = t.sum(0)
      # column j sums row-major values i*M + j over the N rows; all values stay far below 2**24, so float32 is exact
      expected = [sum(i*M + j for i in range(N)) for j in range(M)]
      self.assertListEqual(out.tolist(), expected)
class TestAllreduceCast(unittest.TestCase):
  # Checks which dtypes the COPY kernels of a 2-device sharded sum carry, with and without ALLREDUCE_CAST.

  def _get_copy_dtypes(self, dtype, allreduce_cast):
    # Schedule (without realizing) the sum over the shard axis of a (4, 4) `dtype` tensor split across two CPU
    # devices, and return the set of scalar dtypes of the buffers written by COPY schedule items.
    # RING=0 and SCACHE=0 are pinned alongside ALLREDUCE_CAST; presumably this keeps the non-ring allreduce path and
    # stops the on/off calls from reusing each other's cached schedule (NOTE(review): confirm against Context vars).
    ds = tuple(f"CPU:{i}" for i in range(2))
    with Context(ALLREDUCE_CAST=allreduce_cast, RING=0, SCACHE=0):
      t = Tensor.empty(4, 4, dtype=dtype).shard(ds, axis=0)
      schedule = t.sum(0).schedule_with_vars()[0]
      return {si.bufs[0].dtype.scalar() for si in schedule if si.ast.op is Ops.COPY}

  def _check_low_precision(self, dtype):
    # Shared body of the bf16/half tests.
    on = self._get_copy_dtypes(dtype, allreduce_cast=1)
    off = self._get_copy_dtypes(dtype, allreduce_cast=0)
    # with ALLREDUCE_CAST, allreduce copies stay in `dtype` instead of promoting to float32.
    # Requiring `dtype` to be present keeps assertNotIn from passing vacuously on an empty set.
    self.assertIn(dtype, on)
    self.assertNotIn(dtypes.float, on)
    # without it, the copies are promoted to float32
    self.assertIn(dtypes.float, off)

  def test_allreduce_cast_bf16(self): self._check_low_precision(dtypes.bfloat16)
  def test_allreduce_cast_half(self): self._check_low_precision(dtypes.half)

  def test_allreduce_cast_float32_noop(self):
    # float32 should not be affected by ALLREDUCE_CAST (no promotion happens)
    dtypes_on = self._get_copy_dtypes(dtypes.float, allreduce_cast=1)
    dtypes_off = self._get_copy_dtypes(dtypes.float, allreduce_cast=0)
    # guard against both sets being empty, which would make the equality check vacuous
    self.assertIn(dtypes.float, dtypes_on)
    self.assertEqual(dtypes_on, dtypes_off)
if __name__ == "__main__":
  # executed directly rather than imported by a test runner: run every TestCase defined in this module
  unittest.main()