diff --git a/tinygrad/llm/model.py b/tinygrad/llm/model.py index 7037bff09b..e989c9c578 100644 --- a/tinygrad/llm/model.py +++ b/tinygrad/llm/model.py @@ -219,9 +219,8 @@ class MLATransformerBlock(FFNBlock): self.freqs_cis[start_pos:start_pos+T]) k_store = c_kv.reshape(B, 1, T, self.config.kv_lora_rank).cat(k_rope.reshape(B, 1, T, self.config.rope_dim), dim=-1) - v_store = c_kv.reshape(B, 1, T, self.config.kv_lora_rank) k = Tensor(self.cache_k.uop.after(self.cache_k[:, :, start_pos:start_pos+T, :].uop.store(k_store.uop)))[:, :, 0:start_pos+T, :] - v = Tensor(self.cache_v.uop.after(self.cache_v[:, :, start_pos:start_pos+T, :].uop.store(v_store.uop)))[:, :, 0:start_pos+T, :] + v = k[..., :self.config.kv_lora_rank] mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None attn = q @ k.transpose(-1, -2) * (1.0 / self.config.head_dim ** 0.5) @@ -233,7 +232,6 @@ class MLATransformerBlock(FFNBlock): def _init_state(self, x:Tensor): if not hasattr(self, "cache_k"): self.cache_k = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank + self.config.rope_dim, device=x.device) - self.cache_v = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank, device=x.device) self.freqs_cis = precompute_freqs_cis(self.config.rope_dim, self.config.max_context, self.config.rope_theta) class GatedDeltaNetBlock(FFNBlock):