mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 00:15:35 +08:00
don't change lazy state in schedule [pr] (#7867)
This commit is contained in:
@@ -323,9 +323,7 @@ def group_realizes(ctx:ScheduleContext, realizes:Dict[UOp, UOp]) -> List[List[UO
|
||||
|
||||
# **** Schedule creation and BFS toposort
|
||||
|
||||
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None:
|
||||
ctx[b] = to_store
|
||||
return None
|
||||
def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> None: return ctx.update([(b, to_store)])
|
||||
|
||||
def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> None:
|
||||
if to_store.op in {Ops.CONST, Ops.BIND}: return None
|
||||
@@ -339,12 +337,14 @@ def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) ->
|
||||
# otherwise safety check pads
|
||||
return None if (all(v.mask is None for v in st.views) or can_pad(base, ctx, set())) else realize(ctx, b, to_store, base)
|
||||
|
||||
def fold_img_cast(ctx, xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]:
|
||||
def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> Optional[UOp]:
|
||||
if not isinstance(xb.dtype, ImageDType) or b not in ctx or xb not in ctx or uval(to_cast).op in GroupOp.Meta: return None
|
||||
del ctx[b]
|
||||
return to_cast.view(unwrap(view.st))
|
||||
|
||||
do_realize = PatternMatcher([
|
||||
# always realize sinked ops
|
||||
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: ctx.update((x.buf_uop, x) for x in sink.src)),
|
||||
# always realize meta ops
|
||||
(UPatSrc({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize),
|
||||
# realize before expand or unsafe pad ops
|
||||
@@ -373,7 +373,6 @@ break_sched = PatternMatcher([
|
||||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
|
||||
if len(outs:=dedup(x.base for x in outs if x.base.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
|
||||
for out in outs: out.forced_realize = True
|
||||
# create the big graph
|
||||
ctx = ScheduleContext()
|
||||
cache: Dict[LazyBuffer, UOp] = {}
|
||||
|
||||
Reference in New Issue
Block a user