diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 6813359208..3c555bbe64 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -118,9 +118,6 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: sink = graph_rewrite(sink, pm_dtype_decomps, ctx=(set(), ren), name="decomp dtypes") sink = graph_rewrite(sink, pm_transcendental, name="transcendental") - # move gates from unrenderable INVALID where - sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index") - # GEP/STACK stuff sink = graph_rewrite(sink, pm_render, name="pm_render gep/stack") @@ -130,6 +127,9 @@ def full_rewrite_to_sink(ast:UOp, ren:Renderer, optimize:bool=True) -> UOp: name_to_slot = {nm:num_params+i for i,nm in enumerate(sorted([x.arg[0] for x in sink.toposort() if x.op is Ops.DEFINE_VAR]))} sink = graph_rewrite(sink, pm_remove_vec_dtypes, ctx=name_to_slot, name="transform to new style") + # move gates from unrenderable INVALID where + sink = graph_rewrite(sink, pm_move_gates_from_index, name="move gates from index") + # final rules for the renderer (without sym) extra_matcher = ren.extra_matcher if ren.extra_matcher is not None else PatternMatcher([]) pm_final_rewrite = pm_decomp+extra_matcher+pm_split_ends diff --git a/tinygrad/codegen/late/gater.py b/tinygrad/codegen/late/gater.py index 2f2195bd96..cd88f5b608 100644 --- a/tinygrad/codegen/late/gater.py +++ b/tinygrad/codegen/late/gater.py @@ -3,19 +3,19 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops from tinygrad.dtype import Invalid, dtypes pm_move_gates_from_index = PatternMatcher([ - # here we create the alt value for load to be 0s and remove the where Invalid - (UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").load(name="l"), - lambda buf,gate,idx,cast,l: buf.index(idx, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)), - (UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")), - lambda buf,gate,idx,cast,data: buf.index(idx, ptr=True).cast(cast.dtype).store(data, gate)), + # for image idx (must be first) + (UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx_y"), UPat(arg=Invalid)), + UPat.var("gate").where(UPat.var("idx_x"), UPat(arg=Invalid))).load(name="l"), + lambda buf,gate,idx_y,idx_x,l: buf.index(idx_y, idx_x).load(l.vconst_like(0), gate)), + (UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx_y"), UPat(arg=Invalid)), + UPat.var("gate").where(UPat.var("idx_x"), UPat(arg=Invalid))).store(UPat.var("data")), + lambda buf,gate,idx_y,idx_x,data: buf.index(idx_y, idx_x).store(data, gate)), - # for image idx - (UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx_y"), UPat(arg=Invalid)), - UPat.var("gate").where(UPat.var("idx_x"), UPat(arg=Invalid))).or_casted(name="cast").load(name="l"), - lambda buf,gate,idx_y,idx_x,cast,l: buf.index(idx_y, idx_x, ptr=True).cast(cast.dtype).load(l.const_like(0), gate, dtype=l.dtype)), - (UPat.var("buf").index(UPat.var("gate").where(UPat.var("idx_y"), UPat(arg=Invalid)), - UPat.var("gate").where(UPat.var("idx_x"), UPat(arg=Invalid))).or_casted(name="cast").store(UPat.var("data")), - lambda buf,gate,idx_y,idx_x,cast,data: buf.index(idx_y, idx_x, ptr=True).cast(cast.dtype).store(data, gate)), + # here we create the alt value for load to be 0s and remove the where Invalid + (UPat((Ops.INDEX, Ops.SHRINK), src=(UPat(), UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid)),), name="mop", allow_any_len=True) \ + .load(name="l"), lambda mop,gate,idx,l: mop.replace(src=(mop.src[0],idx)+mop.src[2:]).load(l.vconst_like(0), gate)), + (UPat((Ops.INDEX, Ops.SHRINK), src=(UPat(), UPat.var("gate").where(UPat.var("idx"), UPat(arg=Invalid)),), name="mop", allow_any_len=True) \ + .store(UPat.var("data")), lambda mop,gate,idx,data: mop.replace(src=(mop.src[0],idx)+mop.src[2:]).store(data, gate)), # Where after gated load becomes alt value (UPat.var("gate").where(UPat().load(UPat(), UPat.var("gate", dtype=dtypes.bool), name="l").or_casted(), UPat.var("a")), lambda gate,l,a: diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index bb7a0f49a6..49381dacb7 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -497,6 +497,12 @@ class UOp(RandMixin, metaclass=UOpMetaClass): def _wrap_uop(self, u:UOp) -> UOp: return u def const_like(self, b:ConstLike, dtype:DType|None=None): return UOp.const(dtype or self.dtype.base, b, shape=self._shape) + def vconst_like(self, b:ConstLike, dtype:DType|None=None): + # for use after movement ops have been removed + ret = UOp.const(dtype or self.dtype.base, b) + if self.shape == (): return ret + if len(self.shape) == 1: return UOp(Ops.STACK, ret.dtype, (ret,)*self.max_numel()) + raise RuntimeError(f"vconst_like only works on 0 or 1D shapes, not {self.shape}") def ufix(self, x): if isinstance(x, UOp): return x # float self keeps its dtype for any scalar, int self only for int/Invalid scalars