llama: minimize peak init mem (#16440)

This commit is contained in:
wozeparrot
2026-05-29 21:00:37 -04:00
committed by GitHub
parent d943493b79
commit c23652e486
2 changed files with 13 additions and 19 deletions

View File

@@ -1419,10 +1419,7 @@ def train_llama3():
for p in optim.params:
grad_dtype = dtypes.bfloat16 if p.dtype == FP8_DTYPE else p.dtype
if isinstance(p.device, tuple) and p.uop.axis is not None:
p.grad = Tensor.zeros(p.shape, dtype=grad_dtype, device=p.device[0]).shard_(p.device, axis=p.uop.axis).contiguous()
else:
p.grad = Tensor.zeros(p.shape, dtype=grad_dtype, device=p.device).contiguous()
p.grad = p.zeros_like(dtype=grad_dtype).contiguous()
grads = [p.grad for p in optim.params]
scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)

View File

@@ -222,14 +222,19 @@ class FlatTransformer:
for v in get_parameters(self): v.shard_(device, axis=None)
else:
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
self.wqkv.shard_(device, axis=1).realize() # (n_layers, out, dim) shard out
self.wo.shard_(device, axis=2).realize() # (n_layers, dim, in) shard in
# Shard one weight of this model and bring its two fp8 scale tensors along with it.
#   name: attribute name of the weight on `self`; also used as the key into
#         self._fp8_inv_scale and self._fp8_next_inv_scale.
#   axis: axis handed to shard_ for the weight.
# `self` and `device` are taken from the enclosing scope.
def _shard_fp8(name:str, axis:int):
# shard the weight itself across `device` along `axis` (shard_ is called for its in-place effect)
getattr(self, name).shard_(device, axis=axis)
# both scale entries are moved with .to(device), made contiguous, and passed through is_param_(False)
# NOTE(review): is_param_(False) presumably keeps them out of the trainable parameter set -- confirm
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False)
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False)
# realize the weight and both of its scales together in a single call
# NOTE(review): presumably done once per weight so each is materialized before the next one is sharded, limiting peak memory -- confirm
Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
_shard_fp8("wo", 2) # (n_layers, dim, in) shard in
if SPLIT_W13:
self.w1.shard_(device, axis=1).realize()
self.w3.shard_(device, axis=1).realize()
_shard_fp8("w1", 1)
_shard_fp8("w3", 1)
else:
self.w13.shard_(device, axis=1).realize() # (n_layers, hidden*2, dim) shard out
self.w2.shard_(device, axis=2).realize() # (n_layers, dim, hidden) shard in
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
_shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in
self.attention_norm.shard_(device, axis=None).realize()
self.ffn_norm.shard_(device, axis=None).realize()
self.norm.weight.shard_(device, axis=None).realize()
@@ -240,10 +245,6 @@ class FlatTransformer:
for name in amax_dict:
for i in range(len(amax_dict[name])):
amax_dict[name][i] = amax_dict[name][i].to(device).contiguous().is_param_(False)
for name in self._fp8_inv_scale:
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False)
for name in self._fp8_next_inv_scale:
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False)
def __call__(self, tokens:Tensor, save:bool=True):
h = self.tok_embeddings(tokens)
@@ -325,11 +326,7 @@ if __name__ == "__main__":
# preallocate all the grad buffers and zero them out
grad_dtype = lambda x: dtypes.bfloat16 if x.dtype in dtypes.fp8s else x.dtype
# Build a zero-filled tensor with x's shape and dtype grad_dtype(x), placed like x.
# Both paths finish with .contiguous().
def _make_grad(x):
# x.device is a tuple and x.uop has a shard axis: create the zeros on the first device,
# then shard_ them across all of x's devices along that same axis
if isinstance(x.device, tuple) and x.uop.axis is not None:
return Tensor.zeros(x.shape, dtype=grad_dtype(x), device=x.device[0]).shard_(x.device, axis=x.uop.axis).contiguous()
# otherwise: create the zeros directly on x.device
return Tensor.zeros(x.shape, dtype=grad_dtype(x), device=x.device).contiguous()
grads = {x:_make_grad(x) for x in state.values() if x.is_param}
grads = {x:x.zeros_like(dtype=grad_dtype(x)).contiguous() for x in state.values() if x.is_param}
fp8_amax = [t for ts in model._fp8_amax.values() for t in ts]
fp8_grad_amax = [t for ts in model._fp8_grad_amax.values() for t in ts]