diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 32d2bf329d..ed01648d92 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -136,7 +136,7 @@ def soft_allreduce(c:UOp, a:UOp): to = c.src[1].param_like(0) src = c.src[2].param_like(1) red = UOp(Ops.ALLREDUCE, dtype=a.arg, src=(src, a.src[1]), arg=a.arg) - return handle_allreduce(src, red).assign(to).sink().call(*c.src[1:]) + return to.assign(handle_allreduce(src, red)).sink().call(*c.src[1:]) pm_schedule = PatternMatcher([ (UPat(Ops.SINK, name="function"), lower_sink_to_linear),