mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
move unneeded fields out of ScheduleContext [pr] (#8752)
This commit is contained in:
@@ -84,7 +84,6 @@ class ScheduleContext:
|
||||
contiguous: dict[UOp, UOp] = field(default_factory=dict) # this maps roots to places they are made contiguous
|
||||
children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
|
||||
preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
|
||||
becomes_map: dict[UOp, UOp] = field(default_factory=dict)
|
||||
|
||||
# wrap tensor uops around a VIEW(BUFFER, <uop>)
|
||||
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
|
||||
@@ -473,7 +472,6 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
|
||||
if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
|
||||
for x in op.base.src:
|
||||
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
|
||||
buf_uop.buffer.ref(1)
|
||||
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
|
||||
|
||||
# **** movement ops
|
||||
@@ -502,24 +500,27 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
|
||||
# group realizes into kernels
|
||||
store_groups = group_realizes(ctx)
|
||||
graph_rewrite(sink, break_sched, ctx)
|
||||
# preschedule realize groups
|
||||
# create schedule items + map buffers to realized tensors
|
||||
prescheduled: list[ScheduleItem] = []
|
||||
becomes_map: dict[UOp, UOp] = {}
|
||||
for store_uops in store_groups:
|
||||
small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops])
|
||||
if not all(x.op is Ops.STORE for x in small_sink.src): raise RuntimeError(f"expected all realized BUFFERs to get a STORE {sink}")
|
||||
prescheduled.append(schedule_uop(small_sink, ctx))
|
||||
# can only schedule once
|
||||
for buf_uop in store_uops:
|
||||
for tensor_uop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
|
||||
for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
|
||||
# increment refcount for this buffer
|
||||
buf_uop.buffer.ref(1)
|
||||
|
||||
# tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
|
||||
for k,v in tensor_map.items():
|
||||
# NOOP
|
||||
if k.base is v.base: continue
|
||||
# NOTE: only the base tensors get a BUFFER UOp
|
||||
if v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st))
|
||||
if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st))
|
||||
# otherwise if it simplified to a CONST the UOp just becomes that CONST
|
||||
elif v.op is Ops.CONST: ctx.becomes_map[k] = v
|
||||
elif v.op is Ops.CONST: becomes_map[k] = v
|
||||
|
||||
# add kernel children
|
||||
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
|
||||
@@ -548,4 +549,4 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
|
||||
# confirm everything was scheduled correctly
|
||||
if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
|
||||
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
|
||||
return schedule, ctx.var_vals, ctx.becomes_map
|
||||
return schedule, ctx.var_vals, becomes_map
|
||||
|
||||
Reference in New Issue
Block a user