simplify bool.cast() != const (#15874)

This commit is contained in:
chenyu
2026-04-22 17:08:09 -04:00
committed by GitHub
parent 2041945f4b
commit b9e2bc619e
2 changed files with 16 additions and 0 deletions

View File

@@ -851,6 +851,19 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((a+b+c*2<1).ne(True), 0, 1, "((((a+b)+c)<1)!=True)")
self.helper_test_variable((a+b*2+c*4<1).ne(True), 0, 1, "((((a+b)+c)<1)!=True)")
def test_cast_bool_to_int_ne_const(self):
cond = Variable("a", 0, 3) < 2
# CAST(bool -> int) != 0 -> cond
self.helper_test_variable(cond.cast(dtypes.int).ne(0), 0, 1, "(a<2)")
# CAST(bool -> int) != 1 -> !cond
self.helper_test_variable(cond.cast(dtypes.int).ne(1), 0, 1, "((a<2)!=True)")
# CAST(bool -> int) != c (c not in {0,1}) -> always True (CAST is 0 or 1)
self.helper_test_variable(cond.cast(dtypes.int).ne(2), 1, 1, "True")
self.helper_test_variable(cond.cast(dtypes.int).ne(-1), 1, 1, "True")
# CAST(bool -> weakint) folds too
self.helper_test_variable(cond.cast(dtypes.weakint).ne(0), 0, 1, "(a<2)")
self.helper_test_variable(cond.cast(dtypes.weakint).ne(1), 0, 1, "((a<2)!=True)")
def test_where_removal(self):
cond = Variable("a", 0, 3) < 2
u1, u0 = cond.const_like(True), cond.const_like(False)

View File

@@ -91,6 +91,9 @@ symbolic_simple = propagate_invalid + PatternMatcher([
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, False), UPat.const(dtypes.bool, True)), lambda x: x.logical_not()),
# CAST(bool -> int) != const — CAST(True)=1, CAST(False)=0, so fold based on const value
(UPat.var("x", dtype=dtypes.bool).cast(dtypes.ints+(dtypes.weakint,)) != UPat.cvar("c", vec=False),
lambda x,c: x if c.arg == 0 else x.logical_not() if c.arg == 1 else x.const_like(True)),
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)).trunc(), lambda x: x),
# ** zero folding **
(UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False