diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index d7ad56b0d1..01680996ca 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -42,9 +42,9 @@ class _LBScheduleItem: inputs: Tuple[LazyBuffer, ...] var_vals: Dict[Variable, int] -# recursively create a lazyop def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker, realizes:Dict[LazyBuffer, None], cache, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp: + """recursively create a lazyop""" if (buf, st) in cache: return cache[(buf, st)] if buf != buf.base: st = buf.st + st @@ -95,6 +95,7 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[Laz return ret def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem: + """create a schedule item from a list of outputs""" inputs: List[LazyBuffer] = [] ast: List[LazyOp] = [] var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs]) @@ -115,9 +116,9 @@ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None] # *** DAG creation: decide which LazyBuffers should realize *** -# recursively search the entire graph for all LazyBuffers, insert realizes after expands def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False): + """recursively search the entire graph for all LazyBuffers, insert realizes after expands""" if buf in allbufs or buf.base.realized: return if GRAPH: log_lazybuffer(buf, scheduled) if buf.base != buf: @@ -148,9 +149,9 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Dict[LazyBuffer, None]) -> bool: if buf.op in UNSAFE_PAD_OPS: return False return all(_is_padding_okay(x.base, realizes) for x in 
buf.srcs) -# recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer], group:Set[LazyBuffer]): + """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group""" if tr in realizes: # can only fuse contiguous # max one reduceop per kernel @@ -166,6 +167,7 @@ def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:Defa def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], DefaultDict[LazyBuffer, int], Dict[LazyBuffer, _LBScheduleItem]]: + """create a graph for realizing the outputs""" # start by just realizing the buffers passed in realizes: Dict[LazyBuffer, None] = {x.base: None for x in outs if not x.base.realized} allbufs: Dict[LazyBuffer, None] = {} @@ -181,7 +183,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul # find all reduces, and pair them to a elementwise op. 
if they can't be cleanly paired, force realize the reduce (or a contig child) reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {} - for r in allbufs.keys(): + for r in allbufs: if r != r.base or r.op not in ReduceOps or r in realizes: continue group: Set[LazyBuffer] = set() @@ -191,11 +193,13 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs forced_realize = r in group if not forced_realize and len(group) > 1: + # create a multi output kernel if the LazyBuffers can cleanly group rc_parents, rc_children = deque(group), deque(group) while rc_parents and not forced_realize: # max one reduceop per kernel if (p:=rc_parents.pop()).op in ReduceOps: forced_realize = True else: rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r) + # search descendants of the reduceop that can cleanly group realized_descendants: Set[LazyBuffer] = set() while rc_children and not forced_realize: if (c:=rc_children.pop()).op in ReduceOps or not c.st.contiguous or c.st.size != r.st.size or c in reduce_for_op: @@ -204,6 +208,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul if c in realizes and c not in group: realized_descendants.add(c) rc_children.extend(x for x in children[c] if x.realized is None and x.device == r.device) group.update(realized_descendants) + # can only fuse assign if no other assign_target is used in the kernel if not forced_realize and any(x.op is LoadOps.ASSIGN for x in group): parents = deque((r, *group)) while parents and not forced_realize: