* qwen3.5

* faster

* or

* rm zero hack

* less float

* T=1

* clean

* clean

* 4b

* rope_dim

* Revert "jit: captures linears, not execitems (#15399)"

This reverts commit 9656d97d97.

* DeltaNetBlock

* pairwise_topk

* clean

* Reapply "jit: captures linears, not execitems (#15399)"

This reverts commit cf3deff53d.

* clean topk, _swiglu

* common

* FFNBlock

* clean

* half

* no mix

* qwen3.5 test

* fix ssm cache invalidation

* TransformerConfig

* SSMConfig

* clean

* reset_state

* llm: reuse server conversation tokens to avoid BPE roundtrip cache miss

* import error

* prefill

* none check

* put it back

* clean pairwise_topk

* symbolic: fold BIND(CONST, CONST) to CONST

* clean

* simpler pm

* _cached_msg_count

* stream decoder; ssm checkpoints

* rm checkpoint

* attn_output_gate

* conflict, attn_output_gate

* clean, less has_ssm, assert

* chunked prefill

* _reset_cache

* _reusable_prefix_len

* revert loop

---------

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
b1tg
2026-04-13 15:35:24 +08:00
committed by GitHub
parent 2ada38f777
commit 2b5ba0095d
2 changed files with 100 additions and 10 deletions

View File

@@ -508,6 +508,8 @@ jobs:
run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model llama3.2:1b | tee /dev/stderr | grep -i rooster
- name: Test 1B LLM (llama q4)
run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model llama3.2:1b-q4 | tee /dev/stderr | grep -i rooster
- name: Test 1B LLM (qwen3.5)
run: echo "What's a male chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model qwen3.5:0.8b | tee /dev/stderr | grep -i rooster
- name: Test 1B LLM (qwen)
# NOTE: qwen is dumb and only knows about female chickens
run: echo "What's a female chicken called? Answer with only one word." | MAX_BUFFER_SIZE=0 python3 -m tinygrad.apps.llm --model qwen3:0.6b | tee /dev/stderr | grep -i hen

View File

@@ -98,6 +98,14 @@ def pairwise_topk(x: Tensor, k: int) -> tuple[Tensor, Tensor]:
sel = Tensor.zeros_like(x).scatter(-1, cmp.sum(axis=-1).cast('int32'), vals)[:,:,n-k:].cast('int32')
return x.gather(-1, sel), sel
# Immutable bundle of the state-space ("ssm") hyperparameters consumed by GatedDeltaNetBlock.
# Transformer.from_gguf fills every field from the GGUF metadata key `{arch}.ssm.<field name>`.
@dataclass(frozen=True)
class SSMConfig:
# number of time steps in the conv window; the block caches (conv_kernel - 1) past qkv projections
conv_kernel: int
# per-head q/k width (GatedDeltaNetBlock.head_k_dim)
state_size: int
# number of q/k heads (GatedDeltaNetBlock.num_k_heads)
group_count: int
# number of v heads (GatedDeltaNetBlock.num_v_heads); also the output width of the alpha/beta projections
time_step_rank: int
# total v width across heads; per-head v width is inner_size // time_step_rank
inner_size: int
@dataclass(frozen=True)
class TransformerConfig:
num_blocks: int
@@ -118,6 +126,9 @@ class TransformerConfig:
norm_topk_prob: bool = False
kv_lora_rank: int = 0
shared_expert_dim: int = 0
full_attention_interval: int = 0
attn_output_gate: bool = False
ssm: SSMConfig|None = None
shared_expert_gate: bool = True
leading_dense_blocks: int = 0
dense_hidden_dim: int = 0
@@ -171,6 +182,10 @@ class FFNBlock:
# TODO: remove the need for this contiguous
return self.ffn_down(self.ffn_gate(x).silu().contiguous() * self.ffn_up(x))
# given the token-prefix match, return how much cached state this block can still reuse
def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int: return prefix_len
# return writes that reset this block's state after a cache mismatch
def _state_reset_ops(self) -> list[Tensor]: return []
def _init_state(self, x:Tensor): raise NotImplementedError
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor: raise NotImplementedError
@@ -189,12 +204,12 @@ class TransformerBlock(FFNBlock):
assert config.v_head_dim == config.head_dim, "TransformerBlock requires v_head_dim == head_dim"
# --- attention projections (all linear, bias-free) ------------------
q_proj_out = config.head_dim * config.n_heads
q_proj_out = config.head_dim * config.n_heads * (2 if config.attn_output_gate else 1)
kv_proj_out = config.head_dim * config.n_kv_heads
self.attn_q = nn.Linear(config.dim, q_proj_out, bias=False)
self.attn_k = nn.Linear(config.dim, kv_proj_out, bias=False)
self.attn_v = nn.Linear(config.dim, kv_proj_out, bias=False)
self.attn_output = nn.Linear(q_proj_out, config.dim, bias=False)
self.attn_output = nn.Linear(config.head_dim * config.n_heads, config.dim, bias=False)
if config.qk_norm: self.attn_q_norm, self.attn_k_norm = nn.RMSNorm(config.qk_norm, config.norm_eps), nn.RMSNorm(config.qk_norm, config.norm_eps)
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
@@ -202,6 +217,9 @@ class TransformerBlock(FFNBlock):
if self.config.qk_norm and self.config.qk_norm != self.config.head_dim: q, k = self.attn_q_norm(q), self.attn_k_norm(k)
B, T, _ = x.shape
if self.config.attn_output_gate:
qg = q.reshape(B, T, self.config.n_heads, 2, self.config.head_dim)
q, gate = qg[:, :, :, 0, :], qg[:, :, :, 1, :].reshape(B, T, self.config.n_heads * self.config.head_dim)
q = q.reshape(B, T, self.config.n_heads, self.config.head_dim).transpose(1, 2) # (B,H,T,Hd)
k = k.reshape(B, T, self.config.n_kv_heads, self.config.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
v = v.reshape(B, T, self.config.n_kv_heads, self.config.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
@@ -224,7 +242,7 @@ class TransformerBlock(FFNBlock):
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
return self.attn_output(attn)
return self.attn_output(attn if not self.config.attn_output_gate else (attn * gate.sigmoid()))
def _init_state(self, x:Tensor):
if not hasattr(self, "cache_kv"):
@@ -274,15 +292,71 @@ class MLATransformerBlock(FFNBlock):
self.cache_v = Tensor.empty(x.shape[0], 1, self.config.max_context, self.config.kv_lora_rank, device=x.device)
self.freqs_cis = precompute_freqs_cis(self.config.rope_dim, self.config.max_context, self.config.rope_theta)
# Recurrent block that Transformer.__init__ builds instead of an attention block when config.ssm is set
# and (i+1) is not a multiple of config.full_attention_interval.
# All per-sequence state lives in one flat `delta_cache` tensor of shape (B, conv_flat + ssm_flat):
#   [:, :conv_flat]                    -> the last (ssm_conv_kernel - 1) qkv projections
#   [:, conv_flat:conv_flat+ssm_flat]  -> one (head_v_dim x head_v_dim) state matrix per v head
# Only single-token steps (T == 1) are supported.
class GatedDeltaNetBlock(FFNBlock):
# config: shared transformer hyperparameters (dim, norm_eps are read here)
# ssm: SSMConfig giving the head geometry and conv window length
def __init__(self, config:TransformerConfig, ssm:SSMConfig):
super().__init__(config)
# map the ssm fields onto head geometry: state_size -> q/k head width, group_count -> q/k heads, time_step_rank -> v heads
self.head_k_dim, self.num_k_heads, self.num_v_heads = ssm.state_size, ssm.group_count, ssm.time_step_rank
self.head_v_dim, self.ssm_conv_kernel = ssm.inner_size // ssm.time_step_rank, ssm.conv_kernel
# conv_channels is the width of the fused qkv projection: v (inner_size) plus q and k (q_dim each)
self.conv_channels, self.q_dim = ssm.inner_size + 2*ssm.group_count*ssm.state_size, ssm.state_size*ssm.group_count
self.attn_qkv, self.attn_gate = nn.Linear(config.dim, self.conv_channels, bias=False), nn.Linear(config.dim, ssm.inner_size, bias=False)
# one scalar per v head for the state scaling (alpha) and the update strength (beta)
self.ssm_alpha, self.ssm_beta = nn.Linear(config.dim, self.num_v_heads, bias=False), nn.Linear(config.dim, self.num_v_heads, bias=False)
# zero-filled placeholders; presumably overwritten by load_state_dict under the names
# ssm_conv1d.weight / ssm_dt.bias / ssm_a -- TODO confirm against the checkpoint's tensor names
self.ssm_conv1d = {"weight": Tensor.zeros(self.conv_channels, self.ssm_conv_kernel)}
self.ssm_dt = {"bias": Tensor.zeros(self.num_v_heads)}
self.ssm_a = Tensor.zeros(self.num_v_heads)
# the norm acts on each head's output (width head_v_dim); ssm_out projects the concatenated heads back to config.dim
self.ssm_norm, self.ssm_out = nn.RMSNorm(self.head_v_dim, config.norm_eps), nn.Linear(ssm.inner_size, config.dim, bias=False)
# Run one recurrent step for a single new token and return the block's mixed output of shape (B, 1, config.dim).
# Side effect: writes the updated conv history and recurrent state into self.delta_cache.
# start_pos is accepted for interface compatibility with the other blocks and is not read here.
def _attention(self, x:Tensor, start_pos:int|UOp) -> Tensor:
B, T, _ = x.shape
assert T == 1, "GatedDeltaNetBlock currently only supports T=1"
x = x.half()
out_gate = self.attn_gate(x).reshape(B, 1, self.num_v_heads, self.head_v_dim)
# beta in (0, 1) per v head, shaped (B, Hv, 1, 1) so it broadcasts over the state matrix
beta = self.ssm_beta(x).sigmoid().reshape(B, self.num_v_heads, 1, 1)
# alpha = exp(softplus(alpha_proj + dt_bias) * ssm_a), computed in float; it multiplies the recurrent state below
alpha = ((self.ssm_alpha(x).float() + self.ssm_dt["bias"]).softplus() * self.ssm_a).reshape(B, self.num_v_heads, 1, 1).exp()
# sizes of the two regions packed into delta_cache (must match _init_state)
conv_flat = (self.ssm_conv_kernel - 1) * self.conv_channels
ssm_flat = self.num_v_heads * self.head_v_dim * self.head_v_dim
conv_state = self.delta_cache[:, :conv_flat].reshape(B, self.ssm_conv_kernel - 1, self.conv_channels)
recurrent_state = self.delta_cache[:, conv_flat:conv_flat + ssm_flat].reshape(B, self.num_v_heads, self.head_v_dim, self.head_v_dim)
# append the current qkv projection to the cached history -> window of ssm_conv_kernel steps, (B, K, C)
conv_window = conv_state.cat(self.attn_qkv(x), dim=1)
# per-channel weighted sum over the window (the 1-D conv evaluated only at the newest position), then SiLU
conv_out = (conv_window * self.ssm_conv1d["weight"].T.unsqueeze(0)).sum(1).silu()
q, k, v = conv_out.split([self.q_dim, self.q_dim, self.conv_channels - 2*self.q_dim], dim=-1)
# q and k are normalized along the head dimension
q, k = q.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1), k.reshape(B, self.num_k_heads, self.head_k_dim).normalize(dim=-1)
v = v.reshape(B, self.num_v_heads, self.head_v_dim)
if self.num_v_heads != self.num_k_heads:
# tile q/k so there is one per v head: v head h uses q/k head (h % num_k_heads)
# NOTE(review): the reshape only works when num_v_heads is an exact multiple of num_k_heads
k_repeat = self.num_v_heads // self.num_k_heads
q = q.unsqueeze(1).expand(B, k_repeat, self.num_k_heads, self.head_k_dim).reshape(B, self.num_v_heads, self.head_k_dim)
k = k.unsqueeze(1).expand(B, k_repeat, self.num_k_heads, self.head_k_dim).reshape(B, self.num_v_heads, self.head_k_dim)
# scale q by head_k_dim**-0.5 and turn q, k, v into column vectors (trailing dim of 1)
q, k, v = (q * self.head_k_dim**-0.5).unsqueeze(-1), k.unsqueeze(-1), v.unsqueeze(-1)
recurrent_state = recurrent_state * alpha
# add beta * (v - state@k) @ k^T: move the state's output for key k toward v
# NOTE(review): state is (head_v_dim x head_v_dim) while k has head_k_dim rows, so this looks like it
# requires head_k_dim == head_v_dim -- confirm for every supported model
recurrent_state = recurrent_state + ((v - recurrent_state@k) * beta)@k.transpose(-1, -2)
# new cache contents: drop the oldest conv step, keep the rest plus the current one, followed by the updated state
new_cache = conv_window[:, 1:, :].reshape(B, -1).cat(recurrent_state.reshape(B, -1), dim=-1).contiguous()
# store into the existing delta_cache buffer and wrap the resulting uop in a Tensor;
# presumably this makes the read below depend on the write -- TODO confirm UOp.after/store semantics
assigned = self.delta_cache.uop.after(self.delta_cache.uop.store(new_cache.cast(self.delta_cache.dtype).uop))
cache_tensor = Tensor(assigned, device=self.delta_cache.device)
# read the state back from the stored cache rather than reusing recurrent_state directly
final_state = cache_tensor[:, conv_flat:conv_flat + ssm_flat].reshape(B, self.num_v_heads, self.head_v_dim, self.head_v_dim)
# per-head output state@q, normalized per head
core_attn_out = self.ssm_norm((final_state@q).squeeze(-1).reshape(B, 1, self.num_v_heads, self.head_v_dim))
# gate with SiLU(out_gate), flatten heads, cast to x's dtype (half after the cast above) and project to config.dim
return self.ssm_out((core_attn_out * out_gate.silu()).reshape(B, 1, -1).cast(x.dtype))
# recurrent state can't be partially reused after divergence, force a full rebuild
# returns a one-element list zeroing delta_cache, or [] if the cache has not been allocated yet
def _state_reset_ops(self): return [self.delta_cache.assign(Tensor.zeros_like(self.delta_cache))] if hasattr(self, "delta_cache") else []
# all-or-nothing reuse: the prefix is kept only when it equals the full cached length, otherwise 0
def _reusable_prefix_len(self, prefix_len:int, cached_len:int) -> int: return 0 if prefix_len != cached_len else prefix_len
# lazily allocate the zero-filled delta_cache on first use; batch size and device are taken from x
def _init_state(self, x):
if not hasattr(self, "delta_cache"):
conv_flat = (self.ssm_conv_kernel - 1) * self.conv_channels
ssm_flat = self.num_v_heads * self.head_v_dim * self.head_v_dim
self.delta_cache = Tensor.zeros(x.shape[0], conv_flat + ssm_flat, device=x.device).clone()
class Transformer:
def __init__(self, config:TransformerConfig):
dense_config = replace(config, num_experts=0, num_experts_per_tok=0, shared_expert_dim=0, hidden_dim=config.dense_hidden_dim or config.hidden_dim)
if config.ssm: config = replace(config, qk_norm=config.head_dim)
block_cls = MLATransformerBlock if config.kv_lora_rank > 0 else TransformerBlock
self.blk = [block_cls(dense_config if i < config.leading_dense_blocks else config) for i in range(config.num_blocks)]
self.blk:list[FFNBlock] = [GatedDeltaNetBlock(config, config.ssm) if config.ssm and (i+1) % config.full_attention_interval != 0 else
block_cls(dense_config if i < config.leading_dense_blocks else config) for i in range(config.num_blocks)]
self.token_embd = nn.Embedding(config.vocab_size, config.dim)
self.output_norm = nn.RMSNorm(config.dim, config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
self.max_context = config.max_context
self.has_recurrent_block = any(isinstance(b, GatedDeltaNetBlock) for b in self.blk)
self._cached_tokens: list[int] = []
# we specialize the JIT for prefill and rollout
self.prefill_jit = TinyJit(self.forward)
@@ -296,7 +370,7 @@ class Transformer:
return (logits / temperature.maximum(1e-12) - (Tensor.rand_like(logits).maximum(1e-12).log().neg()).log()).argmax(-1, keepdim=True)
def __call__(self, tokens:Tensor, start_pos:int|UOp, temperature:Tensor) -> Tensor:
return (self.prefill_jit if resolve(tokens.shape[1] != 1) else self.rollout_jit)(tokens, start_pos, temperature)
return (self.prefill_jit if resolve(tokens.shape[1] != 1) else self.rollout_jit)(tokens.contiguous(), start_pos, temperature)
@staticmethod
def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 0))) -> tuple[Transformer, dict]:
@@ -313,6 +387,11 @@ class Transformer:
max_context = min(max_context, kv[f'{arch}.context_length']) if max_context is not None else kv[f'{arch}.context_length']
n_heads, n_kv_heads = kv[f'{arch}.attention.head_count'], kv[f'{arch}.attention.head_count_kv']
ssm = None
if arch in ('qwen35', 'qwen35moe'):
ssm = SSMConfig(**{k: kv[f'{arch}.ssm.{k}'] for k in ('conv_kernel','state_size','group_count','time_step_rank','inner_size')})
state_dict = {k.replace('post_attention_norm', 'ffn_norm'):v for k,v in state_dict.items()}
kv_lora_rank = kv.get(f'{arch}.attention.kv_lora_rank', 0)
head_dim = kv.get(f'{arch}.attention.key_length_mla', kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads))
rope_dim = kv.get(f'{arch}.rope.dimension_count', head_dim)
@@ -330,7 +409,7 @@ class Transformer:
state_dict[name] = state_dict[name][:kv_lora_rank].cat(state_dict[name][kv_lora_rank:].rearrange("(h two) d -> (two h) d", two=2), dim=0)
config = TransformerConfig(
num_blocks=kv[f'{arch}.block_count'], dim=kv[f'{arch}.embedding_length'],
hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv[f'{arch}.feed_forward_length']),
hidden_dim=kv.get(f'{arch}.expert_feed_forward_length', kv.get(f'{arch}.feed_forward_length', 0)),
n_heads=n_heads, n_kv_heads=n_kv_heads, norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'],
vocab_size=len(kv['tokenizer.ggml.tokens']),
head_dim=head_dim,
@@ -348,7 +427,8 @@ class Transformer:
kv.get(f'{arch}.expert_shared_count', 0) * kv.get(f'{arch}.expert_feed_forward_length', 0)),
shared_expert_gate=f"blk.{kv.get(f'{arch}.leading_dense_block_count', 0)}.ffn_gate_inp_shexp.weight" in state_dict,
dense_hidden_dim=kv.get(f'{arch}.feed_forward_length', 0) if kv.get(f'{arch}.leading_dense_block_count', 0) else 0,
routed_scaling_factor=kv.get(f'{arch}.expert_weights_scale', 1.0))
routed_scaling_factor=kv.get(f'{arch}.expert_weights_scale', 1.0), attn_output_gate=arch in ('qwen35', 'qwen35moe'), ssm=ssm,
full_attention_interval=kv.get(f'{arch}.full_attention_interval', 0))
model = Transformer(config)
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
@@ -357,18 +437,21 @@ class Transformer:
Tensor.realize(*params)
return model, kv
def get_start_pos(self, tokens:list[int]):
return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens[:-1], self._cached_tokens)))
def get_start_pos(self, tokens:list[int]) -> int:
prefix_len = sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens[:-1], self._cached_tokens)))
return min(block._reusable_prefix_len(prefix_len, len(self._cached_tokens)) for block in self.blk)
def generate(self, tokens:list[int], chunk_size:int=32, temperature:float=0.0):
if self.has_recurrent_block: chunk_size = 1
v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
v_toks = UOp.variable("toks", 1, chunk_size)
# TODO: use UOp.variable for temperature once float variables are supported
temp = Tensor(temperature).contiguous()
# assign all input tokens once, then slice from start_pos for the model call
t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
# recompute start_pos from what's currently valid in the kv cache
# recompute start_pos from what's currently valid in the caches
start_pos = self.get_start_pos(tokens)
if start_pos < len(self._cached_tokens) and (resets := [r for b in self.blk for r in b._state_reset_ops()]): Tensor.realize(*resets)
out, prompt_len = None, len(tokens)
while len(tokens) < self.max_context:
sp, nt = v_start_pos.bind(start_pos), v_toks.bind(min(chunk_size, len(tokens) - start_pos))
@@ -390,6 +473,11 @@ models = {
"qwen3:1.7b": "https://huggingface.co/unsloth/Qwen3-1.7B-GGUF/resolve/main/Qwen3-1.7B-Q4_K_M.gguf",
"qwen3:8b": "https://huggingface.co/Qwen/Qwen3-8B-GGUF/resolve/main/Qwen3-8B-Q4_K_M.gguf",
"qwen3:30b-a3b": "https://huggingface.co/Qwen/Qwen3-30B-A3B-GGUF/resolve/main/Qwen3-30B-A3B-Q4_K_M.gguf",
"qwen3.5:0.8b": "https://huggingface.co/unsloth/Qwen3.5-0.8B-GGUF/resolve/main/Qwen3.5-0.8B-Q8_0.gguf",
"qwen3.5:4b": "https://huggingface.co/unsloth/Qwen3.5-4B-GGUF/resolve/main/Qwen3.5-4B-Q4_K_M.gguf",
"qwen3.5:9b": "https://huggingface.co/unsloth/Qwen3.5-9B-GGUF/resolve/main/Qwen3.5-9B-Q4_K_M.gguf",
"qwen3.5:27b": "https://huggingface.co/unsloth/Qwen3.5-27B-GGUF/resolve/main/Qwen3.5-27B-Q4_K_M.gguf",
"qwen3.5:35b-a3b": "https://huggingface.co/unsloth/Qwen3.5-35B-A3B-GGUF/resolve/main/Qwen3.5-35B-A3B-Q4_K_M.gguf",
"olmoe": "https://huggingface.co/allenai/OLMoE-1B-7B-0924-Instruct-GGUF/resolve/main/olmoe-1b-7b-0924-instruct-q4_k_m.gguf",
"moonlight": "https://huggingface.co/gabriellarson/Moonlight-16B-A3B-Instruct-GGUF/resolve/main/Moonlight-16B-A3B-Instruct-Q4_K_M.gguf",
}