diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 8faf467dbb..e0dd59f734 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -902,7 +902,8 @@ class TestLinearizer(unittest.TestCase): lin = helper_linearizer_opt(out, wanna_output=[a.numpy().sum()*a.numpy().sum()])[0] # RANGE -> LOAD -> ASSIGN -> ALU end = max(i for i,u in enumerate(lin.uops) if u.op is Ops.ENDRANGE) - assert lin.uops[end+1].op in GroupOp.ALU + # the INDEX can be first + assert lin.uops[end+1].op in GroupOp.ALU or lin.uops[end+2].op in GroupOp.ALU def test_range_outer_op_after_phi_nested_range(self): a = Tensor.randn(2, ).realize() @@ -910,7 +911,8 @@ class TestLinearizer(unittest.TestCase): lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3))).sum()*2])[0] # RANGE -> LOAD -> ASSIGN -> ALU end = max(i for i,u in enumerate(lin.uops) if u.op is Ops.ENDRANGE) - assert lin.uops[end+1].op in GroupOp.ALU + # the INDEX can be first + assert lin.uops[end+1].op in GroupOp.ALU or lin.uops[end+2].op in GroupOp.ALU def test_load_dedup(self): # for different leaves in the AST, the same loads may occur. diff --git a/test/test_uops.py b/test/test_uops.py index 53060f44a1..85896bf82b 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -336,8 +336,9 @@ class TestAssembly(unittest.TestCase): a2 = UOp(Ops.MUL, dtypes.int, (l1, c2)) uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer) Device[Device.DEFAULT].renderer.render("test", uops) - self.assertEqual(uops[-1].op, Ops.SHL) - self.assertEqual(uops[-2].op, Ops.MUL) + ops = [x.op for x in uops] + self.assertIn(Ops.SHL, ops) + self.assertIn(Ops.MUL, ops) def test_bitshift_right(self): g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0) @@ -348,8 +349,9 @@ class TestAssembly(unittest.TestCase): a2 = UOp(Ops.IDIV, dtypes.int, (l1, c2)) uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer) Device[Device.DEFAULT].renderer.render("test", uops) - self.assertEqual(uops[-1].op, Ops.SHR) - self.assertEqual(uops[-2].op, Ops.IDIV) + ops = [x.op for x in uops] + self.assertIn(Ops.SHR, ops) + self.assertIn(Ops.IDIV, ops) class TestUOpMethod(unittest.TestCase): @unittest.skip("uops lt no longer ordered") diff --git a/tinygrad/ops.py b/tinygrad/ops.py index e091371b28..f19b50536f 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -107,6 +107,9 @@ class Ops(FastEnum): EMPTY = auto() BUFFER_VIEW = auto() + # blocks in linearizer + BLOCK = auto(); BLOCKSTART = auto(); BLOCKFORK = auto(); BLOCKEND = auto() # noqa: E702 + EXPAND = auto() CONTRACT = auto() VIEW = auto() diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 6b4bd8b7a4..2d93b4ced4 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -11,7 +11,8 @@ from tinygrad.codegen.kernel import Kernel uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", - Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.BUFFER: "#B0BDFF",} + Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", **{x:"#ffffc0" for x in GroupOp.ALU}, + Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF",} # ** API spec