From a6078c099fb64a444d5ced5efce05ab9573d7534 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 24 Sep 2024 03:35:39 -0400 Subject: [PATCH] simpler idx rewrite structure in simplify_valid_image_load (#6704) express valid into things to check when rewriting idx. it's the same for single clause or a simplex [run_process_replay] --- tinygrad/codegen/uopgraph.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 7fc6a9afe7..2c9351ed65 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -212,26 +212,26 @@ def simplify_valid_image_load(load:UOp, buf:UOp): # simplify idx given that valid is True for uop,v in bounds.items(): - # some expr has lower bound > upper bound -> valid is an empty set + # some expr has lower bound > upper bound -> valid is an empty set and we return early if v[0] is not None and v[1] is not None and v[0] > v[1]: return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, valid.const_like(False))) if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)): # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output - newidxs: List[List[UOp]] = [[], []] - for variable in _get_chain(uop, BinaryOps.ADD): - new = UOp(UOps.DEFINE_VAR, variable.dtype, (), ("fake", 1, variable.vmax)) - newidx = replace_uop(graph_rewrite(replace_uop(idx, variable, new), constant_folder), new, variable) - newidxs[0].append(newidx.src[0]) - newidxs[1].append(newidx.src[1]) - - if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1])) - if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0])) - + to_check = [(Xi, UOp(UOps.DEFINE_VAR, Xi.dtype, (), ("fake", 1, Xi.vmax))) for Xi in _get_chain(uop, BinaryOps.ADD)] else: - new = UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]) - newidx = replace_uop(graph_rewrite(replace_uop(idx, uop, new), constant_folder), new, uop) - if newidx.key != idx.key: idx = newidx + # try checking the whole clause + to_check = [(uop, UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]))] + + newidxs:List[List[UOp]] = [[], []] + for X,newX in to_check: + newidx = replace_uop(graph_rewrite(replace_uop(idx, X, newX), constant_folder), newX, X) + newidxs[0].append(newidx.src[0]) + newidxs[1].append(newidx.src[1]) + + # if every branch in to_check gives the same simplified output, we can rewrite the idx + if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1])) + if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0])) # can drop valid if idx is out of bound when valid is False drop_stmt = [] @@ -253,7 +253,9 @@ def simplify_valid_image_load(load:UOp, buf:UOp): for i,b in zip(idx.src, (buf_dtype.shape[1], buf_dtype.shape[0])): if is_increasing(i): rw = graph_rewrite(replace_uop(i, X, X.const_like(test_value)), constant_folder) - if rw.vmin >= b or rw.vmax < 0: drop_stmt.append(stmt) + if rw.vmin >= b or rw.vmax < 0: + drop_stmt.append(stmt) + break if drop_stmt or idx.key != start_idx.key: new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None