support mselect input to CALL, less kernels in allreduce (#16567)

* support mselect input to CALL, less kernels in allreduce

* resolve mstack
This commit is contained in:
qazal
2026-06-11 17:10:47 +08:00
committed by GitHub
parent 7d4a77dce4
commit a83710396c
5 changed files with 9 additions and 6 deletions

View File

@@ -40,7 +40,7 @@ class TestRingAllReduce(unittest.TestCase):
pairs = [(c.src[1].buffer.device, c.src[2].buffer.device) for c in copies]
self.assertEqual(len(pairs), N*(N-1))
self.assertEqual(len(sinks), N+2)
self.assertEqual(len(sinks), 2)
self.assertTrue(all(dst != src for dst, src in pairs))
def test_correct_ring(self):

View File

@@ -138,6 +138,7 @@ class ExecContext:
def _resolve(b:UOp, inputs:tuple[UOp, ...]) -> UOp:
if b.op in (Ops.SLICE, Ops.MSELECT) and b.src[0].op is Ops.PARAM: return b.replace(src=(inputs[b.src[0].arg.slot], *b.src[1:]))
if b.op is Ops.MSTACK: return b.replace(src=tuple(_resolve(x, inputs) for x in b.src))
return inputs[b.arg.slot] if b.op is Ops.PARAM else b
def resolve_params(call:UOp, inputs:tuple[UOp, ...]) -> list[UOp]: return [_resolve(b, inputs) for b in get_call_arg_uops(call)]

View File

@@ -39,6 +39,7 @@ def create_schedule(sched_sink:UOp) -> UOp:
case Ops.MSELECT | Ops.MSTACK:
for ss in s.src:
if ss.op is Ops.MSELECT: ss = ss.src[0]
ss = _unwrap_src(ss)
if ss.op not in {Ops.BUFFER, Ops.PARAM}:
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
for t in _split_after(ss)[0]:

View File

@@ -178,6 +178,9 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
(UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),
lambda c,r,d: c.replace(src=(r.contiguous(), d)) if resolve(r.numel() != r.base.numel(), False) else None),
# copying mselect to same device is just mselect (no NOOP kernel)
(UPat(Ops.COPY, src=(UPat(Ops.MSELECT, name="ms"), UPat()), name="copy"), lambda ms,copy: ms if ms.device == copy.device else None),
# copy only to different device
(UPat(Ops.COPY, src=(UPat.var("x"), UPat()), name="copy"), lambda x,copy: x.f(Ops.NOOP) if x.device == copy.device else None),
@@ -487,8 +490,6 @@ def unbind_kernel(ctx:LocalAddBufferContext, b:UOp):
def handle_after(ctx:LocalAddBufferContext, after:UOp):
if isinstance(after.dtype, PtrDType) and after.addrspace == AddrSpace.LOCAL: return None
buf = after.buf_uop
# HACK to put the buffer in the MAP instead of MSTACK/MSELECT
if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0]
# NOTE: this is bottom up, so we only add it once
if buf not in ctx.map: ctx.map[buf] = after
return buf
@@ -507,7 +508,7 @@ def find_bufs(x:UOp):
to_define_global = PatternMatcher([
(UPat(Ops.STORE, name="x"), find_bufs),
(UPat(Ops.BUFFER, name="buf"), debuf),
(UPat((Ops.BUFFER, Ops.MSTACK, Ops.MSELECT), name="buf"), debuf),
(UPat(Ops.PARAM, name="v"), lambda v:
UOp.variable(v.arg.name, v.arg.vmin_vmax[0], v.arg.vmin_vmax[1], v.dtype)
if v.arg.name is not None and v.arg.vmin_vmax is not None else None),
@@ -516,7 +517,7 @@ to_define_global = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_VAR, name="v"),)), lambda v: v),
(UPat(Ops.BIND, name="b"), unbind_kernel),
(UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after),
(UPat(Ops.AFTER, name="after"), handle_after),
# remove device from local BUFFERIZE
(UPat(Ops.STAGE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),

View File

@@ -788,7 +788,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
if self.op is Ops.LOAD: return AddrSpace.REG # LOAD brings things into registers
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE}:
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:
return self.src[0].addrspace
if self.op in GroupOp.Movement: return self.src[0].addrspace
if self.op in {Ops.STACK, Ops.WMMA} or self.op in GroupOp.Elementwise: