llama: use mlperf model (#15257)

This commit is contained in:
wozeparrot
2026-03-13 23:08:32 +08:00
committed by GitHub
parent 4b59083d7c
commit a191ac0566
2 changed files with 84 additions and 4 deletions

View File

@@ -1282,7 +1282,7 @@ def train_bert():
previous_step = i
def train_llama3():
from extra.models.llama import Transformer
from examples.mlperf.models.llama import Transformer
from examples.llama3 import MODEL_PARAMS
from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup
from examples.mlperf.optim import GradAccClipAdamW
@@ -1343,7 +1343,7 @@ def train_llama3():
if (MP := getenv("MP", 1)) > 1: model_params['vocab_size'] = round_up(model_params['vocab_size'], 256 * MP)
vocab_mask:Tensor = Tensor.arange(model_params['vocab_size']).reshape(1, 1, -1) >= real_vocab_size
model = Transformer(**model_params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
model = Transformer(**model_params, max_context=SEQLEN)
params = get_parameters(model)
# weights are all bfloat16 for now
@@ -1417,7 +1417,7 @@ def train_llama3():
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
tokens = tokens.shard(device)
if DP == 1 and MP == 1: tokens = tokens.to(None)
logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
logits:Tensor = model(tokens[:, :-1])
loss = vocab_mask.where(-1e9, logits).sparse_categorical_crossentropy(tokens[:, 1:])
loss.backward()
assert all(p.grad is g for p,g in zip(optim.params, grads))
@@ -1449,7 +1449,7 @@ def train_llama3():
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
tokens = tokens.shard(device)
if DP == 1 and MP == 1: tokens = tokens.to(None)
logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
logits:Tensor = model(tokens[:, :-1])
loss = vocab_mask.where(-1e9, logits).sparse_categorical_crossentropy(tokens[:, 1:])
return loss.flatten().float().to("CPU")

View File

@@ -0,0 +1,80 @@
from tinygrad import Tensor, nn
from tinygrad.helpers import getenv
from extra.models.llama import apply_rotary_emb, precompute_freqs_cis
class Attention:
def __init__(self, dim:int, n_heads:int, n_kv_heads:int|None=None, linear=nn.Linear):
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1]
self.head_dim = dim // n_heads
self.n_rep = self.n_heads // self.n_kv_heads
if getenv("WQKV"):
self.wqkv = linear(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2, bias=False)
else:
self.wq = linear(dim, self.n_heads * self.head_dim, bias=False)
self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False)
self.wo = linear(self.n_heads * self.head_dim, dim, bias=False)
def __call__(self, x:Tensor, freqs_cis:Tensor) -> Tensor:
if getenv("WQKV"):
xqkv = self.wqkv(x)
xqkv = xqkv.reshape(xqkv.shape[0], xqkv.shape[1], self.n_kv_heads, self.n_rep + 2, self.head_dim)
xq = xqkv[:, :, :, :self.n_rep].reshape(xqkv.shape[0], xqkv.shape[1], -1)
xk = xqkv[:, :, :, self.n_rep:self.n_rep+1].reshape(xqkv.shape[0], xqkv.shape[1], -1)
xv = xqkv[:, :, :, self.n_rep+1:self.n_rep+2].reshape(xqkv.shape[0], xqkv.shape[1], -1)
else:
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim)
xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim)
xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
bsz, seqlen, _, _ = xq.shape
xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2)
attn = xq.scaled_dot_product_attention(xk, xv, is_causal=True, enable_gqa=True).transpose(1, 2)
attn = attn.reshape(bsz, seqlen, -1)
return self.wo(attn)
class FeedForward:
def __init__(self, dim:int, hidden_dim:int, linear=nn.Linear):
self.w1 = linear(dim, hidden_dim, bias=False)
self.w2 = linear(hidden_dim, dim, bias=False)
self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
def __call__(self, x:Tensor) -> Tensor:
w1 = self.w1(x).silu()
w3 = self.w3(x)
return self.w2(w1 * w3)
class TransformerBlock:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int|None, norm_eps:float, linear=nn.Linear):
self.attention = Attention(dim, n_heads, n_kv_heads, linear)
self.feed_forward = FeedForward(dim, hidden_dim, linear)
self.attention_norm = nn.RMSNorm(dim, norm_eps)
self.ffn_norm = nn.RMSNorm(dim, norm_eps)
def __call__(self, x:Tensor, freqs_cis:Tensor):
h = x + self.attention(self.attention_norm(x), freqs_cis)
return h + self.feed_forward(self.ffn_norm(h))
class Transformer:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size:int, n_kv_heads:int|None=None,
rope_theta:int=10000, max_context:int=1024, linear=nn.Linear, embedding=nn.Embedding):
self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, linear) for _ in range(n_layers)]
self.norm = nn.RMSNorm(dim, norm_eps)
self.tok_embeddings = embedding(vocab_size, dim)
self.output = nn.Linear(dim, vocab_size, bias=False) if embedding == nn.Embedding else linear(dim, vocab_size, bias=False)
self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_context * 2, rope_theta).contiguous().requires_grad_(False)
def __call__(self, tokens:Tensor):
h = self.tok_embeddings(tokens)
freqs_cis = self.freqs_cis.cast(h.dtype)[:, :tokens.shape[1], :, :, :]
for layer in self.layers: h = layer(h, freqs_cis)
logits = self.output(self.norm(h))
return logits