mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-15 09:33:03 +08:00
renumber ranges (#12182)
* enable rangeify const folding * renumber ranges for kernel deduping
This commit is contained in:
@@ -458,6 +458,7 @@ class LocalAddBufferContext:
|
||||
dg:int = 0
|
||||
map:dict = field(default_factory=dict)
|
||||
vars:dict = field(default_factory=dict)
|
||||
range:int = 0
|
||||
|
||||
def debuf(ctx:LocalAddBufferContext, buf:UOp):
|
||||
ret = UOp(Ops.DEFINE_GLOBAL, buf.dtype.ptr(buf.arg), arg=ctx.dg)
|
||||
@@ -477,6 +478,12 @@ def handle_assign(ctx:LocalAddBufferContext, assign:UOp):
|
||||
ctx.map[buf] = assign
|
||||
return buf
|
||||
|
||||
def renumber_range(ctx:LocalAddBufferContext, r:UOp):
|
||||
if r.tag is not None: return None
|
||||
ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=())
|
||||
ctx.range += 1
|
||||
return ret
|
||||
|
||||
to_define_global = PatternMatcher([
|
||||
(UPat(Ops.BUFFER, name="buf"), debuf),
|
||||
(UPat(Ops.BIND, name="b"), unbind_kernel),
|
||||
@@ -485,6 +492,9 @@ to_define_global = PatternMatcher([
|
||||
# HACK in case any CONSTs were replaced
|
||||
# this is only needed if you are using symbolic
|
||||
#(UPat(Ops.CONST, name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
|
||||
|
||||
# renumber the ranges starting with 0 so that kernel deduping works
|
||||
(UPat(Ops.RANGE, name="r"), renumber_range),
|
||||
])
|
||||
|
||||
rangeify_codegen = PatternMatcher([
|
||||
|
||||
Reference in New Issue
Block a user