diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index eb4452799b..8174a527d2 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -458,6 +458,7 @@ class LocalAddBufferContext: dg:int = 0 map:dict = field(default_factory=dict) vars:dict = field(default_factory=dict) + range:int = 0 def debuf(ctx:LocalAddBufferContext, buf:UOp): ret = UOp(Ops.DEFINE_GLOBAL, buf.dtype.ptr(buf.arg), arg=ctx.dg) @@ -477,6 +478,12 @@ def handle_assign(ctx:LocalAddBufferContext, assign:UOp): ctx.map[buf] = assign return buf +def renumber_range(ctx:LocalAddBufferContext, r:UOp): + if r.tag is not None: return None + ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=()) + ctx.range += 1 + return ret + to_define_global = PatternMatcher([ (UPat(Ops.BUFFER, name="buf"), debuf), (UPat(Ops.BIND, name="b"), unbind_kernel), @@ -485,6 +492,9 @@ to_define_global = PatternMatcher([ # HACK in case any CONSTs were replaced # this is only needed if you are using symbolic #(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None), + + # renumber the ranges starting with 0 so that kernel deduping works + (UPat(Ops.RANGE, name="r"), renumber_range), ]) rangeify_codegen = PatternMatcher([