From a1ec32cfd2931f36f7f06a8b84c8e893be5b8dd6 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Fri, 5 Jun 2026 18:39:41 -0400 Subject: [PATCH] llama: current grad scaling (#16518) --- extra/gemm/cdna_asm_gemm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/extra/gemm/cdna_asm_gemm.py b/extra/gemm/cdna_asm_gemm.py index 0c367353c8..9c75d0753c 100644 --- a/extra/gemm/cdna_asm_gemm.py +++ b/extra/gemm/cdna_asm_gemm.py @@ -2783,7 +2783,9 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool= g_scale = Tensor(inv_scale_u, device=a.device) else: assert grad_amax_state is not None, "fp8 matmul bwd needs either a mailbox entry or a grad_amax_state" - if getenv("FUSED_GRAD_QUANTIZE", 0): + if getenv("CURRENT_GRAD_SCALE", 0): + g_fp8, g_scale, _ = quantize_fp8(g_t, amax_state=None) + elif getenv("FUSED_GRAD_QUANTIZE", 0): g_fp8, g_scale, _, store_effect = quantize_fp8_delayed(g_t, Tensor(grad_amax_state, device=a.device)) assert g_fp8.uop.op is Ops.AFTER, f"expected AFTER, got {g_fp8.uop.op}" g_fp8 = Tensor(g_fp8.uop.replace(src=g_fp8.uop.src + (store_effect,)), device=a.device)