From b4d267dfd4187801535d1cf2646954fa83734a3d Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Thu, 14 May 2026 20:46:29 -0400 Subject: [PATCH] llama: only save when small (#16208) --- examples/mlperf/model_train.py | 2 +- examples/mlperf/models/flat_llama.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 8420a1727a..e944279155 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1458,7 +1458,7 @@ def train_llama3(): if is_dp: tokens = tokens.to(None).shard(device, 0) if is_mp: tokens = tokens.shard(device) if not is_sharding: tokens = tokens.to(None) - logits:Tensor = model(tokens[:, :-1]) + logits:Tensor = model(tokens[:, :-1], save=bool(SMALL)) if getenv("FAST_CE", 0): from extra.llama_kernels.fused_ce import fused_ce_loss loss = fused_ce_loss(logits.cast(dtypes.bfloat16), tokens[:, 1:], label_smoothing=0.0) diff --git a/examples/mlperf/models/flat_llama.py b/examples/mlperf/models/flat_llama.py index 6a41c4aa78..50fa509ccb 100644 --- a/examples/mlperf/models/flat_llama.py +++ b/examples/mlperf/models/flat_llama.py @@ -208,11 +208,12 @@ class FlatTransformer: return out, h, amaxs, saves @function(precompile=True, precompile_backward=True) - def run_layer(self, x:Tensor, freqs_cis:Tensor, attn_kwargs:dict, ffn_kwargs:dict): + def run_layer(self, x:Tensor, freqs_cis:Tensor, attn_kwargs:dict, ffn_kwargs:dict, save:bool=True): attn, attn_amaxs, attn_saves = self.attention(x, freqs_cis, **attn_kwargs) ffn, h, ffn_amaxs, ffn_saves = self.feed_forward(x, attn, **ffn_kwargs) h = h + ffn - return (h, *attn_amaxs, *ffn_amaxs, *attn_saves, *ffn_saves) + if save: return (h, *attn_amaxs, *ffn_amaxs, *attn_saves, *ffn_saves) + else: return (h, *attn_amaxs, *ffn_amaxs) def shard(self, device:tuple[str, ...], mp:bool=False): from tinygrad.nn.state import get_parameters @@ -241,7 +242,7 @@ class FlatTransformer: for name in self._fp8_inv_scale: self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().requires_grad_(False) - def __call__(self, tokens:Tensor): + def __call__(self, tokens:Tensor, save:bool=True): h = self.tok_embeddings(tokens) freqs_cis = self.freqs_cis.cast(h.dtype)[:, :tokens.shape[1], :, :, :] a, ga, s = self._fp8_amax, self._fp8_grad_amax, self._fp8_inv_scale @@ -256,7 +257,7 @@ class FlatTransformer: s_1=s["w1"][i], s_3=s["w3"][i], grad_amax_xw1=ga["xw1"][i], grad_amax_xw3=ga["xw3"][i]) else: ffn_kwargs.update(w13=self.w13[i], amax_x13=a["x13"][i], s_13=s["w13"][i], grad_amax_xw13=ga["xw13"][i]) - h, *ret = self.run_layer(h, freqs_cis, attn_kwargs, ffn_kwargs) + h, *ret = self.run_layer(h, freqs_cis, attn_kwargs, ffn_kwargs, save=save) amax_names = ["xqkv", "xo"] + (["x1", "x3"] if SPLIT_W13 else ["x13"]) + ["x2"] for name, new_val in zip(amax_names, ret[:len(amax_names)]): a[name][i].assign(new_val)