mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
reorder binops (#9328)
* reorder binops * test improvements + fix string tests * ugh, okay this
This commit is contained in:
@@ -29,7 +29,7 @@ class TestArange(unittest.TestCase):
|
||||
f1 = self._get_flops(256, opts) + 1
|
||||
f2 = self._get_flops(2560, opts) + 1
|
||||
print(f"{f1=}, {f2=}")
|
||||
assert (f1 < 5000 and f2 < 5000) or (f2 / f1 < 15), f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
|
||||
assert (f1 < 6000 and f2 < 6000) or (f2 / f1 < 16), f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
|
||||
if limit is not None and not getenv("PTX"):
|
||||
# PTX counts index ALU in flops
|
||||
assert f1 <= limit, f"{f1=}, {limit=}"
|
||||
|
||||
@@ -102,7 +102,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
stores = [u for u in lin.uops if u.op is Ops.STORE]
|
||||
mutable_bufs = dedup(flatten([[x for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL] for u in stores]))
|
||||
assert len(mutable_bufs) == len(stores) == 2
|
||||
assert [u.arg for u in mutable_bufs] == [0, 1]
|
||||
self.assertSetEqual(set([u.arg for u in mutable_bufs]), set([0,1]))
|
||||
|
||||
def _test_no_nested_ranges(self, lins, skip=None):
|
||||
for l in lins:
|
||||
|
||||
@@ -260,7 +260,7 @@ class TestImageSimplification(unittest.TestCase):
|
||||
self.check(load,
|
||||
"((((idx2*2)+ridx0)<11)&((((idx1*8)+ridx1)<3)!=True))",
|
||||
"(((idx0+((idx1*512)+(ridx1*64)))+832)%1024)",
|
||||
"((((idx2*2)+ridx0)+(((idx1+((ridx1+5)//8))+1)//2))+-4)")
|
||||
"(((((idx1+((ridx1+5)//8))+1)//2)+((idx2*2)+ridx0))+-4)")
|
||||
|
||||
def test_simplify1(self):
|
||||
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
|
||||
|
||||
@@ -197,7 +197,8 @@ class TestSymbolic(unittest.TestCase):
|
||||
|
||||
def test_mod_congruence_multiple_vars(self):
|
||||
self.helper_test_variable((9+9*Variable("x",0,3)+9*Variable("y",0,3))%10, 3, 9, "(((x*-1)+(y*-1))+9)")
|
||||
self.helper_test_variable((7+9*Variable("x",0,2)+9*Variable("y",0,2)+Variable("z",0,2))%10, 3, 9, "(((z+(x*-1))+(y*-1))+7)")
|
||||
self.helper_test_variable((7+9*Variable("x",0,2)+9*Variable("y",0,2)+Variable("z",0,2))%10, 3, 9,
|
||||
("(((z+(x*-1))+(y*-1))+7)", "(((y*-1)+(z+(x*-1)))+7)"))
|
||||
self.helper_test_variable((10+12*Variable("x",0,2)+Variable("y", 0, 4)%3)%13, 8, 12, "(((x*-1)+(y%3))+10)")
|
||||
|
||||
def test_div_congruence(self):
|
||||
@@ -289,11 +290,11 @@ class TestSymbolic(unittest.TestCase):
|
||||
|
||||
def test_lt_sum_factor_rhs_partial(self):
|
||||
self.helper_test_variable((Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8) < 4, 0, 1,
|
||||
"((((a*3)+(b*2))+(c*4))<2)")
|
||||
("((((a*3)+(b*2))+(c*4))<2)", "(((b*2)+((a*3)+(c*4)))<2)"))
|
||||
|
||||
def test_lt_sum_factor_rhs_all(self):
|
||||
self.helper_test_variable((Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8) < 2, 0, 1,
|
||||
"((((a*3)+(b*2))+(c*4))<1)")
|
||||
("((((a*3)+(b*2))+(c*4))<1)", "(((b*2)+((a*3)+(c*4)))<1)"))
|
||||
|
||||
def test_and_fold(self):
|
||||
self.helper_test_variable(uand([uconst(0), Variable("a", 0, 1)]), 0, 0, "0")
|
||||
@@ -369,6 +370,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
alu0 = gidx2*640+gidx1*160+(gidx0//5)*2+lidx0*320+lidx1*10
|
||||
self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192,
|
||||
("((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(((gidx2*32)+(gidx1*8))+(lidx0*16)))",
|
||||
"(((lidx1+((lidx2+(gidx0//5))//5))//2)+((gidx2*32)+((gidx1*8)+(lidx0*16))))",
|
||||
"((((gidx1*8)+(gidx2*32))+(lidx0*16))+((lidx1+((lidx2+(gidx0//5))//5))//2))"))
|
||||
|
||||
def test_sum_div_complex2(self):
|
||||
@@ -390,7 +392,8 @@ class TestSymbolic(unittest.TestCase):
|
||||
gidx0 = Variable("gidx0", 0, 7)
|
||||
lidx2 = Variable("lidx2", 0, 12)
|
||||
lidx3 = Variable("lidx3", 0, 1)
|
||||
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "(((gidx0*4)+(lidx2*4))+(lidx3*4))")
|
||||
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80,
|
||||
("(((gidx0*4)+(lidx2*4))+(lidx3*4))","((lidx3*4)+((gidx0*4)+(lidx2*4)))"))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_variable_divmod(self):
|
||||
@@ -497,8 +500,8 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable((a*3+b*4<1).ne(True), 0, 1, "(((a+b)<1)!=True)")
|
||||
self.helper_test_variable((a*(-3)+b*4<1).ne(True), 0, 1, "((((a*-3)+(b*4))<1)!=True)") # negative coeff, should not be simplified
|
||||
self.helper_test_variable((a*3+d*4<1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=True)") # var can be negative, should not be simplified
|
||||
self.helper_test_variable((a+b+c*2<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
|
||||
self.helper_test_variable((a+b*2+c*4<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
|
||||
self.helper_test_variable((a+b+c*2<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)", '(((b+(a+c))<1)!=True)'))
|
||||
self.helper_test_variable((a+b*2+c*4<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)", '(((b+(a+c))<1)!=True)'))
|
||||
|
||||
def test_where_removal(self):
|
||||
cond = Variable("a", 0, 3) < 2
|
||||
@@ -740,6 +743,7 @@ class TestSymbolicRealWorld(unittest.TestCase):
|
||||
self.assertIn(idx.render(),
|
||||
("((((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352))+lidx3)+2207744)",
|
||||
'((lidx3+((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352)))+2207744)',
|
||||
'((lidx3+((lidx4*100352)+((gidx2*8)+((gidx1*784)+((gidx0*3211264)+((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49)))))))+2207744)',
|
||||
))
|
||||
|
||||
class TestBounds(unittest.TestCase):
|
||||
|
||||
@@ -132,8 +132,8 @@ class Ops(FastEnum):
|
||||
WMMA = auto()
|
||||
|
||||
# BinaryOps
|
||||
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
|
||||
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
|
||||
MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); ADD = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto() # noqa: E702
|
||||
XOR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
|
||||
|
||||
# TernaryOps
|
||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||
|
||||
Reference in New Issue
Block a user