move new style transform up more (#16593)

* move new style transform up more

* pm_move_gates_from_index works on new style
This commit is contained in:
George Hotz
2026-06-12 17:20:12 -07:00
committed by GitHub
parent a35964493e
commit 96b86aad7b
3 changed files with 21 additions and 15 deletions

View File

@@ -118,9 +118,6 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes")
sink = graph_rewrite(sink, pm_transcendental, name="transcendental")
# move gates from unrenderable INVALID where
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
# GEP/STACK stuff
sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack")
@@ -130,6 +127,9 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp:
name_to_slot = {nm:num_params+i for i,nm in enumerate(sorted([x.arg[0] for x in sink.toposort() if x.op is Ops.DEFINE_VAR]))}
sink = graph_rewrite(sink, pm_remove_vec_dtypes, ctx=name_to_slot, name="transform to new style")
# move gates from unrenderable INVALID where
sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index")
# final rules for the renderer (without sym)
extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([])
pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends

View File

@@ -3,19 +3,19 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops
from tinygrad.dtype import Invalid, dtypes
pm_move_gates_from_index = PatternMatcher([
# here we create the alt value for load to be 0s and remove the where Invalid
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").load(name="l"),
lambda buf,gate,idx,cast,l: buf.index(idx, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)),
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")),
lambda buf,gate,idx,cast,data: buf.index(idx, ptr=True).cast(cast.dtype).store(data, gate)),
# for image idx (must be first)
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx_y"), UPat(arg=Invalid)),
UPat.var("gate").where(UPat.var("idx_x"), UPat(arg=Invalid))).load(name="l"),
lambda buf,gate,idx_y,idx_x,l: buf.index(idx_y, idx_x).load(l.vconst_like(0), gate)),
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx_y"), UPat(arg=Invalid)),
UPat.var("gate").where(UPat.var("idx_x"), UPat(arg=Invalid))).store(UPat.var("data")),
lambda buf,gate,idx_y,idx_x,data: buf.index(idx_y, idx_x).store(data, gate)),
# for image idx
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx_y"), UPat(arg=Invalid)),
UPat.var("gate").where(UPat.var("idx_x"), UPat(arg=Invalid))).or_casted(name="cast").load(name="l"),
lambda buf,gate,idx_y,idx_x,cast,l: buf.index(idx_y, idx_x, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)),
(UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx_y"), UPat(arg=Invalid)),
UPat.var("gate").where(UPat.var("idx_x"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")),
lambda buf,gate,idx_y,idx_x,cast,data: buf.index(idx_y, idx_x, ptr=True).cast(cast.dtype).store(data, gate)),
# here we create the alt value for load to be 0s and remove the where Invalid
(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat(), UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid)),), name="mop", allow_any_len=True) \
.load(name="l"), lambda mop,gate,idx,l: mop.replace(src=(mop.src[0],idx)+mop.src[2:]).load(l.vconst_like(0), gate)),
(UPat((Ops.INDEX, Ops.SHRINK), src=(UPat(), UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid)),), name="mop", allow_any_len=True) \
.store(UPat.var("data")), lambda mop,gate,idx,data: mop.replace(src=(mop.src[0],idx)+mop.src[2:]).store(data, gate)),
# Where after gated load becomes alt value
(UPat.var("gate").where(UPat().load(UPat(), UPat.var("gate", dtype=dtypes.bool), name="l").or_casted(), UPat.var("a")), lambda gate,l,a:

View File

@@ -497,6 +497,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass):
def _wrap_uop(self, u:UOp) -> UOp: return u
def const_like(self, b:ConstLike, dtype:DType|None=None):
return UOp.const(dtype or self.dtype.base, b, shape=self._shape)
def vconst_like(self, b:ConstLike, dtype:DType|None=None):
# for use after movement ops have been removed
ret = UOp.const(dtype or self.dtype.base, b)
if self.shape == (): return ret
if len(self.shape) == 1: return UOp(Ops.STACK, ret.dtype, (ret,)*self.max_numel())
raise RuntimeError(f"vconst_like only works on 0 or 1D shapes, not {self.shape}")
def ufix(self, x):
if isinstance(x, UOp): return x
# float self keeps its dtype for any scalar, int self only for int/Invalid scalars