mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
llama: current grad scaling (#16518)
This commit is contained in:
@@ -2783,7 +2783,9 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=
     g_scale = Tensor(inv_scale_u, device=a.device)
   else:
     assert grad_amax_state is not None, "fp8 matmul bwd needs either a mailbox entry or a grad_amax_state"
-    if getenv("FUSED_GRAD_QUANTIZE", 0):
+    if getenv("CURRENT_GRAD_SCALE", 0):
+      g_fp8, g_scale, _ = quantize_fp8(g_t, amax_state=None)
+    elif getenv("FUSED_GRAD_QUANTIZE", 0):
       g_fp8, g_scale, _, store_effect = quantize_fp8_delayed(g_t, Tensor(grad_amax_state, device=a.device))
       assert g_fp8.uop.op is Ops.AFTER, f"expected AFTER, got {g_fp8.uop.op}"
       g_fp8 = Tensor(g_fp8.uop.replace(src=g_fp8.uop.src + (store_effect,)), device=a.device)
Reference in New Issue
Block a user