mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
does this pass?
This commit is contained in:
@@ -77,11 +77,13 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
|
||||
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
|
||||
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
|
||||
|
||||
# lower the index dtype to a concrete int. this needs to happen while gates are still present
|
||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||
|
||||
# move the gates from index onto the loads and stores
|
||||
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
|
||||
|
||||
# lower the index dtype to a concrete int
|
||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, name="lower all index dtypes")
|
||||
# a final symbolic
|
||||
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
|
||||
|
||||
# optional pre matcher
|
||||
|
||||
@@ -37,6 +37,7 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
|
||||
return drop_stmt
|
||||
|
||||
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
|
||||
start_idx = start_idx.simplify() # if you don't do this, uop_given_valid may simplify things and this might inf loop
|
||||
idx = uop_given_valid(valid, start_idx)
|
||||
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid), ptr=True)
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
# this transforms Invalid into gated load/stores
|
||||
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat
|
||||
from tinygrad.dtype import Invalid
|
||||
from tinygrad.dtype import Invalid, dtypes
|
||||
|
||||
pm_move_gates_from_index = PatternMatcher([
|
||||
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").load(name="l"),
|
||||
lambda buf,gate,idx,cast,l: buf.index(idx).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)),
|
||||
lambda buf,gate,idx,cast,l: buf.index(idx, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)),
|
||||
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")),
|
||||
lambda buf,gate,idx,cast,data: buf.index(idx).cast(cast.dtype).store(data, gate)),
|
||||
lambda buf,gate,idx,cast,data: buf.index(idx, ptr=True).cast(cast.dtype).store(data, gate)),
|
||||
# remove hanging weakint casts
|
||||
(UPat.var("buf").index(UPat.var("idx", dtypes.ints).cast()), lambda buf,idx: buf.index(idx, ptr=True)),
|
||||
])
|
||||
|
||||
@@ -1542,12 +1542,6 @@ pm_lower_index_dtype = PatternMatcher([
|
||||
(UPat(Ops.DEFINE_VAR, dtype=dtypes.weakint, name="u"), lambda u: u.replace(dtype=dtypes.int).cast(dtypes.weakint)),
|
||||
(UPat(Ops.BIND, src=(UPat.var("var").cast(dtypes.weakint), UPat.cvar("val").cast(dtypes.weakint))),
|
||||
lambda var,val: var.bind(val).cast(dtypes.weakint)),
|
||||
# lower Invalid
|
||||
(UPat.var("buf").index(UPat.var("cond").where(UPat.var("idx"), UPat(Ops.CONST, arg=Invalid))), lambda buf,idx,cond: buf.index(idx, cond, ptr=True)),
|
||||
# remove hanging casts
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast()),), lambda buf,idx: buf.index(idx, ptr=True)),
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.ints).cast(), UPat.var("valid"))),
|
||||
lambda buf,idx,valid: buf.index(idx, valid, ptr=True)),
|
||||
(UPat((Ops.SINK, Ops.NOOP, Ops.END), name="n"),
|
||||
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
|
||||
# vectorized indexes (ie. images) must be int
|
||||
|
||||
Reference in New Issue
Block a user