diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index 9ebdd63b89..f2ec339483 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -154,7 +154,7 @@ class TestRealStrides(unittest.TestCase): View.create((1, 3, 22, 21), (0, 192, 16, 1), 0, ((0, 1), (0, 3), (0, 12), (0, 16))), View.create((3, 11, 7, 2, 3), (462, 21, 1, 231, 7), 0, None), )) - self.assertEqual(st.real_strides(), (132, 12, None, None, None)) + self.assertEqual(st.real_strides(), (132, None, None, None, None)) class TestRealSimplifies(unittest.TestCase): def tearDown(self): diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 474dd83ebd..cf5c26854e 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -271,91 +271,6 @@ gep_pushing = PatternMatcher([ (UPat(Ops.WMMA, name="wmma").f(Ops.GEP, name="gep"), gep_through_wmma), ]) -# ******** we take a small aside to "simplify_valid" to rewrite "and" clauses (valids) ******** - -def parse_valid(valid:UOp) -> tuple[UOp, bool, int]: - # if it's X <= c, returns X, True, c - # if it's X >= c, returns X, False, c - - # (X < c).ne(True) -> X >= c - if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \ - (s0:=valid.src[0]).op is Ops.CMPLT and dtypes.is_int(s0.src[0].dtype): return s0.src[0], False, int(s0.src[1].vmin) - # X < c -> X <= c-1 - if valid.op is Ops.CMPLT and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, int((valid.src[1]).vmax)-1 - raise ValueError(f"not able to parse {valid=}") - -def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None: - # return None if valid is always False, otherwise the simplified uop (might be the same as input) - - # first, parse valid into {expr: (lower_bound, upper_bound)} - bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None]) - for stmt in valid.split_uop(Ops.AND): - try: expr, is_upper, c = parse_valid(stmt) - except ValueError: return uop # give up if we cannot parse the valid - bounds[expr][int(is_upper)] = c - - # don't simplify any other gates, can lead to OOB, we substitute them back later - uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, arg=u) for u in uop.toposort() if u.op is Ops.INDEX})) - - # simplify uop given that valid is True - for expr,v in bounds.items(): - v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1]) - expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop - # some expr has lower bound > upper bound -> valid is an empty set and we return None - if v0 > v1: return None - # whole node became a const - if v0 == v1: - uop = uop.substitute({expr:expr.const_like(v0)}).simplify() - continue - # every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop - candidates = [] - if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)): - # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output - candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)]) - # try checking the whole clause - if expr in uop.toposort(): candidates.append([(expr, UOp.variable("fake", v0, v1, expr.dtype))]) - - for candidate in candidates: - # if every branch in candidate gives the same simplified uop, we can rewrite the uop - newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate] - if uop.op is Ops.VECTORIZE and len(uop.src) == 2: - if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1])) - if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1])) - elif all_same(newuops): uop = newuops[0] - - # put the loads back in - uop = uop.substitute({v:k for k,v in load_subs.items()}) - return uop - -def _valid_priority(v: UOp, valids:list[UOp]): - # we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified - try: return sum(-1 if parse_valid(v)[0] in other.toposort() else 0 for other in valids) - except ValueError: return 0 - -def simplify_valid(valid:UOp) -> UOp|None: - ret:list[UOp] = [] - something_changed = False - valids = list(valid.split_uop(Ops.AND)) - for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)): - # TODO: root cause this and test_simplify_valid_from_div - if stmt.op is Ops.CAST: return None - ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt) - if ret[-1] is not stmt: something_changed = True - return functools.reduce(operator.and_, ret) if something_changed else None - -# ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ******** - -def reduce_mul_chain(r:UOp): - if r.arg not in {Ops.ADD, Ops.MAX}: return None - if r.dtype != r.src[0].dtype: return None - inside, outside = [], [] - for m in r.src[0].split_uop(Ops.MUL): - m_parents = m.toposort() - if all(r not in m_parents for r in r.src[1:]) and (r.arg != Ops.MAX or m.vmin >= 0): outside.append(m) - else: inside.append(m) - if len(outside) == 0: return None - return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside) - commutative = PatternMatcher([ # ** COMMUTATIVE flipping (only for index) ** # NOTE: this can break merging vector math by only flipping some of them @@ -367,8 +282,6 @@ symbolic = symbolic_simple+commutative+PatternMatcher([ (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x # TODO: make a more general or folder like simplify_valid (UPat.var("x", dtype=dtypes.bool) | UPat.var("x").logical_not(), lambda x: x.const_like(True)), # x|!x -> True - # simplify valid - (UPat(Ops.AND, name="valid"), simplify_valid), # ** combine terms ** (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1) ((UPat.var("y") + UPat.var("x") * UPat.cvar("c0")) + UPat.var("x") * UPat.cvar("c1"), lambda x,y,c0,c1: y+x*(c0+c1)), @@ -462,10 +375,97 @@ symbolic_flat = symbolic+PatternMatcher([ ((UPat.var("x", dtypes.index) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c), ]) +# ******** we take a small aside to "simplify_valid" to rewrite valids ******** + +def parse_valid(valid:UOp) -> tuple[UOp, bool, int]: + # if it's X <= c, returns X, True, c + # if it's X >= c, returns X, False, c + + # (X < c).ne(True) -> X >= c + if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \ + (s0:=valid.src[0]).op is Ops.CMPLT and dtypes.is_int(s0.src[0].dtype): return s0.src[0], False, int(s0.src[1].vmin) + # X < c -> X <= c-1 + if valid.op is Ops.CMPLT and dtypes.is_int(valid.src[0].dtype): return valid.src[0], True, int((valid.src[1]).vmax)-1 + raise ValueError(f"not able to parse {valid=}") + +def uop_given_valid(valid:UOp, uop:UOp) -> UOp|None: + # return None if valid is always False, otherwise the simplified uop (might be the same as input) + + # first, parse valid into {expr: (lower_bound, upper_bound)} + bounds:defaultdict[UOp, list[ConstType|None]] = defaultdict(lambda: [None, None]) + for stmt in valid.split_uop(Ops.AND): + try: expr, is_upper, c = parse_valid(stmt) + except ValueError: return uop # give up if we cannot parse the valid + bounds[expr][int(is_upper)] = c + + # don't simplify any other gates, can lead to OOB, we substitute them back later + uop = uop.substitute((load_subs:={u: UOp(Ops.NOOP, arg=u) for u in uop.toposort() if u.op is Ops.INDEX})) + + # simplify uop given that valid is True + for expr,v in bounds.items(): + v0, v1 = (expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1]) + expr = expr.substitute(load_subs) # make sure expr appears in same form in the uop + # some expr has lower bound > upper bound -> valid is an empty set and we return None + if v0 > v1: return None + # whole node became a const + if v0 == v1: + uop = uop.substitute({expr:expr.const_like(v0)}).simplify() + continue + # every candidate is a set of constrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop + candidates = [] + if expr.op is Ops.ADD and v0 == 1 and all(u.op in GroupOp.Irreducible for u in expr.split_uop(Ops.ADD)): + # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output + candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in expr.split_uop(Ops.ADD)]) + # try checking the whole clause + if expr in uop.toposort(): candidates.append([(expr, UOp.variable("fake", v0, v1, expr.dtype))]) + + for candidate in candidates: + # if every branch in candidate gives the same simplified uop, we can rewrite the uop + newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate] + if uop.op is Ops.VECTORIZE and len(uop.src) == 2: + if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1])) + if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1])) + elif all_same(newuops): uop = newuops[0] + + # put the loads back in + uop = uop.substitute({v:k for k,v in load_subs.items()}) + return uop + +def _valid_priority(v: UOp, valids:list[UOp]): + # we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified + try: return sum(-1 if parse_valid(v)[0] in other.toposort() else 0 for other in valids) + except ValueError: return 0 + +def simplify_valid(valid:UOp) -> UOp|None: + ret:list[UOp] = [] + something_changed = False + valids = list(valid.split_uop(Ops.AND)) + for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)): + # TODO: root cause this and test_simplify_valid_from_div + if stmt.op is Ops.CAST: return None + ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt) + if ret[-1] is not stmt: something_changed = True + return functools.reduce(operator.and_, ret) if something_changed else None + +# ******** phase 3 is the complete symbolic, and deals with very complex things like loop rewriting and threefry transform ******** + +def reduce_mul_chain(r:UOp): + if r.arg not in {Ops.ADD, Ops.MAX}: return None + if r.dtype != r.src[0].dtype: return None + inside, outside = [], [] + for m in r.src[0].split_uop(Ops.MUL): + m_parents = m.toposort() + if all(r not in m_parents for r in r.src[1:]) and (r.arg != Ops.MAX or m.vmin >= 0): outside.append(m) + else: inside.append(m) + if len(outside) == 0: return None + return r.replace(src=(prod(inside) if len(inside) else r.src[0].const_like(1),)+r.src[1:])*prod(outside) + # this is symbolic 2.0 REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT, Ops.NOOP} REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT, Ops.NOOP} sym = symbolic_flat+PatternMatcher([ + # simplify valid + (UPat(Ops.AND, name="valid"), simplify_valid), # LOAD/STORE -> NOOP (UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]), (UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),