mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
llm: dedup MLA cache_v (#15887)
This commit is contained in:
@@ -219,9 +219,8 @@ class MLATransformerBlock(FFNBlock):
|
||||
self.freqs_cis[start_pos:start_pos+T])
|
||||
|
||||
k_store = c_kv.reshape(B, 1, T, self.config.kv_lora_rank).cat(k_rope.reshape(B, 1, T, self.config.rope_dim), dim=-1)
|
||||
v_store = c_kv.reshape(B, 1, T, self.config.kv_lora_rank)
|
||||
k = Tensor(self.cache_k.uop.after(self.cache_k[:, :, start_pos:start_pos+T, :].uop.store(k_store.uop)))[:, :, 0:start_pos+T, :]
|
||||
v = Tensor(self.cache_v.uop.after(self.cache_v[:, :, start_pos:start_pos+T, :].uop.store(v_store.uop)))[:, :, 0:start_pos+T, :]
|
||||
v = k[..., :self.config.kv_lora_rank]
|
||||
|
||||
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
|
||||
attn = q @ k.transpose(-1, -2) * (1.0 / self.config.head_dim ** 0.5)
|
||||
@@ -233,7 +232,6 @@ class MLATransformerBlock(FFNBlock):
|
||||
def _init_state(self, x:Tensor):
|
||||
if not hasattr(self, "cache_k"):
|
||||
self.cache_k = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank + self.config.rope_dim, device=x.device)
|
||||
self.cache_v = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank, device=x.device)
|
||||
self.freqs_cis = precompute_freqs_cis(self.config.rope_dim, self.config.max_context, self.config.rope_theta)
|
||||
|
||||
class GatedDeltaNetBlock(FFNBlock):
|
||||
|
||||
Reference in New Issue
Block a user