mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
simplify bool.cast() != const (#15874)
This commit is contained in:
@@ -851,6 +851,19 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable((a+b+c*2<1).ne(True), 0, 1, "((((a+b)+c)<1)!=True)")
|
||||
self.helper_test_variable((a+b*2+c*4<1).ne(True), 0, 1, "((((a+b)+c)<1)!=True)")
|
||||
|
||||
def test_cast_bool_to_int_ne_const(self):
|
||||
cond = Variable("a", 0, 3) < 2
|
||||
# CAST(bool -> int) != 0 -> cond
|
||||
self.helper_test_variable(cond.cast(dtypes.int).ne(0), 0, 1, "(a<2)")
|
||||
# CAST(bool -> int) != 1 -> !cond
|
||||
self.helper_test_variable(cond.cast(dtypes.int).ne(1), 0, 1, "((a<2)!=True)")
|
||||
# CAST(bool -> int) != c (c not in {0,1}) -> always True (CAST is 0 or 1)
|
||||
self.helper_test_variable(cond.cast(dtypes.int).ne(2), 1, 1, "True")
|
||||
self.helper_test_variable(cond.cast(dtypes.int).ne(-1), 1, 1, "True")
|
||||
# CAST(bool -> weakint) folds too
|
||||
self.helper_test_variable(cond.cast(dtypes.weakint).ne(0), 0, 1, "(a<2)")
|
||||
self.helper_test_variable(cond.cast(dtypes.weakint).ne(1), 0, 1, "((a<2)!=True)")
|
||||
|
||||
def test_where_removal(self):
|
||||
cond = Variable("a", 0, 3) < 2
|
||||
u1, u0 = cond.const_like(True), cond.const_like(False)
|
||||
|
||||
@@ -91,6 +91,9 @@ symbolic_simple = propagate_invalid + PatternMatcher([
|
||||
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
|
||||
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x),
|
||||
(UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, False), UPat.const(dtypes.bool, True)), lambda x: x.logical_not()),
|
||||
# CAST(bool -> int) != const — CAST(True)=1, CAST(False)=0, so fold based on const value
|
||||
(UPat.var("x", dtype=dtypes.bool).cast(dtypes.ints+(dtypes.weakint,)) != UPat.cvar("c", vec=False),
|
||||
lambda x,c: x if c.arg == 0 else x.logical_not() if c.arg == 1 else x.const_like(True)),
|
||||
(UPat.var("x", dtype=dtypes.ints+(dtypes.bool, dtypes.weakint)).trunc(), lambda x: x),
|
||||
# ** zero folding **
|
||||
(UPat.var("x") < UPat.var("x"), lambda x: x.const_like(False).cast(dtypes.bool.vec(x.dtype.count))), # x < x -> False
|
||||
|
||||
Reference in New Issue
Block a user