diff --git a/test/test_schedule.py b/test/test_schedule.py index fb4ed96e13..19e26b9295 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1092,6 +1092,18 @@ class TestSchedule(unittest.TestCase): np.testing.assert_allclose(out[0].numpy(), np.sqrt(np.square(x.numpy() - np_mu).sum(-1)/x.shape[-1]), atol=1e-4, rtol=1e-4) np.testing.assert_allclose(out[1].numpy(), np.sqrt(np.square(y.numpy() - np_mu).sum(-1)/y.shape[-1]), atol=1e-4, rtol=1e-4) + def test_cumsum_parallel_reduce_fused(self): + # two-stage cumsum + ops triggers parallel REDUCEs in one kernel that must share an END + step, num_steps = 513, 10 + t = Tensor.arange(step).float().realize() + phase = t.cumsum() + tiled = phase.repeat((num_steps,)).reshape(num_steps, step) + pattern = Tensor([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1) + out = (tiled * pattern).flatten() + expected = np.tile(np.arange(step).astype(np.float32).cumsum(), num_steps).reshape(num_steps, step) + expected = (expected * np.array([1,0,0,1,0,0,0,0,1,0]).reshape(num_steps, 1)).flatten() + np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4) + def test_multimatmul_fusion(self): Tensor.manual_seed(0) a,b = Tensor.randn(4, 64).realize(), Tensor.rand(64,8).realize() diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index dbf4fd3ae6..f4223be7f8 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -1,7 +1,7 @@ from typing import Any, cast import functools, operator, itertools from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, field from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid, PtrDType from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, GroupOp, identity_element from tinygrad.uop.symbolic import uop_given_valid, parse_valid, invalid_gate @@ -299,6 +299,8 @@ pm_render = PatternMatcher([ @dataclass class ReduceContext: acc_num: int = 0 + # track ENDs by range for merging parallel reduces + range_to_ends: dict[tuple[UOp, ...], list[UOp]] = field(default_factory=dict) def horizontal_reduce(inp:UOp, out_dtype:DType) -> list[UOp]: # if this has a horizontal reduction component, do that first @@ -324,11 +326,19 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp): ctx.acc_num += 1 ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst) if len(reduce_range) == 0: return ret - return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0)) + end = acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range) + ctx.range_to_ends.setdefault(reduce_range, []).append(end) + return acc.after(end).index(UOp.const(dtypes.int, 0)) + +def merge_reduce_ends(ctx:ReduceContext, sink:UOp): + # merge ENDs that share the same range + subs = {e: UOp.group(*(e.src[0] for e in ends)).end(*r) for r, ends in ctx.range_to_ends.items() if len(ends) > 1 for e in ends} + return sink.substitute(subs) if subs else None pm_reduce = PatternMatcher([ - # REDUCE -> DEFINE_ACC+ASSIGN + # REDUCE -> DEFINE_ACC+ASSIGN, then merge ENDs with same range (UPat(Ops.REDUCE, name="red"), reduce_to_acc), + (UPat(Ops.SINK, name="sink"), merge_reduce_ends), # tensor core built in accumulate (UPat(Ops.WMMA, name="wmma") + UPat.var("add"), lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),