From c23652e4865648c8b2b76ff63e16d52feb6609de Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Fri, 29 May 2026 21:00:37 -0400 Subject: [PATCH] llama: minimize peak init mem (#16440) --- examples/mlperf/model_train.py | 5 +---- examples/mlperf/models/flat_llama.py | 27 ++++++++++++--------------- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index fbe0e7ffb0..f294bb3071 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1419,10 +1419,7 @@ def train_llama3(): for p in optim.params: grad_dtype = dtypes.bfloat16 if p.dtype == FP8_DTYPE else p.dtype - if isinstance(p.device, tuple) and p.uop.axis is not None: - p.grad = Tensor.zeros(p.shape, dtype=grad_dtype, device=p.device[0]).shard_(p.device, axis=p.uop.axis).contiguous() - else: - p.grad = Tensor.zeros(p.shape, dtype=grad_dtype, device=p.device).contiguous() + p.grad = p.zeros_like(dtype=grad_dtype).contiguous() grads = [p.grad for p in optim.params] scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps) diff --git a/examples/mlperf/models/flat_llama.py b/examples/mlperf/models/flat_llama.py index d8f032c6a6..12240d068c 100644 --- a/examples/mlperf/models/flat_llama.py +++ b/examples/mlperf/models/flat_llama.py @@ -222,14 +222,19 @@ class FlatTransformer: for v in get_parameters(self): v.shard_(device, axis=None) else: # flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer - self.wqkv.shard_(device, axis=1).realize() # (n_layers, out, dim) shard out - self.wo.shard_(device, axis=2).realize() # (n_layers, dim, in) shard in + def _shard_fp8(name:str, axis:int): + getattr(self, name).shard_(device, axis=axis) + self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False) + self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False) + Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name]) + _shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out + _shard_fp8("wo", 2) # (n_layers, dim, in) shard in if SPLIT_W13: - self.w1.shard_(device, axis=1).realize() - self.w3.shard_(device, axis=1).realize() + _shard_fp8("w1", 1) + _shard_fp8("w3", 1) else: - self.w13.shard_(device, axis=1).realize() # (n_layers, hidden*2, dim) shard out - self.w2.shard_(device, axis=2).realize() # (n_layers, dim, hidden) shard in + _shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out + _shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in self.attention_norm.shard_(device, axis=None).realize() self.ffn_norm.shard_(device, axis=None).realize() self.norm.weight.shard_(device, axis=None).realize() @@ -240,10 +245,6 @@ class FlatTransformer: for name in amax_dict: for i in range(len(amax_dict[name])): amax_dict[name][i] = amax_dict[name][i].to(device).contiguous().is_param_(False) - for name in self._fp8_inv_scale: - self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False) - for name in self._fp8_next_inv_scale: - self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False) def __call__(self, tokens:Tensor, save:bool=True): h = self.tok_embeddings(tokens) @@ -325,11 +326,7 @@ if __name__ == "__main__": # preallocate all the grad buffers and zero them out grad_dtype = lambda x: dtypes.bfloat16 if x.dtype in dtypes.fp8s else x.dtype - def _make_grad(x): - if isinstance(x.device, tuple) and x.uop.axis is not None: - return Tensor.zeros(x.shape, dtype=grad_dtype(x), device=x.device[0]).shard_(x.device, axis=x.uop.axis).contiguous() - return Tensor.zeros(x.shape, dtype=grad_dtype(x), device=x.device).contiguous() - grads = {x:_make_grad(x) for x in state.values() if x.is_param} + grads = {x:x.zeros_like(dtype=grad_dtype(x)).contiguous() for x in state.values() if x.is_param} fp8_amax = [t for ts in model._fp8_amax.values() for t in ts] fp8_grad_amax = [t for ts in model._fp8_grad_amax.values() for t in ts]