From 575b40b93ac4b4672826980aaed43c82164c18dc Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Mon, 16 Mar 2026 20:16:33 -0700 Subject: [PATCH 1/4] determine image shapes before index devectorization (#15304) --- tinygrad/codegen/late/devectorizer.py | 35 +++++++++++++-------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 93d65da486..072094ae95 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -62,6 +62,18 @@ load_store_indexing = PatternMatcher([ # ***** load/store grouping ***** def expand_index(buf:UOp, vec:UOp): + # determine optimal image shapes + if IMAGE == 1 and isinstance(dt:=buf.dtype, ImageDType): + x, valid = vec.get_idx().gep(0), vec.get_valid().gep(0) + # search for dims that drop the most valid statements + best_drop, cands = -1, [] + for ch, cw in ImageDType.valid_dims(dt): + if (dropped:=len(_drop_valid_stmts(valid, cidx:=uop_given_valid(valid, UOp.vectorize((x//4)%cw, x//(4*cw))), ch, cw))) > best_drop: + best_drop, cands = dropped, [(ch, cw, cidx)] + elif dropped == best_drop: cands.append((ch, cw, cidx)) + # and tiebreak with indexing complexity (ie. number of nodes) + h, w, _ = cands[0] if len(cands) == 1 else min(cands, key=lambda cand: len(cand[2].gep(1).simplify().backward_slice)) + buf = buf.replace(dtype=(dtypes.imageh if dt.itemsize == 2 else dtypes.imagef)((h, w, 4), w * 4 * dt.itemsize)) if getenv("UNSAFE_DISABLE_MASK", 0): vec = vec.get_idx() # generate the individual indexes return UOp(Ops.VECTORIZE, buf.dtype, tuple(buf.index(vec.gep(i), ptr=True) for i in range(vec.dtype.count))) @@ -184,34 +196,21 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp): if len(ret) <= 1: return None return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret) -def _do_image_fixup(dt:ImageDType, idx:UOp) -> tuple[UOp, UOp, int, int]: - buf = idx.src[0] - x, valid = idx.src[1].get_idx(), idx.src[1].get_valid() - h, w = dt.shape[0], dt.shape[1] - if IMAGE == 1: - # search for dims that drop the most valid statements - best_drop, cands = -1, [] - for ch, cw in ImageDType.valid_dims(dt): - if (dropped:=len(_drop_valid_stmts(valid, cidx:=uop_given_valid(valid, UOp.vectorize((x//4)%cw, x//(4*cw))), ch, cw))) > best_drop: - best_drop, cands = dropped, [(ch, cw, cidx)] - elif dropped == best_drop: cands.append((ch, cw, cidx)) - # and tiebreak with indexing complexity (ie. number of nodes) - h, w, _ = cands[0] if len(cands) == 1 else min(cands, key=lambda cand: len(cand[2].simplify().gep(1).backward_slice)) - buf = buf.replace(dtype=(dtypes.imageh if dt.itemsize == 2 else dtypes.imagef)((h, w, 4), w * 4 * dt.itemsize)) - oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % w, (x // (4*w)))) - return x, idx.replace(src=(buf, oidx.valid(valid))), w, h +def get_image_idx(idx:UOp, width:int): + oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), (((x:=idx.src[1].get_idx()) // 4) % width, (x // (4*width)))) + return idx.replace(src=(idx.src[0], oidx.valid(idx.src[1].get_valid()))) def image_fixup(ls:UOp): # normal image load or store, with the CAST from expand_index if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType): assert ls.src[0].dtype.count == 4, "image must be casted to 4" - _, idx, _, _ = _do_image_fixup(image_dtype, ls.src[0].src[0]) + idx = get_image_idx(ls.src[0].src[0], image_dtype.shape[1]) return ls.replace(src=(idx,)+ls.src[1:]) # this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].get_idx().dtype != dtypes.index.vec(2): assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it" - x, idx, width, height = _do_image_fixup(image_dtype, ls.src[0]) + x, idx = ls.src[0].src[1].get_idx(), get_image_idx(ls.src[0], image_dtype.shape[1]) vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:]) # image pixels have 4 channels (.xyzw), select channel based on x % 4 x_mod_4 = x % 4 From 1283b57b4e3e68e58ce4a643a628aec7d22f5841 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 16 Mar 2026 23:55:59 -0400 Subject: [PATCH 2/4] update fix_store_after_hazard (#15309) actual gate is just not CONTIGUOUS, also don't need to check against full backward_slice --- tinygrad/schedule/rangeify.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 1ceeca13c7..fc2495ed48 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -8,7 +8,7 @@ from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLI from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify from tinygrad.codegen.opt import Opt -from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op +from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, IndexingContext, apply_movement_op from tinygrad.schedule.multi import multi_pm from tinygrad.schedule.allreduce import create_allreduce_function @@ -71,8 +71,11 @@ pm_mops = PatternMatcher([ def fix_store_after_hazard(after:UOp, target:UOp, src:UOp): # PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set()) - if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS or s.op is Ops.AFTER)): - return after.replace(src=(after.src[0], target.store(src.contiguous()))) + base = target.base + reaches_base: dict[UOp, bool] = {} + for s in src.toposort(gate=lambda s: s.op is not Ops.CONTIGUOUS): + reaches_base[s] = s is base or any(reaches_base.get(c) for c in s.src) + if reaches_base[s] and s.op in unsafe: return after.replace(src=(after.src[0], target.store(src.contiguous()))) def normalize_store_after_target_chain(after:UOp, target:UOp, src:UOp): root_target = target From 856a839efc4de0a1df9887c4156c5a75ecfef1c8 Mon Sep 17 00:00:00 2001 From: b1tg <33436708+b1tg@users.noreply.github.com> Date: Tue, 17 Mar 2026 12:57:33 +0800 Subject: [PATCH 3/4] llm: fix qwen3 moe topk renormalization (#15201) --- test/unit/test_llm_moe.py | 21 +++++++++++++++++++++ tinygrad/apps/llm.py | 11 +++++++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/test/unit/test_llm_moe.py b/test/unit/test_llm_moe.py index 764ddd3857..f413570e27 100644 --- a/test/unit/test_llm_moe.py +++ b/test/unit/test_llm_moe.py @@ -49,5 +49,26 @@ class TestMoEFeedForward(unittest.TestCase): expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2 np.testing.assert_allclose(out.numpy(), expected, rtol=1e-2) + def test_moe_feed_forward_norm_topk_prob(self): + from tinygrad.apps.llm import TransformerBlock + dim, hidden, n_heads = 8, 16, 2 + num_experts, k = 4, 2 + + block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads, + rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k) + block.norm_topk_prob = True + + block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)]) + block.ffn_up_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) for _ in range(num_experts)]) + block.ffn_down_exps.weight = Tensor.stack(*[Tensor.eye(dim, hidden) for _ in range(num_experts)]) + block.ffn_gate_inp.weight = Tensor([[0.1, 0, 0.1, 0]] * dim).T # equal top-2 experts, but only ~69% mass before renorm + block.ffn_norm.weight = Tensor.ones(dim) + + h = Tensor.ones(1, 1, dim) + out = block._feed_forward(h) + + expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2 + np.testing.assert_allclose(out.numpy()[0, 0, 0], expected, rtol=1e-2) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py index 704eba4414..38a46f74cc 100644 --- a/tinygrad/apps/llm.py +++ b/tinygrad/apps/llm.py @@ -84,7 +84,7 @@ def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor: class TransformerBlock: def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, head_dim:int, rope_theta:float, - max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0): + max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0, norm_topk_prob:bool=False): self.n_heads = n_heads self.n_kv_heads = n_kv_heads self.head_dim = head_dim @@ -107,6 +107,7 @@ class TransformerBlock: # --- feed-forward (MoE or dense) ------------------------------------- if num_experts > 0: + self.norm_topk_prob = norm_topk_prob self.num_experts_per_tok = num_experts_per_tok self.ffn_gate_inp = nn.Linear(dim, num_experts, bias=False) # router self.ffn_gate_exps = ExpertWeights(num_experts, dim, hidden_dim) @@ -150,6 +151,7 @@ class TransformerBlock: if hasattr(self, 'ffn_gate_exps'): x = h_norm.unsqueeze(2) # (B, T, 1, D) - add expert dim for broadcasting probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).topk(self.num_experts_per_tok) # (B, T, k) each + if self.norm_topk_prob: probs = probs / probs.sum(axis=-1, keepdim=True) x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, x).silu() * self.ffn_up_exps(sel, x)) # (B, T, k, D) return h + (x_down * probs.unsqueeze(-1)).sum(axis=2) # (B, T, D) # TODO: remove the need for this contiguous @@ -166,9 +168,9 @@ class TransformerBlock: class Transformer: def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float, - max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0): + max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0, norm_topk_prob:bool=False): self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm, - num_experts, num_experts_per_tok) for _ in range(num_blocks)] + num_experts, num_experts_per_tok, norm_topk_prob) for _ in range(num_blocks)] self.token_embd = nn.Embedding(vocab_size, dim) self.output_norm = nn.RMSNorm(dim, norm_eps) self.output = nn.Linear(dim, vocab_size, bias=False) @@ -215,7 +217,8 @@ class Transformer: head_dim=kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads), rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context, qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0, - num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0)) + num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0), + norm_topk_prob=True if arch=='qwen3moe' else False) nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster if realize: From 6b6d1814cab9ae1ee3af7e628a7bd37fa7f76f5f Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 17 Mar 2026 03:05:23 -0400 Subject: [PATCH 4/4] update no_vectorized_index [pr] (#15313) combine no_vectorized_index and no_vectorized_index_broadcast --- tinygrad/codegen/late/devectorizer.py | 31 +++++++++++---------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index 072094ae95..e8bc757f5e 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -251,32 +251,27 @@ def no_vectorized_alu(alu:UOp): def no_vectorized_buf(buf:UOp): return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(buf.ptrdtype.size*buf.ptrdtype.count, buf.ptrdtype.addrspace)).cast(buf.dtype) -def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp): +def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None): cnt = cast.dtype.count - assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}" - return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))), ptr=True) - -def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp): - cnt = cast.dtype.count - vcnt = cast.dtype.vcount - precnt = bcast.dtype.vcount - # TODO: I have no idea *why* this is. I just change things until the tests pass. No AI, old school. - if bcast.op is Ops.GEP: - gep_arg = tuple(flatten([range(precnt) for _ in range(vcnt)])) - sum_arg = tuple(flatten([[i+y for y in bcast.arg] for i in range(vcnt)])) + if bcast is not None and bcast.op is Ops.GEP: + # GEP selects specific lanes; bcast.arg[k] is the offset for lane k, iterate groups × selected lanes + pairs = [(k, g + bcast.arg[k]) for g, k in itertools.product(range(cast.dtype.vcount), range(len(bcast.arg)))] + elif bcast is not None: + # BROADCAST: cross product of components × lanes + pairs = [(j, c) for c, j in itertools.product(range(cnt), range(bcast.dtype.vcount))] else: - gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)])) - sum_arg = tuple(flatten([[i]*precnt for i in range(cnt)])) - new_idx = idx.gep(gep_arg)*cnt + UOp.const(dtypes.index.vec(len(sum_arg)), sum_arg) - return buf.broadcast(cnt*precnt).index(new_idx, ptr=True) + # simple scalar index: one lane, all components + pairs = [(0, c) for c in range(cnt)] + idx_lanes, offsets = (tuple(x) for x in zip(*pairs)) + return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.index.vec(len(pairs)), offsets), ptr=True) devectorize_buf_and_index = PatternMatcher([ (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf), (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index), (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")), - no_vectorized_index_broadcast), + no_vectorized_index), (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")), - no_vectorized_index_broadcast), + no_vectorized_index), ]) devectorize = PatternMatcher([