This commit is contained in:
George Hotz
2026-03-04 13:01:35 +08:00
parent 6b82b51759
commit ccb5dcf3b8

View File

@@ -136,7 +136,7 @@ def soft_allreduce(c:UOp, a:UOp):
to = c.src[1].param_like(0)
src = c.src[2].param_like(1)
red = UOp(Ops.ALLREDUCE, dtype=a.arg, src=(src, a.src[1]), arg=a.arg)
return handle_allreduce(src, red).assign(to).sink().call(*c.src[1:])
return to.assign(handle_allreduce(src, red)).sink().call(*c.src[1:])
pm_schedule = PatternMatcher([
(UPat(Ops.SINK, name="function"), lower_sink_to_linear),