renumber ranges (#12182)

* enable rangeify const folding

* renumber ranges for kernel deduping
This commit is contained in:
George Hotz
2025-09-15 13:03:39 +08:00
committed by GitHub
parent e1fef895b1
commit ae0edc8a67

View File

@@ -458,6 +458,7 @@ class LocalAddBufferContext:
dg:int = 0
map:dict = field(default_factory=dict)
vars:dict = field(default_factory=dict)
range:int = 0
def debuf(ctx:LocalAddBufferContext, buf:UOp):
ret = UOp(Ops.DEFINE_GLOBAL, buf.dtype.ptr(buf.arg), arg=ctx.dg)
@@ -477,6 +478,12 @@ def handle_assign(ctx:LocalAddBufferContext, assign:UOp):
ctx.map[buf] = assign
return buf
def renumber_range(ctx:LocalAddBufferContext, r:UOp):
  """Rewrite rule: give range `r` the next sequential id held in `ctx.range`.

  The first element of `r.arg` is swapped for the current value of `ctx.range`;
  the remaining elements of `r.arg` are carried over unchanged. The rewritten
  range is tagged with `()` and the counter on `ctx` is advanced by one.

  Returns None (no rewrite) when `r` already carries a tag, so a range that has
  been renumbered once is not renumbered again.
  """
  if r.tag is None:
    # keep everything after the leading id, only the id itself is renumbered
    renumbered_arg = (ctx.range,)+r.arg[1:]
    # the empty-tuple tag is what makes the guard above skip this node next time
    renumbered = r.replace(arg=renumbered_arg, tag=())
    ctx.range += 1
    return renumbered
  return None
to_define_global = PatternMatcher([
(UPat(Ops.BUFFER, name="buf"), debuf),
(UPat(Ops.BIND, name="b"), unbind_kernel),
@@ -485,6 +492,9 @@ to_define_global = PatternMatcher([
# HACK in case any CONSTs were replaced
# this is only needed if you are using symbolic
#(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
# renumber the ranges starting with 0 so that kernel deduping works
(UPat(Ops.RANGE, name="r"), renumber_range),
])
rangeify_codegen = PatternMatcher([