mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
jit: memplan before compile (#16560)
This commit is contained in:
@@ -69,8 +69,8 @@ def jit_lower(linear:UOp, held_bufs:set[UOp], input_uops:list[UOp]) -> UOp:
|
||||
|
||||
# parametrize input buffers: map each input buffer UOp to a PARAM with the correct slot index
|
||||
linear = linear.substitute({u: UOp.param(i, u.dtype, u.shape, u.device) for i,u in enumerate(input_uops)}, walk=True)
|
||||
linear = compile_linear(linear, beam=getenv("JITBEAM", BEAM.value))
|
||||
linear = memory_plan_rewrite(linear, held_bufs)
|
||||
linear = compile_linear(linear, beam=getenv("JITBEAM", BEAM.value))
|
||||
if JIT < 2: linear = graph_split_rewrite(linear, max_batch_size=JIT_BATCH_SIZE.value)
|
||||
if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View graphed linear")
|
||||
return linear
|
||||
|
||||
Reference in New Issue
Block a user