mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
llama: only save when small (#16208)
This commit is contained in:
@@ -1458,7 +1458,7 @@ def train_llama3():
|
||||
if is_dp: tokens = tokens.to(None).shard(device, 0)
|
||||
if is_mp: tokens = tokens.shard(device)
|
||||
if not is_sharding: tokens = tokens.to(None)
|
||||
logits:Tensor = model(tokens[:, :-1])
|
||||
logits:Tensor = model(tokens[:, :-1], save=bool(SMALL))
|
||||
if getenv("FAST_CE", 0):
|
||||
from extra.llama_kernels.fused_ce import fused_ce_loss
|
||||
loss = fused_ce_loss(logits.cast(dtypes.bfloat16), tokens[:, 1:], label_smoothing=0.0)
|
||||
|
||||
@@ -208,11 +208,12 @@ class FlatTransformer:
|
||||
return out, h, amaxs, saves
|
||||
|
||||
@function(precompile=True, precompile_backward=True)
|
||||
def run_layer(self, x:Tensor, freqs_cis:Tensor, attn_kwargs:dict, ffn_kwargs:dict):
|
||||
def run_layer(self, x:Tensor, freqs_cis:Tensor, attn_kwargs:dict, ffn_kwargs:dict, save:bool=True):
|
||||
attn, attn_amaxs, attn_saves = self.attention(x, freqs_cis, **attn_kwargs)
|
||||
ffn, h, ffn_amaxs, ffn_saves = self.feed_forward(x, attn, **ffn_kwargs)
|
||||
h = h + ffn
|
||||
return (h, *attn_amaxs, *ffn_amaxs, *attn_saves, *ffn_saves)
|
||||
if save: return (h, *attn_amaxs, *ffn_amaxs, *attn_saves, *ffn_saves)
|
||||
else: return (h, *attn_amaxs, *ffn_amaxs)
|
||||
|
||||
def shard(self, device:tuple[str, ...], mp:bool=False):
|
||||
from tinygrad.nn.state import get_parameters
|
||||
@@ -241,7 +242,7 @@ class FlatTransformer:
|
||||
for name in self._fp8_inv_scale:
|
||||
self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().requires_grad_(False)
|
||||
|
||||
def __call__(self, tokens:Tensor):
|
||||
def __call__(self, tokens:Tensor, save:bool=True):
|
||||
h = self.tok_embeddings(tokens)
|
||||
freqs_cis = self.freqs_cis.cast(h.dtype)[:, :tokens.shape[1], :, :, :]
|
||||
a, ga, s = self._fp8_amax, self._fp8_grad_amax, self._fp8_inv_scale
|
||||
@@ -256,7 +257,7 @@ class FlatTransformer:
|
||||
s_1=s["w1"][i], s_3=s["w3"][i], grad_amax_xw1=ga["xw1"][i], grad_amax_xw3=ga["xw3"][i])
|
||||
else:
|
||||
ffn_kwargs.update(w13=self.w13[i], amax_x13=a["x13"][i], s_13=s["w13"][i], grad_amax_xw13=ga["xw13"][i])
|
||||
h, *ret = self.run_layer(h, freqs_cis, attn_kwargs, ffn_kwargs)
|
||||
h, *ret = self.run_layer(h, freqs_cis, attn_kwargs, ffn_kwargs, save=save)
|
||||
amax_names = ["xqkv", "xo"] + (["x1", "x3"] if SPLIT_W13 else ["x13"]) + ["x2"]
|
||||
for name, new_val in zip(amax_names, ret[:len(amax_names)]):
|
||||
a[name][i].assign(new_val)
|
||||
|
||||
Reference in New Issue
Block a user