From 3bd992dc953d2cc1e62000bb4c321ab35ed5a391 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 9 Apr 2025 15:59:45 +0800 Subject: [PATCH] multi stage graph_rewrite_map (#9803) * multistage graph_rewrite_map * s/merge_map/input_map * build up kernel_map from the tensor_map --- test/unit/test_rewrite_map.py | 16 ++++++++++++++++ tinygrad/engine/grouper.py | 7 ++++--- tinygrad/ops.py | 7 +++++-- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/test/unit/test_rewrite_map.py b/test/unit/test_rewrite_map.py index 0625dc07fa..8220351683 100644 --- a/test/unit/test_rewrite_map.py +++ b/test/unit/test_rewrite_map.py @@ -28,6 +28,22 @@ class TestRewriteMap(unittest.TestCase): self.assertIs(sub_map[a+b], e) self.assertIs(sub_map[(a+b)*c], f) + def test_multistage_substitute(self): + a = UOp.variable('a', 0, 10) + b = UOp.variable('b', 0, 10) + c = UOp.variable('c', 0, 10) + d = UOp.variable('d', 0, 10) + sub1 = {a+b:c} + start = (a+b)*c + # stage 1: (a+b)*c -> c*c + sub_map1 = graph_rewrite_map(start, _substitute, sub1, bottom_up=True) + self.assertIs(sub_map1[(a+b)*c], c*c) + # stage 2: c*c -> d + sub2 = {c*c:d} + sub_map2 = graph_rewrite_map(sub_map1[start], _substitute, sub2, input_map=sub_map1, bottom_up=True) + # (a+b)*c -> c*c -> d + self.assertIs(sub_map2[(a+b)*c], d) + def test_add_zero(self): # Build a small graph: add(0, add(const=0, const=5)) zero_node = UOp.const(dtypes.int, 0) diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index 42bb29df93..0668e9c81d 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -401,14 +401,15 @@ def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]: # group into kernels sink = tensor_map[big_sink] realize_map = group_realizes(sink) - kernel_map = graph_rewrite_map(sink, create_kernels, KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}), bottom_up=True) - sched_sink = kernel_map[sink] + tensor_map = graph_rewrite_map(sink, create_kernels, KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}), bottom_up=True, + input_map=tensor_map) + sched_sink = tensor_map[sink] type_verify(list(sched_sink.toposort), kernel_spec) # map tensors to buffer/const, optionally apply a VIEW on top becomes_map: dict[UOp, UOp] = {} for k,v in tensor_map.items(): - if (kernel:=kernel_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st)) + if (kernel:=tensor_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st)) if k is v: continue if k.op is Ops.ASSIGN: becomes_map[k] = k.src[0] diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 96c2ead8a0..f06ef7b110 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -958,9 +958,12 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=N return rewrite_ctx.bottom_up_rewrite(sink) if bottom_up else rewrite_ctx.top_down_rewrite(sink) @track_matches -def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False) -> dict[UOp, UOp]: +def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False, input_map=None) -> dict[UOp, UOp]: rewrite_ctx = RewriteContext(pm, ctx, children=sink.get_children_map() if track_children else None) - return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]} + new_map = {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]} + if input_map is not None: + for k,v in input_map.items(): new_map[k] = new_map.get(v,v) + return new_map def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x