mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
llama: minimize peak init mem (#16440)
This commit is contained in:
@@ -1419,10 +1419,7 @@ def train_llama3():
|
||||
|
||||
for p in optim.params:
|
||||
grad_dtype = dtypes.bfloat16 if p.dtype == FP8_DTYPE else p.dtype
|
||||
if isinstance(p.device, tuple) and p.uop.axis is not None:
|
||||
p.grad = Tensor.zeros(p.shape, dtype=grad_dtype, device=p.device[0]).shard_(p.device, axis=p.uop.axis).contiguous()
|
||||
else:
|
||||
p.grad = Tensor.zeros(p.shape, dtype=grad_dtype, device=p.device).contiguous()
|
||||
p.grad = p.zeros_like(dtype=grad_dtype).contiguous()
|
||||
grads = [p.grad for p in optim.params]
|
||||
|
||||
scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)
|
||||
|
||||
@@ -222,14 +222,19 @@ class FlatTransformer:
|
||||
for v in get_parameters(self): v.shard_(device, axis=None)
|
||||
else:
|
||||
# flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer
|
||||
self.wqkv.shard_(device, axis=1).realize() # (n_layers, out, dim) shard out
|
||||
self.wo.shard_(device, axis=2).realize() # (n_layers, dim, in) shard in
|
||||
def _shard_fp8(name:str, axis:int):
|
||||
getattr(self, name).shard_(device, axis=axis)
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False)
|
||||
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False)
|
||||
Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name])
|
||||
_shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out
|
||||
_shard_fp8("wo", 2) # (n_layers, dim, in) shard in
|
||||
if SPLIT_W13:
|
||||
self.w1.shard_(device, axis=1).realize()
|
||||
self.w3.shard_(device, axis=1).realize()
|
||||
_shard_fp8("w1", 1)
|
||||
_shard_fp8("w3", 1)
|
||||
else:
|
||||
self.w13.shard_(device, axis=1).realize() # (n_layers, hidden*2, dim) shard out
|
||||
self.w2.shard_(device, axis=2).realize() # (n_layers, dim, hidden) shard in
|
||||
_shard_fp8("w13", 1) # (n_layers, hidden*2, dim) shard out
|
||||
_shard_fp8("w2", 2) # (n_layers, dim, hidden) shard in
|
||||
self.attention_norm.shard_(device, axis=None).realize()
|
||||
self.ffn_norm.shard_(device, axis=None).realize()
|
||||
self.norm.weight.shard_(device, axis=None).realize()
|
||||
@@ -240,10 +245,6 @@ class FlatTransformer:
|
||||
for name in amax_dict:
|
||||
for i in range(len(amax_dict[name])):
|
||||
amax_dict[name][i] = amax_dict[name][i].to(device).contiguous().is_param_(False)
|
||||
for name in self._fp8_inv_scale:
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False)
|
||||
for name in self._fp8_next_inv_scale:
|
||||
self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False)
|
||||
|
||||
def __call__(self, tokens:Tensor, save:bool=True):
|
||||
h = self.tok_embeddings(tokens)
|
||||
@@ -325,11 +326,7 @@ if __name__ == "__main__":
|
||||
|
||||
# preallocate all the grad buffers and zero them out
|
||||
grad_dtype = lambda x: dtypes.bfloat16 if x.dtype in dtypes.fp8s else x.dtype
|
||||
def _make_grad(x):
|
||||
if isinstance(x.device, tuple) and x.uop.axis is not None:
|
||||
return Tensor.zeros(x.shape, dtype=grad_dtype(x), device=x.device[0]).shard_(x.device, axis=x.uop.axis).contiguous()
|
||||
return Tensor.zeros(x.shape, dtype=grad_dtype(x), device=x.device).contiguous()
|
||||
grads = {x:_make_grad(x) for x in state.values() if x.is_param}
|
||||
grads = {x:x.zeros_like(dtype=grad_dtype(x)).contiguous() for x in state.values() if x.is_param}
|
||||
|
||||
fp8_amax = [t for ts in model._fp8_amax.values() for t in ts]
|
||||
fp8_grad_amax = [t for ts in model._fp8_grad_amax.values() for t in ts]
|
||||
|
||||
Reference in New Issue
Block a user