mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
wmma: enable METAL half tensor cores and clean up cstyle (#3095)
* wmma: enable METAL half tensor cores and clean up cstyle * revert simple_matmul rand changes and break line in tensor * added metal fp16->fp32 tensor core
This commit is contained in:
@@ -2,14 +2,14 @@ import numpy as np
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad import dtypes, Tensor
|
||||
dtype_in = dtypes.half if getenv("HALF") else dtypes.float
|
||||
acc_dtype = dtypes.half if getenv("ACC_HALF") else None
|
||||
N = getenv("N", 4096)
|
||||
CNT = getenv("CNT", 10)
|
||||
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
|
||||
for i in range(CNT):
|
||||
if i > 0 and getenv("RAND", 0) != 0:
|
||||
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
|
||||
# NOTE: accumulate is in float32
|
||||
c = (a @ b).realize()
|
||||
c = a.matmul(b, acc_dtype=acc_dtype).realize()
|
||||
comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
|
||||
nc = c.numpy()
|
||||
np.testing.assert_allclose(nc, comp, atol=1e-4, rtol=3e-2)
|
||||
|
||||
@@ -87,10 +87,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
if tc.arch is not None and tc.arch != os.uname().machine: continue
|
||||
a, b = Tensor.rand(tc.dims[0], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[1], dtype=tc.dtype_in)
|
||||
np_a, np_b = a.numpy(), b.numpy()
|
||||
if tc.dtype_out != tc.dtype_in:
|
||||
r = (a.reshape(tc.dims[0], 1, tc.dims[2]) * b.permute(1,0).reshape(1, tc.dims[1], tc.dims[2])).cast(tc.dtype_out).sum(axis=2)
|
||||
else:
|
||||
r = a @ b
|
||||
r = a.matmul(b, acc_dtype=tc.dtype_out)
|
||||
realized_ast, _ = helper_realized_ast(r)
|
||||
k = Linearizer(realized_ast)
|
||||
k.apply_tensor_cores(1)
|
||||
|
||||
@@ -33,18 +33,18 @@ class TensorCore:
|
||||
upcast_dim: int # which TC dim to upcast
|
||||
thread_local_aliases: List[List[List[int]]] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim # noqa: E501
|
||||
thread_local_sizes: List[int] # in each thread, the number of elements stored in registers for each TC dim
|
||||
wmma_func: str # name of wmma function to call
|
||||
arch: Optional[str] = None
|
||||
def __str__(self): return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>"
|
||||
|
||||
tensor_cores: Dict[str, List[TensorCore]] = {
|
||||
"METAL": [
|
||||
TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
|
||||
# TODO: enable half @ half -> half tensor core with correct dtypes in uop
|
||||
# TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
|
||||
TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.float, dtype_out=dtypes.float, wmma_func="__metal_wmma<float2,simdgroup_float8x8,float2>", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
|
||||
TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__metal_wmma<half2,simdgroup_float8x8,float2>", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
|
||||
TensorCore(device="METAL", dims=[8,8,8], dtype_in=dtypes.half, dtype_out=dtypes.half, wmma_func="__metal_wmma<half2,simdgroup_half8x8,half2>", upcast_dim=0, threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[2,2,2], thread_local_aliases= [ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], arch="arm64"), # noqa: E501
|
||||
],
|
||||
"HIP": [
|
||||
TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
|
||||
TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.half, upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
|
||||
TensorCore(device="HIP", dims=[16,16,16], dtype_in=dtypes.half, dtype_out=dtypes.float, wmma_func="__builtin_amdgcn_wmma_f32_16x16x16_f16_w32", upcast_dim=1, threads=[(0,16),(1,2)], thread_local_sizes=[16,16,8], thread_local_aliases=[ [[0],[0],[-1],[1]], [[0],[1],[-1],[0]], [[0],[1],[0],[2,-1]] ]), # noqa: E501
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -265,7 +265,7 @@ class Linearizer(Kernel):
|
||||
# define accumulator
|
||||
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop))
|
||||
|
||||
if self.tensor_core:
|
||||
if (tc:=self.tensor_core):
|
||||
def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
|
||||
replace_idxs = []
|
||||
for alias in aliases:
|
||||
@@ -277,11 +277,11 @@ class Linearizer(Kernel):
|
||||
full_var_sz *= next_var.max+1
|
||||
replace_idxs.append(full_var)
|
||||
return replace_idxs
|
||||
replace_acc_idxs = calc_tc_idxs(self.tensor_core.thread_local_sizes[2], self.tensor_core.thread_local_aliases[2])
|
||||
for n in range(len(self.tensor_core.threads)):
|
||||
local_idxs[self.local_dims-len(self.tensor_core.threads)+n] = replace_acc_idxs[n] # replace locals
|
||||
for n in range(len(replace_acc_idxs)-len(self.tensor_core.threads)):
|
||||
upcast_idxs[n] = replace_acc_idxs[len(self.tensor_core.threads)+n] # replace upcasts
|
||||
replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
|
||||
for n in range(len(tc.threads)):
|
||||
local_idxs[self.local_dims-len(tc.threads)+n] = replace_acc_idxs[n] # replace locals
|
||||
for n in range(len(replace_acc_idxs)-len(tc.threads)):
|
||||
upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
|
||||
|
||||
# reduce loop
|
||||
loop_ctx = render_loop(reduce_idxs)
|
||||
@@ -306,8 +306,8 @@ class Linearizer(Kernel):
|
||||
locals_to_store.append((localbuf_idx, buf_idxs, ll))
|
||||
|
||||
# copy in any global buffers
|
||||
if self.tensor_core:
|
||||
wmma_sz, dtype_in, dtype_out = self.tensor_core.thread_local_sizes, self.tensor_core.dtype_in, self.tensor_core.dtype_out
|
||||
if (tc:=self.tensor_core):
|
||||
wmma_sz = tc.thread_local_sizes
|
||||
# calculate the number of local accumulator reduces and render WMMAs: this is bad... this needs to come from someplace else
|
||||
nx, ny, nacc = (len(locals_to_store[0][2])//wmma_sz[0]), (len(locals_to_store[1][2])//wmma_sz[1]), (len(acc)//wmma_sz[2])
|
||||
acc_reds = math.isqrt((nx*ny)//nacc)
|
||||
@@ -315,12 +315,12 @@ class Linearizer(Kernel):
|
||||
for y in range(by):
|
||||
for x in range(bx):
|
||||
for j in range(acc_reds):
|
||||
ops = (self.uop(UOps.CAST, dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]])),
|
||||
self.uop(UOps.CAST, dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]])),
|
||||
self.uop(UOps.CAST, dtype_out.vec(wmma_sz[2]), tuple(op3:=acc[i:i+wmma_sz[2]])))
|
||||
ret = self.uop(UOps.WMMA, dtype_out.vec(wmma_sz[2]), ops, (self.opts.device, self.tensor_core.dtype_in, self.tensor_core.dtype_out,))
|
||||
for z in range(cast(DType, ret.dtype).sz):
|
||||
acc[i+z] = self.uop(UOps.PHI, dtypes.float, (op3[z], self.uop(UOps.GEP, dtypes.float, (ret,), z)) + loop_ctx)
|
||||
ops = (self.uop(UOps.CAST, tc.dtype_in.vec(wmma_sz[0]), tuple(locals_to_store[0][2][(x+(j*bx))*wmma_sz[0]:(x+(j*bx)+1)*wmma_sz[0]])),
|
||||
self.uop(UOps.CAST, tc.dtype_in.vec(wmma_sz[1]), tuple(locals_to_store[1][2][(y+(j*by))*wmma_sz[1]:(y+(j*by)+1)*wmma_sz[1]])),
|
||||
self.uop(UOps.CAST, tc.dtype_out.vec(wmma_sz[2]), tuple(op3:=acc[i:i+wmma_sz[2]])))
|
||||
ret = self.uop(UOps.WMMA, tc.dtype_out.vec(wmma_sz[2]), ops, tc.wmma_func)
|
||||
for z in range(wmma_sz[2]):
|
||||
acc[i+z] = self.uop(UOps.PHI, tc.dtype_out, (op3[z], self.uop(UOps.GEP, tc.dtype_out, (ret,), z)) + loop_ctx)
|
||||
i += wmma_sz[2]
|
||||
else:
|
||||
if locals_to_store:
|
||||
|
||||
@@ -4,7 +4,7 @@ from tinygrad.dtype import dtypes
|
||||
|
||||
# *** image Tensor function replacements ***
|
||||
|
||||
def image_dot(self, w):
|
||||
def image_dot(self, w, acc_dtype=None):
|
||||
# NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
|
||||
n1, n2 = len(self.shape), len(w.shape)
|
||||
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
|
||||
@@ -17,9 +17,9 @@ def image_dot(self, w):
|
||||
cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
|
||||
# groups*cout x cin x H, W
|
||||
cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
|
||||
return image_conv2d(cx, cw, groups=groups).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
|
||||
return image_conv2d(cx, cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
|
||||
|
||||
def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0):
|
||||
def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None):
|
||||
base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
|
||||
|
||||
(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
|
||||
@@ -72,7 +72,7 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
|
||||
w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
|
||||
|
||||
# the conv!
|
||||
ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1))
|
||||
ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype)
|
||||
|
||||
# undo hack for non multiples of 4 on C.rcout
|
||||
if added_output_channels != 0:
|
||||
|
||||
@@ -138,10 +138,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
|
||||
kk(lang.render_for(ssa(u,'ridx'), r[vin[0]], r[vin[1]]))
|
||||
depth += 1
|
||||
elif uop == UOps.WMMA:
|
||||
if args[0] == "METAL" and dtype == dtypes.float.vec(2): wmma_func = "__metal_wmma<float2,simdgroup_float8x8>"
|
||||
elif args[0] == "HIP" and dtype == dtypes.float.vec(8): wmma_func = "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32"
|
||||
else: raise NotImplementedError(f"WMMA not implemented for {args}")
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = {wmma_func}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});") # noqa: E501
|
||||
kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {ssa(u, 'wmma')} = {args}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});") # noqa: E501
|
||||
elif uop == UOps.ALU:
|
||||
# remove parens if ALU types are the same. TODO: can do more here
|
||||
if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.XOR}:
|
||||
@@ -218,10 +215,10 @@ class OpenCLLanguage(CStyleLanguage):
|
||||
OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())
|
||||
|
||||
class MetalLanguage(CStyleLanguage):
|
||||
kernel_prefix = """#include <metal_stdlib>\nusing namespace metal;\ntemplate<typename T, typename S> T __metal_wmma(T m, T n, T o) {
|
||||
kernel_prefix = """#include <metal_stdlib>\nusing namespace metal;\ntemplate<typename T, typename S, typename U> U __metal_wmma(T m, T n, U o) {
|
||||
S a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x; b.thread_elements()[1] = n.y;
|
||||
c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
|
||||
return T(c.thread_elements()[0], c.thread_elements()[1]);\n}\nkernel """
|
||||
return U(c.thread_elements()[0], c.thread_elements()[1]);\n}\nkernel """
|
||||
buffer_prefix = "device "
|
||||
smem_prefix = "threadgroup "
|
||||
arg_int_prefix = "constant int&"
|
||||
|
||||
@@ -522,10 +522,10 @@ class Tensor:
|
||||
ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
|
||||
return ret if keepdim else ret.reshape(shape=shape)
|
||||
|
||||
def sum(self, axis=None, keepdim=False):
|
||||
acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
|
||||
least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
|
||||
least_upper_dtype(self.dtype, dtypes.float)
|
||||
def sum(self, axis=None, keepdim=False, acc_dtype=None):
|
||||
if acc_dtype is None: acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \
|
||||
least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \
|
||||
least_upper_dtype(self.dtype, dtypes.float)
|
||||
# cast back to float16 or bfloat16 to match torch / jax behavior, but we use float for acc
|
||||
output_dtype = self.dtype if self.dtype in (dtypes.float16, dtypes.bfloat16) else acc_dtype
|
||||
return self.cast(acc_dtype)._reduce(mlops.Sum, axis, keepdim).cast(output_dtype)
|
||||
@@ -680,15 +680,16 @@ class Tensor:
|
||||
|
||||
return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
|
||||
|
||||
def dot(self, w:Tensor) -> Tensor:
|
||||
def dot(self, w:Tensor, acc_dtype=None) -> Tensor:
|
||||
n1, n2 = len(self.shape), len(w.shape)
|
||||
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
|
||||
assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501
|
||||
x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
|
||||
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
|
||||
return (x*w).sum(-1).cast(least_upper_dtype(x.dtype, w.dtype))
|
||||
return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype))
|
||||
|
||||
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
|
||||
def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
|
||||
return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
|
||||
|
||||
def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
|
||||
return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
|
||||
|
||||
Reference in New Issue
Block a user