From aab50d1bca9992a9b6a5c5316ca00a1329f7f145 Mon Sep 17 00:00:00 2001 From: b1tg <33436708+b1tg@users.noreply.github.com> Date: Fri, 24 Apr 2026 12:32:10 +0800 Subject: [PATCH] llm: dedup MLA cache_v (#15887) --- tinygrad/llm/model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tinygrad/llm/model.py b/tinygrad/llm/model.py index 7037bff09b..e989c9c578 100644 --- a/tinygrad/llm/model.py +++ b/tinygrad/llm/model.py @@ -219,9 +219,8 @@ class MLATransformerBlock(FFNBlock): self.freqs_cis[start_pos:start_pos+T]) k_store = c_kv.reshape(B, 1, T, self.config.kv_lora_rank).cat(k_rope.reshape(B, 1, T, self.config.rope_dim), dim=-1) - v_store = c_kv.reshape(B, 1, T, self.config.kv_lora_rank) k = Tensor(self.cache_k.uop.after(self.cache_k[:, :, start_pos:start_pos+T, :].uop.store(k_store.uop)))[:, :, 0:start_pos+T, :] - v = Tensor(self.cache_v.uop.after(self.cache_v[:, :, start_pos:start_pos+T, :].uop.store(v_store.uop)))[:, :, 0:start_pos+T, :] + v = k[..., :self.config.kv_lora_rank] mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None attn = q @ k.transpose(-1, -2) * (1.0 / self.config.head_dim ** 0.5) @@ -233,7 +232,6 @@ class MLATransformerBlock(FFNBlock): def _init_state(self, x:Tensor): if not hasattr(self, "cache_k"): self.cache_k = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank + self.config.rope_dim, device=x.device) - self.cache_v = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank, device=x.device) self.freqs_cis = precompute_freqs_cis(self.config.rope_dim, self.config.max_context, self.config.rope_theta) class GatedDeltaNetBlock(FFNBlock):