mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
lil image refactors + vectorize->stack
This commit is contained in:
@@ -20,8 +20,8 @@ def hand_spec_tc_cores():
|
||||
|
||||
gk = UOp.range(N // 8, 0, AxisType.REDUCE)
|
||||
|
||||
a_tc = UOp.vectorize(*[mat_idx(a, gx, gk, warp, i) for i in range(2)])
|
||||
b_tc = UOp.vectorize(*[mat_idx(b, gk, gy, warp, i) for i in range(2)])
|
||||
a_tc = UOp.stack(*[mat_idx(a, gx, gk, warp, i) for i in range(2)])
|
||||
b_tc = UOp.stack(*[mat_idx(b, gk, gy, warp, i) for i in range(2)])
|
||||
|
||||
acc = UOp.placeholder((2,), dtypes.float, slot=0, addrspace=AddrSpace.REG)
|
||||
acc = acc[0].set(0.0)
|
||||
@@ -30,7 +30,7 @@ def hand_spec_tc_cores():
|
||||
# TODO: make this simple
|
||||
wmma_arg = ('WMMA_8_8_8_float_float', (8, 8, 8), dtypes.float, dtypes.float, 'METAL', 32, (((3, 2),), ((3, 2),), ((3, 2),)), ())
|
||||
|
||||
acc_load = UOp.vectorize(acc.after(gk)[0], acc.after(gk)[1])
|
||||
acc_load = UOp.stack(acc.after(gk)[0], acc.after(gk)[1])
|
||||
out = UOp(Ops.WMMA, dtypes.float.vec(2), (a_tc, b_tc, acc_load), arg=wmma_arg)
|
||||
|
||||
end_loop = UOp.group(*[acc[i].store(out.gep(i)) for i in range(2)]).end(gk)
|
||||
|
||||
@@ -84,13 +84,13 @@ class Group:
|
||||
for width in self.ker.range(c.shape[-2], track=False):
|
||||
for inner in self.ker.range(a.shape[-2], axis_type=AxisType.REDUCE, track=False):
|
||||
if a_base_shape.cols == 16:
|
||||
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(4)])
|
||||
b_in = UOp.vectorize(*[b[inner, width, i] for i in range(4)])
|
||||
a_in = UOp.stack(*[a[height, inner, i] for i in range(4)])
|
||||
b_in = UOp.stack(*[b[inner, width, i] for i in range(4)])
|
||||
elif a_base_shape.cols == 32:
|
||||
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(8)])
|
||||
b_in = UOp.vectorize(*[b[inner, width, i] for i in range(8)])
|
||||
a_in = UOp.stack(*[a[height, inner, i] for i in range(8)])
|
||||
b_in = UOp.stack(*[b[inner, width, i] for i in range(8)])
|
||||
else: raise NotImplementedError(f"mma_AB not implemented for {a_base_shape.cols=}")
|
||||
d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
|
||||
d_in = UOp.stack(*[c[height, width, i] for i in range(4)])
|
||||
|
||||
out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
|
||||
c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
|
||||
@@ -114,13 +114,13 @@ class Group:
|
||||
for width in self.ker.range(c.shape[-2], track=False):
|
||||
for inner in self.ker.range(a.shape[-2], axis_type=AxisType.REDUCE, track=False):
|
||||
if a_base_shape.cols == 16:
|
||||
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(4)])
|
||||
b_in = UOp.vectorize(*[b[width, inner, i] for i in range(4)])
|
||||
a_in = UOp.stack(*[a[height, inner, i] for i in range(4)])
|
||||
b_in = UOp.stack(*[b[width, inner, i] for i in range(4)])
|
||||
elif a_base_shape.cols == 32:
|
||||
a_in = UOp.vectorize(*[a[height, inner, i] for i in range(8)])
|
||||
b_in = UOp.vectorize(*[b[width, inner, i] for i in range(8)])
|
||||
a_in = UOp.stack(*[a[height, inner, i] for i in range(8)])
|
||||
b_in = UOp.stack(*[b[width, inner, i] for i in range(8)])
|
||||
else: raise NotImplementedError(f"mma_ABt not implemented for {a_base_shape.cols=}")
|
||||
d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
|
||||
d_in = UOp.stack(*[c[height, width, i] for i in range(4)])
|
||||
|
||||
out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
|
||||
c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
|
||||
@@ -144,13 +144,13 @@ class Group:
|
||||
for width in self.ker.range(c.shape[-2], track=False):
|
||||
for inner in self.ker.range(a.shape[-3], axis_type=AxisType.REDUCE, track=False):
|
||||
if a_base_shape.cols == 16:
|
||||
a_in = UOp.vectorize(*[a[inner, height, i] for i in range(4)])
|
||||
b_in = UOp.vectorize(*[b[inner, width, i] for i in range(4)])
|
||||
a_in = UOp.stack(*[a[inner, height, i] for i in range(4)])
|
||||
b_in = UOp.stack(*[b[inner, width, i] for i in range(4)])
|
||||
elif a_base_shape.cols == 32:
|
||||
a_in = UOp.vectorize(*[a[inner, height, i] for i in range(8)])
|
||||
b_in = UOp.vectorize(*[b[inner, width, i] for i in range(8)])
|
||||
a_in = UOp.stack(*[a[inner, height, i] for i in range(8)])
|
||||
b_in = UOp.stack(*[b[inner, width, i] for i in range(8)])
|
||||
else: raise NotImplementedError(f"mma_AtB not implemented for {a_base_shape.cols=}")
|
||||
d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
|
||||
d_in = UOp.stack(*[c[height, width, i] for i in range(4)])
|
||||
|
||||
out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
|
||||
c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
|
||||
@@ -174,13 +174,13 @@ class Group:
|
||||
for width in self.ker.range(c.shape[-2], track=False):
|
||||
for inner in self.ker.range(a.shape[-3], axis_type=AxisType.REDUCE, track=False):
|
||||
if a_base_shape.cols == 16:
|
||||
a_in = UOp.vectorize(*[a[inner, height, i] for i in range(4)])
|
||||
b_in = UOp.vectorize(*[b[width, inner, i] for i in range(4)])
|
||||
a_in = UOp.stack(*[a[inner, height, i] for i in range(4)])
|
||||
b_in = UOp.stack(*[b[width, inner, i] for i in range(4)])
|
||||
elif a_base_shape.cols == 32:
|
||||
a_in = UOp.vectorize(*[a[inner, height, i] for i in range(8)])
|
||||
b_in = UOp.vectorize(*[b[width, inner, i] for i in range(8)])
|
||||
a_in = UOp.stack(*[a[inner, height, i] for i in range(8)])
|
||||
b_in = UOp.stack(*[b[width, inner, i] for i in range(8)])
|
||||
else: raise NotImplementedError(f"mma_AtBt not implemented for {a_base_shape.cols=}")
|
||||
d_in = UOp.vectorize(*[c[height, width, i] for i in range(4)])
|
||||
d_in = UOp.stack(*[c[height, width, i] for i in range(4)])
|
||||
|
||||
out = UOp(Ops.WMMA, dtypes.float32.vec(4), (a_in, b_in, d_in), arg=wmma_arg)
|
||||
c_i = [c[height, width, i].store(out.gep(i)) for i in range(4)]
|
||||
|
||||
@@ -68,7 +68,7 @@ def expand_index(buf:UOp, vec:UOp):
|
||||
# search for dims that drop the most valid statements
|
||||
best_drop, cands = -1, []
|
||||
for ch, cw in ImageDType.valid_dims(dt):
|
||||
if (dropped:=len(_drop_valid_stmts(valid, cidx:=uop_given_valid(valid, UOp.vectorize((x//4)%cw, x//(4*cw))), ch, cw))) > best_drop:
|
||||
if (dropped:=len(_drop_valid_stmts(valid, cidx:=uop_given_valid(valid, UOp.stack((x//4)%cw, x//(4*cw))), ch, cw))) > best_drop:
|
||||
best_drop, cands = dropped, [(ch, cw, cidx)]
|
||||
elif dropped == best_drop: cands.append((ch, cw, cidx))
|
||||
# and tiebreak with indexing complexity (ie. number of nodes)
|
||||
@@ -197,8 +197,9 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
|
||||
return UOp(Ops.VCAT, ls.dtype, tuple(ret)) if ls.op is Ops.LOAD else UOp.group(*ret)
|
||||
|
||||
def get_image_idx(idx:UOp, width:int):
|
||||
oidx = UOp(Ops.STACK, dtypes.weakint.vec(2), (((x:=idx.src[1].get_idx()) // 4) % width, (x // (4*width))))
|
||||
return idx.replace(src=(idx.src[0], oidx.valid(idx.src[1].get_valid())))
|
||||
x, valid = idx.src[1].get_idx(), idx.src[1].get_valid()
|
||||
idx_x, idx_y = (x // 4) % width, x // (4*width)
|
||||
return idx.replace(src=(idx.src[0], UOp.stack(idx_x, idx_y).valid(valid)))
|
||||
|
||||
def image_fixup(ls:UOp):
|
||||
# normal image load or store, with the CAST from expand_index
|
||||
@@ -385,7 +386,7 @@ def make_image(ls, buf, off):
|
||||
if (vcount:=buf.dtype.vcount) != 1: buf = buf.src[0]
|
||||
if buf.op == Ops.PARAM and not isinstance(dt:=buf.dtype, ImageDType) and (dims:=ImageDType.valid_dims(dt)):
|
||||
buf = buf.replace(dtype=(dtypes.imageh if dt.base == dtypes.half else dtypes.imagef)((*dims[0], 4)))
|
||||
if vcount != 1: buf = UOp.vectorize(*([buf] * vcount))
|
||||
if vcount != 1: buf = UOp.stack(*([buf] * vcount))
|
||||
if ls.op is Ops.LOAD: return ls.replace(src=(buf.index(off, ptr=True),), dtype=dtypes.float.vec(ls.dtype.vcount)).cast(dt.base)
|
||||
return buf.index(off, ptr=True).store(pm_imageh_store.rewrite(ls.src[1]) if dt.base == dtypes.half else ls.src[1])
|
||||
|
||||
|
||||
@@ -418,7 +418,7 @@ def f2f_clamp(val:UOp, dt:DType) -> UOp:
|
||||
|
||||
def f2f_load(x: UOp, fr:DType, to:DType) -> UOp:
|
||||
if (n:=x.dtype.count) == 1: return f2f(x.replace(dtype=f2f_dt[fr]), fr, to)
|
||||
return UOp.vectorize(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),)), fr, to) for i in range(n)))
|
||||
return UOp.stack(*(f2f(x.replace(dtype=f2f_dt[fr], src=(reindex(x.src[0].src[0], i, 1),)), fr, to) for i in range(n)))
|
||||
|
||||
def f2f_store(st, idx, val, fr:DType, to:DType):
|
||||
if (n:=val.dtype.count) == 1: return st.replace(src=(idx, f2f(val.bitcast(f2f_dt[to]), to, fr)))
|
||||
|
||||
@@ -412,6 +412,8 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return UOp(Ops.SINK, dtypes.void, tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def maketuple(*srcs:UOp): # pylint: disable=no-self-argument
|
||||
return UOp(Ops.TUPLE, dtypes.void, srcs)
|
||||
def stack(*srcs:UOp, **kwargs): # pylint: disable=no-self-argument
|
||||
return UOp(Ops.STACK, srcs[0].dtype.vec(len(srcs)), srcs, **kwargs)
|
||||
def gettuple(self, idx:int) -> UOp:
|
||||
in_tuple = self.src[0] if self.op is Ops.FUNCTION else self
|
||||
assert in_tuple.op is Ops.TUPLE, f"gettuple requires FUNCTION or TUPLE source, got {self.op}"
|
||||
@@ -419,8 +421,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
def group(*srcs:UOp|None): # pylint: disable=no-self-argument
|
||||
if len(srcs) == 1 and isinstance(srcs[0], UOp): return srcs[0]
|
||||
return UOp(Ops.GROUP, dtypes.void, tuple([x for x in srcs if x is not None]))
|
||||
def vectorize(self, *srcs, **kwargs):
|
||||
return UOp(Ops.STACK, self.dtype.vec(len(srcs)+1), (self,)+srcs, **kwargs)
|
||||
def index(self, *srcs:UOp|None, ptr=False, **kwargs):
|
||||
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype if ptr else self.dtype.base), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def __getitem__(self, idx):
|
||||
@@ -1136,8 +1136,8 @@ class UPat(OpMixin):
|
||||
|
||||
# copied from UOp
|
||||
def sink(self, *srcs:UPat|None, **kwargs): return UPat(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def index(self, idx:UPat, valid:UPat|None=None, **kwargs):
|
||||
return UPat(Ops.INDEX, self.match_dtype, (self,idx,valid) if valid is not None else (self,idx), **kwargs)
|
||||
def index(self, *srcs:UPat|None, **kwargs):
|
||||
return UPat(Ops.INDEX, self.match_dtype, (self,)+tuple(x for x in srcs if x is not None), **kwargs)
|
||||
def cast(self, dtype=None, **kwargs):
|
||||
if dtype is not None and self.match_dtype == (dtype,): return self
|
||||
return UPat(Ops.CAST, dtype, (self,), **kwargs)
|
||||
@@ -1533,7 +1533,7 @@ pm_lower_index_dtype = PatternMatcher([
|
||||
lambda n: n.replace(src=tuple(s.src[0] if s.op is Ops.CAST and s.dtype == dtypes.weakint else s for s in n.src))),
|
||||
# vectorized indexes (ie. images) must be int
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat(Ops.STACK, dtypes.long, name="vec")), allow_any_len=True, name="idx"),
|
||||
lambda idx,vec: idx.replace(src=(idx.src[0], UOp.vectorize(*(u.cast(dtypes.int) for u in vec.src)), *idx.src[2:])))
|
||||
lambda idx,vec: idx.replace(src=(idx.src[0], UOp.stack(*(u.cast(dtypes.int) for u in vec.src)), *idx.src[2:])))
|
||||
])
|
||||
def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user