mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
support mselect input to CALL, less kernels in allreduce (#16567)
* support mselect input to CALL, less kernels in allreduce * resolve mstack
This commit is contained in:
@@ -40,7 +40,7 @@ class TestRingAllReduce(unittest.TestCase):
|
||||
pairs = [(c.src[1].buffer.device, c.src[2].buffer.device) for c in copies]
|
||||
|
||||
self.assertEqual(len(pairs), N*(N-1))
|
||||
self.assertEqual(len(sinks), N+2)
|
||||
self.assertEqual(len(sinks), 2)
|
||||
self.assertTrue(all(dst != src for dst, src in pairs))
|
||||
|
||||
def test_correct_ring(self):
|
||||
|
||||
@@ -138,6 +138,7 @@ class ExecContext:
|
||||
|
||||
def _resolve(b:UOp, inputs:tuple[UOp, ...]) -> UOp:
|
||||
if b.op in (Ops.SLICE, Ops.MSELECT) and b.src[0].op is Ops.PARAM: return b.replace(src=(inputs[b.src[0].arg.slot], *b.src[1:]))
|
||||
if b.op is Ops.MSTACK: return b.replace(src=tuple(_resolve(x, inputs) for x in b.src))
|
||||
return inputs[b.arg.slot] if b.op is Ops.PARAM else b
|
||||
def resolve_params(call:UOp, inputs:tuple[UOp, ...]) -> list[UOp]: return [_resolve(b, inputs) for b in get_call_arg_uops(call)]
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ def create_schedule(sched_sink:UOp) -> UOp:
|
||||
case Ops.MSELECT | Ops.MSTACK:
|
||||
for ss in s.src:
|
||||
if ss.op is Ops.MSELECT: ss = ss.src[0]
|
||||
ss = _unwrap_src(ss)
|
||||
if ss.op not in {Ops.BUFFER, Ops.PARAM}:
|
||||
assert ss.op is Ops.AFTER, f"ss.op is not AFTER, it's {ss.op}"
|
||||
for t in _split_after(ss)[0]:
|
||||
|
||||
@@ -178,6 +178,9 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
(UPat(Ops.COPY, src=(UPat(GroupOp.Movement, name="r"), UPat(name="d")), name="c"),
|
||||
lambda c,r,d: c.replace(src=(r.contiguous(), d)) if resolve(r.numel() != r.base.numel(), False) else None),
|
||||
|
||||
# copying mselect to same device is just mselect (no NOOP kernel)
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.MSELECT, name="ms"), UPat()), name="copy"), lambda ms,copy: ms if ms.device == copy.device else None),
|
||||
|
||||
# copy only to different device
|
||||
(UPat(Ops.COPY, src=(UPat.var("x"), UPat()), name="copy"), lambda x,copy: x.f(Ops.NOOP) if x.device == copy.device else None),
|
||||
|
||||
@@ -487,8 +490,6 @@ def unbind_kernel(ctx:LocalAddBufferContext, b:UOp):
|
||||
def handle_after(ctx:LocalAddBufferContext, after:UOp):
|
||||
if isinstance(after.dtype, PtrDType) and after.addrspace == AddrSpace.LOCAL: return None
|
||||
buf = after.buf_uop
|
||||
# HACK to put the buffer in the MAP instead of MSTACK/MSELECT
|
||||
if buf.op in {Ops.MSTACK, Ops.MSELECT}: buf = buf.src[0]
|
||||
# NOTE: this is bottom up, so we only add it once
|
||||
if buf not in ctx.map: ctx.map[buf] = after
|
||||
return buf
|
||||
@@ -507,7 +508,7 @@ def find_bufs(x:UOp):
|
||||
|
||||
to_define_global = PatternMatcher([
|
||||
(UPat(Ops.STORE, name="x"), find_bufs),
|
||||
(UPat(Ops.BUFFER, name="buf"), debuf),
|
||||
(UPat((Ops.BUFFER, Ops.MSTACK, Ops.MSELECT), name="buf"), debuf),
|
||||
(UPat(Ops.PARAM, name="v"), lambda v:
|
||||
UOp.variable(v.arg.name, v.arg.vmin_vmax[0], v.arg.vmin_vmax[1], v.dtype)
|
||||
if v.arg.name is not None and v.arg.vmin_vmax is not None else None),
|
||||
@@ -516,7 +517,7 @@ to_define_global = PatternMatcher([
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.DEFINE_VAR, name="v"),)), lambda v: v),
|
||||
|
||||
(UPat(Ops.BIND, name="b"), unbind_kernel),
|
||||
(UPat((Ops.MSTACK, Ops.MSELECT, Ops.AFTER), name="after"), handle_after),
|
||||
(UPat(Ops.AFTER, name="after"), handle_after),
|
||||
|
||||
# remove device from local BUFFERIZE
|
||||
(UPat(Ops.STAGE, name="b"), lambda b: b.replace(arg=replace(b.arg, device=None))),
|
||||
|
||||
@@ -788,7 +788,7 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
|
||||
if self.op is Ops.DEFINE_LOCAL: return AddrSpace.LOCAL
|
||||
if self.op is Ops.DEFINE_REG: return AddrSpace.REG
|
||||
if self.op is Ops.LOAD: return AddrSpace.REG # LOAD brings things into registers
|
||||
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE}:
|
||||
if self.op in {Ops.INDEX, Ops.CAST, Ops.AFTER, Ops.REDUCE, Ops.GEP, Ops.STORE, Ops.MSTACK, Ops.MSELECT}:
|
||||
return self.src[0].addrspace
|
||||
if self.op in GroupOp.Movement: return self.src[0].addrspace
|
||||
if self.op in {Ops.STACK, Ops.WMMA} or self.op in GroupOp.Elementwise:
|
||||
|
||||
Reference in New Issue
Block a user