#!/usr/bin/env python
"""
JIT Footguns: Documenting unexpected behavior changes when using @TinyJit

Each test shows behavior that works without JIT but changes with JIT.
Comments marked "should be X!" indicate the intuitively expected value.

SILENT MISMATCHES (highest priority - wrong results, no error):
  class_method_shared_across_instances  EASY  could check if first arg is self and warn
  output_buffer_reuse                   MED   performance tradeoff, could add option or better docs
  symbolic_pad_view_frozen              MED   pad view BIND values baked in at capture time
  python_constants_frozen               HARD  inherent to tracing JITs
  conditional_branches_frozen           HARD  inherent to tracing JITs

ERRORS RAISED (lower priority - at least users know):
  item_bakes_in_values                  EASY  raises JitError if .item()/.data() accessed during capture
  unrealized_const_input_error          EASY  raises JitError for unrealized const inputs
  non_tensor_outputs_error              EASY  raises JitError if return contains non-Tensor values
  positional_kwargs_cannot_mix          EASY  normalize positional args to kwargs using function signature
  duplicate_inputs_fail                 MED   would need to handle aliasing in input_replace
  nested_jit_fails_on_second_call       MED   could fail on first call instead of second

RESOLVED (kept as regression tests; JIT result now matches the non-JIT result):
  slice_assign_works_without_realize          pending slice assigns are side-realized before the read
                                              (formerly listed above as slice_assign_requires_realize)
"""
import unittest
import numpy as np
from tinygrad import Tensor, TinyJit, Device
from tinygrad.engine.jit import JitError
from tinygrad.helpers import JIT

class TestJitFootguns(unittest.TestCase):
  def test_output_buffer_reuse(self):
    """Output tensors share buffer after capture - old references get overwritten."""
    @TinyJit
    def f(x): return x.sum().realize()
    r1 = f(Tensor([1, 1]))  # warmup
    r2 = f(Tensor([2, 2]))  # capture
    r3 = f(Tensor([3, 3]))  # jit exec
    self.assertEqual(r1.item(), 2)  # warmup result independent
    self.assertEqual(r3.item(), 6)  # latest is correct
    self.assertEqual(r2.item(), 6)  # should be 4! (overwritten by r3)

  def test_output_buffer_workaround(self):
    """Use .clone().realize() to get independent copies."""
    @TinyJit
    def f(x): return x.sum().realize()
    r1 = f(Tensor([1, 1])).clone().realize()
    r2 = f(Tensor([2, 2])).clone().realize()
    r3 = f(Tensor([3, 3])).clone().realize()
    self.assertEqual([r1.item(), r2.item(), r3.item()], [2, 4, 6])

  def test_graph_input_output_aliasing(self):
    """Test that JIT handles input=output aliasing correctly, simulating LLM generate pattern.

    The LLM generate pattern:
    1. First "session": multiple iterations where output becomes next input
    2. Second "session": starts with a NEW input tensor (not the previous output)

    The bug: GraphRunner computes input_replace during _first_run. If at that time
    input buffer == output buffer (aliasing), it incorrectly includes the output position
    in input_replace. Later, when a DIFFERENT input is passed, the output position
    gets overwritten with the input, corrupting the computation.

    This requires multiple kernels to trigger because single-kernel JITs don't get
    graphed ("only one kernel doesn't graph").
    """
    if Device[Device.DEFAULT].graph is None or JIT != 1:
      self.skipTest("test requires JIT graph support")

    # Multiple operations to create multiple kernels that get batched into a GraphRunner
    @TinyJit
    def step(x):
      y = (x + 1).realize()  # kernel 1
      z = (y * 2).realize()  # kernel 2
      return z

    # Phase 1: warmup and capture
    a = Tensor([10]).contiguous().realize()
    step(a)  # warmup (cnt=0)
    b = Tensor([20]).contiguous().realize()
    x = step(b)  # capture (cnt=1), x = (20+1)*2 = 42

    # Phase 2: first "session" - iterations where output becomes input (triggers _first_run with aliasing)
    for _ in range(3):
      x = step(x)  # (42+1)*2=86, (86+1)*2=174, (174+1)*2=350
    self.assertEqual(x.item(), 350)

    # Phase 3: second "session" - NEW input tensor (simulates new generate() call)
    # The bug: GraphRunner's input_replace incorrectly includes the output position
    # When new input y is passed, it overwrites the output buffer, using old value (350) instead of new (100)
    y = Tensor([100]).contiguous().realize()
    for _ in range(3):
      y = step(y)  # should be (100+1)*2=202, (202+1)*2=406, (406+1)*2=814
    self.assertEqual(y.item(), 814)  # fails with 1406 if bug exists (uses 350 instead of 100)

  def test_multiple_outputs_same_intermediate(self):
    """Multiple outputs derived from the same intermediate - JIT copies aliased inputs to prevent hazard."""
    @TinyJit
    def f(buf, frame):
      new_buf = buf[1:].cat(frame, dim=0)
      return new_buf.contiguous(), new_buf[:1].contiguous()

    buf = Tensor([[0], [1], [2]]).contiguous().realize()
    for i in range(4):
      frame = Tensor([[10+i]]).contiguous().realize()
      # first row of new_buf is the second row of the buffer passed in
      expected_first = buf[1:2].numpy().item()
      new_buf, first = f(buf, frame)
      self.assertEqual(first.numpy().item(), expected_first)
      buf = new_buf  # feed the output back in as the next input

  def test_intra_kernel_output_input_aliasing(self):
    """JIT must copy aliased input when output buffer is fed back as input (read-write race in same kernel)."""
    N = 1 << 20
    f = TinyJit(lambda buf, new: buf[N//2:].cat(new), prune=True)
    buf = Tensor.zeros(N, dtype='int32').contiguous().realize()
    for i in range(10):
      buf = f(buf, Tensor(np.ones(N//2, dtype=np.int32)*(i+1)))
      # the first half now holds the previous call's second half, which was filled with i
      np.testing.assert_array_equal(buf[:N//2].numpy(), np.full(N//2, i, dtype=np.int32))

  def test_slice_assign_works_without_realize(self):
    """Slice assign then read from same buffer - pending assigns are side-realized."""
    from tinygrad import Variable
    v_pos = Variable("pos", 0, 3)
    cache = Tensor.zeros(4, 4).contiguous().realize()

    @TinyJit
    def f(pos):
      cache[pos:pos+1, :].assign(Tensor.ones(1, 4))
      return cache.sum().realize()

    for i in range(4):
      cache.assign(Tensor.zeros(4, 4)).realize()  # reset so only this call's row of ones is counted
      self.assertEqual(f(v_pos.bind(i)).item(), 4.0)

  def test_symbolic_pad_view_frozen(self):
    """Symbolic pad view has BIND values baked in at capture time. TODO: pad should be captured in jit."""
    from tinygrad import Variable
    a = Tensor.rand(3, 10).realize()

    # broken: pad is a view, BIND values frozen at capture (i=2)
    @TinyJit
    def f_broken(a): return (a+1).pad((None, (0, 10-a.shape[1]))).realize()
    for i in range(1, 5): f_broken(a[:, :Variable("i", 1, 10).bind(i)])
    self.assertEqual(int((f_broken(a[:, :Variable("i", 1, 10).bind(4)])[0] != 0).sum().item()), 2)  # should be 4!

    # workaround: contiguous fuses pad into kernel
    @TinyJit
    def f_fixed(a): return (a+1).pad((None, (0, 10-a.shape[1]))).contiguous().realize()
    for i in range(1, 5): f_fixed(a[:, :Variable("i", 1, 10).bind(i)])
    self.assertEqual(int((f_fixed(a[:, :Variable("i", 1, 10).bind(4)])[0] != 0).sum().item()), 4)

  def test_non_tensor_outputs_error(self):
    @TinyJit
    def f(x, mult): return (x * 2).realize(), mult * 10
    with self.assertRaises(JitError):
      for i in range(3): f(Tensor([i]), i)

  def test_duplicate_inputs_fail(self):
    """JIT cannot handle the same tensor passed as multiple arguments."""
    @TinyJit
    def f(a, b): return (a + b).realize()
    x = Tensor([1, 2, 3])
    with self.assertRaises(JitError):
      f(x, x)

  def test_tensors_in_containers(self):
    @TinyJit
    def f(a, arr): return (a + arr[0]).realize()
    for i in range(4):
      a, b = Tensor([1, 1, 1]).realize(), Tensor([i, i, i]).realize()
      np.testing.assert_array_equal(f(a, [b]).numpy(), [1+i, 1+i, 1+i])

  def test_nested_jit_fails_on_second_call(self):
    """Nested JIT works on first call but fails on second."""
    @TinyJit
    def inner(t): return t + 1
    @TinyJit
    def outer(t): return inner(t) * 3

    self.assertEqual(outer(Tensor([1])).realize().item(), 6)  # works!
    with self.assertRaises(RuntimeError):
      outer(Tensor([2])).realize()  # fails

  def test_implicit_inputs_need_realize(self):
    """Closure tensors must be realized before JIT call."""
    x = Tensor([0])

    @TinyJit
    def f(): return (x * 2).realize()

    for i in range(5):
      x.assign(Tensor([i])).realize()  # must realize!
      self.assertEqual(f().item(), i * 2)

  def test_views_with_different_offsets_fail(self):
    """JIT requires consistent tensor views across calls."""
    @TinyJit
    def f(a): return (a + 1).realize()

    base = Tensor.randn(10, 10).realize()
    with self.assertRaises(JitError):
      for i in range(1, 5):
        f(base[:, i:i+2])  # different offset each time

  def test_shape_change_after_capture_fails(self):
    """Shapes are locked at capture time."""
    @TinyJit
    def f(a, b): return (a + b).realize()

    f(Tensor.randn(10, 10), Tensor.randn(10, 10))  # warmup
    f(Tensor.randn(10, 10), Tensor.randn(10, 10))  # capture

    with self.assertRaises(JitError):
      f(Tensor.randn(20, 20), Tensor.randn(20, 20))

  def test_python_constants_frozen(self):
    """Python variables inside JIT use capture-time values."""
    mult = 1

    @TinyJit
    def f(x): return (x * mult).realize()

    results = []
    for i in range(5):
      mult = i + 1
      results.append(f(Tensor([10])).item())

    self.assertEqual(results[0], 10)  # warmup, mult=1
    self.assertEqual(results[1], 20)  # capture, mult=2
    self.assertEqual(results[2], 20)  # should be 30!
    self.assertEqual(results[3], 20)  # should be 40!

  def test_unrealized_const_input_error(self):
    """Const tensors have no buffer to replace, so JIT raises an error. Even explicit .realize() doesn't help."""
    @TinyJit
    def f(a, b): return (a * b).realize()

    # unrealized const fails
    with self.assertRaises(JitError):
      f(Tensor([1, 2, 3]).realize(), Tensor(2))

    # explicit .realize() on const still fails - const cannot be realized to have a buffer
    @TinyJit
    def g(a, b): return (a * b).realize()
    with self.assertRaises(JitError):
      g(Tensor([1, 2, 3]).realize(), Tensor(2).realize())

  def test_conditional_branches_frozen(self):
    """Only the branch taken during capture runs thereafter."""
    @TinyJit
    def f(x, use_square):
      if use_square:
        return (x * x).realize()
      return (x * 2).realize()

    f(Tensor([3]), True)  # warmup
    f(Tensor([3]), False)  # capture (False branch)

    result = f(Tensor([3]), True)  # passing True but False branch runs
    self.assertEqual(result.item(), 6)  # should be 9!

  def test_positional_kwargs_cannot_mix(self):
    """Must use same calling convention after capture."""
    @TinyJit
    def f(a, b): return (a + b).realize()

    f(Tensor([1]), Tensor([2]))  # warmup with positional
    f(Tensor([1]), Tensor([2]))  # capture with positional

    with self.assertRaises(JitError):
      f(a=Tensor([3]), b=Tensor([4]))  # kwargs fail

  def test_class_method_shared_across_instances(self):
    """JIT on instance methods is shared at class level."""
    class Model:
      def __init__(self, scale):
        self.scale = Tensor([scale])

      @TinyJit
      def forward(self, x):
        return (x * self.scale).realize()

    m1, m2 = Model(2), Model(3)

    m1.forward(Tensor([5]))  # warmup
    m1.forward(Tensor([5]))  # capture with m1.scale=2

    self.assertEqual(m1.forward(Tensor([5])).item(), 10)
    self.assertEqual(m2.forward(Tensor([5])).item(), 10)  # should be 15!

  def test_side_effects_only_during_capture(self):
    """Function body not executed during JIT replay."""
    call_count = [0]

    @TinyJit
    def f(x):
      call_count[0] += 1
      return (x * 2).realize()

    f(Tensor([1]))  # warmup
    f(Tensor([2]))  # capture
    self.assertEqual(call_count[0], 2)

    f(Tensor([3]))
    f(Tensor([4]))
    f(Tensor([5]))
    self.assertEqual(call_count[0], 2)  # still 2, not 5!

  def test_nothing_realized_fails(self):
    """Must JIT at least one kernel."""
    @TinyJit
    def f(a, b): return None

    with self.assertRaises(JitError):
      for _ in range(3):
        f(Tensor([1]), Tensor([2]))

  def test_item_creates_unrealized_return(self):
    """.item() in shape computation raises error during JIT capture."""
    @TinyJit
    def f(x): return Tensor.zeros(x.sum().item())

    f(Tensor([1, 1, 1]))  # warmup
    with self.assertRaises(JitError):
      f(Tensor([1, 1, 1]))  # capture - .item() raises

  def test_item_bakes_in_values(self):
    """.item() during JIT capture raises error (would bake in value)."""
    @TinyJit
    def f(x, mask): return x.masked_select(mask)

    f(Tensor([1, 2, 3, 4]), Tensor([True, False, True, False]))  # warmup
    with self.assertRaises(JitError):
      f(Tensor([1, 2, 3, 4]), Tensor([True, False, True, False]))  # capture - .item() raises

  def test_masked_select_static_size_jittable(self):
    @TinyJit
    def f(x, mask): return x.masked_select(mask, size=4, fill_value=-1).realize()

    for _ in range(3):
      np.testing.assert_equal(f(Tensor([1, 2, 3, 4]), Tensor([True, False, True, False])).numpy(), [1, 3, -1, -1])
      np.testing.assert_equal(f(Tensor([5, 6, 7, 8]), Tensor([False, True, True, True])).numpy(), [6, 7, 8, -1])
      np.testing.assert_equal(f(Tensor([9, 8, 7, 6]), Tensor([True, True, True, True])).numpy(), [9, 8, 7, 6])
      np.testing.assert_equal(f(Tensor([1, 1, 1, 1]), Tensor([False, False, False, False])).numpy(), [-1, -1, -1, -1])

  def test_nonzero_static_size_jittable(self):
    @TinyJit
    def f(x): return x.nonzero(size=3, fill_value=-1).realize()

    for _ in range(3):
      np.testing.assert_equal(f(Tensor([1, 0, 2, 0, 3])).numpy(), [[0], [2], [4]])
      np.testing.assert_equal(f(Tensor([0, 0, 5, 0, 0])).numpy(), [[2], [-1], [-1]])
      np.testing.assert_equal(f(Tensor([0, 0, 0, 0, 0])).numpy(), [[-1], [-1], [-1]])

  def test_tolist_bakes_in_values(self):
    """.tolist() raises error during JIT capture (would bake in values)."""
    @TinyJit
    def f(x): return Tensor(x.tolist())

    f(Tensor([1, 2, 3]))  # warmup
    with self.assertRaises(JitError):
      f(Tensor([1, 2, 3]))  # capture - .tolist() raises

class TestJitCorrectBehavior(unittest.TestCase):
  """Behaviors that work correctly - documented for clarity."""

  def test_random_regenerates(self):
    """Random tensors regenerate each call."""
    @TinyJit
    def f(x):
      return (x + Tensor.rand(3)).realize()

    f(Tensor([0, 0, 0]))  # warmup
    f(Tensor([0, 0, 0]))  # capture

    results = {tuple(f(Tensor([0, 0, 0])).numpy().tolist()) for _ in range(5)}
    self.assertEqual(len(results), 5)

  def test_unrealized_return_auto_realized(self):
    """Unrealized return tensors are auto-realized."""
    @TinyJit
    def f(a, b): return a + b  # no explicit realize

    for _ in range(5):
      a, b = Tensor.randn(10), Tensor.randn(10)
      np.testing.assert_allclose(f(a, b).numpy(), a.numpy() + b.numpy(), atol=1e-5)

  def test_kwargs_order_doesnt_matter(self):
    """Kwargs are sorted by name, so order doesn't matter."""
    @TinyJit
    def f(first, second): return (first / second).realize()

    for _ in range(3):
      a, b = Tensor.randn(10), Tensor.randn(10) + 1
      # rtol: randn + 1 can land near 0, making a/b large; with only atol the check would then
      # demand agreement below float32 precision and flake on a last-ulp difference in the division
      np.testing.assert_allclose(f(second=b, first=a).numpy(), a.numpy() / b.numpy(), rtol=1e-5, atol=1e-4)
      np.testing.assert_allclose(f(first=a, second=b).numpy(), a.numpy() / b.numpy(), rtol=1e-5, atol=1e-4)

  def test_input_mutation_consistent(self):
    """Input mutation via assign works consistently."""
    @TinyJit
    def f(x):
      x += 1
      x.realize()
      return x

    a = Tensor([0]).contiguous().realize()
    for _ in range(5):
      f(a)
    self.assertEqual(a.item(), 5)

if __name__ == '__main__':
  unittest.main()