gemm/asm: fp8 cleanups (#15580)

* normal gemm here

* s/dtypes.fp8e4m3/FP8_DTYPE

* gemm_bw

* device UOp stays NULL
This commit is contained in:
qazal
2026-04-02 13:02:38 +03:00
committed by GitHub
parent 61bc91aa8c
commit fefb0ebc2a
2 changed files with 13 additions and 11 deletions

View File

@@ -6,6 +6,7 @@ from tinygrad.renderer import Estimates
from tinygrad.helpers import getenv, all_same, DEBUG
from tinygrad.runtime.support.compiler_amd import HIPCCCompiler
from tinygrad.runtime.autogen.amd.cdna.ins import *
from examples.mlperf.models.flat_llama import FP8_DTYPE
# ** CDNA4 assembly gemm
@@ -2655,7 +2656,7 @@ atexit.register(_asm_gemm_report)
def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool:
if a.dtype != b.dtype: return todo(f"dtypes must match {a.dtype} != {b.dtype}")
if a.dtype not in {dtypes.bfloat16, dtypes.float16, dtypes.fp8e4m3}: return todo(f"only bfloat16/float16/fp8, got {a.dtype}")
if a.dtype not in {dtypes.bfloat16, dtypes.float16, FP8_DTYPE}: return todo(f"only bfloat16/float16/fp8, got {a.dtype}")
batch, M, K = (1, *a.shape) if a.ndim == 2 else a.shape
N = b.shape[1]
if isinstance(a.device, tuple):
@@ -2695,13 +2696,14 @@ def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
# ** backward gemm, might use the asm gemm
def custom_gemm_bw(gradient:UOp, kernel:UOp, rhs_transposed=False):
def custom_gemm_bw(gradient:UOp, kernel:UOp):
out, a, b = kernel.src[1:]
assert all_same([gradient.device, a.device, b.device, out.device])
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
# TODO: this needs to be cleaned up and done properly, the batch dim of grad and a multi need to align
g_t = g_t[:a.shape[0]]
if rhs_transposed:
if a.dtype.base == FP8_DTYPE:
# fp8 gemm computes a@b.T
grad_a = (g_t @ b_t).uop
grad_b = (g_t.permute(2, 0, 1).reshape(g_t.shape[2], -1) @ a_t.reshape(-1, a_t.shape[-1])).uop
else:
@@ -2720,7 +2722,7 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
a = a.reshape(a.shape[0]*a.shape[1], a.shape[2])
squeeze = a.ndim == 2
if squeeze: a = a.unsqueeze(0)
out_dtype = dtypes.bfloat16 if a.dtype == dtypes.fp8e4m3 else a.dtype
out_dtype = dtypes.bfloat16 if a.dtype == FP8_DTYPE else a.dtype
batch, M, K = a.shape
N = b.shape[1]
@@ -2740,13 +2742,12 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
else:
out = Tensor.invalid(batch, M, N, dtype=out_dtype, device=a.device)
renderer = Device[a.device[0] if is_multi else a.device].renderer
dname, arch = renderer.device, getattr(renderer, "arch", "")
renderer = Device[dname:=(a.device[0] if is_multi else a.device)].renderer
dname, arch = dname.split(":")[0], getattr(renderer, "arch", "")
if arch.startswith("gfx950") and getenv("USE_ASM", 1):
# the FP8 gemm computes a @ b.T
if a.dtype == dtypes.fp8e4m3:
out = Tensor.custom_kernel(out, a, b.T, fxn=functools.partial(custom_hk_fp8_gemm, dname=dname),
grad_fxn=functools.partial(custom_gemm_bw, rhs_transposed=True))[0]
# fp8 gemm computes a@b.T
if a.dtype == FP8_DTYPE:
out = Tensor.custom_kernel(out, a, b.T, fxn=functools.partial(custom_hk_fp8_gemm, dname=dname), grad_fxn=custom_gemm_bw)[0]
else:
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname), grad_fxn=custom_gemm_bw)[0]
else:

View File

@@ -4,6 +4,7 @@ from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, system
from extra.gemm.cdna_asm_gemm import asm_gemm
from test.helpers import needs_second_gpu
from examples.mlperf.models.flat_llama import FP8_DTYPE
# On non CDNA4 it will only validate the Tensor.custom_kernel integration
# Use DEV=NULL EMULATE=AMD_CDNA4 to also test the assembly
@@ -188,7 +189,7 @@ def has_hipcc():
return True
@unittest.skipUnless(has_hipcc(), "FP8 gemm requires hipcc to compile")
class TestGemmLlamaFP8(TestGemmLlama): dtype = dtypes.fp8e4m3
class TestGemmLlamaFP8(TestGemmLlama): dtype = FP8_DTYPE
class TestMagicGu(unittest.TestCase):
def test_magicgu_matches_old(self):