From 2ccefa11ec686f248039cf489d949ba2a7b17773 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 6 May 2026 09:30:58 -0700 Subject: [PATCH] does this pass? --- tinygrad/codegen/__init__.py | 6 ++++-- tinygrad/codegen/late/devectorizer.py | 1 + tinygrad/codegen/late/gater.py | 8 +++++--- tinygrad/uop/ops.py | 6 ------ 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 2930b5f538..d878f27f20 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -77,11 +77,13 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize") + # lower the index dtype to a concrete int. this needs to happen while gates are still present + sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes") + # move the gates from index onto the loads and stores sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index") - # lower the index dtype to a concrete int - sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes") + # a final symbolic sink = graph_rewrite(sink, symbolic, name="post index symbolic") # optional pre matcher diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index f536c97de5..7329ef7d49 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -37,6 +37,7 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]: return drop_stmt def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None: + start_idx = start_idx.simplify() # if you don't do this, uop_given_valid may simplify things and this might inf loop idx = uop_given_valid(valid, start_idx) if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid), ptr=True) diff --git a/tinygrad/codegen/late/gater.py b/tinygrad/codegen/late/gater.py index 0b0b7c70e2..e78e51c096 100644 --- a/tinygrad/codegen/late/gater.py +++ b/tinygrad/codegen/late/gater.py @@ -1,11 +1,13 @@ # this transforms Invalid into gated load/stores from tinygrad.uop.ops import PatternMatcher, UPat -from tinygrad.dtype import Invalid +from tinygrad.dtype import Invalid, dtypes pm_move_gates_from_index = PatternMatcher([ (UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").load(name="l"), - lambda buf,gate,idx,cast,l: buf.index(idx).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)), + lambda buf,gate,idx,cast,l: buf.index(idx, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)), (UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")), - lambda buf,gate,idx,cast,data: buf.index(idx).cast(cast.dtype).store(data, gate)), + lambda buf,gate,idx,cast,data: buf.index(idx, ptr=True).cast(cast.dtype).store(data, gate)), + # remove hanging weakint casts + (UPat.var("buf").index(UPat.var("idx", dtypes.ints).cast()), lambda buf,idx: buf.index(idx, ptr=True)), ]) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 6e3275f054..51b55e9ceb 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -1542,12 +1542,6 @@ pm_lower_index_dtype = PatternMatcher([ (UPat(Ops.DEFINE_VAR, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint)), (UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))), lambda var,val: var.bind(val).cast(dtypes.weakint)), - # lower Invalid - (UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)), - # remove hanging casts - (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)), - (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))), - lambda buf,idx,valid: buf.index(idx, valid, ptr=True)), (UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"), lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))), # vectorized indexes (ie. images) must be int