Files
tinygrad/test/unit/test_llm_mla.py
George Hotz a9b6cfece0 refactor llm into files (#15780)
* refactor llm into files

* chat.html

* tokenizer cleanup

* cleanup

* tests
2026-04-17 12:33:11 +08:00

72 lines
3.4 KiB
Python

import unittest
import numpy as np
from tinygrad import Tensor
from tinygrad.llm.model import Transformer, TransformerConfig, apply_rope, MLATransformerBlock, precompute_freqs_cis
class TestMLA(unittest.TestCase):
  """Tests for MLATransformerBlock attention math and the optional shared-expert gate."""

  def _make_config(self, **kwargs):
    """Return a small TransformerConfig; any keyword argument overrides the matching default."""
    defaults = {
      "num_blocks": 1, "dim": 64, "hidden_dim": 128, "n_heads": 4, "n_kv_heads": 1,
      "norm_eps": 1e-5, "vocab_size": 100, "head_dim": 16, "rope_theta": 10000.0, "rope_dim": 8, "max_context": 32,
      "kv_lora_rank": 16, "v_head_dim": 8,
    }
    return TransformerConfig(**{**defaults, **kwargs})

  def test_mla_attention_matches_naive(self):
    """Absorbed attention (scores taken against the compressed KV) must equal naive attention (K and V expanded)."""
    ctx_len = 16  # used for both the config's max_context and the rope frequency table
    cfg = self._make_config(max_context=ctx_len)
    block = MLATransformerBlock(cfg)
    bsz, seqlen = 1, 4
    lora, rope = cfg.kv_lora_rank, cfg.rope_dim
    nope = cfg.head_dim - rope
    wk_b, wv_b = block.attn_k_b["weight"], block.attn_v_b["weight"]

    hidden = block.attn_norm(Tensor.randn(bsz, seqlen, cfg.dim))

    # Queries: split each head into a non-rotary slice and a rotary slice; only the latter gets rope.
    q_heads = block.attn_q(hidden).reshape(bsz, seqlen, cfg.n_heads, cfg.head_dim).transpose(1, 2)
    rope_freqs = precompute_freqs_cis(rope, ctx_len, cfg.rope_theta)[0:seqlen]
    q_pass = q_heads[..., :nope]
    q_rot = apply_rope(q_heads[..., nope:], rope_freqs)

    # Compressed KV: the first kv_lora_rank features are normalized, the remainder is the single rotary key head.
    kv_packed = block.attn_kv_a_mqa(hidden)
    latent = block.attn_kv_a_norm(kv_packed[..., :lora])
    k_rot = apply_rope(kv_packed[..., lora:].reshape(bsz, seqlen, 1, rope).transpose(1, 2), rope_freqs)

    # Both paths share the same softmax scale and the same additive causal mask (-inf strictly above the diagonal).
    scale = 1.0 / cfg.head_dim ** 0.5
    causal = Tensor.full((1, 1, seqlen, seqlen), float("-inf")).triu(1)

    # --- Naive path: expand K and V per head, then run standard attention ---
    latent_heads = latent.unsqueeze(1)
    k_full = (latent_heads @ wk_b).cat(k_rot.expand(-1, cfg.n_heads, -1, -1), dim=-1)  # (B, H, T, nope+rope)
    v_full = latent_heads @ wv_b.transpose(-1, -2)  # (B, H, T, v_dim)
    naive_scores = (q_pass.cat(q_rot, dim=-1) @ k_full.transpose(-1, -2)) * scale
    naive_ctx = (naive_scores + causal).softmax(-1) @ v_full  # (B, H, T, v_dim)
    out_naive = block.attn_output(naive_ctx.transpose(1, 2).reshape(bsz, seqlen, -1))

    # --- Absorbed path: fold wk_b into the query and apply wv_b only after attending over the compressed KV ---
    latent_4d = latent.reshape(bsz, 1, seqlen, lora)
    q_absorbed = (q_pass @ wk_b.transpose(-1, -2)).cat(q_rot, dim=-1)  # (B, H, T, lora+rope)
    k_absorbed = latent_4d.cat(k_rot.reshape(bsz, 1, seqlen, rope), dim=-1)
    absorbed_probs = ((q_absorbed @ k_absorbed.transpose(-1, -2)) * scale + causal).softmax(-1)
    absorbed_ctx = (absorbed_probs @ latent_4d) @ wv_b.transpose(-1, -2)
    out_abs = block.attn_output(absorbed_ctx.transpose(1, 2).reshape(bsz, seqlen, -1))

    # Compare the two block outputs numerically.
    np.testing.assert_allclose(out_naive.realize().numpy(), out_abs.realize().numpy(), atol=1e-4, rtol=1e-4,
                               err_msg="Absorbed MLA should match naive MLA")

  def test_shared_expert_gate_optional(self):
    """With shared_expert_gate=False there is no gate weight in the state dict and _feed_forward keeps the input shape."""
    from tinygrad import nn
    cfg = self._make_config(num_experts=4, num_experts_per_tok=2, shared_expert_dim=32, shared_expert_gate=False)
    model = Transformer(cfg)
    state = nn.state.get_state_dict(model)
    self.assertNotIn('blk.0.ffn_gate_inp_shexp.weight', state)
    first_block = model.blk[0]
    dim = first_block.config.dim
    out = first_block._feed_forward(Tensor.randn(1, 4, dim))
    self.assertEqual(out.shape, (1, 4, dim))