does this pass?

This commit is contained in:
George Hotz
2026-05-06 09:30:58 -07:00
parent 95b0a651c2
commit 2ccefa11ec
4 changed files with 10 additions and 11 deletions

View File

@@ -77,11 +77,13 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
# lower the index dtype to a concrete int. this needs to happen while gates are still present
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
# move the gates from index onto the loads and stores
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
# lower the index dtype to a concrete int
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
# a final symbolic
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
# optional pre matcher

View File

@@ -37,6 +37,7 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
return drop_stmt
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
start_idx = start_idx.simplify() # if you don't do this, uop_given_valid may simplify things and this might inf loop
idx = uop_given_valid(valid, start_idx)
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid), ptr=True)

View File

@@ -1,11 +1,13 @@
# this transforms Invalid into gated load/stores
from tinygrad.uop.ops import PatternMatcher, UPat
from tinygrad.dtype import Invalid
from tinygrad.dtype import Invalid, dtypes
pm_move_gates_from_index = PatternMatcher([
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").load(name="l"),
lambda buf,gate,idx,cast,l: buf.index(idx).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)),
lambda buf,gate,idx,cast,l: buf.index(idx, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)),
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")),
lambda buf,gate,idx,cast,data: buf.index(idx).cast(cast.dtype).store(data, gate)),
lambda buf,gate,idx,cast,data: buf.index(idx, ptr=True).cast(cast.dtype).store(data, gate)),
# remove hanging weakint casts
(UPat.var("buf").index(UPat.var("idx", dtypes.ints).cast()), lambda buf,idx: buf.index(idx, ptr=True)),
])

View File

@@ -1542,12 +1542,6 @@ pm_lower_index_dtype = PatternMatcher([
(UPat(Ops.DEFINE_VAR, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint)),
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))),
lambda var,val: var.bind(val).cast(dtypes.weakint)),
# lower Invalid
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
# remove hanging casts
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))),
lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
# vectorized indexes (ie. images) must be int