From fefb0ebc2abbb162be9447224880e7ef977f79ec Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 2 Apr 2026 13:02:38 +0300 Subject: [PATCH] gemm/asm: fp8 cleanups (#15580) * normal gemm here * s/dtypes.fp8e4m3/FP8_DTYPE * gemm_bw * device UOp stays NULL --- extra/gemm/cdna_asm_gemm.py | 21 +++++++++++---------- test/backend/test_asm_gemm.py | 3 ++- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/extra/gemm/cdna_asm_gemm.py b/extra/gemm/cdna_asm_gemm.py index a110f16eba..e59ab4fda3 100644 --- a/extra/gemm/cdna_asm_gemm.py +++ b/extra/gemm/cdna_asm_gemm.py @@ -6,6 +6,7 @@ from tinygrad.renderer import Estimates from tinygrad.helpers import getenv, all_same, DEBUG from tinygrad.runtime.support.compiler_amd import HIPCCCompiler from tinygrad.runtime.autogen.amd.cdna.ins import * +from examples.mlperf.models.flat_llama import FP8_DTYPE # ** CDNA4 assembly gemm @@ -2655,7 +2656,7 @@ atexit.register(_asm_gemm_report) def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool: if a.dtype != b.dtype: return todo(f"dtypes must match {a.dtype} != {b.dtype}") - if a.dtype not in {dtypes.bfloat16, dtypes.float16, dtypes.fp8e4m3}: return todo(f"only bfloat16/float16/fp8, got {a.dtype}") + if a.dtype not in {dtypes.bfloat16, dtypes.float16, FP8_DTYPE}: return todo(f"only bfloat16/float16/fp8, got {a.dtype}") batch, M, K = (1, *a.shape) if a.ndim == 2 else a.shape N = b.shape[1] if isinstance(a.device, tuple): @@ -2695,13 +2696,14 @@ def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp: # ** backward gemm, might use the asm gemm -def custom_gemm_bw(gradient:UOp, kernel:UOp, rhs_transposed=False): +def custom_gemm_bw(gradient:UOp, kernel:UOp): out, a, b = kernel.src[1:] assert all_same([gradient.device, a.device, b.device, out.device]) a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device) # TODO: this needs to be cleaned up and done properly, the batch dim of grad and a multi need to align g_t = g_t[:a.shape[0]] - if rhs_transposed: + if a.dtype.base == FP8_DTYPE: + # fp8 gemm computes a@b.T grad_a = (g_t @ b_t).uop grad_b = (g_t.permute(2, 0, 1).reshape(g_t.shape[2], -1) @ a_t.reshape(-1, a_t.shape[-1])).uop else: @@ -2720,7 +2722,7 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor: a = a.reshape(a.shape[0]*a.shape[1], a.shape[2]) squeeze = a.ndim == 2 if squeeze: a = a.unsqueeze(0) - out_dtype = dtypes.bfloat16 if a.dtype == dtypes.fp8e4m3 else a.dtype + out_dtype = dtypes.bfloat16 if a.dtype == FP8_DTYPE else a.dtype batch, M, K = a.shape N = b.shape[1] @@ -2740,13 +2742,12 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor: else: out = Tensor.invalid(batch, M, N, dtype=out_dtype, device=a.device) - renderer = Device[a.device[0] if is_multi else a.device].renderer - dname, arch = renderer.device, getattr(renderer, "arch", "") + renderer = Device[dname:=(a.device[0] if is_multi else a.device)].renderer + dname, arch = dname.split(":")[0], getattr(renderer, "arch", "") if arch.startswith("gfx950") and getenv("USE_ASM", 1): - # the FP8 gemm computes a @ b.T - if a.dtype == dtypes.fp8e4m3: - out = Tensor.custom_kernel(out, a, b.T, fxn=functools.partial(custom_hk_fp8_gemm, dname=dname), - grad_fxn=functools.partial(custom_gemm_bw, rhs_transposed=True))[0] + # fp8 gemm computes a@b.T + if a.dtype == FP8_DTYPE: + out = Tensor.custom_kernel(out, a, b.T, fxn=functools.partial(custom_hk_fp8_gemm, dname=dname), grad_fxn=custom_gemm_bw)[0] else: out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname), grad_fxn=custom_gemm_bw)[0] else: diff --git a/test/backend/test_asm_gemm.py b/test/backend/test_asm_gemm.py index d42c4702cd..8c39188ff7 100644 --- a/test/backend/test_asm_gemm.py +++ b/test/backend/test_asm_gemm.py @@ -4,6 +4,7 @@ from tinygrad.device import is_dtype_supported from tinygrad.helpers import getenv, system from extra.gemm.cdna_asm_gemm import asm_gemm from test.helpers import needs_second_gpu +from examples.mlperf.models.flat_llama import FP8_DTYPE # On non CDNA4 it will only validate the Tensor.custom_kernel integration # Use DEV=NULL EMULATE=AMD_CDNA4 to also test the assembly @@ -188,7 +189,7 @@ def has_hipcc(): return True @unittest.skipUnless(has_hipcc(), "FP8 gemm requires hipcc to compile") -class TestGemmLlamaFP8(TestGemmLlama): dtype = dtypes.fp8e4m3 +class TestGemmLlamaFP8(TestGemmLlama): dtype = FP8_DTYPE class TestMagicGu(unittest.TestCase): def test_magicgu_matches_old(self):