llama: immediate scaling on flag (#16494)

This commit is contained in:
wozeparrot
2026-06-04 13:30:00 -04:00
committed by GitHub
parent 4fb8ce1831
commit f11f63007d

View File

@@ -7,6 +7,7 @@ from tinygrad.uop.ops import UOp, Ops
STOCHASTIC_ROUND = getenv("STOCHASTIC_ROUND", 0)
MASTER_WEIGHTS = getenv("MASTER_WEIGHTS", 0)
FP8_AMAX_MARGIN = getenv("FP8_AMAX_MARGIN", 1.1)
IMMEDIATE_SCALE = getenv("IMMEDIATE_SCALE", 0)
def stochastic_round_bf16(x:Tensor) -> Tensor:
bits = x.bitcast(dtypes.uint32)
@@ -90,6 +91,13 @@ class GradAccClipAdamW(Optimizer):
return out.shard_like(t) if offloaded else out
if t.dtype in dtypes.fp8s:
from examples.mlperf.models.flat_llama import FP8_MAX
if IMMEDIATE_SCALE:
amax_axis = tuple(range(t._inv_scale.ndim, new_w.ndim))
new_inv = ((new_w.float().abs().max(axis=amax_axis).detach() + 1e-8) / FP8_MAX).cast(t._inv_scale.dtype)
t._inv_scale.assign(new_inv.shard_like(t._inv_scale) if offloaded else new_inv)
scale = new_inv.reciprocal().reshape(*new_inv.shape, *([1]*(new_w.ndim-new_inv.ndim)))
ret = (new_w * scale).clamp(-FP8_MAX, FP8_MAX).cast(t.dtype)
return ret.shard_like(t) if offloaded else ret
# delayed scaling: reuse previous step's inv_scale
t._inv_scale.assign(t._next_inv_scale)
inv_scale = t._inv_scale.to(new_w.device) if offloaded else t._inv_scale