diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index d404babdf7..0e6fa01d64 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -323,9 +323,7 @@ def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UO # **** Schedule creation and BFS toposort -def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None: - ctx[b] = to_store - return None +def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None: return ctx.update([(b, to_store)]) def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> None: if to_store.op in {Ops.CONST, Ops.BIND}: return None @@ -339,12 +337,14 @@ def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> # otherwise safety check pads return None if (all(v.mask is None for v in st.views) or can_pad(base, ctx, set())) else realize(ctx, b, to_store, base) -def fold_img_cast(ctx, xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]: +def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]: if not isinstance(xb.dtype, ImageDType) or b not in ctx or xb not in ctx or uval(to_cast).op in GroupOp.Meta: return None del ctx[b] return to_cast.view(unwrap(view.st)) do_realize = PatternMatcher([ + # always realize sinked ops + (UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.update((x.buf_uop, x) for x in sink.src)), # always realize meta ops (UPatSrc({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize), # realize before expand or unsafe pad ops @@ -373,7 +373,6 @@ break_sched = PatternMatcher([ @track_rewrites(named=True) def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: if len(outs:=dedup(x.base for x in outs if x.base.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {} - for out in outs: out.forced_realize = True # create the big graph ctx = ScheduleContext() cache: Dict[LazyBuffer, UOp] = {}