less E kernels in all2all (#16546)

This commit is contained in:
qazal
2026-06-09 12:51:57 +08:00
committed by GitHub
parent b8931440ae
commit fa400f9790
2 changed files with 3 additions and 1 deletions

View File

@@ -26,7 +26,7 @@ class TestRingAllReduce(unittest.TestCase):
copies = [si for si in linear.src if si.src[0].op is Ops.COPY]
sinks = [si for si in linear.src if si.src[0].op is Ops.SINK]
self.assertEqual(len(copies), 24)
-self.assertEqual(len(sinks), 30)
+self.assertEqual(len(sinks), 26)
def test_correct_ring(self):
with Context(RING=2):

View File

@@ -310,6 +310,8 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
(UPat(Ops.STAGE, name="b"), cleanup_dead_axes),
# remove noop buffers. if we look at the next index we can remove even more of these
(UPat(Ops.INDEX, name="idx").f(Ops.STAGE, allow_any_len=True, name="b2"), remove_noop_bufferize),
+(UPat(Ops.INDEX, src=(UPat(Ops.STAGE),), allow_any_len=True, name="idx").f(Ops.NOOP).f(Ops.STAGE, allow_any_len=True, name="b2"),
+remove_noop_bufferize),
# no buffers for const (ranges don't matter for const - it's the same value everywhere)
(UPat(Ops.CONST, name='c').f(Ops.STAGE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg)),
# indexing a const is a const