mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
llm: fix nan in kvcache (#15552)
This commit is contained in:
@@ -30,10 +30,10 @@ class TestTransformerGenerate(unittest.TestCase):
|
||||
gen = model.generate(tokens)
|
||||
next(gen)
|
||||
|
||||
# should only process tokens[7:] = [10, 11, 12] since first 7 are cached
|
||||
# should process tokens[6:] = [42, 10, 11, 12] since first 6 have cached k/v
|
||||
toks_shape = captured_inputs[0][0][-1]
|
||||
self.assertEqual(toks_shape.val if isinstance(toks_shape, UOp) else toks_shape, 3)
|
||||
self.assertEqual(captured_inputs[0][1], 7)
|
||||
self.assertEqual(toks_shape.val if isinstance(toks_shape, UOp) else toks_shape, 4)
|
||||
self.assertEqual(captured_inputs[0][1], 6)
|
||||
|
||||
def test_kv_cache_invalidation(self):
|
||||
"""Test that generate invalidates the KV cache when tokens diverge from the cached prefix."""
|
||||
@@ -106,6 +106,26 @@ class TestTransformerGenerate(unittest.TestCase):
|
||||
# 4 tokens, chunk_size=4 -> 1 prefill chunk
|
||||
self.assertEqual(get_prefill_flags(list(range(4)), 4), [True, False, False])
|
||||
|
||||
def test_kv_cache_resume_matches_fresh(self):
|
||||
model = Transformer(TEST_CONFIG)
|
||||
|
||||
# generate 2 tokens, then abandon
|
||||
prompt = list(range(1, 6))
|
||||
gen = model.generate(list(prompt))
|
||||
out1, out2 = next(gen), next(gen)
|
||||
|
||||
# resume with conversation history + new user tokens appended
|
||||
extended = prompt + [out1, out2, 10, 11, 12]
|
||||
gen = model.generate(list(extended))
|
||||
resumed_out = [next(gen) for _ in range(3)]
|
||||
|
||||
# compare against fresh generation (no cache) of the same prompt
|
||||
model._cached_tokens = []
|
||||
gen = model.generate(list(extended))
|
||||
fresh_out = [next(gen) for _ in range(3)]
|
||||
|
||||
self.assertEqual(fresh_out, resumed_out)
|
||||
|
||||
def test_temperature_zero_is_greedy(self):
|
||||
"""Temperature 0 (or near 0) should produce deterministic output."""
|
||||
model = Transformer(TEST_CONFIG)
|
||||
|
||||
@@ -271,7 +271,7 @@ class Transformer:
|
||||
# chunked prefill: keep processing until all prompt tokens are consumed
|
||||
if start_pos < len(tokens): continue
|
||||
tokens.append(int(out.item()))
|
||||
self._cached_tokens = tokens[:]
|
||||
self._cached_tokens = tokens[:-1]
|
||||
yield tokens[-1]
|
||||
|
||||
models = {
|
||||
|
||||
Reference in New Issue
Block a user