diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index f0fb8bc1db..b35c4c38d5 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -69,8 +69,8 @@ def jit_lower(linear:UOp, held_bufs:set[UOp], input_uops:list[UOp]) -> UOp: # parametrize input buffers: map each input buffer UOp to a PARAM with the correct slot index linear = linear.substitute({u: UOp.param(i, u.dtype, u.shape, u.device) for i,u in enumerate(input_uops)}, walk=True) - linear = compile_linear(linear, beam=getenv("JITBEAM", BEAM.value)) linear = memory_plan_rewrite(linear, held_bufs) + linear = compile_linear(linear, beam=getenv("JITBEAM", BEAM.value)) if JIT < 2: linear = graph_split_rewrite(linear, max_batch_size=JIT_BATCH_SIZE.value) if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View graphed linear") return linear