mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
fix multiple REDUCE on same RANGE (#14504)
each RANGE maps to one END, but reduce_to_acc is local and would not know this
This commit is contained in:
@@ -1092,6 +1092,18 @@ class TestSchedule(unittest.TestCase):
|
||||
np.testing.assert_allclose(out[0].numpy(), np.sqrt(np.square(x.numpy() - np_mu).sum(-1)/x.shape[-1]), atol=1e-4, rtol=1e-4)
|
||||
np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_cumsum_parallel_reduce_fused(self):
|
||||
# two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END
|
||||
step, num_steps = 513, 10
|
||||
t = Tensor.arange(step).float().realize()
|
||||
phase = t.cumsum()
|
||||
tiled = phase.repeat((num_steps,)).reshape(num_steps, step)
|
||||
pattern = Tensor([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1)
|
||||
out = (tiled * pattern).flatten()
|
||||
expected = np.tile(np.arange(step).astype(np.float32).cumsum(), num_steps).reshape(num_steps, step)
|
||||
expected = (expected * np.array([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1)).flatten()
|
||||
np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_multimatmul_fusion(self):
|
||||
Tensor.manual_seed(0)
|
||||
a,b = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, cast
|
||||
import functools, operator, itertools
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType
|
||||
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, identity_element
|
||||
from tinygrad.uop.symbolic import uop_given_valid, parse_valid, invalid_gate
|
||||
@@ -299,6 +299,8 @@ pm_render = PatternMatcher([
|
||||
@dataclass
|
||||
class ReduceContext:
|
||||
acc_num: int = 0
|
||||
# track ENDs by range for merging parallel reduces
|
||||
range_to_ends: dict[tuple[UOp, ...], list[UOp]] = field(default_factory=dict)
|
||||
|
||||
def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]:
|
||||
# if this has a horizontal reduction component, do that first
|
||||
@@ -324,11 +326,19 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
||||
ctx.acc_num += 1
|
||||
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
|
||||
if len(reduce_range) == 0: return ret
|
||||
return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0))
|
||||
end = acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)
|
||||
ctx.range_to_ends.setdefault(reduce_range, []).append(end)
|
||||
return acc.after(end).index(UOp.const(dtypes.int, 0))
|
||||
|
||||
def merge_reduce_ends(ctx:ReduceContext, sink:UOp):
|
||||
# merge ENDs that share the same range
|
||||
subs = {e: UOp.group(*(e.src[0] for e in ends)).end(*r) for r, ends in ctx.range_to_ends.items() if len(ends) > 1 for e in ends}
|
||||
return sink.substitute(subs) if subs else None
|
||||
|
||||
pm_reduce = PatternMatcher([
|
||||
# REDUCE -> DEFINE_ACC+ASSIGN
|
||||
# REDUCE -> DEFINE_ACC+ASSIGN, then merge ENDs with same range
|
||||
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
|
||||
(UPat(Ops.SINK, name="sink"), merge_reduce_ends),
|
||||
# tensor core built in accumulate
|
||||
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
|
||||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||
|
||||
Reference in New Issue
Block a user