fix multiple REDUCE on same RANGE (#14504)

each RANGE maps to one END, but reduce_to_acc is local and would not know this
This commit is contained in:
chenyu
2026-02-02 20:42:09 -05:00
committed by GitHub
parent 93c41a78fa
commit 4f2e7aed24
2 changed files with 25 additions and 3 deletions

View File

@@ -1092,6 +1092,18 @@ class TestSchedule(unittest.TestCase):
np.testing.assert_allclose(out[0].numpy(), np.sqrt(np.square(x.numpy() - np_mu).sum(-1)/x.shape[-1]), atol=1e-4, rtol=1e-4)
np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4)
def test_cumsum_parallel_reduce_fused(self):
# two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END
step, num_steps = 513, 10
t = Tensor.arange(step).float().realize()
phase = t.cumsum()
tiled = phase.repeat((num_steps,)).reshape(num_steps, step)
pattern = Tensor([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1)
out = (tiled * pattern).flatten()
expected = np.tile(np.arange(step).astype(np.float32).cumsum(), num_steps).reshape(num_steps, step)
expected = (expected * np.array([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1)).flatten()
np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
def test_multimatmul_fusion(self):
Tensor.manual_seed(0)
a,b = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()

View File

@@ -1,7 +1,7 @@
from typing import Any, cast
import functools, operator, itertools
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, identity_element
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, invalid_gate
@@ -299,6 +299,8 @@ pm_render = PatternMatcher([
@dataclass
class ReduceContext:
acc_num: int = 0
# track ENDs by range for merging parallel reduces
range_to_ends: dict[tuple[UOp, ...], list[UOp]] = field(default_factory=dict)
def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]:
# if this has a horizontal reduction component, do that first
@@ -324,11 +326,19 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
ctx.acc_num += 1
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
if len(reduce_range) == 0: return ret
return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0))
end = acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)
ctx.range_to_ends.setdefault(reduce_range, []).append(end)
return acc.after(end).index(UOp.const(dtypes.int, 0))
def merge_reduce_ends(ctx:ReduceContext, sink:UOp):
# merge ENDs that share the same range
subs = {e: UOp.group(*(e.src[0] for e in ends)).end(*r) for r, ends in ctx.range_to_ends.items() if len(ends) > 1 for e in ends}
return sink.substitute(subs) if subs else None
pm_reduce = PatternMatcher([
# REDUCE -> DEFINE_ACC+ASSIGN
# REDUCE -> DEFINE_ACC+ASSIGN, then merge ENDs with same range
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
(UPat(Ops.SINK, name="sink"), merge_reduce_ends),
# tensor core built in accumulate
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),