mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 16:37:04 +08:00
less E kernels in all2all (#16546)
This commit is contained in:
@@ -26,7 +26,7 @@ class TestRingAllReduce(unittest.TestCase):
|
||||
copies = [si for si in linear.src if si.src[0].op is Ops.COPY]
|
||||
sinks = [si for si in linear.src if si.src[0].op is Ops.SINK]
|
||||
self.assertEqual(len(copies), 24)
|
||||
self.assertEqual(len(sinks), 30)
|
||||
self.assertEqual(len(sinks), 26)
|
||||
|
||||
def test_correct_ring(self):
|
||||
with Context(RING=2):
|
||||
|
||||
@@ -310,6 +310,8 @@ pm_const_buffer_folding = pm_mops+PatternMatcher([
|
||||
(UPat(Ops.STAGE, name="b"), cleanup_dead_axes),
|
||||
# remove noop buffers. if we look at the next index we can remove even more of these
|
||||
(UPat(Ops.INDEX, name="idx").f(Ops.STAGE, allow_any_len=True, name="b2"), remove_noop_bufferize),
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.STAGE),), allow_any_len=True, name="idx").f(Ops.NOOP).f(Ops.STAGE, allow_any_len=True, name="b2"),
|
||||
remove_noop_bufferize),
|
||||
# no buffers for const (ranges don't matter for const - it's the same value everywhere)
|
||||
(UPat(Ops.CONST, name='c').f(Ops.STAGE, allow_any_len=True, name="b"), lambda c,b: b.const_like(c.arg)),
|
||||
# indexing a const is a const
|
||||
|
||||
Reference in New Issue
Block a user