diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index 9ca3c58538..9093f4d037 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -124,6 +124,20 @@ class TestValidIdxSimplification(unittest.TestCase): "(((ridx0*2)+(ridx3*-1))+1)", "(ridx2<1)") + def test_valid_becomes_const1(self): + # from DSP mobilenetv2: under the gate (ridx2<1)&(ridx1<6) the index is expected to simplify to ridx0*1568 while the gate string stays unchanged + ridx0 = Range(0, 30) + ridx1 = Range(1, 7) + ridx2 = Range(2, 2) + alu11 = (ridx1+ridx2) + alu15 = ((alu11+1)//7) + idx = (alu15*-31)+(((((alu11+218)//224)+ridx0)%30)*1568) + valid = (ridx2<1)&(ridx1<6) + load = get_gated_load_uop(valid, idx) + self.check(load, + "(ridx0*1568)", + "((ridx2<1)&(ridx1<6))") + class TestImageSimplification(unittest.TestCase): def check(self, load, svalid, sidx0, sidx1): load = full_graph_rewrite(load.sink()).src[0] diff --git a/tinygrad/codegen/symbolic.py b/tinygrad/codegen/symbolic.py index ab4c83b914..1c2d01e4ba 100644 --- a/tinygrad/codegen/symbolic.py +++ b/tinygrad/codegen/symbolic.py @@ -275,17 +275,20 @@ def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None: # simplify uop given that valid is True for expr,v in bounds.items(): + v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1]) # some expr has lower bound > upper bound -> valid is an empty set and we return None - if v[0] is not None and v[1] is not None and v[0] > v[1]: return None - + if v0 > v1: return None + # the bounds pin expr to a single value (v0 == v1): substitute that constant into uop, simplify again, and skip the candidate search for this expr + if v0 == v1: + uop = uop.substitute({expr:expr.const_like(v0)}).simplify() + continue # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop candidates = [] - if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)): + if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)): # if the constraint is a simplex: X0 + X1 + ... 
> 0, we can check if all Xi > 0 simplify into the same output candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)]) # try checking the whole clause - if expr in uop.toposort: - candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))]) + if expr in uop.toposort: candidates.append([(expr, UOp.variable("fake", v0, v1, expr.dtype))]) for candidate in candidates: # if every branch in candidate gives the same simplified uop, we can rewrite the uop