simpler idx rewrite structure in simplify_valid_image_load (#6704)

express valid into things to check when rewriting idx. it's the same for single clause or a simplex
[run_process_replay]
This commit is contained in:
chenyu
2024-09-24 03:35:39 -04:00
committed by GitHub
parent d3ed50c769
commit a6078c099f

View File

@@ -212,26 +212,26 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
# simplify idx given that valid is True
for uop,v in bounds.items():
# some expr has lower bound > upper bound -> valid is an empty set
# some expr has lower bound > upper bound -> valid is an empty set and we return early
if v[0] is not None and v[1] is not None and v[0] > v[1]:
return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, valid.const_like(False)))
if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
newidxs: List[List[UOp]] = [[], []]
for variable in _get_chain(uop, BinaryOps.ADD):
new = UOp(UOps.DEFINE_VAR, variable.dtype, (), ("fake", 1, variable.vmax))
newidx = replace_uop(graph_rewrite(replace_uop(idx, variable, new), constant_folder), new, variable)
newidxs[0].append(newidx.src[0])
newidxs[1].append(newidx.src[1])
if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1]))
if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0]))
to_check = [(Xi, UOp(UOps.DEFINE_VAR, Xi.dtype, (), ("fake", 1, Xi.vmax))) for Xi in _get_chain(uop, BinaryOps.ADD)]
else:
new = UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1])
newidx = replace_uop(graph_rewrite(replace_uop(idx, uop, new), constant_folder), new, uop)
if newidx.key != idx.key: idx = newidx
# try checking the whole clause
to_check = [(uop, UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]))]
newidxs:List[List[UOp]] = [[], []]
for X,newX in to_check:
newidx = replace_uop(graph_rewrite(replace_uop(idx, X, newX), constant_folder), newX, X)
newidxs[0].append(newidx.src[0])
newidxs[1].append(newidx.src[1])
# if every branch in to_check gives the same simplified output, we can rewrite the idx
if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1]))
if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0]))
# can drop valid if idx is out of bound when valid is False
drop_stmt = []
@@ -253,7 +253,9 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
for i,b in zip(idx.src, (buf_dtype.shape[1], buf_dtype.shape[0])):
if is_increasing(i):
rw = graph_rewrite(replace_uop(i, X, X.const_like(test_value)), constant_folder)
if rw.vmin >= b or rw.vmax < 0: drop_stmt.append(stmt)
if rw.vmin >= b or rw.vmax < 0:
drop_stmt.append(stmt)
break
if drop_stmt or idx.key != start_idx.key:
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None