diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 405ccd535f..d4c7145cfe 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -1837,20 +1837,34 @@ class TestKernelOpts(unittest.TestCase):
         [Opt(OptOps.UPCAST, 1, 4)],
         [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4)], # check upcasts
         [Opt(OptOps.UNROLL, 0, 2)], # check unroll
-        [Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals
-        [Opt(OptOps.LOCAL, 0, 4)], # check local
         [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2)], # check combo of unroll and local
         [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2)],
         [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)],
-        [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)],
         [Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4)], # check permutations
         [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
         [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4)],
         [Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)],
-        [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)],
         # [Opt(OptOps.GROUP, 0, 2)] # doesn't work because group_for_reduce dims become early locals (conflicting with TC)
       ], apply_tc=True, atol=atol, rtol=rtol)
 
+  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
+  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
+  def test_tensor_core_opts_locals(self): # tensor-core opt sequences that are only run when the renderer has locals
+    N = 128 # side length of the square (N, N) matmul operands
+    Tensor.manual_seed(1552) # fixed seed for the Tensor.rand inputs below
+    for tc in Device[Device.DEFAULT].renderer.tensor_cores: # one pass per tensor core config of the default device
+      # bf16 buffers come back as float32 numpy outputs, so the test would fail; testing the opts with half suffices.
+      if tc.dtype_in == dtypes.bfloat16: continue
+      a, b = Tensor.rand(N, N, dtype=tc.dtype_in), Tensor.rand(N, N, dtype=tc.dtype_in)
+      r = a.matmul(b, acc_dtype=tc.dtype_out) # accumulate in the tensor core's output dtype
+      (atol, rtol) = ((0.25, 0.01) if tc.dtype_out == dtypes.half else (3e-2, 1e-3)) if tc.dtype_in == dtypes.half else (1e-4, 1e-4)
+      helper_linearizer_opt(r, [ # each inner list is one sequence of Opts handed to the helper
+        [Opt(OptOps.UNROLL, 0, 0)], # check full unroll of reduce with locals
+        [Opt(OptOps.LOCAL, 0, 4)], # check local
+        [Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.LOCAL, 0, 2)], # upcasts + unroll, LOCAL last
+        [Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 2), Opt(OptOps.UPCAST, 0, 4)], # LOCAL first, then the rest
+      ], apply_tc=True, atol=atol, rtol=rtol)
+
   @unittest.skip("parallel tensor cores")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
   def test_fused_tensor_core_simple(self):