From b9e2bc619e9b5cb26bdf370abe7aedca2e06cfc1 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 22 Apr 2026 17:08:09 -0400 Subject: [PATCH] simplify bool.cast() != const (#15874) --- test/null/test_uop_symbolic.py | 13 +++++++++++++ tinygrad/uop/symbolic.py | 3 +++ 2 files changed, 16 insertions(+) diff --git a/test/null/test_uop_symbolic.py b/test/null/test_uop_symbolic.py index 51c4ebe714..19f25ce7cf 100644 --- a/test/null/test_uop_symbolic.py +++ b/test/null/test_uop_symbolic.py @@ -851,6 +851,19 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((a+b+c*2<1).ne(True), 0, 1, "((((a+b)+c)<1)!=True)") self.helper_test_variable((a+b*2+c*4<1).ne(True), 0, 1, "((((a+b)+c)<1)!=True)") + def test_cast_bool_to_int_ne_const(self): + cond = Variable("a", 0, 3) < 2 + # CAST(bool -> int) != 0 -> cond + self.helper_test_variable(cond.cast(dtypes.int).ne(0), 0, 1, "(a<2)") + # CAST(bool -> int) != 1 -> !cond + self.helper_test_variable(cond.cast(dtypes.int).ne(1), 0, 1, "((a<2)!=True)") + # CAST(bool -> int) != c (c not in {0,1}) -> always True (CAST is 0 or 1) + self.helper_test_variable(cond.cast(dtypes.int).ne(2), 1, 1, "True") + self.helper_test_variable(cond.cast(dtypes.int).ne(-1), 1, 1, "True") + # CAST(bool -> weakint) folds too + self.helper_test_variable(cond.cast(dtypes.weakint).ne(0), 0, 1, "(a<2)") + self.helper_test_variable(cond.cast(dtypes.weakint).ne(1), 0, 1, "((a<2)!=True)") + def test_where_removal(self): cond = Variable("a", 0, 3) < 2 u1, u0 = cond.const_like(True), cond.const_like(False) diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index 4f8aeb358f..ad422a4115 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -91,6 +91,9 @@ symbolic_simple = propagate_invalid + PatternMatcher([ (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x), (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x), (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, False), UPat.const(dtypes.bool, True)), lambda x: x.logical_not()), + # CAST(bool -> int) != const — CAST(True)=1, CAST(False)=0, so fold based on const value + (UPat.var("x", dtype=dtypes.bool).cast(dtypes.ints+(dtypes.weakint,)) != UPat.cvar("c", vec=False), + lambda x,c: x if c.arg == 0 else x.logical_not() if c.arg == 1 else x.const_like(True)), (UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)).trunc(), lambda x: x), # ** zero folding ** (UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False