Merge branch 'master' into precompile_backward

This commit is contained in:
George Hotz
2026-03-17 15:17:41 +08:00
committed by GitHub
4 changed files with 64 additions and 43 deletions

View File

@@ -49,5 +49,26 @@ class TestMoEFeedForward(unittest.TestCase):
expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2
np.testing.assert_allclose(out.numpy(), expected, rtol=1e-2)
def test_moe_feed_forward_norm_topk_prob(self):
  # MoE feed-forward with norm_topk_prob switched on: the top-k router probabilities are rescaled before mixing
  from tinygrad.apps.llm import TransformerBlock
  dim, hidden, n_heads = 8, 16, 2
  num_experts, k = 4, 2
  blk = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads, rope_theta=10000,
                         max_context=16, num_experts=num_experts, num_experts_per_tok=k)
  blk.norm_topk_prob = True

  # expert i scales its gate projection by (i + 1); up and down projections are identical across experts
  gate_w = [Tensor.eye(hidden, dim) * scale for scale in range(1, num_experts + 1)]
  up_w = [Tensor.eye(hidden, dim) for _ in range(num_experts)]
  down_w = [Tensor.eye(dim, hidden) for _ in range(num_experts)]
  blk.ffn_gate_exps.weight = Tensor.stack(*gate_w)
  blk.ffn_up_exps.weight = Tensor.stack(*up_w)
  blk.ffn_down_exps.weight = Tensor.stack(*down_w)

  # equal top-2 experts, but only ~69% mass before renorm
  router_rows = [[0.1, 0, 0.1, 0]] * dim
  blk.ffn_gate_inp.weight = Tensor(router_rows).T
  blk.ffn_norm.weight = Tensor.ones(dim)

  out = blk._feed_forward(Tensor.ones(1, 1, dim))

  # residual of 1 plus the plain average of silu(1) and silu(3)
  silu_1 = Tensor([1.0]).silu().item()
  silu_3 = Tensor([3.0]).silu().item()
  expected = 1 + (silu_1 + silu_3) / 2
  np.testing.assert_allclose(out.numpy()[0, 0, 0], expected, rtol=1e-2)
# run the unittest runner when this file is executed directly as a script
if __name__ == '__main__':
unittest.main()

View File

@@ -84,7 +84,7 @@ def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
class TransformerBlock:
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, head_dim:int, rope_theta:float,
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0, norm_topk_prob:bool=False):
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = head_dim
@@ -107,6 +107,7 @@ class TransformerBlock:
# --- feed-forward (MoE or dense) -------------------------------------
if num_experts > 0:
self.norm_topk_prob = norm_topk_prob
self.num_experts_per_tok = num_experts_per_tok
self.ffn_gate_inp = nn.Linear(dim, num_experts, bias=False) # router
self.ffn_gate_exps = ExpertWeights(num_experts, dim, hidden_dim)
@@ -150,6 +151,7 @@ class TransformerBlock:
if hasattr(self, 'ffn_gate_exps'):
x = h_norm.unsqueeze(2) # (B, T, 1, D) - add expert dim for broadcasting
probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).topk(self.num_experts_per_tok) # (B, T, k) each
if self.norm_topk_prob: probs = probs / probs.sum(axis=-1, keepdim=True)
x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, x).silu() * self.ffn_up_exps(sel, x)) # (B, T, k, D)
return h + (x_down * probs.unsqueeze(-1)).sum(axis=2) # (B, T, D)
# TODO: remove the need for this contiguous
@@ -166,9 +168,9 @@ class TransformerBlock:
class Transformer:
def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float,
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0, norm_topk_prob:bool=False):
self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm,
num_experts, num_experts_per_tok) for _ in range(num_blocks)]
num_experts, num_experts_per_tok, norm_topk_prob) for _ in range(num_blocks)]
self.token_embd = nn.Embedding(vocab_size, dim)
self.output_norm = nn.RMSNorm(dim, norm_eps)
self.output = nn.Linear(dim, vocab_size, bias=False)
@@ -215,7 +217,8 @@ class Transformer:
head_dim=kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads),
rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context,
qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0))
num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0),
norm_topk_prob=True if arch=='qwen3moe' else False)
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
if realize:

View File

@@ -62,6 +62,18 @@ load_store_indexing = PatternMatcher([
# ***** load/store grouping *****
def expand_index(buf:UOp, vec:UOp):
  # with IMAGE=1 and an image-typed buffer, first choose the image shape to use for it
  if IMAGE == 1 and isinstance(dt:=buf.dtype, ImageDType):
    # the search only looks at lane 0 of the index and of its valid
    x, valid = vec.get_idx().gep(0), vec.get_valid().gep(0)
    # collect every candidate shape that ties for the most dropped valid statements
    most_dropped, tied = -1, []
    for ch, cw in ImageDType.valid_dims(dt):
      cidx = uop_given_valid(valid, UOp.vectorize((x//4)%cw, x//(4*cw)))
      n_dropped = len(_drop_valid_stmts(valid, cidx, ch, cw))
      if n_dropped > most_dropped: most_dropped, tied = n_dropped, [(ch, cw, cidx)]
      elif n_dropped == most_dropped: tied.append((ch, cw, cidx))
    # ties are broken by indexing complexity: fewest nodes behind the simplified second coordinate
    if len(tied) == 1: h, w, _ = tied[0]
    else: h, w, _ = min(tied, key=lambda cand: len(cand[2].gep(1).simplify().backward_slice))
    image_ctor = dtypes.imageh if dt.itemsize == 2 else dtypes.imagef
    buf = buf.replace(dtype=image_ctor((h, w, 4), w * 4 * dt.itemsize))
  if getenv("UNSAFE_DISABLE_MASK", 0): vec = vec.get_idx()
  # one pointer index per lane of vec, gathered into a single VECTORIZE
  lanes = tuple(buf.index(vec.gep(i), ptr=True) for i in range(vec.dtype.count))
  return UOp(Ops.VECTORIZE, buf.dtype, lanes)
@@ -184,34 +196,21 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
if len(ret) <= 1: return None
return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
# Rebuild an image index with a 2-element coordinate vector in place of the flat offset.
# Returns (x, rewritten idx, w, h): note that width comes before height in the result.
# NOTE(review): indentation is not preserved in this listing, so it cannot be read off here whether the
# `buf = buf.replace(...)` line belongs to the `if IMAGE == 1:` branch -- confirm against the real file.
def _do_image_fixup(dt:ImageDType, idx:UOp) -> tuple[UOp, UOp, int, int]:
buf = idx.src[0]
x, valid = idx.src[1].get_idx(), idx.src[1].get_valid()
# start from the shape already recorded on the image dtype
h, w = dt.shape[0], dt.shape[1]
if IMAGE == 1:
# search for dims that drop the most valid statements
best_drop, cands = -1, []
for ch, cw in ImageDType.valid_dims(dt):
# cidx is the candidate 2D coordinate for this (ch, cw); dropped counts what _drop_valid_stmts returns for it
if (dropped:=len(_drop_valid_stmts(valid, cidx:=uop_given_valid(valid, UOp.vectorize((x//4)%cw, x//(4*cw))), ch, cw))) > best_drop:
best_drop, cands = dropped, [(ch, cw, cidx)]
elif dropped == best_drop: cands.append((ch, cw, cidx))
# and tiebreak with indexing complexity (ie. number of nodes)
h, w, _ = cands[0] if len(cands) == 1 else min(cands, key=lambda cand: len(cand[2].simplify().gep(1).backward_slice))
# imageh is used for 2-byte items, imagef otherwise; the second argument is w * 4 * itemsize
buf = buf.replace(dtype=(dtypes.imageh if dt.itemsize == 2 else dtypes.imagef)((h, w, 4), w * 4 * dt.itemsize))
# coordinate vector: ((x // 4) % w, x // (4*w)), carrying the original valid
oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % w, (x // (4*w))))
return x, idx.replace(src=(buf, oidx.valid(valid))), w, h
def get_image_idx(idx:UOp, width:int):
  # rebuild idx with its second source replaced by a 2-element coordinate vector; the first source is kept as is
  buf, off = idx.src[0], idx.src[1]
  x = off.get_idx()
  col, row = (x // 4) % width, x // (4*width)
  oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), (col, row))
  # the original valid is carried over onto the new coordinate vector
  return idx.replace(src=(buf, oidx.valid(off.get_valid())))
def image_fixup(ls:UOp):
# normal image load or store, with the CAST from expand_index
if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType):
assert ls.src[0].dtype.count == 4, "image must be casted to 4"
_, idx, _, _ = _do_image_fixup(image_dtype, ls.src[0].src[0])
idx = get_image_idx(ls.src[0].src[0], image_dtype.shape[1])
return ls.replace(src=(idx,)+ls.src[1:])
# this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].get_idx().dtype != dtypes.index.vec(2):
assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
x, idx, width, height = _do_image_fixup(image_dtype, ls.src[0])
x, idx = ls.src[0].src[1].get_idx(), get_image_idx(ls.src[0], image_dtype.shape[1])
vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
# image pixels have 4 channels (.xyzw), select channel based on x % 4
x_mod_4 = x % 4
@@ -252,32 +251,27 @@ def no_vectorized_alu(alu:UOp):
def no_vectorized_buf(buf:UOp):
  # retype the buffer as a pointer to size*count scalars in the same address space, then cast back to the original dtype
  pt = buf.ptrdtype
  scalar_ptr = pt.base.scalar().ptr(pt.size*pt.count, pt.addrspace)
  return buf.replace(dtype=scalar_ptr).cast(buf.dtype)
# NOTE(review): this span appears to interleave the pre- and post-change text of one diff hunk without +/- markers:
# two `def no_vectorized_index` headers, the `no_vectorized_index_broadcast` helper, and two `if ... Ops.GEP` tests
# sit side by side. Confirm against the real file before editing any line here.
# In the `pairs` formulation, every output element is described by (lane of idx, component offset) and the final
# index is idx[lane]*cnt + offset into the broadcast buffer.
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
# cnt is the element count of the dtype the buffer was cast to
cnt = cast.dtype.count
assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))), ptr=True)
def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
cnt = cast.dtype.count
vcnt = cast.dtype.vcount
precnt = bcast.dtype.vcount
# TODO: I have no idea *why* this is. I just change things until the tests pass. No AI, old school.
if bcast.op is Ops.GEP:
gep_arg = tuple(flatten([range(precnt) for _ in range(vcnt)]))
sum_arg = tuple(flatten([[i+y for y in bcast.arg] for i in range(vcnt)]))
if bcast is not None and bcast.op is Ops.GEP:
# GEP selects specific lanes; bcast.arg[k] is the offset for lane k, iterate groups × selected lanes
pairs = [(k, g + bcast.arg[k]) for g, k in itertools.product(range(cast.dtype.vcount), range(len(bcast.arg)))]
elif bcast is not None:
# BROADCAST: cross product of components × lanes
pairs = [(j, c) for c, j in itertools.product(range(cnt), range(bcast.dtype.vcount))]
else:
gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)]))
sum_arg = tuple(flatten([[i]*precnt for i in range(cnt)]))
new_idx = idx.gep(gep_arg)*cnt + UOp.const(dtypes.index.vec(len(sum_arg)), sum_arg)
return buf.broadcast(cnt*precnt).index(new_idx, ptr=True)
# simple scalar index: one lane, all components
pairs = [(0, c) for c in range(cnt)]
# split the (lane, offset) pairs into two parallel tuples
idx_lanes, offsets = (tuple(x) for x in zip(*pairs))
# one pointer per pair: the selected idx lane scaled by cnt, plus that pair's constant offset
return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.index.vec(len(pairs)), offsets), ptr=True)
# Rules for DEFINE_LOCAL / DEFINE_REG buffers: the buffer itself goes to no_vectorized_buf, and an index taken
# through a cast of it (directly, through a broadcast, or through a gep) goes to an index handler.
# NOTE(review): the broadcast and gep rules are each followed by two handler lines (no_vectorized_index_broadcast and
# no_vectorized_index); this looks like old and new diff lines shown together -- confirm which one the real file has.
devectorize_buf_and_index = PatternMatcher([
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
# the broadcast / gep node is captured under the name "bcast"
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
no_vectorized_index_broadcast),
no_vectorized_index),
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")),
no_vectorized_index_broadcast),
no_vectorized_index),
])
devectorize = PatternMatcher([

View File

@@ -8,7 +8,7 @@ from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLI
from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
from tinygrad.codegen.opt import Opt
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, IndexingContext, apply_movement_op
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.allreduce import create_allreduce_function
@@ -71,8 +71,11 @@ pm_mops = PatternMatcher([
# If the stored value reads the store target's base through an op in `unsafe`, return `after` rebuilt so the store
# writes src.contiguous() instead of src; otherwise nothing is returned.
# NOTE(review): the `if any(...)` check and the `reaches_base` loop below look like the old and new versions of the
# same test shown together (diff markers are missing in this listing) -- confirm which one the real file has.
def fix_store_after_hazard(after:UOp, target:UOp, src:UOp):
# PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS or s.op is Ops.AFTER)):
return after.replace(src=(after.src[0], target.store(src.contiguous())))
base = target.base
# per-node flag: does this node read from `base`?
reaches_base: dict[UOp, bool] = {}
# presumably the gate keeps the walk from descending past CONTIGUOUS nodes, and sources are visited before their
# users -- confirm against UOp.toposort
for s in src.toposort(gate=lambda s: s.op is not Ops.CONTIGUOUS):
# a node reaches the base if it is the base or any source is already flagged; unvisited sources read as falsy via .get
reaches_base[s] = s is base or any(reaches_base.get(c) for c in s.src)
# stop at the first unsafe op found on a path to the base
if reaches_base[s] and s.op in unsafe: return after.replace(src=(after.src[0], target.store(src.contiguous())))
def normalize_store_after_target_chain(after:UOp, target:UOp, src:UOp):
root_target = target