diff --git a/test/unit/test_uop_resolve.py b/test/unit/test_uop_resolve.py index d117e56ecd..dbb0692c35 100644 --- a/test/unit/test_uop_resolve.py +++ b/test/unit/test_uop_resolve.py @@ -1,6 +1,6 @@ import unittest from tinygrad.dtype import dtypes -from tinygrad.ops import UOp +from tinygrad.ops import UOp, resolve class TestUOpResolve(unittest.TestCase): def test_simple_int(self): @@ -39,6 +39,14 @@ class TestUOpResolve(unittest.TestCase): u = UOp.const(dtypes.int, 4) > 7 self.assertFalse(u) + def test_ambiguous_less_than(self): + u = UOp.define_var("i", dtypes.pyint, 1, 10) + self.assertTrue(resolve(u < 4)) + self.assertFalse(resolve(u < 4, False)) + self.assertTrue(resolve(u < 11, False)) + self.assertFalse(resolve(u < -1, False)) + self.assertFalse(resolve(u < -1, True)) + def test_float_direct(self): u = UOp.const(dtypes.float, 4.5) + 7 self.assertEqual(float(u), 11.5) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index aee854e85a..16ba72358b 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -557,7 +557,7 @@ class Kernel: if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0)) # if it's small, upcast a second reduce dimension too - if self.first_reduce < self.first_upcast and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int): + if self.first_reduce < self.first_upcast and s <= 3 and isinstance(s2:=self.full_unupcasted_shape[-1], int) and s2 <= 3: self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0)) else: for splits in [4]: diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 5f077c5ccb..d9adb6d3f0 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -452,8 +452,6 @@ sym = simple_pm+PatternMatcher([ # mod folding (UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None), # ** combine terms (opinionated) ** - (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1) - (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1) (UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2 ((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3) (-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 354f78dd75..d02690c284 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -146,6 +146,12 @@ BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.VALID} COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR} END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)} +# With True as the default, this matches the old symbolic behavior +# python3 -c 'from tinygrad.shape.symbolic import Variable; print(bool(Variable("a", 1, 10) < 4))' -> True +def resolve(x, default:bool=True): + try: return bool(x) + except ValueError: return default + class UOp(MathTrait): __slots__ = ["op", "dtype", "src", "arg"] def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None): @@ -175,9 +181,10 @@ class UOp(MathTrait): def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))") def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg # *** uop evaluation *** + def simplify(self): return graph_rewrite(self, simple_pm) def _eval(self, dtype, expected_type) -> ConstType: assert self.dtype in dtype, f"eval with wrong dtype {self}" - simple_self = graph_rewrite(self, simple_pm) + simple_self = self.simplify() vmin, vmax = simple_self._min_max if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self}") assert type(vmin) is expected_type, f"vmin is wrong dtype {vmin} != {expected_type}" @@ -704,14 +711,19 @@ simple_pm = PatternMatcher([ (UPat.var("x") // 1, lambda x: x), # x//1 -> x (UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x (UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1 - (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False ((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x (UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c), (UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x), + (UPat.var("x").max(UPat.var("x")), lambda x: x), ((UPat.var("x") & UPat.var("x")), lambda x: x), ((UPat.var("x") | UPat.var("x")), lambda x: x), (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x), + # ** combine terms ** + (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1) + (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1) # ** zero folding ** + (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False + (UPat.var("x", dtype=dtypes.ints) != UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints) # x*0 -> 0 or 0*x -> 0 # if x is nan or inf it should render the nan value. # NOTE: this can be wrong for loaded NaN