llm: fix nan in kvcache (#15552)

This commit is contained in:
nimlgen
2026-04-01 00:38:45 +03:00
committed by GitHub
parent 3af25ccdb4
commit 5181c8e23a
2 changed files with 24 additions and 4 deletions

View File

@@ -30,10 +30,10 @@ class TestTransformerGenerate(unittest.TestCase):
gen = model.generate(tokens)
next(gen)
# should only process tokens[7:] = [10, 11, 12] since first 7 are cached
# should process tokens[6:] = [42, 10, 11, 12] since first 6 have cached k/v
toks_shape = captured_inputs[0][0][-1]
self.assertEqual(toks_shape.val if isinstance(toks_shape, UOp) else toks_shape, 3)
self.assertEqual(captured_inputs[0][1], 7)
self.assertEqual(toks_shape.val if isinstance(toks_shape, UOp) else toks_shape, 4)
self.assertEqual(captured_inputs[0][1], 6)
def test_kv_cache_invalidation(self):
"""Test that generate invalidates the KV cache when tokens diverge from the cached prefix."""
@@ -106,6 +106,26 @@ class TestTransformerGenerate(unittest.TestCase):
# 4 tokens, chunk_size=4 -> 1 prefill chunk
self.assertEqual(get_prefill_flags(list(range(4)), 4), [True, False, False])
def test_kv_cache_resume_matches_fresh(self):
model = Transformer(TEST_CONFIG)
# generate 2 tokens, then abandon
prompt = list(range(1, 6))
gen = model.generate(list(prompt))
out1, out2 = next(gen), next(gen)
# resume with conversation history + new user tokens appended
extended = prompt + [out1, out2, 10, 11, 12]
gen = model.generate(list(extended))
resumed_out = [next(gen) for _ in range(3)]
# compare against fresh generation (no cache) of the same prompt
model._cached_tokens = []
gen = model.generate(list(extended))
fresh_out = [next(gen) for _ in range(3)]
self.assertEqual(fresh_out, resumed_out)
def test_temperature_zero_is_greedy(self):
"""Temperature 0 (or near 0) should produce deterministic output."""
model = Transformer(TEST_CONFIG)

View File

@@ -271,7 +271,7 @@ class Transformer:
# chunked prefill: keep processing until all prompt tokens are consumed
if start_pos < len(tokens): continue
tokens.append(int(out.item()))
self._cached_tokens = tokens[:]
self._cached_tokens = tokens[:-1]
yield tokens[-1]
models = {