mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
Merge branch 'master' into precompile_backward
This commit is contained in:
@@ -49,5 +49,26 @@ class TestMoEFeedForward(unittest.TestCase):
|
||||
expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2
|
||||
np.testing.assert_allclose(out.numpy(), expected, rtol=1e-2)
|
||||
|
||||
def test_moe_feed_forward_norm_topk_prob(self):
  """MoE feed-forward with `norm_topk_prob` rescales the selected router probabilities to sum to 1.

  The router is set up so the two selected experts tie, but together hold only part of the softmax
  mass. The expected value below is only reached if the top-k probabilities are renormalized.
  """
  from tinygrad.apps.llm import TransformerBlock
  dim, hidden, n_heads = 8, 16, 2
  num_experts, k = 4, 2

  # turn renormalization on through the constructor argument instead of patching the attribute afterwards
  block = TransformerBlock(dim, hidden, n_heads, n_heads, norm_eps=1e-5, head_dim=dim//n_heads,
                           rope_theta=10000, max_context=16, num_experts=num_experts, num_experts_per_tok=k,
                           norm_topk_prob=True)

  # expert i multiplies its gate projection by (i + 1); up/down projections are rectangular identities
  block.ffn_gate_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) * (i + 1) for i in range(num_experts)])
  block.ffn_up_exps.weight = Tensor.stack(*[Tensor.eye(hidden, dim) for _ in range(num_experts)])
  block.ffn_down_exps.weight = Tensor.stack(*[Tensor.eye(dim, hidden) for _ in range(num_experts)])
  block.ffn_gate_inp.weight = Tensor([[0.1, 0, 0.1, 0]] * dim).T  # equal top-2 experts, but only ~69% mass before renorm
  block.ffn_norm.weight = Tensor.ones(dim)

  h = Tensor.ones(1, 1, dim)
  out = block._feed_forward(h)

  # experts 0 and 2 (gate scales 1 and 3) each get weight 1/2 after renormalization; the leading 1 is the residual
  expected = 1 + (Tensor([1.0]).silu().item() + Tensor([3.0]).silu().item()) / 2
  # every feature is configured identically, so compare the whole output rather than a single element
  np.testing.assert_allclose(out.numpy(), expected, rtol=1e-2)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -84,7 +84,7 @@ def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
|
||||
|
||||
class TransformerBlock:
|
||||
def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, head_dim:int, rope_theta:float,
|
||||
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
|
||||
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0, norm_topk_prob:bool=False):
|
||||
self.n_heads = n_heads
|
||||
self.n_kv_heads = n_kv_heads
|
||||
self.head_dim = head_dim
|
||||
@@ -107,6 +107,7 @@ class TransformerBlock:
|
||||
|
||||
# --- feed-forward (MoE or dense) -------------------------------------
|
||||
if num_experts > 0:
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.ffn_gate_inp = nn.Linear(dim, num_experts, bias=False) # router
|
||||
self.ffn_gate_exps = ExpertWeights(num_experts, dim, hidden_dim)
|
||||
@@ -150,6 +151,7 @@ class TransformerBlock:
|
||||
if hasattr(self, 'ffn_gate_exps'):
|
||||
x = h_norm.unsqueeze(2) # (B, T, 1, D) - add expert dim for broadcasting
|
||||
probs, sel = self.ffn_gate_inp(h_norm).softmax(-1).topk(self.num_experts_per_tok) # (B, T, k) each
|
||||
if self.norm_topk_prob: probs = probs / probs.sum(axis=-1, keepdim=True)
|
||||
x_down = self.ffn_down_exps(sel, self.ffn_gate_exps(sel, x).silu() * self.ffn_up_exps(sel, x)) # (B, T, k, D)
|
||||
return h + (x_down * probs.unsqueeze(-1)).sum(axis=2) # (B, T, D)
|
||||
# TODO: remove the need for this contiguous
|
||||
@@ -166,9 +168,9 @@ class TransformerBlock:
|
||||
|
||||
class Transformer:
|
||||
def __init__(self, *, num_blocks, dim, hidden_dim, n_heads, n_kv_heads, norm_eps, vocab_size, head_dim:int, rope_theta:float,
|
||||
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0):
|
||||
max_context:int=0, qk_norm:int=0, num_experts:int=0, num_experts_per_tok:int=0, norm_topk_prob:bool=False):
|
||||
self.blk = [TransformerBlock(dim, hidden_dim, n_heads, n_kv_heads, norm_eps, head_dim, rope_theta, max_context, qk_norm,
|
||||
num_experts, num_experts_per_tok) for _ in range(num_blocks)]
|
||||
num_experts, num_experts_per_tok, norm_topk_prob) for _ in range(num_blocks)]
|
||||
self.token_embd = nn.Embedding(vocab_size, dim)
|
||||
self.output_norm = nn.RMSNorm(dim, norm_eps)
|
||||
self.output = nn.Linear(dim, vocab_size, bias=False)
|
||||
@@ -215,7 +217,8 @@ class Transformer:
|
||||
head_dim=kv.get(f'{arch}.attention.key_length', kv[f'{arch}.embedding_length'] // n_heads),
|
||||
rope_theta=kv[f'{arch}.rope.freq_base'], max_context=max_context,
|
||||
qk_norm=int(state_dict['blk.0.attn_q_norm.weight'].shape[0]) if 'blk.0.attn_q_norm.weight' in state_dict else 0,
|
||||
num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0))
|
||||
num_experts=kv.get(f'{arch}.expert_count', 0), num_experts_per_tok=kv.get(f'{arch}.expert_used_count', 0),
|
||||
norm_topk_prob=True if arch=='qwen3moe' else False)
|
||||
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
|
||||
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
|
||||
if realize:
|
||||
|
||||
@@ -62,6 +62,18 @@ load_store_indexing = PatternMatcher([
|
||||
# ***** load/store grouping *****
|
||||
|
||||
def expand_index(buf:UOp, vec:UOp):
  """Turn a vectorized index into a VECTORIZE of one pointer index per lane of `vec`.

  When IMAGE == 1 and `buf` has an image dtype, the image shape is re-chosen first: among the
  candidate shapes from `ImageDType.valid_dims`, keep the ones that let the most valid statements
  be dropped, then break ties on the node count of the simplified second index component.
  """
  # determine optimal image shapes
  if IMAGE == 1 and isinstance(dt:=buf.dtype, ImageDType):
    x, valid = vec.get_idx().gep(0), vec.get_valid().gep(0)
    # search for dims that drop the most valid statements
    best_drop, cands = -1, []
    for ch, cw in ImageDType.valid_dims(dt):
      cidx = uop_given_valid(valid, UOp.vectorize((x//4)%cw, x//(4*cw)))
      dropped = len(_drop_valid_stmts(valid, cidx, ch, cw))
      if dropped > best_drop:
        # strictly better candidate: restart the tie list
        best_drop, cands = dropped, [(ch, cw, cidx)]
      elif dropped == best_drop:
        cands.append((ch, cw, cidx))
    # and tiebreak with indexing complexity (ie. number of nodes)
    if len(cands) == 1: h, w, _ = cands[0]
    else: h, w, _ = min(cands, key=lambda cand: len(cand[2].gep(1).simplify().backward_slice))
    # keep the element width of the original image dtype, only the shape changes
    make_image = dtypes.imageh if dt.itemsize == 2 else dtypes.imagef
    buf = buf.replace(dtype=make_image((h, w, 4), w * 4 * dt.itemsize))
  if getenv("UNSAFE_DISABLE_MASK", 0): vec = vec.get_idx()
  # generate the individual indexes
  lanes = tuple(buf.index(vec.gep(i), ptr=True) for i in range(vec.dtype.count))
  return UOp(Ops.VECTORIZE, buf.dtype, lanes)
|
||||
@@ -184,34 +196,21 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
||||
if len(ret) <= 1: return None
|
||||
return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
|
||||
|
||||
def _do_image_fixup(dt:ImageDType, idx:UOp) -> tuple[UOp, UOp, int, int]:
|
||||
buf = idx.src[0]
|
||||
x, valid = idx.src[1].get_idx(), idx.src[1].get_valid()
|
||||
h, w = dt.shape[0], dt.shape[1]
|
||||
if IMAGE == 1:
|
||||
# search for dims that drop the most valid statements
|
||||
best_drop, cands = -1, []
|
||||
for ch, cw in ImageDType.valid_dims(dt):
|
||||
if (dropped:=len(_drop_valid_stmts(valid, cidx:=uop_given_valid(valid, UOp.vectorize((x//4)%cw, x//(4*cw))), ch, cw))) > best_drop:
|
||||
best_drop, cands = dropped, [(ch, cw, cidx)]
|
||||
elif dropped == best_drop: cands.append((ch, cw, cidx))
|
||||
# and tiebreak with indexing complexity (ie. number of nodes)
|
||||
h, w, _ = cands[0] if len(cands) == 1 else min(cands, key=lambda cand: len(cand[2].simplify().gep(1).backward_slice))
|
||||
buf = buf.replace(dtype=(dtypes.imageh if dt.itemsize == 2 else dtypes.imagef)((h, w, 4), w * 4 * dt.itemsize))
|
||||
oidx = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % w, (x // (4*w))))
|
||||
return x, idx.replace(src=(buf, oidx.valid(valid))), w, h
|
||||
def get_image_idx(idx:UOp, width:int):
  """Replace the flat index of `idx` with a 2-component coordinate, keeping its buffer and validity.

  `width` is used as the row length in groups of 4 elements (presumably the image width in pixels).
  """
  buf, flat = idx.src[0], idx.src[1]
  x = flat.get_idx()
  # first component: (x // 4) wrapped by width; second component: x // (4 * width)
  coord = UOp(Ops.VECTORIZE, dtypes.index.vec(2), ((x // 4) % width, (x // (4*width))))
  return idx.replace(src=(buf, coord.valid(flat.get_valid())))
|
||||
|
||||
def image_fixup(ls:UOp):
|
||||
# normal image load or store, with the CAST from expand_index
|
||||
if ls.src[0].op is Ops.CAST and isinstance(image_dtype:=ls.src[0].src[0].dtype, ImageDType):
|
||||
assert ls.src[0].dtype.count == 4, "image must be casted to 4"
|
||||
_, idx, _, _ = _do_image_fixup(image_dtype, ls.src[0].src[0])
|
||||
idx = get_image_idx(ls.src[0].src[0], image_dtype.shape[1])
|
||||
return ls.replace(src=(idx,)+ls.src[1:])
|
||||
|
||||
# this is an unprocessed image without a cast, aka unfoldable image load. this doesn't work for stores
|
||||
if isinstance(image_dtype:=ls.src[0].dtype, ImageDType) and ls.src[0].src[1].get_idx().dtype != dtypes.index.vec(2):
|
||||
assert ls.op is Ops.LOAD, "if an image store isn't upcasted to 4, we can't store it"
|
||||
x, idx, width, height = _do_image_fixup(image_dtype, ls.src[0])
|
||||
x, idx = ls.src[0].src[1].get_idx(), get_image_idx(ls.src[0], image_dtype.shape[1])
|
||||
vec_load = ls.replace(dtype=ls.dtype.vec(4), src=(idx,)+ls.src[1:])
|
||||
# image pixels have 4 channels (.xyzw), select channel based on x % 4
|
||||
x_mod_4 = x % 4
|
||||
@@ -252,32 +251,27 @@ def no_vectorized_alu(alu:UOp):
|
||||
def no_vectorized_buf(buf:UOp):
  """Retype `buf` as a pointer to its scalar base type, then cast the result back to the original dtype."""
  pdt = buf.ptrdtype
  # the scalar pointer spans size*count elements in the same address space
  scalar_ptr = pdt.base.scalar().ptr(pdt.size*pdt.count, pdt.addrspace)
  return buf.replace(dtype=scalar_ptr).cast(buf.dtype)
|
||||
|
||||
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
|
||||
def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp, bcast:UOp|None=None):
|
||||
cnt = cast.dtype.count
|
||||
assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
|
||||
return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.index.vec(cnt), tuple(range(cnt))), ptr=True)
|
||||
|
||||
def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
|
||||
cnt = cast.dtype.count
|
||||
vcnt = cast.dtype.vcount
|
||||
precnt = bcast.dtype.vcount
|
||||
# TODO: I have no idea *why* this is. I just change things until the tests pass. No AI, old school.
|
||||
if bcast.op is Ops.GEP:
|
||||
gep_arg = tuple(flatten([range(precnt) for _ in range(vcnt)]))
|
||||
sum_arg = tuple(flatten([[i+y for y in bcast.arg] for i in range(vcnt)]))
|
||||
if bcast is not None and bcast.op is Ops.GEP:
|
||||
# GEP selects specific lanes; bcast.arg[k] is the offset for lane k, iterate groups × selected lanes
|
||||
pairs = [(k, g + bcast.arg[k]) for g, k in itertools.product(range(cast.dtype.vcount), range(len(bcast.arg)))]
|
||||
elif bcast is not None:
|
||||
# BROADCAST: cross product of components × lanes
|
||||
pairs = [(j, c) for c, j in itertools.product(range(cnt), range(bcast.dtype.vcount))]
|
||||
else:
|
||||
gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)]))
|
||||
sum_arg = tuple(flatten([[i]*precnt for i in range(cnt)]))
|
||||
new_idx = idx.gep(gep_arg)*cnt + UOp.const(dtypes.index.vec(len(sum_arg)), sum_arg)
|
||||
return buf.broadcast(cnt*precnt).index(new_idx, ptr=True)
|
||||
# simple scalar index: one lane, all components
|
||||
pairs = [(0, c) for c in range(cnt)]
|
||||
idx_lanes, offsets = (tuple(x) for x in zip(*pairs))
|
||||
return buf.broadcast(len(pairs)).index(idx.gep(idx_lanes)*cnt + UOp.const(dtypes.index.vec(len(pairs)), offsets), ptr=True)
|
||||
|
||||
devectorize_buf_and_index = PatternMatcher([
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").broadcast(name="bcast").index(UPat.var("idx")),
|
||||
no_vectorized_index_broadcast),
|
||||
no_vectorized_index),
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG)).or_after(name="buf").cast(name="cast").gep(name="bcast").index(UPat.var("idx")),
|
||||
no_vectorized_index_broadcast),
|
||||
no_vectorized_index),
|
||||
])
|
||||
|
||||
devectorize = PatternMatcher([
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLI
|
||||
from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
|
||||
from tinygrad.codegen.opt import Opt
|
||||
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
|
||||
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, IndexingContext, apply_movement_op
|
||||
from tinygrad.schedule.multi import multi_pm
|
||||
from tinygrad.schedule.allreduce import create_allreduce_function
|
||||
|
||||
@@ -71,8 +71,11 @@ pm_mops = PatternMatcher([
|
||||
def fix_store_after_hazard(after:UOp, target:UOp, src:UOp):
|
||||
# PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk
|
||||
unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set())
|
||||
if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS or s.op is Ops.AFTER)):
|
||||
return after.replace(src=(after.src[0], target.store(src.contiguous())))
|
||||
base = target.base
|
||||
reaches_base: dict[UOp, bool] = {}
|
||||
for s in src.toposort(gate=lambda s: s.op is not Ops.CONTIGUOUS):
|
||||
reaches_base[s] = s is base or any(reaches_base.get(c) for c in s.src)
|
||||
if reaches_base[s] and s.op in unsafe: return after.replace(src=(after.src[0], target.store(src.contiguous())))
|
||||
|
||||
def normalize_store_after_target_chain(after:UOp, target:UOp, src:UOp):
|
||||
root_target = target
|
||||
|
||||
Reference in New Issue
Block a user