From f11f63007d09c69ca5ecfeda7d9652da972b95f2 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Thu, 4 Jun 2026 13:30:00 -0400 Subject: [PATCH] llama: immediate scaling on flag (#16494) --- examples/mlperf/optim.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/examples/mlperf/optim.py b/examples/mlperf/optim.py index 053bc5d3e6..fb92ecd10a 100644 --- a/examples/mlperf/optim.py +++ b/examples/mlperf/optim.py @@ -7,6 +7,7 @@ from tinygrad.uop.ops import UOp, Ops STOCHASTIC_ROUND = getenv("STOCHASTIC_ROUND", 0) MASTER_WEIGHTS = getenv("MASTER_WEIGHTS", 0) FP8_AMAX_MARGIN = getenv("FP8_AMAX_MARGIN", 1.1) +IMMEDIATE_SCALE = getenv("IMMEDIATE_SCALE", 0) def stochastic_round_bf16(x:Tensor) -> Tensor: bits = x.bitcast(dtypes.uint32) @@ -90,6 +91,13 @@ class GradAccClipAdamW(Optimizer): return out.shard_like(t) if offloaded else out if t.dtype in dtypes.fp8s: from examples.mlperf.models.flat_llama import FP8_MAX + if IMMEDIATE_SCALE: + amax_axis = tuple(range(t._inv_scale.ndim, new_w.ndim)) + new_inv = ((new_w.float().abs().max(axis=amax_axis).detach() + 1e-8) / FP8_MAX).cast(t._inv_scale.dtype) + t._inv_scale.assign(new_inv.shard_like(t._inv_scale) if offloaded else new_inv) + scale = new_inv.reciprocal().reshape(*new_inv.shape, *([1]*(new_w.ndim-new_inv.ndim))) + ret = (new_w * scale).clamp(-FP8_MAX, FP8_MAX).cast(t.dtype) + return ret.shard_like(t) if offloaded else ret # delayed scaling: reuse previous step's inv_scale t._inv_scale.assign(t._next_inv_scale) inv_scale = t._inv_scale.to(new_w.device) if offloaded else t._inv_scale