diff --git a/extra/gemm/cdna_asm_gemm.py b/extra/gemm/cdna_asm_gemm.py index 0c367353c8..9c75d0753c 100644 --- a/extra/gemm/cdna_asm_gemm.py +++ b/extra/gemm/cdna_asm_gemm.py @@ -2783,7 +2783,9 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool= g_scale = Tensor(inv_scale_u, device=a.device) else: assert grad_amax_state is not None, "fp8 matmul bwd needs either a mailbox entry or a grad_amax_state" - if getenv("FUSED_GRAD_QUANTIZE", 0): + if getenv("CURRENT_GRAD_SCALE", 0): + g_fp8, g_scale, _ = quantize_fp8(g_t, amax_state=None) + elif getenv("FUSED_GRAD_QUANTIZE", 0): g_fp8, g_scale, _, store_effect = quantize_fp8_delayed(g_t, Tensor(grad_amax_state, device=a.device)) assert g_fp8.uop.op is Ops.AFTER, f"expected AFTER, got {g_fp8.uop.op}" g_fp8 = Tensor(g_fp8.uop.replace(src=g_fp8.uop.src + (store_effect,)), device=a.device)