diff --git a/extra/gemm/metal_uop_matmul.py b/extra/gemm/metal_uop_matmul.py index a2d619b45e..a8b795324a 100644 --- a/extra/gemm/metal_uop_matmul.py +++ b/extra/gemm/metal_uop_matmul.py @@ -20,8 +20,8 @@ def hand_spec_tc_cores(): gk = UOp.range(N // 8, 0, AxisType.REDUCE) - a_tc = UOp.vectorize(*[mat_idx(a, gx, gk, warp, i) for i in range(2)]) - b_tc = UOp.vectorize(*[mat_idx(b, gk, gy, warp, i) for i in range(2)]) + a_tc = UOp.stack(*[mat_idx(a, gx, gk, warp, i) for i in range(2)]) + b_tc = UOp.stack(*[mat_idx(b, gk, gy, warp, i) for i in range(2)]) acc = UOp.placeholder((2,), dtypes.float, slot=0, addrspace=AddrSpace.REG) acc = acc[0].set(0.0) @@ -30,7 +30,7 @@ def hand_spec_tc_cores(): # TODO: make this simple wmma_arg = ('WMMA_8_8_8_float_float', (8, 8, 8), dtypes.float, dtypes.float, 'METAL', 32, (((3, 2),), ((3, 2),), ((3, 2),)), ()) - acc_load = UOp.vectorize(acc.after(gk)[0], acc.after(gk)[1]) + acc_load = UOp.stack(acc.after(gk)[0], acc.after(gk)[1]) out = UOp(Ops.WMMA, dtypes.float.vec(2), (a_tc, b_tc, acc_load), arg=wmma_arg) end_loop = UOp.group(*[acc[i].store(out.gep(i)) for i in range(2)]).end(gk) diff --git a/extra/thunder/tiny/tk/group.py b/extra/thunder/tiny/tk/group.py index d592a44f9d..d528e61b7a 100644 --- a/extra/thunder/tiny/tk/group.py +++ b/extra/thunder/tiny/tk/group.py @@ -84,13 +84,13 @@ class Group: for width in self.ker.range(c.shape[-2], track=False): for inner in self.ker.range(a.shape[-2], axis_type=AxisType.REDUCE, track=False): if a_base_shape.cols == 16: - a_in = UOp.vectorize(*[a[height, inner, i] for i in range(4)]) - b_in = UOp.vectorize(*[b[inner, width, i] for i in range(4)]) + a_in = UOp.stack(*[a[height, inner, i] for i in range(4)]) + b_in = UOp.stack(*[b[inner, width, i] for i in range(4)]) elif a_base_shape.cols == 32: - a_in = UOp.vectorize(*[a[height, inner, i] for i in range(8)]) - b_in = UOp.vectorize(*[b[inner, width, i] for i in range(8)]) + a_in = UOp.stack(*[a[height, inner, i] for i in range(8)]) + b_in = UOp.stack(*[b[inner, width, i] for i in range(8)]) else: raise NotImplementedError(f"mma_AB not implemented for {a_base_shape.cols=}") - d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)]) + d_in = UOp.stack(*[c[height, width, i] for i in range(4)]) out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg) c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)] @@ -114,13 +114,13 @@ class Group: for width in self.ker.range(c.shape[-2], track=False): for inner in self.ker.range(a.shape[-2], axis_type=AxisType.REDUCE, track=False): if a_base_shape.cols == 16: - a_in = UOp.vectorize(*[a[height, inner, i] for i in range(4)]) - b_in = UOp.vectorize(*[b[width, inner, i] for i in range(4)]) + a_in = UOp.stack(*[a[height, inner, i] for i in range(4)]) + b_in = UOp.stack(*[b[width, inner, i] for i in range(4)]) elif a_base_shape.cols == 32: - a_in = UOp.vectorize(*[a[height, inner, i] for i in range(8)]) - b_in = UOp.vectorize(*[b[width, inner, i] for i in range(8)]) + a_in = UOp.stack(*[a[height, inner, i] for i in range(8)]) + b_in = UOp.stack(*[b[width, inner, i] for i in range(8)]) else: raise NotImplementedError(f"mma_ABt not implemented for {a_base_shape.cols=}") - d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)]) + d_in = UOp.stack(*[c[height, width, i] for i in range(4)]) out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg) c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)] @@ -144,13 +144,13 @@ class Group: for width in self.ker.range(c.shape[-2], track=False): for inner in self.ker.range(a.shape[-3], axis_type=AxisType.REDUCE, track=False): if a_base_shape.cols == 16: - a_in = UOp.vectorize(*[a[inner, height, i] for i in range(4)]) - b_in = UOp.vectorize(*[b[inner, width, i] for i in range(4)]) + a_in = UOp.stack(*[a[inner, height, i] for i in range(4)]) + b_in = UOp.stack(*[b[inner, width, i] for i in range(4)]) elif a_base_shape.cols == 32: - a_in = UOp.vectorize(*[a[inner, height, i] for i in range(8)]) - b_in = UOp.vectorize(*[b[inner, width, i] for i in range(8)]) + a_in = UOp.stack(*[a[inner, height, i] for i in range(8)]) + b_in = UOp.stack(*[b[inner, width, i] for i in range(8)]) else: raise NotImplementedError(f"mma_AtB not implemented for {a_base_shape.cols=}") - d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)]) + d_in = UOp.stack(*[c[height, width, i] for i in range(4)]) out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg) c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)] @@ -174,13 +174,13 @@ class Group: for width in self.ker.range(c.shape[-2], track=False): for inner in self.ker.range(a.shape[-3], axis_type=AxisType.REDUCE, track=False): if a_base_shape.cols == 16: - a_in = UOp.vectorize(*[a[inner, height, i] for i in range(4)]) - b_in = UOp.vectorize(*[b[width, inner, i] for i in range(4)]) + a_in = UOp.stack(*[a[inner, height, i] for i in range(4)]) + b_in = UOp.stack(*[b[width, inner, i] for i in range(4)]) elif a_base_shape.cols == 32: - a_in = UOp.vectorize(*[a[inner, height, i] for i in range(8)]) - b_in = UOp.vectorize(*[b[width, inner, i] for i in range(8)]) + a_in = UOp.stack(*[a[inner, height, i] for i in range(8)]) + b_in = UOp.stack(*[b[width, inner, i] for i in range(8)]) else: raise NotImplementedError(f"mma_AtBt not implemented for {a_base_shape.cols=}") - d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)]) + d_in = UOp.stack(*[c[height, width, i] for i in range(4)]) out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg) c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)] diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index dd4c4556e8..89d1afee26 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -68,7 +68,7 @@ def expand_index(buf:UOp, vec:UOp): # search for dims that drop the most valid statements best_drop, cands = -1, [] for ch, cw in ImageDType.valid_dims(dt): - if (dropped:=len(_drop_valid_stmts(valid, cidx:=uop_given_valid(valid, UOp.vectorize((x//4)%cw, x//(4*cw))), ch, cw))) > best_drop: + if (dropped:=len(_drop_valid_stmts(valid, cidx:=uop_given_valid(valid, UOp.stack((x//4)%cw, x//(4*cw))), ch, cw))) > best_drop: best_drop, cands = dropped, [(ch, cw, cidx)] elif dropped == best_drop: cands.append((ch, cw, cidx)) # and tiebreak with indexing complexity (ie. number of nodes) @@ -197,8 +197,9 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp): return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret) def get_image_idx(idx:UOp, width:int): - oidx = UOp(Ops.STACK, dtypes.weakint.vec(2), (((x:=idx.src[1].get_idx()) // 4) % width, (x // (4*width)))) - return idx.replace(src=(idx.src[0], oidx.valid(idx.src[1].get_valid()))) + x, valid = idx.src[1].get_idx(), idx.src[1].get_valid() + idx_x, idx_y = (x // 4) % width, x // (4*width) + return idx.replace(src=(idx.src[0], UOp.stack(idx_x, idx_y).valid(valid))) def image_fixup(ls:UOp): # normal image load or store, with the CAST from expand_index @@ -385,7 +386,7 @@ def make_image(ls, buf, off): if (vcount:=buf.dtype.vcount) != 1: buf = buf.src[0] if buf.op == Ops.PARAM and not isinstance(dt:=buf.dtype, ImageDType) and (dims:=ImageDType.valid_dims(dt)): buf = buf.replace(dtype=(dtypes.imageh if dt.base == dtypes.half else dtypes.imagef)((*dims[0], 4))) - if vcount != 1: buf = UOp.vectorize(*([buf] * vcount)) + if vcount != 1: buf = UOp.stack(*([buf] * vcount)) if ls.op is Ops.LOAD: return ls.replace(src=(buf.index(off, ptr=True),), dtype=dtypes.float.vec(ls.dtype.vcount)).cast(dt.base) return buf.index(off, ptr=True).store(pm_imageh_store.rewrite(ls.src[1]) if dt.base == dtypes.half else ls.src[1]) diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index 086bd270d1..dedd95d102 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -418,7 +418,7 @@ def f2f_clamp(val:UOp, dt:DType) -> UOp: def f2f_load(x: UOp, fr:DType, to:DType) -> UOp: if (n:=x.dtype.count) == 1: return f2f(x.replace(dtype=f2f_dt[fr]), fr, to) - return UOp.vectorize(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),)), fr, to) for i in range(n))) + return UOp.stack(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),)), fr, to) for i in range(n))) def f2f_store(st, idx, val, fr:DType, to:DType): if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(f2f_dt[to]), to, fr))) diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index a468d9a74e..a452ac3f9f 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -412,6 +412,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass): return UOp(Ops.SINK, dtypes.void, tuple([x for x in srcs if x is not None]), **kwargs) def maketuple(*srcs:UOp): # pylint: disable=no-self-argument return UOp(Ops.TUPLE, dtypes.void, srcs) + def stack(*srcs:UOp, **kwargs): # pylint: disable=no-self-argument + return UOp(Ops.STACK, srcs[0].dtype.vec(len(srcs)), srcs, **kwargs) def gettuple(self, idx:int) -> UOp: in_tuple = self.src[0] if self.op is Ops.FUNCTION else self assert in_tuple.op is Ops.TUPLE, f"gettuple requires FUNCTION or TUPLE source, got {self.op}" @@ -419,8 +421,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass): def group(*srcs:UOp|None): # pylint: disable=no-self-argument if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0] return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None])) - def vectorize(self, *srcs, **kwargs): - return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs, **kwargs) def index(self, *srcs:UOp|None, ptr=False, **kwargs): return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs) def __getitem__(self, idx): @@ -1136,8 +1136,8 @@ class UPat(OpMixin): # copied from UOp def sink(self, *srcs:UPat|None, **kwargs): return UPat(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs) - def index(self, idx:UPat, valid:UPat|None=None, **kwargs): - return UPat(Ops.INDEX, self.match_dtype, (self,idx,valid) if valid is not None else (self,idx), **kwargs) + def index(self, *srcs:UPat|None, **kwargs): + return UPat(Ops.INDEX, self.match_dtype, (self,)+tuple(x for x in srcs if x is not None), **kwargs) def cast(self, dtype=None, **kwargs): if dtype is not None and self.match_dtype == (dtype,): return self return UPat(Ops.CAST, dtype, (self,), **kwargs) @@ -1533,7 +1533,7 @@ pm_lower_index_dtype = PatternMatcher([ lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))), # vectorized indexes (ie. images) must be int (UPat(Ops.INDEX, src=(UPat(), UPat(Ops.STACK, dtypes.long, name="vec")), allow_any_len=True, name="idx"), - lambda idx,vec: idx.replace(src=(idx.src[0], UOp.vectorize(*(u.cast(dtypes.int) for u in vec.src)), *idx.src[2:]))) + lambda idx,vec: idx.replace(src=(idx.src[0], UOp.stack(*(u.cast(dtypes.int) for u in vec.src)), *idx.src[2:]))) ]) def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]