diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py index 1433da7f37..a131023c6c 100644 --- a/tinygrad/codegen/simplify.py +++ b/tinygrad/codegen/simplify.py @@ -99,14 +99,10 @@ pm_reduce_collapse = PatternMatcher([ # MUL casted bool ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")), lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)), - # WHERE on LOAD (works on max too) - (UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True), - lambda buf,idx,gate: buf.index(idx.valid(gate)).load()), - (UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True), - lambda buf,idx,gate: buf.index(idx.valid(gate.logical_not())).load()), - # INDEX on RANGE / gated RANGE - (UPat.var("buf").index(UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted()).where(UPat.var("expr"), invalid_pat)), - lambda buf,r,idx,expr,i: buf.index(expr.substitute({r:idx.cast(r.dtype)}).valid((idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0])))), + # reduce on gated load becomes can substitute the range and remove the reduce + (UPat.var("buf").index(UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted()).where(UPat.var("expr"), invalid_pat)).load() + .reduce(arg=Ops.ADD, allow_any_len=True), lambda buf,r,idx,expr,i: + buf.index(expr.substitute({r:idx.cast(r.dtype)}).valid((idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))).load()), # AND on WHERE ((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \ .where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"), diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 079a30b7c7..01a83de380 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -344,7 +344,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)])) @staticmethod def invalid(count=1): return UOp(Ops.CONST, dtypes.index.vec(count), src=(), arg=Invalid) - def valid(self, cond): return cond.where(self, UOp.invalid(self.dtype.count)) + def valid(self, cond): return self if cond.op is Ops.WHERE and cond.arg else cond.where(self, UOp.invalid(self.dtype.count)) def get_idx(self) -> UOp: assert self.dtype.scalar() is dtypes.index, "Can only call get_idx on index dtype" return self.src[1] if self.op is Ops.WHERE and self.src[2].arg is Invalid else self diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 5111d61ab6..7b94be3fcc 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -473,6 +473,17 @@ def drop_and_clauses(cond:UOp, x:UOp, i:UOp) -> UOp|None: if not (dropped_clauses:=[c for c in cond.split_uop(Ops.AND) if not any(r in x.ranges for r in c.ranges)]): return None return functools.reduce(operator.and_, [c for c in cond.split_uop(Ops.AND) if c not in dropped_clauses], UOp.const(dtypes.bool, True)).where(x, i) pm_drop_and_clauses = PatternMatcher([(UPat.var("cond").where(UPat.var("x", dtype=dtypes.index), invalid_pat), drop_and_clauses)]) +def where_on_load(l, c1, buf, x): + c2 = x.get_valid() + duplicate_clauses = [c for c in c1.split_uop(Ops.AND) if c in c2.split_uop(Ops.AND)] + # we move the condition from the where to the load _as long as_ the condtition doesn't have some range that would place it inside of a new range + # also no data dependent loads! + moved_clauses = [c for c in c1.split_uop(Ops.AND) if c not in duplicate_clauses and all(r in x.ranges for r in c.ranges) + and not c.op_in_backward_slice_with_self(Ops.LOAD)] + if not (removed:=moved_clauses+duplicate_clauses): return None + # aditionally we can drop the clause on the where if it already exists in the load + remaining_clause = functools.reduce(operator.and_, [c for c in c1.split_uop(Ops.AND) if c not in removed], UOp.const(dtypes.bool, True)) + return remaining_clause.where(UOp.load(buf.index(x.get_idx().valid(functools.reduce(operator.and_, moved_clauses, c2)), *l.src[1:])), 0) pm_simplify_valid = PatternMatcher([ # simplify valid @@ -518,8 +529,9 @@ sym = symbolic_flat+pm_simplify_valid+PatternMatcher([ (UPat((Ops.LOAD, Ops.STORE), src=(UPat().index(UPat.const(dtypes.index, Invalid)).or_casted(),), allow_any_len=True, name="x"), lambda x: UOp(Ops.NOOP) if x.op is Ops.STORE else x.const_like(0)), # invalid store does nothing. invalid load produces 0 # # Where after gated load becomes alt value, TODO: this is sort of duplicated with rules in devectorizer - (UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat().index(UPat.var("c2").where(UPat(), invalid_pat)).or_casted(),), name="l"), 0), - lambda c1,c2,l,i: l.replace(src=(l.src[0],)+l.src[1:]) if all(c in list(c2.split_uop(Ops.AND)) for c in c1.split_uop(Ops.AND)) else None), + (UPat.var("c1").where(UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l"), 0), where_on_load), + (UPat.var("c1").where(0, UPat(Ops.LOAD, src=(UPat.var("buf").index(UPat.var("x")),), name="l")), + lambda l,c1,buf,x: where_on_load(l,c1.logical_not(),buf,x)), # remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels (UPat(Ops.BARRIER, name="root"), lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)