reorder binops (#9328)

* reorder binops

* test improvements + fix string tests

* ugh, okay this
This commit is contained in:
George Hotz
2025-03-03 14:58:18 +08:00
committed by GitHub
parent 146eb73790
commit 2cc4cb74f0
5 changed files with 15 additions and 11 deletions

View File

@@ -29,7 +29,7 @@ class TestArange(unittest.TestCase):
f1 = self._get_flops(256, opts) + 1
f2 = self._get_flops(2560, opts) + 1
print(f"{f1=}, {f2=}")
assert (f1 < 5000 and f2 < 5000) or (f2 / f1 < 15), f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
assert (f1 < 6000 and f2 < 6000) or (f2 / f1 < 16), f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
if limit is not None and not getenv("PTX"):
# PTX counts index ALU in flops
assert f1 <= limit, f"{f1=}, {limit=}"

View File

@@ -102,7 +102,7 @@ class TestLinearizer(unittest.TestCase):
stores = [u for u in lin.uops if u.op is Ops.STORE]
mutable_bufs = dedup(flatten([[x for x in u.src[0].toposort if x.op is Ops.DEFINE_GLOBAL] for u in stores]))
assert len(mutable_bufs) == len(stores) == 2
assert [u.arg for u in mutable_bufs] == [0, 1]
self.assertSetEqual(set([u.arg for u in mutable_bufs]), set([0,1]))
def _test_no_nested_ranges(self, lins, skip=None):
for l in lins:

View File

@@ -260,7 +260,7 @@ class TestImageSimplification(unittest.TestCase):
self.check(load,
"((((idx2*2)+ridx0)<11)&((((idx1*8)+ridx1)<3)!=True))",
"(((idx0+((idx1*512)+(ridx1*64)))+832)%1024)",
"((((idx2*2)+ridx0)+(((idx1+((ridx1+5)//8))+1)//2))+-4)")
"(((((idx1+((ridx1+5)//8))+1)//2)+((idx2*2)+ridx0))+-4)")
def test_simplify1(self):
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)

View File

@@ -197,7 +197,8 @@ class TestSymbolic(unittest.TestCase):
def test_mod_congruence_multiple_vars(self):
self.helper_test_variable((9+9*Variable("x",0,3)+9*Variable("y",0,3))%10, 3, 9, "(((x*-1)+(y*-1))+9)")
self.helper_test_variable((7+9*Variable("x",0,2)+9*Variable("y",0,2)+Variable("z",0,2))%10, 3, 9, "(((z+(x*-1))+(y*-1))+7)")
self.helper_test_variable((7+9*Variable("x",0,2)+9*Variable("y",0,2)+Variable("z",0,2))%10, 3, 9,
("(((z+(x*-1))+(y*-1))+7)", "(((y*-1)+(z+(x*-1)))+7)"))
self.helper_test_variable((10+12*Variable("x",0,2)+Variable("y", 0, 4)%3)%13, 8, 12, "(((x*-1)+(y%3))+10)")
def test_div_congruence(self):
@@ -289,11 +290,11 @@ class TestSymbolic(unittest.TestCase):
def test_lt_sum_factor_rhs_partial(self):
self.helper_test_variable((Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8) < 4, 0, 1,
"((((a*3)+(b*2))+(c*4))<2)")
("((((a*3)+(b*2))+(c*4))<2)", "(((b*2)+((a*3)+(c*4)))<2)"))
def test_lt_sum_factor_rhs_all(self):
self.helper_test_variable((Variable("a", 0, 6)*6 + Variable("b", 0, 6)*4 + Variable("c", 0, 6)*8) < 2, 0, 1,
"((((a*3)+(b*2))+(c*4))<1)")
("((((a*3)+(b*2))+(c*4))<1)", "(((b*2)+((a*3)+(c*4)))<1)"))
def test_and_fold(self):
self.helper_test_variable(uand([uconst(0), Variable("a", 0, 1)]), 0, 0, "0")
@@ -369,6 +370,7 @@ class TestSymbolic(unittest.TestCase):
alu0 = gidx2*640+gidx1*160+(gidx0//5)*2+lidx0*320+lidx1*10
self.helper_test_variable((alu0+lidx2*2+1)//20, 0, 8192,
("((((((gidx0//5)+lidx2)//5)+lidx1)//2)+(((gidx2*32)+(gidx1*8))+(lidx0*16)))",
"(((lidx1+((lidx2+(gidx0//5))//5))//2)+((gidx2*32)+((gidx1*8)+(lidx0*16))))",
"((((gidx1*8)+(gidx2*32))+(lidx0*16))+((lidx1+((lidx2+(gidx0//5))//5))//2))"))
def test_sum_div_complex2(self):
@@ -390,7 +392,8 @@ class TestSymbolic(unittest.TestCase):
gidx0 = Variable("gidx0", 0, 7)
lidx2 = Variable("lidx2", 0, 12)
lidx3 = Variable("lidx3", 0, 1)
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80, "(((gidx0*4)+(lidx2*4))+(lidx3*4))")
self.helper_test_variable((gidx0+lidx2+lidx3)*4, 0, 80,
("(((gidx0*4)+(lidx2*4))+(lidx3*4))","((lidx3*4)+((gidx0*4)+(lidx2*4)))"))
@unittest.expectedFailure
def test_variable_divmod(self):
@@ -497,8 +500,8 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((a*3+b*4<1).ne(True), 0, 1, "(((a+b)<1)!=True)")
self.helper_test_variable((a*(-3)+b*4<1).ne(True), 0, 1, "((((a*-3)+(b*4))<1)!=True)") # negative coeff, should not be simplified
self.helper_test_variable((a*3+d*4<1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=True)") # var can be negative, should not be simplified
self.helper_test_variable((a+b+c*2<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
self.helper_test_variable((a+b*2+c*4<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)"))
self.helper_test_variable((a+b+c*2<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)", '(((b+(a+c))<1)!=True)'))
self.helper_test_variable((a+b*2+c*4<1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)", '(((b+(a+c))<1)!=True)'))
def test_where_removal(self):
cond = Variable("a", 0, 3) < 2
@@ -740,6 +743,7 @@ class TestSymbolicRealWorld(unittest.TestCase):
self.assertIn(idx.render(),
("((((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352))+lidx3)+2207744)",
'((lidx3+((((((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49))+(gidx0*3211264))+(gidx1*784))+(gidx2*8))+(lidx4*100352)))+2207744)',
'((lidx3+((lidx4*100352)+((gidx2*8)+((gidx1*784)+((gidx0*3211264)+((((lidx5+1)//16)*802816)+(((lidx5+1)%16)*49)))))))+2207744)',
))
class TestBounds(unittest.TestCase):

View File

@@ -132,8 +132,8 @@ class Ops(FastEnum):
WMMA = auto()
# BinaryOps
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
MUL = auto(); SHL = auto(); SHR = auto(); IDIV = auto(); ADD = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto() # noqa: E702
XOR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
# TernaryOps
WHERE = auto(); MULACC = auto() # noqa: E702