use is_dtype_supported to check dtype support in tc tests (#10035)

This commit is contained in:
Ignacio Sica
2025-04-24 14:59:14 -03:00
committed by GitHub
parent 93a1e9eeb9
commit 373ca59b7f

View File

@@ -1062,16 +1062,14 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if (getenv("EMULATE_CUDA") or getenv("EMULATE_INTEL") or getenv("EMULATE_METAL") or getenv("EMULATE_AMD_MFMA") or getenv("EMULATE_AMD")) and \
(tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
if CI and Device.DEFAULT in ("METAL", "AMD") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
# for AMX, tc.dims[2] == 1 so reduceop is None thus tensor_cores are not triggered
helper_tc_allclose(tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_codegen(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if CI and Device.DEFAULT == "AMD" and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
n, m, k = tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2]
a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in)
r = a.matmul(b, dtype=tc.dtype_out)
@@ -1092,16 +1090,14 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_padded(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if (getenv("EMULATE_CUDA") or getenv("EMULATE_METAL") or getenv("EMULATE_AMD_MFMA") or getenv("EMULATE_AMD")) and \
(tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
if CI and Device.DEFAULT in ("METAL", "AMD") and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
helper_tc_allclose(tc.dims[0]+(pad:=1), tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2)
@unittest.skipUnless(Device.DEFAULT in {"AMD"}, "Test for AMD device")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_padded_amd(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if CI and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
helper_tc_allclose(tc.dims[0]+(pad:=3), tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
@@ -1129,7 +1125,7 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tensor_cores_multi_reduce(self):
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
if tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16: continue
if not is_dtype_supported(tc.dtype_in) or not is_dtype_supported(tc.dtype_out): continue
# this will be a M=G16, N=G32, M=G16, M=G16, K=R16, K=R16, K=R16 with 9 choices of TC MNK axes
golden_result = None
for axis in range(9):