From eaceafecae9cf699cb09e0e8581cce7b96d197bd Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 28 Apr 2025 20:45:37 -0400 Subject: [PATCH] do fusion locally (#10095) * do fusion locally * oops, that's the right way * explicit delete closure --- test/test_softmax_fusion.py | 9 +++++++-- tinygrad/engine/grouper.py | 12 +++++++++++- tinygrad/viz/serve.py | 2 +- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/test/test_softmax_fusion.py b/test/test_softmax_fusion.py index 193244278a..fd0dedc068 100644 --- a/test/test_softmax_fusion.py +++ b/test/test_softmax_fusion.py @@ -30,10 +30,10 @@ def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Te def run_one_schedule_item(out): lower_schedule_item(get_single_element(out.schedule())).run() class TestFuse(unittest.TestCase): - def _test_fuse(self, fxn, *args, atol=1e-7, **kwargs): + def _test_fuse(self, fxn, *args, atol=1e-7, allow_multiple=False, **kwargs): GlobalCounters.reset() out_single = fxn(*args, **kwargs).fuse() - run_one_schedule_item(out_single) + if not allow_multiple: run_one_schedule_item(out_single) np_single = out_single.numpy() GlobalCounters.reset() np_multi = fxn(*args, **kwargs).numpy() @@ -51,6 +51,11 @@ class TestFuse(unittest.TestCase): a = Tensor.rand(50,50).realize() self._test_fuse(lambda a: a.softmax(axis=-1), a) + def test_fuse_gemm_softmax(self): + a = Tensor.rand(50,50).realize() + b = Tensor.rand(50,50).realize() + self._test_fuse(lambda a,b: ((a@b).relu()+a).contiguous().softmax(axis=-1), a,b, allow_multiple=True) + @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}") def test_fuse_softmax_dtype(self): a = Tensor.rand(50,50).realize() diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index 4b8253cbfa..86f06fdadf 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -412,6 +412,16 @@ pm_fuse = PatternMatcher([ (UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))), ]) +def do_fusion(x:UOp): + found_contiguous = {} + def gate_contiguous(x): + if is_contiguous:=(x.op is Ops.CONTIGUOUS): found_contiguous[x] = x.replace(src=(UOp(Ops.VIEW, arg=x.st),)) + return not is_contiguous + x.toposort(gate=gate_contiguous) + del gate_contiguous + return graph_rewrite(x.substitute(found_contiguous), pm_fuse, name="local fusion").substitute({v:k for k,v in found_contiguous.items()}) +do_fuse = PatternMatcher([(UPat(Ops.FUSE, name="x"), do_fusion),]) + PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {} if CAPTURE_PROCESS_REPLAY: import atexit @@ -427,7 +437,7 @@ def get_name(becomes_map:dict[UOp, UOp]) -> str: @track_rewrites(name_fxn=get_name) def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]: # merge_views + simplify - tensor_map = graph_rewrite_map(big_sink, merge_views+sym+replace_contiguous+pm_fuse, ctx={}) + tensor_map = graph_rewrite_map(big_sink, do_fuse+merge_views+sym+replace_contiguous, ctx={}) # display the cleaned up tensor graph if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph") diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 2408401178..0444424c86 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -14,7 +14,7 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.IGNORE: "#00C000", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF", - Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0"} + Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500"} # VIZ API