diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 8a1c940a67..112d85f04f 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -12,13 +12,9 @@ from tinygrad.ops import BinaryOps import functools def render(self) -> str: - graph = UOpGraph() # NOTE: we need STORE so the ALU op has children glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=(0,True)) - def recursive_add(x): - graph.add(x.uop, x.dtype, x.vin, x.arg) - for c in x.vin: recursive_add(c) - recursive_add(UOp(UOps.STORE, None, (glbl,UOp.const(dtypes.int, 0),self))) + graph = UOpGraph([UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, 0), self))]) graph.linearize() from tinygrad.renderer.cstyle import CStyleLanguage class TestRenderer(CStyleLanguage): diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 9e00f3091e..ef9e0a0647 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -4,7 +4,7 @@ import itertools, math, functools from collections import defaultdict from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType -from tinygrad.helpers import colored, DEBUG, dedup, diskcache_put, prod, getenv, to_function_name +from tinygrad.helpers import colored, DEBUG, dedup, diskcache_put, prod, getenv, to_function_name, flatten from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node @@ -49,7 +49,7 @@ class Linearizer(Kernel): # NOTE: the consts have to be cached for deduping of downstream uops to work def const(self, b:ConstType|Variable, dtype:DType=dtypes.int32) -> UOp: - return self.uops.add(UOps.DEFINE_VAR, dtype, (), b) if isinstance(b, Variable) else UOp.const(dtype, b) + return UOp(UOps.DEFINE_VAR, dtype, (), b) if isinstance(b, Variable) else UOp.const(dtype, b) def get_reduce_acc(self, reduceop:LazyOp): if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(reduceop.dtype) else 0 @@ -99,7 +99,7 @@ class Linearizer(Kernel): key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501 if key not in self.load_cache: if acc is not None: - self.load_cache[key] = self.uops.add(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count)) + self.load_cache[key] = UOp(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count)) acc_count += 1 elif this_const is not None: self.load_cache[key] = self.const(this_const, localtype) @@ -110,16 +110,16 @@ class Linearizer(Kernel): buf_uop = self.buf_uops[i] assert buf_uop is not None, f"buffer {i} wasn't UOped" image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid) - rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx)) + rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), tuple(x.render(self.render_ops, self) for x in image_idx)) valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, buf.dtype.base.vec(4))) if valid.min == 0 else tuple() - self.load_cache[key] = self.uops.add(UOps.LOAD, buf.dtype.base.vec(4), + self.load_cache[key] = UOp(UOps.LOAD, buf.dtype.base.vec(4), (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ())) if localtype == localtype.scalar(): idx_small = idx%4 res = idx_small.render(self.render_ops, self) - out = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max) + out = UOp(UOps.GEP, localtype, (self.load_cache[key],), idx_small.max) for ix in range(idx_small.max, idx_small.min, -1): - rvv = self.uops.add(UOps.GEP, localtype, (self.load_cache[key],), ix-1) + rvv = UOp(UOps.GEP, localtype, (self.load_cache[key],), ix-1) sel = UOp.alu(BinaryOps.CMPLT, res, self.const(ix)) out = UOp.alu(TernaryOps.WHERE, sel, rvv, out) self.load_cache[key] = out @@ -128,8 +128,8 @@ class Linearizer(Kernel): assert buf_uop is not None, f"buffer {i} wasn't UOped" rendered_idx = idx.render(self.render_ops, self) valid_tuple = (valid.render(self.render_ops, self), self.const(invalid_value, localtype)) if valid.min == 0 else tuple() - self.load_cache[key] = self.uops.add(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ())) - ret.append(self.uops.add(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key]) + self.load_cache[key] = UOp(UOps.LOAD, localtype, (buf_uop, rendered_idx) + valid_tuple + ((barrier,) if barrier else ())) + ret.append(UOp(UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]) if dim is not None else self.load_cache[key]) return ret def global_store(self, i:int, idxs:List[Node], store:List[UOp]) -> List[UOp]: @@ -152,7 +152,7 @@ class Linearizer(Kernel): amt = len(grouped) idx, valid = self.sts[i].expr_idxs(k) assert idx == ((idx//amt)*amt), "float4 stores are always aligned" - store_offset_new[k] = self.uops.add(UOps.CAST, buf.dtype.vec(amt), tuple(grouped)) + store_offset_new[k] = UOp(UOps.CAST, buf.dtype.vec(amt), tuple(grouped)) store_offset = store_offset_new stores = [] @@ -160,17 +160,17 @@ class Linearizer(Kernel): idx, valid = self.sts[i].expr_idxs(_idx) if isinstance(buf.dtype, ImageDType): image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid) - rendered_idx = self.uops.add(UOps.CAST, dtypes.int.vec(2), \ + rendered_idx = UOp(UOps.CAST, dtypes.int.vec(2), \ tuple(x.render(self.render_ops, self) for x in image_idx)) else: rendered_idx = idx.render(self.render_ops, self) - if valid.min == 1: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var))) - else: stores.append(self.uops.add(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self)))) + if valid.min == 1: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var))) + else: stores.append(UOp(UOps.STORE, None, (buf_uop, rendered_idx, var, valid.render(self.render_ops, self)))) return stores # render loop def render_loop(self, xx:List[Variable], depth:int) -> Tuple[UOp, ...]: - new_loops = {x.expr:self.uops.add(UOps.RANGE, dtypes.int32, ( + new_loops = {x.expr:UOp(UOps.RANGE, dtypes.int32, ( self.const(x.min) if isinstance(x.min, int) else cast(Node, x.min).render(self.render_ops, self), self.const(x.max+1) if isinstance(x.max, int) else cast(Node, x.max+1).render(self.render_ops, self)), arg=(depth,i)) for i,x in enumerate(xx) if not isinstance(x, NumNode) and x.expr is not None} # noqa: E501 self.loop_uops.update(new_loops) @@ -240,17 +240,17 @@ class Linearizer(Kernel): return strides upcasts, dev = [upcast_strides(x) for x in [locals_to_store[0][0], locals_to_store[1][0], 0]], self.opts.device # cast initial accs - wmmas = [self.uops.add(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]])) + wmmas = [UOp(UOps.CAST, (dt3:=tc.dtype_out.vec(wmma_sz[2])), tuple(accs[reduceop][x:x+wmma_sz[2]])) for x in range(0, len(accs[reduceop]), wmma_sz[2])] for it in [x[::-1] for x in itertools.product(*[x for x in [range(sz) for _,sz in upcasts[0]][::-1]])]: offs = [x*y for (x,y) in zip([sum([prod(x) for x in zip(it, [stride for stride,_ in y])]) for y in upcasts], wmma_sz)] - ops = (self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])), - self.uops.add(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])), + ops = (UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][offs[0]:offs[0]+wmma_sz[0]])), + UOp(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][offs[1]:offs[1]+wmma_sz[1]])), wmmas[(wmma_idx:=offs[2]//wmma_sz[2])]) # TODO: don't need to DEFINE_ACC, pass to WMMA in op3, or PHI accs that are not valid - wmmas[wmma_idx] = self.uops.add(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev)) + wmmas[wmma_idx] = UOp(UOps.WMMA, dt3, ops, (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, tuple(wmma_sz), dev)) # phi the last wmmas back to accs - accs[reduceop] = [self.uops.add(UOps.PHI, tc.dtype_out, (acc, self.uops.add(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2]))) + accs[reduceop] = [UOp(UOps.PHI, tc.dtype_out, (acc, UOp(UOps.GEP, tc.dtype_out, (wmmas[z//wmma_sz[2]],), z%wmma_sz[2]))) for z, acc in enumerate(accs[reduceop])] else: assert not locals_to_store, "storing locals isn't supported here" @@ -269,12 +269,12 @@ class Linearizer(Kernel): if self.group_for_reduces: fake_global_idxs = [x*0 for x in global_idxs] stores = self.global_store(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop]) # store accumulators - barrier = self.uops.add(UOps.BARRIER, None, tuple(stores)) + barrier = UOp(UOps.BARRIER, None, tuple(stores)) if self.opts.has_local: fake_idxs = [NumNode(0)]*len(self.sts[-1].shape) fake_idxs[self.global_dims+self.local_dims:self.global_dims+len(local_idxs)] = local_idxs[self.local_dims:] if_cond: UOp = create_lt_node(self.sts[-1].expr_idxs(fake_idxs)[0], 1).render(self.render_ops, self) - barrier = self.uops.add(UOps.IF, None, (if_cond, barrier)) + barrier = UOp(UOps.IF, None, (if_cond, barrier)) # create new late reduce local loops and replace local_idxs that have been used end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+self.group_for_reduces)] # noqa: E501 @@ -333,23 +333,22 @@ class Linearizer(Kernel): if self.opts.global_max and self.opts.local_max: self.limit_dims_to_max(self.opts.global_max, self.opts.local_max) # uops - self.uops:UOpGraph = UOpGraph() self.buf_uops: List[Optional[UOp]] = [None]*len(self.bufs) self.loop_uops: Dict[str, UOp] = {} # add global buffers for i,buf in enumerate(self.bufs): if isinstance(buf, MemBuffer): - self.buf_uops[i] = self.uops.add(UOps.DEFINE_GLOBAL, + self.buf_uops[i] = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (), (buf.idx, any(buf.idx == x.idx for x in self.outbufs))) # add var vals for i,var in enumerate(self.vars): assert var.expr is not None - self.loop_uops[var.expr] = self.uops.add(UOps.DEFINE_VAR, dtypes.int32, (), var) + self.loop_uops[var.expr] = UOp(UOps.DEFINE_VAR, dtypes.int32, (), var) # define local buffers for aliases in self.local_alias.values(): - for lb in aliases.values(): self.buf_uops[self.bufs.index(lb)] = self.uops.add(UOps.DEFINE_LOCAL, PtrDType(lb.dtype), + for lb in aliases.values(): self.buf_uops[self.bufs.index(lb)] = UOp(UOps.DEFINE_LOCAL, PtrDType(lb.dtype), (), (lb.name, self.sts[self.bufs.index(lb)].size)) # add a local buffer for multistage reduce. # TODO: use local alias if self.group_for_reduces: @@ -358,7 +357,7 @@ class Linearizer(Kernel): self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501 temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype) self.bufs.append(LocalBuffer(name:=f"temp{i if len(self.reduceops) > 1 else ''}", buf_size:=self.sts[-1].size, temp_dtype)) - self.buf_uops.append(self.uops.add(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), (name, buf_size))) + self.buf_uops.append(UOp(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), (name, buf_size))) # kernel name (before late upcast) self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \ @@ -381,11 +380,11 @@ class Linearizer(Kernel): self.local_size: Optional[List[int]] = None if self.dont_use_locals: self.global_size = [x.max+1 for x in loop_global_idxs][::-1] - self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501 + self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501 elif self.opts.has_local: self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs] - self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501 - self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)}) + self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501 + self.loop_uops.update({x.expr:UOp(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)}) else: self.render_loop(loop_global_idxs+loop_local_idxs, 1) if self.global_size is not None: self.global_size += [1]*(3-len(self.global_size)) @@ -403,8 +402,10 @@ class Linearizer(Kernel): # render reduceops by depth for reduceop in self.reduceops: self.render_block((reduceop, ), global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs) + stores = self.render_block(self.ast, global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs) - self.render_block(self.ast, global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, alias_buf_idxs, loaded_buffers, accs) + # only the final stores are needed to define the full UOps graph + self.uops:UOpGraph = UOpGraph(flatten(stores)) # maybe graph the uops if DEBUG >= 5: self.uops.print() @@ -428,7 +429,7 @@ class Linearizer(Kernel): # TODO: delete render_reduceop and move the logic for group_for_reduces to Block local_idxs[:], upcast_idxs[:] = self.render_reduceop((r:=reduceops[0]),accs,loaded_buffers,global_idxs,local_idxs,upcast_idxs, full_upcast_idxs, reduce_idxs,fake_reduce_idxs,alias_buf_idxs[r]) - return accs[r] + return [accs[r]] # load latebufs loaded_buffers.update({b:self.global_load(i, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs) \ @@ -442,7 +443,7 @@ class Linearizer(Kernel): if x in cache: return cache[x] if x.op in BufferOps: return loaded_buffers[x.arg] if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]: - return [self.uops.add(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST, + return [UOp(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST, self.get_base_dtype(x.arg), (u,)) for u in self.ast_parse(x.src[0], accs, offs, loaded_buffers)] if x.op in ReduceOps and reduce_acc is None: assert offs is None, "not available if we aren't doing reduce" @@ -459,7 +460,7 @@ class Linearizer(Kernel): ret.append(acc[off]) for off in range(len(acc)): if input_acc[off] != acc[off]: - acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off])) + acc[off] = UOp(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off])) else: ret = [UOp.alu(x.op, *vin) for vin in zip(*values)] cache[x] = ret return ret diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index f3e83b76a7..1c6a2fb9a4 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -249,13 +249,39 @@ constant_folder = PatternMatcher([ # *** uop graph *** class UOpGraph: - def __init__(self): + def __init__(self, add_nodes:Optional[List[UOp]]=None): self.nodes: Dict[Tuple, UOp] = {} self._uops: Optional[List[UOp]] = None + if add_nodes is not None: self.multiadd(add_nodes) def __iter__(self) -> Iterator[UOp]: return iter(self.uops) def __getitem__(self, index) -> UOp: return self.uops[index] + def multiadd(self, unprocessed_nodes:List[UOp]): + # add nodes to graph in reverse BFS order + # TODO: i feel like this is written in a few places, possible to library it? + in_degree: DefaultDict[UOp, int] = defaultdict(int) + children: DefaultDict[UOp, List[UOp]] = defaultdict(list) + all_nodes: Dict[UOp, None] = dict() + while len(unprocessed_nodes): + n = unprocessed_nodes.pop(0) + if n in all_nodes: continue + all_nodes[n] = None + for x in n.vin: + in_degree[x] += 1 + children[x].append(n) + unprocessed_nodes += list(n.vin) + queue = [x for x in all_nodes if in_degree[x] == 0] + replace_nodes: Dict[UOp, UOp] = {} + while len(queue): + n = queue.pop(0) + if n in replace_nodes: continue + replace_nodes[n] = self.add(n.uop, n.dtype, tuple(replace_nodes.get(x, x) for x in n.vin), n.arg) + for x in children[n]: + in_degree[x] -= 1 + if in_degree[x] == 0: + queue.append(x) + def vars(self) -> List[Variable]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR] def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_GLOBAL]