llama: current grad scaling (#16518)

This commit is contained in:
wozeparrot
2026-06-05 18:39:41 -04:00
committed by GitHub
parent 8c0ba1da5c
commit a1ec32cfd2

View File

@@ -2783,7 +2783,9 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
g_scale = Tensor(inv_scale_u, device=a.device)
else:
assert grad_amax_state is not None, "fp8 matmul bwd needs either a mailbox entry or a grad_amax_state"
if getenv("FUSED_GRAD_QUANTIZE", 0):
if getenv("CURRENT_GRAD_SCALE", 0):
g_fp8, g_scale, _ = quantize_fp8(g_t, amax_state=None)
elif getenv("FUSED_GRAD_QUANTIZE", 0):
g_fp8, g_scale, _, store_effect = quantize_fp8_delayed(g_t, Tensor(grad_amax_state, device=a.device))
assert g_fp8.uop.op is Ops.AFTER, f"expected AFTER, got {g_fp8.uop.op}"
g_fp8 = Tensor(g_fp8.uop.replace(src=g_fp8.uop.src + (store_effect,)), device=a.device)