something

This commit is contained in:
George Hotz
2026-05-08 17:32:30 -07:00
parent e14b2b41c6
commit b910f1d5c0
2 changed files with 6 additions and 17 deletions

View File

@@ -38,17 +38,7 @@ def _drop_valid_stmts(valid:UOp, idx:UOp, height:int, width:int) -> list[UOp]:
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
idx = uop_given_valid(valid, start_idx)
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx.valid(valid), ptr=True)
# wait for it to be image indexed before running simplification
if start_idx.dtype.count != 2: return None
drop_stmt = _drop_valid_stmts(valid, idx, buf.dtype.shape[0], buf.dtype.shape[1])
if not drop_stmt and idx is start_idx: return None
new_valid = UOp.uprod(*ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
x, y = idx.gep(0), idx.gep(1)
return buf.index(x.valid(new_valid) if new_valid is not None else x, y.valid(new_valid) if new_valid is not None else y, ptr=True)
return None if isinstance(buf.dtype, ImageDType) or idx is start_idx else buf.index(idx.valid(valid), ptr=True)
def simplify_valid_image_load(buf:UOp, start_x:UOp, start_y:UOp, valid:UOp) -> UOp|None:
if not isinstance(buf.dtype, ImageDType) or start_x.dtype.scalar() is not dtypes.weakint or \

View File

@@ -22,20 +22,19 @@ def index_and_valid(idx:UOp) -> tuple[UOp, UOp]:
def valid_idx(idx:UOp, valid:UOp) -> UOp:
return idx if valid.op is Ops.CONST and valid.arg is True else valid.where(idx, idx.const_like(Invalid))
def get_image_idx(idx:UOp, width:int) -> UOp:
def get_image_idx(idx:UOp, height:int, width:int) -> UOp:
x, valid = index_and_valid(idx.src[1])
idx_x, idx_y = (x.gep(0), x.gep(1)) if x.dtype.count == 2 else ((x // 4) % width, x // (4*width))
px = x // 4
idx_x, idx_y = (px, px.const_like(0)) if height == 1 else (px % width, px // width)
return idx.replace(src=(idx.src[0], valid_idx(idx_x, valid), valid_idx(idx_y, valid)))
def image_fixup(ls:UOp):
# normal image load/store from split_load_store: casted linear offset -> image x/y coordinates
if ls.src[0].op is Ops.CAST and (cast_idx:=ls.src[0].src[0]).op is Ops.INDEX and isinstance(dt:=cast_idx.src[0].dtype, ImageDType):
assert ls.src[0].dtype.count == 4, "image must be casted to 4"
return ls.replace(src=(cast_idx if len(cast_idx.src) == 3 else get_image_idx(cast_idx, dt.shape[1]),)+ls.src[1:])
return ls.replace(src=(cast_idx if len(cast_idx.src) == 3 else get_image_idx(cast_idx, dt.shape[0], dt.shape[1]),)+ls.src[1:])
if ls.src[0].op is not Ops.INDEX or not isinstance(dt:=ls.src[0].src[0].dtype, ImageDType) or len(ls.src[0].src) == 3: return None
off, _ = index_and_valid(ls.src[0].src[1])
if off.dtype.count == 2: return ls.replace(src=(get_image_idx(ls.src[0], dt.shape[1]),)+ls.src[1:])
# this is an unprocessed image without a cast, we should just make it a buffer
idx = ls.src[0].src[0].replace(dtype=(new_dt:=dtypes.half if dt.itemsize == 2 else dtypes.float).ptr(dt.size)).index(ls.src[0].src[1])
@@ -64,7 +63,7 @@ pm_move_gates_from_index = PatternMatcher([
(UPat.var("gate").where(UPat.var("a"), UPat().load(UPat(), ~UPat.var("gate", dtype=dtypes.bool), name="l").or_casted()), lambda gate,l,a:
l.replace(src=(l.src[0], a.src[0] if a.op is Ops.CAST and a.src[0].dtype == l.dtype else a.cast(l.dtype), l.src[2])).cast(a.dtype)),
# vectorized indexes (ie. images) must be int
# vectorized indexes must be int
(UPat(Ops.INDEX, src=(UPat(), UPat(Ops.STACK, dtypes.long, name="vec")), allow_any_len=True, name="idx"),
lambda idx,vec: idx.replace(src=(idx.src[0], UOp.vectorize(*(u.cast(dtypes.int) for u in vec.src)), *idx.src[2:]))),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.var("y")), name="idx"), image_coords_to_int),