From a191ac0566e5e46eddd0fe83bf954b73d60464ff Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Fri, 13 Mar 2026 23:08:32 +0800 Subject: [PATCH] llama: use mlperf model (#15257) --- examples/mlperf/model_train.py | 8 ++-- examples/mlperf/models/llama.py | 80 +++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 4 deletions(-) create mode 100644 examples/mlperf/models/llama.py diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 4822361cff..3817ff8c64 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1282,7 +1282,7 @@ def train_bert(): previous_step = i def train_llama3(): - from extra.models.llama import Transformer + from examples.mlperf.models.llama import Transformer from examples.llama3 import MODEL_PARAMS from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup from examples.mlperf.optim import GradAccClipAdamW @@ -1343,7 +1343,7 @@ def train_llama3(): if (MP := getenv("MP", 1)) > 1: model_params['vocab_size'] = round_up(model_params['vocab_size'], 256 * MP) vocab_mask:Tensor = Tensor.arange(model_params['vocab_size']).reshape(1, 1, -1) >= real_vocab_size - model = Transformer(**model_params, max_context=SEQLEN, jit=False, disable_kv_cache=True) + model = Transformer(**model_params, max_context=SEQLEN) params = get_parameters(model) # weights are all bfloat16 for now @@ -1417,7 +1417,7 @@ def train_llama3(): device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP)) tokens = tokens.shard(device) if DP == 1 and MP == 1: tokens = tokens.to(None) - logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan) + logits:Tensor = model(tokens[:, :-1]) loss = vocab_mask.where(-1e9, logits).sparse_categorical_crossentropy(tokens[:, 1:]) loss.backward() assert all(p.grad is g for p,g in zip(optim.params, grads)) @@ -1449,7 +1449,7 @@ def train_llama3(): device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP)) tokens = tokens.shard(device) if DP == 1 and MP == 1: tokens = tokens.to(None) - logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan) + logits:Tensor = model(tokens[:, :-1]) loss = vocab_mask.where(-1e9, logits).sparse_categorical_crossentropy(tokens[:, 1:]) return loss.flatten().float().to("CPU") diff --git a/examples/mlperf/models/llama.py b/examples/mlperf/models/llama.py new file mode 100644 index 0000000000..0ae17fd93c --- /dev/null +++ b/examples/mlperf/models/llama.py @@ -0,0 +1,80 @@ +from tinygrad import Tensor, nn +from tinygrad.helpers import getenv +from extra.models.llama import apply_rotary_emb, precompute_freqs_cis + +class Attention: + def __init__(self, dim:int, n_heads:int, n_kv_heads:int|None=None, linear=nn.Linear): + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads # n_kv_heads != n_heads implies MQA [arxiv/2307.09288, A.2.1] + self.head_dim = dim // n_heads + self.n_rep = self.n_heads // self.n_kv_heads + + if getenv("WQKV"): + self.wqkv = linear(dim, self.n_heads * self.head_dim + self.n_kv_heads * self.head_dim * 2, bias=False) + else: + self.wq = linear(dim, self.n_heads * self.head_dim, bias=False) + self.wk = linear(dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = linear(dim, self.n_kv_heads * self.head_dim, bias=False) + + self.wo = linear(self.n_heads * self.head_dim, dim, bias=False) + + def __call__(self, x:Tensor, freqs_cis:Tensor) -> Tensor: + if getenv("WQKV"): + xqkv = self.wqkv(x) + xqkv = xqkv.reshape(xqkv.shape[0], xqkv.shape[1], self.n_kv_heads, self.n_rep + 2, self.head_dim) + xq = xqkv[:, :, :, :self.n_rep].reshape(xqkv.shape[0], xqkv.shape[1], -1) + xk = xqkv[:, :, :, self.n_rep:self.n_rep+1].reshape(xqkv.shape[0], xqkv.shape[1], -1) + xv = xqkv[:, :, :, self.n_rep+1:self.n_rep+2].reshape(xqkv.shape[0], xqkv.shape[1], -1) + else: + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.reshape(xq.shape[0], xq.shape[1], self.n_heads, self.head_dim) + xk = xk.reshape(xk.shape[0], xk.shape[1], self.n_kv_heads, self.head_dim) + xv = xv.reshape(xv.shape[0], xv.shape[1], self.n_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis) + bsz, seqlen, _, _ = xq.shape + + xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.transpose(1, 2) + attn = xq.scaled_dot_product_attention(xk, xv, is_causal=True, enable_gqa=True).transpose(1, 2) + + attn = attn.reshape(bsz, seqlen, -1) + return self.wo(attn) + +class FeedForward: + def __init__(self, dim:int, hidden_dim:int, linear=nn.Linear): + self.w1 = linear(dim, hidden_dim, bias=False) + self.w2 = linear(hidden_dim, dim, bias=False) + self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit + + def __call__(self, x:Tensor) -> Tensor: + w1 = self.w1(x).silu() + w3 = self.w3(x) + return self.w2(w1 * w3) + +class TransformerBlock: + def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int|None, norm_eps:float, linear=nn.Linear): + self.attention = Attention(dim, n_heads, n_kv_heads, linear) + self.feed_forward = FeedForward(dim, hidden_dim, linear) + self.attention_norm = nn.RMSNorm(dim, norm_eps) + self.ffn_norm = nn.RMSNorm(dim, norm_eps) + + def __call__(self, x:Tensor, freqs_cis:Tensor): + h = x + self.attention(self.attention_norm(x), freqs_cis) + return h + self.feed_forward(self.ffn_norm(h)) + +class Transformer: + def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_layers:int, norm_eps:float, vocab_size:int, n_kv_heads:int|None=None, + rope_theta:int=10000, max_context:int=1024, linear=nn.Linear, embedding=nn.Embedding): + self.layers = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, linear) for _ in range(n_layers)] + self.norm = nn.RMSNorm(dim, norm_eps) + self.tok_embeddings = embedding(vocab_size, dim) + self.output = nn.Linear(dim, vocab_size, bias=False) if embedding == nn.Embedding else linear(dim, vocab_size, bias=False) + self.freqs_cis = precompute_freqs_cis(dim // n_heads, max_context * 2, rope_theta).contiguous().requires_grad_(False) + + def __call__(self, tokens:Tensor): + h = self.tok_embeddings(tokens) + freqs_cis = self.freqs_cis.cast(h.dtype)[:, :tokens.shape[1], :, :, :] + for layer in self.layers: h = layer(h, freqs_cis) + logits = self.output(self.norm(h)) + return logits