llm: dedup MLA cache_v (#15887)

This commit is contained in:
b1tg
2026-04-24 12:32:10 +08:00
committed by GitHub
parent f379b5a40a
commit aab50d1bca

View File

@@ -219,9 +219,8 @@ class MLATransformerBlock(FFNBlock):
self.freqs_cis[start_pos:start_pos+T])
k_store = c_kv.reshape(B, 1, T, self.config.kv_lora_rank).cat(k_rope.reshape(B, 1, T, self.config.rope_dim), dim=-1)
v_store = c_kv.reshape(B, 1, T, self.config.kv_lora_rank)
k = Tensor(self.cache_k.uop.after(self.cache_k[:, :, start_pos:start_pos+T, :].uop.store(k_store.uop)))[:, :, 0:start_pos+T, :]
v = Tensor(self.cache_v.uop.after(self.cache_v[:, :, start_pos:start_pos+T, :].uop.store(v_store.uop)))[:, :, 0:start_pos+T, :]
v = k[..., :self.config.kv_lora_rank]
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
attn = q @ k.transpose(-1, -2) * (1.0 / self.config.head_dim ** 0.5)
@@ -233,7 +232,6 @@ class MLATransformerBlock(FFNBlock):
def _init_state(self, x:Tensor):
if not hasattr(self, "cache_k"):
self.cache_k = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank + self.config.rope_dim, device=x.device)
self.cache_v = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank, device=x.device)
self.freqs_cis = precompute_freqs_cis(self.config.rope_dim, self.config.max_context, self.config.rope_theta)
class GatedDeltaNetBlock(FFNBlock):