changes for symbolic (#6844)

* changes for symbolic

* only for ints

* check int first
This commit is contained in:
George Hotz
2024-10-02 12:57:16 +08:00
committed by GitHub
parent 1735f8ef1c
commit be12409b51
4 changed files with 24 additions and 6 deletions

View File

@@ -1,6 +1,6 @@
import unittest
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp
from tinygrad.ops import UOp, resolve
class TestUOpResolve(unittest.TestCase):
def test_simple_int(self):
@@ -39,6 +39,14 @@ class TestUOpResolve(unittest.TestCase):
u = UOp.const(dtypes.int, 4) > 7
self.assertFalse(u)
def test_ambiguous_less_than(self):
u = UOp.define_var("i", dtypes.pyint, 1, 10)
self.assertTrue(resolve(u < 4))
self.assertFalse(resolve(u < 4, False))
self.assertTrue(resolve(u < 11, False))
self.assertFalse(resolve(u < -1, False))
self.assertFalse(resolve(u < -1, True))
def test_float_direct(self):
u = UOp.const(dtypes.float, 4.5) + 7
self.assertEqual(float(u), 11.5)

View File

@@ -557,7 +557,7 @@ class Kernel:
if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
# if it's small, upcast a second reduce dimension too
if self.first_reduce < self.first_upcast and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
if self.first_reduce < self.first_upcast and s <= 3 and isinstance(s2:=self.full_unupcasted_shape[-1], int) and s2 <= 3:
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
else:
for splits in [4]:

View File

@@ -452,8 +452,6 @@ sym = simple_pm+PatternMatcher([
# mod folding
(UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
# ** combine terms (opinionated) **
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y

View File

@@ -146,6 +146,12 @@ BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.VALID}
COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
# With True as the default, this matches the old symbolic behavior
# python3 -c 'from tinygrad.shape.symbolic import Variable; print(bool(Variable("a", 1, 10) < 4))' -> True
def resolve(x, default:bool=True):
try: return bool(x)
except ValueError: return default
class UOp(MathTrait):
__slots__ = ["op", "dtype", "src", "arg"]
def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
@@ -175,9 +181,10 @@ class UOp(MathTrait):
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
# *** uop evaluation ***
def simplify(self): return graph_rewrite(self, simple_pm)
def _eval(self, dtype, expected_type) -> ConstType:
assert self.dtype in dtype, f"eval with wrong dtype {self}"
simple_self = graph_rewrite(self, simple_pm)
simple_self = self.simplify()
vmin, vmax = simple_self._min_max
if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self}")
assert type(vmin) is expected_type, f"vmin is wrong dtype {vmin} != {expected_type}"
@@ -704,14 +711,19 @@ simple_pm = PatternMatcher([
(UPat.var("x") // 1, lambda x: x), # x//1 -> x
(UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
(UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
(UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
(UPat.var("x").max(UPat.var("x")), lambda x: x),
((UPat.var("x") & UPat.var("x")), lambda x: x),
((UPat.var("x") | UPat.var("x")), lambda x: x),
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
# ** combine terms **
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
# ** zero folding **
(UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
(UPat.var("x", dtype=dtypes.ints) != UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)
# x*0 -> 0 or 0*x -> 0
# if x is nan or inf it should render the nan value.
# NOTE: this can be wrong for loaded NaN