move unneeded fields out of ScheduleContext [pr] (#8752)

This commit is contained in:
qazal
2025-01-26 03:36:15 -05:00
committed by GitHub
parent 1b4618e257
commit 06b58aa7ec

View File

@@ -84,7 +84,6 @@ class ScheduleContext:
contiguous: dict[UOp, UOp] = field(default_factory=dict) # this maps roots to places they are made contiguous
children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
becomes_map: dict[UOp, UOp] = field(default_factory=dict)
# wrap tensor uops around a VIEW(BUFFER, <uop>)
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
@@ -473,7 +472,6 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
for x in op.base.src:
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
buf_uop.buffer.ref(1)
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
# **** movement ops
@@ -502,24 +500,27 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
# group realizes into kernels
store_groups = group_realizes(ctx)
graph_rewrite(sink, break_sched, ctx)
# preschedule realize groups
# create schedule items + map buffers to realized tensors
prescheduled: list[ScheduleItem] = []
becomes_map: dict[UOp, UOp] = {}
for store_uops in store_groups:
small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops])
if not all(x.op is Ops.STORE for x in small_sink.src): raise RuntimeError(f"expected all realized BUFFERs to get a STORE {sink}")
prescheduled.append(schedule_uop(small_sink, ctx))
# can only schedule once
for buf_uop in store_uops:
for tensor_uop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
# increment refcount for this buffer
buf_uop.buffer.ref(1)
# tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
for k,v in tensor_map.items():
# NOOP
if k.base is v.base: continue
# NOTE: only the base tensors get a BUFFER UOp
if v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st))
if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st))
# otherwise if it simplified to a CONST the UOp just becomes that CONST
elif v.op is Ops.CONST: ctx.becomes_map[k] = v
elif v.op is Ops.CONST: becomes_map[k] = v
# add kernel children
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
@@ -548,4 +549,4 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
# confirm everything was scheduled correctly
if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
return schedule, ctx.var_vals, ctx.becomes_map
return schedule, ctx.var_vals, becomes_map