From 06b58aa7ecb4d86f70f2b70fa0d828cf0ba2f270 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 26 Jan 2025 03:36:15 -0500 Subject: [PATCH] move unneeded fields out of ScheduleContext [pr] (#8752) --- tinygrad/engine/schedule.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 801ffdcfd9..c1aa937e36 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -84,7 +84,6 @@ class ScheduleContext: contiguous: dict[UOp, UOp] = field(default_factory=dict) # this maps roots to places they are made contiguous children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) - becomes_map: dict[UOp, UOp] = field(default_factory=dict) # wrap tensor uops around a VIEW(BUFFER, <uop>) # this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it. 
@@ -473,7 +472,6 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None: if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop) for x in op.base.src: if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None - buf_uop.buffer.ref(1) create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)]) # **** movement ops @@ -502,24 +500,27 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu # group realizes into kernels store_groups = group_realizes(ctx) graph_rewrite(sink, break_sched, ctx) - # preschedule realize groups + # create schedule items + map buffers to realized tensors prescheduled: list[ScheduleItem] = [] + becomes_map: dict[UOp, UOp] = {} for store_uops in store_groups: small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops]) if not all(x.op is Ops.STORE for x in small_sink.src): raise RuntimeError(f"expected all realized BUFFERs to get a STORE {sink}") prescheduled.append(schedule_uop(small_sink, ctx)) # can only schedule once for buf_uop in store_uops: - for tensor_uop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st)) + for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st)) + # increment refcount for this buffer + buf_uop.buffer.ref(1) # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed for k,v in tensor_map.items(): # NOOP if k.base is v.base: continue # NOTE: only the base tensors get a BUFFER UOp - if v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st)) + if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st)) # otherwise if it simplified to a CONST the UOp just becomes that CONST - elif v.op is Ops.CONST: ctx.becomes_map[k] = v + elif v.op is Ops.CONST: becomes_map[k] = v # add kernel children schedule_targets = {out:si for si in prescheduled for 
out in si.outputs} @@ -548,4 +549,4 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu # confirm everything was scheduled correctly if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}") if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels") - return schedule, ctx.var_vals, ctx.becomes_map + return schedule, ctx.var_vals, becomes_map