mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
@@ -66,7 +66,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
b_bufs = [u.uop for u in lin.uops.uops[-2].vin[1].vin]
|
||||
|
||||
assert a_bufs == [UOps.LOAD, UOps.CONST]
|
||||
assert b_bufs == [UOps.CONST, UOps.CONST]
|
||||
assert b_bufs == [] # [UOps.CONST, UOps.CONST] will be folded
|
||||
|
||||
def test_upcast_cse(self):
|
||||
# when upcasting, within a subtree, there may be common expressions.
|
||||
@@ -126,7 +126,6 @@ class TestLinearizer(unittest.TestCase):
|
||||
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
|
||||
assert num_ops == 0, "more alu uops than needed"
|
||||
|
||||
@unittest.skip("constant folding not supported yet")
|
||||
def test_constant_fold(self):
|
||||
a, b = Tensor(2), Tensor(3)
|
||||
r = a * b
|
||||
@@ -216,7 +215,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0)
|
||||
c1 = UOp(UOps.CONST, dtypes.float, vin=(), arg=1.0)
|
||||
assert helper_test_simplify(UOps.ALU, dtypes.float, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c1),
|
||||
arg=TernaryOps.WHERE).uop == UOps.CONST
|
||||
arg=TernaryOps.WHERE).arg == c0.arg
|
||||
|
||||
def helper_realized_ast(r:Tensor):
|
||||
s = create_schedule([r.lazydata])
|
||||
|
||||
@@ -146,6 +146,16 @@ class TestExecALU(TestUOps):
|
||||
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.float32, (7.0, 3.0)), 2+(1.0/3.0))
|
||||
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.float32, (7.0, -3.0)), -2-(1.0/3.0))
|
||||
|
||||
def test_bool_neg(self):
|
||||
self.assertEqual(exec_alu(UnaryOps.NEG, dtypes.bool, (False,)), True)
|
||||
self.assertEqual(exec_alu(UnaryOps.NEG, dtypes.bool, (True,)), False)
|
||||
|
||||
def test_bool_cmplt(self):
|
||||
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (False, False)), False)
|
||||
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (False, True)), True)
|
||||
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, False)), False)
|
||||
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, True)), False)
|
||||
|
||||
def test_overflow(self):
|
||||
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)), 244)
|
||||
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (256, 0)), 0)
|
||||
|
||||
@@ -34,14 +34,16 @@ def hook_overflow(dv, fxn):
|
||||
python_alu = {
|
||||
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
|
||||
UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
|
||||
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin, UnaryOps.NEG: operator.neg,
|
||||
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
|
||||
UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
|
||||
BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
|
||||
BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt, BinaryOps.MOD: operator.mod,
|
||||
BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else math.nan),
|
||||
TernaryOps.WHERE: lambda x,y,z: y if x else z}
|
||||
|
||||
truncate: Dict[DType, Callable] = {
|
||||
**{dt:lambda x: x for dt in dtypes.fields().values() if dt == dtypes.bool or dtypes.is_float(dt)},
|
||||
dtypes.bool: lambda x: bool(x),
|
||||
**{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
|
||||
**{dt:functools.partial(lambda vv,x: x&vv, (1 << (dt.itemsize*8))-1) for dt in dtypes.fields().values() if dtypes.is_unsigned(dt)},
|
||||
**{dt:functools.partial(lambda vv,aa,x: ((x+aa)&vv)-aa, (1 << (dt.itemsize*8))-1, 1 << (dt.itemsize*8-1)) \
|
||||
for dt in dtypes.fields().values() if dtypes.is_int(dt) and not dtypes.is_unsigned(dt)}}
|
||||
@@ -89,12 +91,10 @@ class UOpGraph:
|
||||
if arg is BinaryOps.ADD and vin[1].uop is UOps.ALU and vin[1].arg is UnaryOps.NEG:
|
||||
return self.add(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before)
|
||||
# constant folding
|
||||
if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST:
|
||||
return self.add(UOps.CONST, dtype, arg=-vin[0].arg if dtype != dtypes.bool else not vin[0].arg, insert_before=insert_before)
|
||||
if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
|
||||
if arg is TernaryOps.WHERE and vin[0].uop is UOps.CONST: return vin[1] if vin[0].arg else vin[2]
|
||||
if arg is BinaryOps.MUL and vin[0].uop is UOps.CONST and vin[1].uop is UOps.CONST and dtype is not None and dtypes.is_float(dtype):
|
||||
return self.add(UOps.CONST, dtype, arg=vin[0].arg * vin[1].arg, insert_before=insert_before)
|
||||
if all(x.uop is UOps.CONST for x in vin):
|
||||
return self.add(UOps.CONST, dtype, arg=exec_alu(arg, dtype, [x.arg for x in vin]), insert_before=insert_before)
|
||||
# zero folding
|
||||
for x in [0,1]:
|
||||
if arg is BinaryOps.ADD and vin[x].uop is UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
|
||||
|
||||
Reference in New Issue
Block a user