mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
gemm/asm: fp8 cleanups (#15580)
* normal gemm here * s/dtypes.fp8e4m3/FP8_DTYPE * gemm_bw * device UOp stays NULL
This commit is contained in:
@@ -6,6 +6,7 @@ from tinygrad.renderer import Estimates
|
||||
from tinygrad.helpers import getenv, all_same, DEBUG
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCCCompiler
|
||||
from tinygrad.runtime.autogen.amd.cdna.ins import *
|
||||
from examples.mlperf.models.flat_llama import FP8_DTYPE
|
||||
|
||||
# ** CDNA4 assembly gemm
|
||||
|
||||
@@ -2655,7 +2656,7 @@ atexit.register(_asm_gemm_report)
|
||||
|
||||
def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool:
|
||||
if a.dtype != b.dtype: return todo(f"dtypes must match {a.dtype} != {b.dtype}")
|
||||
if a.dtype not in {dtypes.bfloat16, dtypes.float16, dtypes.fp8e4m3}: return todo(f"only bfloat16/float16/fp8, got {a.dtype}")
|
||||
if a.dtype not in {dtypes.bfloat16, dtypes.float16, FP8_DTYPE}: return todo(f"only bfloat16/float16/fp8, got {a.dtype}")
|
||||
batch, M, K = (1, *a.shape) if a.ndim == 2 else a.shape
|
||||
N = b.shape[1]
|
||||
if isinstance(a.device, tuple):
|
||||
@@ -2695,13 +2696,14 @@ def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
|
||||
|
||||
# ** backward gemm, might use the asm gemm
|
||||
|
||||
def custom_gemm_bw(gradient:UOp, kernel:UOp, rhs_transposed=False):
|
||||
def custom_gemm_bw(gradient:UOp, kernel:UOp):
|
||||
out, a, b = kernel.src[1:]
|
||||
assert all_same([gradient.device, a.device, b.device, out.device])
|
||||
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
|
||||
# TODO: this needs to be cleaned up and done properly, the batch dim of grad and a multi need to align
|
||||
g_t = g_t[:a.shape[0]]
|
||||
if rhs_transposed:
|
||||
if a.dtype.base == FP8_DTYPE:
|
||||
# fp8 gemm computes a@b.T
|
||||
grad_a = (g_t @ b_t).uop
|
||||
grad_b = (g_t.permute(2, 0, 1).reshape(g_t.shape[2], -1) @ a_t.reshape(-1, a_t.shape[-1])).uop
|
||||
else:
|
||||
@@ -2720,7 +2722,7 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
|
||||
a = a.reshape(a.shape[0]*a.shape[1], a.shape[2])
|
||||
squeeze = a.ndim == 2
|
||||
if squeeze: a = a.unsqueeze(0)
|
||||
out_dtype = dtypes.bfloat16 if a.dtype == dtypes.fp8e4m3 else a.dtype
|
||||
out_dtype = dtypes.bfloat16 if a.dtype == FP8_DTYPE else a.dtype
|
||||
|
||||
batch, M, K = a.shape
|
||||
N = b.shape[1]
|
||||
@@ -2740,13 +2742,12 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
|
||||
else:
|
||||
out = Tensor.invalid(batch, M, N, dtype=out_dtype, device=a.device)
|
||||
|
||||
renderer = Device[a.device[0] if is_multi else a.device].renderer
|
||||
dname, arch = renderer.device, getattr(renderer, "arch", "")
|
||||
renderer = Device[dname:=(a.device[0] if is_multi else a.device)].renderer
|
||||
dname, arch = dname.split(":")[0], getattr(renderer, "arch", "")
|
||||
if arch.startswith("gfx950") and getenv("USE_ASM", 1):
|
||||
# the FP8 gemm computes a @ b.T
|
||||
if a.dtype == dtypes.fp8e4m3:
|
||||
out = Tensor.custom_kernel(out, a, b.T, fxn=functools.partial(custom_hk_fp8_gemm, dname=dname),
|
||||
grad_fxn=functools.partial(custom_gemm_bw, rhs_transposed=True))[0]
|
||||
# fp8 gemm computes a@b.T
|
||||
if a.dtype == FP8_DTYPE:
|
||||
out = Tensor.custom_kernel(out, a, b.T, fxn=functools.partial(custom_hk_fp8_gemm, dname=dname), grad_fxn=custom_gemm_bw)[0]
|
||||
else:
|
||||
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname), grad_fxn=custom_gemm_bw)[0]
|
||||
else:
|
||||
|
||||
@@ -4,6 +4,7 @@ from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import getenv, system
|
||||
from extra.gemm.cdna_asm_gemm import asm_gemm
|
||||
from test.helpers import needs_second_gpu
|
||||
from examples.mlperf.models.flat_llama import FP8_DTYPE
|
||||
|
||||
# On non CDNA4 it will only validate the Tensor.custom_kernel integration
|
||||
# Use DEV=NULL EMULATE=AMD_CDNA4 to also test the assembly
|
||||
@@ -188,7 +189,7 @@ def has_hipcc():
|
||||
return True
|
||||
|
||||
@unittest.skipUnless(has_hipcc(), "FP8 gemm requires hipcc to compile")
|
||||
class TestGemmLlamaFP8(TestGemmLlama): dtype = dtypes.fp8e4m3
|
||||
class TestGemmLlamaFP8(TestGemmLlama): dtype = FP8_DTYPE
|
||||
|
||||
class TestMagicGu(unittest.TestCase):
|
||||
def test_magicgu_matches_old(self):
|
||||
|
||||
Reference in New Issue
Block a user