mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-08 05:54:59 +08:00
llama: immediate scaling on flag (#16494)
This commit is contained in:
@@ -7,6 +7,7 @@ from tinygrad.uop.ops import UOp, Ops
|
||||
STOCHASTIC_ROUND = getenv("STOCHASTIC_ROUND", 0)
|
||||
MASTER_WEIGHTS = getenv("MASTER_WEIGHTS", 0)
|
||||
FP8_AMAX_MARGIN = getenv("FP8_AMAX_MARGIN", 1.1)
|
||||
IMMEDIATE_SCALE = getenv("IMMEDIATE_SCALE", 0)
|
||||
|
||||
def stochastic_round_bf16(x:Tensor) -> Tensor:
|
||||
bits = x.bitcast(dtypes.uint32)
|
||||
@@ -90,6 +91,13 @@ class GradAccClipAdamW(Optimizer):
|
||||
return out.shard_like(t) if offloaded else out
|
||||
if t.dtype in dtypes.fp8s:
|
||||
from examples.mlperf.models.flat_llama import FP8_MAX
|
||||
if IMMEDIATE_SCALE:
|
||||
amax_axis = tuple(range(t._inv_scale.ndim, new_w.ndim))
|
||||
new_inv = ((new_w.float().abs().max(axis=amax_axis).detach() + 1e-8) / FP8_MAX).cast(t._inv_scale.dtype)
|
||||
t._inv_scale.assign(new_inv.shard_like(t._inv_scale) if offloaded else new_inv)
|
||||
scale = new_inv.reciprocal().reshape(*new_inv.shape, *([1]*(new_w.ndim-new_inv.ndim)))
|
||||
ret = (new_w * scale).clamp(-FP8_MAX, FP8_MAX).cast(t.dtype)
|
||||
return ret.shard_like(t) if offloaded else ret
|
||||
# delayed scaling: reuse previous step's inv_scale
|
||||
t._inv_scale.assign(t._next_inv_scale)
|
||||
inv_scale = t._inv_scale.to(new_w.device) if offloaded else t._inv_scale
|
||||
|
||||
Reference in New Issue
Block a user