diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 8c494f4f5e..263a2c071c 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -38,17 +38,7 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]: def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None: idx = uop_given_valid(valid, start_idx) - if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid), ptr=True) - - # wait for it to be image indexed before running simplification - if start_idx.dtype.count != 2: return None - - drop_stmt = _drop_valid_stmts(valid, idx, buf.dtype.shape[0], buf.dtype.shape[1]) - - if not drop_stmt and idx is start_idx: return None - new_valid = UOp.uprod(*ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None - x, y = idx.gep(0), idx.gep(1) - return buf.index(x.valid(new_valid) if new_valid is not None else x, y.valid(new_valid) if new_valid is not None else y, ptr=True) + return None if isinstance(buf.dtype, ImageDType) or idx is start_idx else buf.index(idx.valid(valid), ptr=True) def simplify_valid_image_load(buf:UOp, start_x:UOp, start_y:UOp, valid:UOp) -> UOp|None: if not isinstance(buf.dtype, ImageDType) or start_x.dtype.scalar() is not dtypes.weakint or \ diff --git a/tinygrad/codegen/late/gater.py b/tinygrad/codegen/late/gater.py index 2321b973c4..3fbafb74e5 100644 --- a/tinygrad/codegen/late/gater.py +++ b/tinygrad/codegen/late/gater.py @@ -22,20 +22,19 @@ def index_and_valid(idx:UOp) -> tuple[UOp, UOp]: def valid_idx(idx:UOp, valid:UOp) -> UOp: return idx if valid.op is Ops.CONST and valid.arg is True else valid.where(idx, idx.const_like(Invalid)) -def get_image_idx(idx:UOp, width:int) -> UOp: +def get_image_idx(idx:UOp, height:int, width:int) -> UOp: x, valid = index_and_valid(idx.src[1]) - idx_x, idx_y = (x.gep(0), x.gep(1)) if x.dtype.count == 2 else ((x // 4) % width, x // (4*width)) + px = x // 4 + idx_x, idx_y = (px, px.const_like(0)) if height == 1 else (px % width, px // width) return idx.replace(src=(idx.src[0], valid_idx(idx_x, valid), valid_idx(idx_y, valid))) def image_fixup(ls:UOp): # normal image load/store from split_load_store: casted linear offset -> image x/y coordinates if ls.src[0].op is Ops.CAST and (cast_idx:=ls.src[0].src[0]).op is Ops.INDEX and isinstance(dt:=cast_idx.src[0].dtype, ImageDType): assert ls.src[0].dtype.count == 4, "image must be casted to 4" - return ls.replace(src=(cast_idx if len(cast_idx.src) == 3 else get_image_idx(cast_idx, dt.shape[1]),)+ls.src[1:]) + return ls.replace(src=(cast_idx if len(cast_idx.src) == 3 else get_image_idx(cast_idx, dt.shape[0], dt.shape[1]),)+ls.src[1:]) if ls.src[0].op is not Ops.INDEX or not isinstance(dt:=ls.src[0].src[0].dtype, ImageDType) or len(ls.src[0].src) == 3: return None - off, _ = index_and_valid(ls.src[0].src[1]) - if off.dtype.count == 2: return ls.replace(src=(get_image_idx(ls.src[0], dt.shape[1]),)+ls.src[1:]) # this is an unprocessed image without a cast, we should just make it a buffer idx = ls.src[0].src[0].replace(dtype=(new_dt:=dtypes.half if dt.itemsize == 2 else dtypes.float).ptr(dt.size)).index(ls.src[0].src[1]) @@ -64,7 +63,7 @@ pm_move_gates_from_index = PatternMatcher([ (UPat.var("gate").where(UPat.var("a"), UPat().load(UPat(), ~UPat.var("gate", dtype=dtypes.bool), name="l").or_casted()), lambda gate,l,a: l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype), l.src[2])).cast(a.dtype)), - # vectorized indexes (ie. images) must be int + # vectorized indexes must be int (UPat(Ops.INDEX, src=(UPat(), UPat(Ops.STACK, dtypes.long, name="vec")), allow_any_len=True, name="idx"), lambda idx,vec: idx.replace(src=(idx.src[0], UOp.vectorize(*(u.cast(dtypes.int) for u in vec.src)), *idx.src[2:]))), (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.var("y")), name="idx"), image_coords_to_int),