jit: memplan before compile (#16560)

This commit is contained in:
nimlgen
2026-06-10 15:05:15 +03:00
committed by GitHub
parent 34481830f1
commit 2c9d2c0d31

View File

@@ -69,8 +69,8 @@ def jit_lower(linear:UOp, held_bufs:set[UOp], input_uops:list[UOp]) -> UOp:
# parametrize input buffers: map each input buffer UOp to a PARAM with the correct slot index
linear = linear.substitute({u: UOp.param(i, u.dtype, u.shape, u.device) for i,u in enumerate(input_uops)}, walk=True)
linear = compile_linear(linear, beam=getenv("JITBEAM", BEAM.value))
linear = memory_plan_rewrite(linear, held_bufs)
linear = compile_linear(linear, beam=getenv("JITBEAM", BEAM.value))
if JIT < 2: linear = graph_split_rewrite(linear, max_batch_size=JIT_BATCH_SIZE.value)
if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View graphed linear")
return linear