mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
changes for symbolic (#6844)
* changes for symbolic * only for ints * check int first
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp
|
||||
from tinygrad.ops import UOp, resolve
|
||||
|
||||
class TestUOpResolve(unittest.TestCase):
|
||||
def test_simple_int(self):
|
||||
@@ -39,6 +39,14 @@ class TestUOpResolve(unittest.TestCase):
|
||||
u = UOp.const(dtypes.int, 4) > 7
|
||||
self.assertFalse(u)
|
||||
|
||||
def test_ambiguous_less_than(self):
|
||||
u = UOp.define_var("i", dtypes.pyint, 1, 10)
|
||||
self.assertTrue(resolve(u < 4))
|
||||
self.assertFalse(resolve(u < 4, False))
|
||||
self.assertTrue(resolve(u < 11, False))
|
||||
self.assertFalse(resolve(u < -1, False))
|
||||
self.assertFalse(resolve(u < -1, True))
|
||||
|
||||
def test_float_direct(self):
|
||||
u = UOp.const(dtypes.float, 4.5) + 7
|
||||
self.assertEqual(float(u), 11.5)
|
||||
|
||||
@@ -557,7 +557,7 @@ class Kernel:
|
||||
if (s:=self.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
|
||||
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
|
||||
# if it's small, upcast a second reduce dimension too
|
||||
if self.first_reduce < self.first_upcast and s <= 3 and (s2:=self.full_unupcasted_shape[-1]) <= 3 and isinstance(s2, int):
|
||||
if self.first_reduce < self.first_upcast and s <= 3 and isinstance(s2:=self.full_unupcasted_shape[-1], int) and s2 <= 3:
|
||||
self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
|
||||
else:
|
||||
for splits in [4]:
|
||||
|
||||
@@ -452,8 +452,6 @@ sym = simple_pm+PatternMatcher([
|
||||
# mod folding
|
||||
(UPat.var("x") % UPat.cvar("c", vec=False), lambda x,c: newx if 0 < c.arg and (newx:=mod_folding(x,c.arg)) is not None else None),
|
||||
# ** combine terms (opinionated) **
|
||||
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
|
||||
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
|
||||
(UPat.var("x") + UPat.var("x"), lambda x: x*2), # (x+x)-> x*2
|
||||
((UPat.var("x") / UPat.var("x2")) / UPat.var("x3"), lambda x,x2,x3: x/(x2*x3)), # (x/x2)/x3 -> x/(x2*x3)
|
||||
(-1 * (UPat.var("x") + UPat.var("y")), lambda x,y: (-x)+(-y)), # -(x+y) -> -x + -y
|
||||
|
||||
@@ -146,6 +146,12 @@ BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.VALID}
|
||||
COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
|
||||
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
|
||||
|
||||
# With True as the default, this matches the old symbolic behavior
|
||||
# python3 -c 'from tinygrad.shape.symbolic import Variable; print(bool(Variable("a", 1, 10) < 4))' -> True
|
||||
def resolve(x, default:bool=True):
|
||||
try: return bool(x)
|
||||
except ValueError: return default
|
||||
|
||||
class UOp(MathTrait):
|
||||
__slots__ = ["op", "dtype", "src", "arg"]
|
||||
def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
|
||||
@@ -175,9 +181,10 @@ class UOp(MathTrait):
|
||||
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
|
||||
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
|
||||
# *** uop evaluation ***
|
||||
def simplify(self): return graph_rewrite(self, simple_pm)
|
||||
def _eval(self, dtype, expected_type) -> ConstType:
|
||||
assert self.dtype in dtype, f"eval with wrong dtype {self}"
|
||||
simple_self = graph_rewrite(self, simple_pm)
|
||||
simple_self = self.simplify()
|
||||
vmin, vmax = simple_self._min_max
|
||||
if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self}")
|
||||
assert type(vmin) is expected_type, f"vmin is wrong dtype {vmin} != {expected_type}"
|
||||
@@ -704,14 +711,19 @@ simple_pm = PatternMatcher([
|
||||
(UPat.var("x") // 1, lambda x: x), # x//1 -> x
|
||||
(UPat.var("x") // -1, lambda x: -x), # x//-1 -> -x
|
||||
(UPat.var("x") / UPat.var("x"), lambda x: x.const_like(1)), # x/x -> 1
|
||||
(UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
|
||||
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
|
||||
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
|
||||
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
|
||||
(UPat.var("x").max(UPat.var("x")), lambda x: x),
|
||||
((UPat.var("x") & UPat.var("x")), lambda x: x),
|
||||
((UPat.var("x") | UPat.var("x")), lambda x: x),
|
||||
(UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x),
|
||||
# ** combine terms **
|
||||
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
|
||||
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
|
||||
# ** zero folding **
|
||||
(UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False
|
||||
(UPat.var("x", dtype=dtypes.ints) != UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x != x -> False (only ints)
|
||||
# x*0 -> 0 or 0*x -> 0
|
||||
# if x is nan or inf it should render the nan value.
|
||||
# NOTE: this can be wrong for loaded NaN
|
||||
|
||||
Reference in New Issue
Block a user