From 5181c8e23a19fc45a0042e2a9adfc20d9834bf1e Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Wed, 1 Apr 2026 00:38:45 +0300 Subject: [PATCH] llm: fix nan in kvcache (#15552) --- test/unit/test_llm_server.py | 26 +++++++++++++++++++++++--- tinygrad/apps/llm.py | 2 +- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/test/unit/test_llm_server.py b/test/unit/test_llm_server.py index b6441af7b0..593f4dec5e 100644 --- a/test/unit/test_llm_server.py +++ b/test/unit/test_llm_server.py @@ -30,10 +30,10 @@ class TestTransformerGenerate(unittest.TestCase): gen = model.generate(tokens) next(gen) - # should only process tokens[7:] = [10, 11, 12] since first 7 are cached + # should process tokens[6:] = [42, 10, 11, 12] since first 6 have cached k/v toks_shape = captured_inputs[0][0][-1] - self.assertEqual(toks_shape.val if isinstance(toks_shape, UOp) else toks_shape, 3) - self.assertEqual(captured_inputs[0][1], 7) + self.assertEqual(toks_shape.val if isinstance(toks_shape, UOp) else toks_shape, 4) + self.assertEqual(captured_inputs[0][1], 6) def test_kv_cache_invalidation(self): """Test that generate invalidates the KV cache when tokens diverge from the cached prefix.""" @@ -106,6 +106,26 @@ class TestTransformerGenerate(unittest.TestCase): # 4 tokens, chunk_size=4 -> 1 prefill chunk self.assertEqual(get_prefill_flags(list(range(4)), 4), [True, False, False]) + def test_kv_cache_resume_matches_fresh(self): + model = Transformer(TEST_CONFIG) + + # generate 2 tokens, then abandon + prompt = list(range(1, 6)) + gen = model.generate(list(prompt)) + out1, out2 = next(gen), next(gen) + + # resume with conversation history + new user tokens appended + extended = prompt + [out1, out2, 10, 11, 12] + gen = model.generate(list(extended)) + resumed_out = [next(gen) for _ in range(3)] + + # compare against fresh generation (no cache) of the same prompt + model._cached_tokens = [] + gen = model.generate(list(extended)) + fresh_out = [next(gen) for _ in range(3)] + + self.assertEqual(fresh_out, resumed_out) + def test_temperature_zero_is_greedy(self): """Temperature 0 (or near 0) should produce deterministic output.""" model = Transformer(TEST_CONFIG) diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py index 2e31a2566b..e1480b2e8c 100644 --- a/tinygrad/apps/llm.py +++ b/tinygrad/apps/llm.py @@ -271,7 +271,7 @@ class Transformer: # chunked prefill: keep processing until all prompt tokens are consumed if start_pos < len(tokens): continue tokens.append(int(out.item())) - self._cached_tokens = tokens[:] + self._cached_tokens = tokens[:-1] yield tokens[-1] models = {