mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 08:28:55 +08:00
Fix amx shape [run_process_replay] (#6524)
* fix amx shape (sz,sz,sz) -> (sz,sz,1) * revert check
This commit is contained in:
@@ -1040,7 +1040,8 @@ class TestLinearizer(unittest.TestCase):
|
||||
def test_tensor_cores(self):
|
||||
for tc in Device[Device.DEFAULT].renderer.tensor_cores:
|
||||
if (getenv("EMULATE_CUDA") or getenv("EMULATE_INTEL")) and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue
|
||||
helper_tc_allclose(tc.dims[0], tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
|
||||
# for AMX, tc.dims[2] == 1 so reduceop is None thus tensor_cores are not triggered
|
||||
helper_tc_allclose(tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_tensor_cores_padded(self):
|
||||
@@ -1061,7 +1062,8 @@ class TestLinearizer(unittest.TestCase):
|
||||
# check excessive padding doesn't trigger padded TC in TC_OPT=2
|
||||
helper_tc_ensure_uops_and_opts_count(tc.dims[0]//4, tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
|
||||
helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1]//4, tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
|
||||
helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1], tc.dims[2]//4, tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
|
||||
if not AMX: # AMX tc.dims[2] == 1
|
||||
helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1], tc.dims[2]//4, tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False)
|
||||
|
||||
# check correctness
|
||||
helper_tc_allclose(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2)
|
||||
|
||||
@@ -330,7 +330,8 @@ class Kernel:
|
||||
if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
|
||||
for (tc_dim, tc_amt) in tc.threads: self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False)
|
||||
elif self.opts.device == "CLANG":
|
||||
for (i, sz) in tc.threads: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], sz), append_opt=False)
|
||||
for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M
|
||||
if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False)
|
||||
elif self.opts.device in {"CUDA", "NV"}:
|
||||
self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, 8), append_opt=False)
|
||||
self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, 2), append_opt=False)
|
||||
@@ -367,7 +368,7 @@ class Kernel:
|
||||
if extra_opts is not None:
|
||||
for opt in extra_opts: self.apply_opt(opt)
|
||||
else:
|
||||
if (self.opts.device == "CLANG" and AMX): return True # skip hand-coded TC opts if CLANG, upcasting will make kernel slower
|
||||
if (self.opts.device == "CLANG" and AMX): return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
|
||||
# hand-coded TC opts
|
||||
def late_upcast_tc(tc_dim: int):
|
||||
if tc_opts.axes_exist[tc_dim]:
|
||||
@@ -661,7 +662,6 @@ class Kernel:
|
||||
[y + (wd if x == 0 else tcd) for x,y in pattern_2] + list(range(tcd+len(tcd_expand), len(new_shape)))
|
||||
return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape).simplify()
|
||||
|
||||
threads = prod(t[1] for t in tc.threads)
|
||||
if self.opts.device in {"AMD", "HIP"}:
|
||||
reduce_axes, upcast_axes = [0], [[(0, 16)], [(0, 16)], [(1, 8)]]
|
||||
# https://gpuopen.com/learn/wmma_on_rdna3/
|
||||
@@ -672,8 +672,8 @@ class Kernel:
|
||||
fix_st1 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((1,1), (0,1), (1,0), (0,3)), ((0,0), (0,2), (1,3), (1,2)))
|
||||
fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3)))
|
||||
elif self.opts.device == "CLANG":
|
||||
reduce_axes, upcast_axes = [], [[(1,tc.dims[0])],[(0,tc.dims[1])],[(1, tc.dims[2]), (0, tc.dims[2])]]
|
||||
threads, fix_st1, fix_st2 = threads // tc.dims[2], None, None
|
||||
reduce_axes, upcast_axes = [], [[(1,tc.dims[0])],[(0,tc.dims[1])],[(1, tc.dims[0]), (0, tc.dims[1])]]
|
||||
fix_st1, fix_st2 = None, None
|
||||
elif self.opts.device in {"CUDA", "NV"}:
|
||||
reduce_axes, upcast_axes = [0, 1], [[(0, 8)], [(2, 2), (3, 2)], [(2, 2), (3, 2)]]
|
||||
# https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
|
||||
@@ -689,7 +689,7 @@ class Kernel:
|
||||
raise RuntimeError("unsupported device for tensor cores")
|
||||
|
||||
assert apply_to_st is None, "double tensor core? not supported"
|
||||
wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, threads,
|
||||
wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, prod(t[1] for t in tc.threads),
|
||||
tuple(tuple((self.first_upcast+ax, sz) for ax, sz in up) for up in upcast_axes),
|
||||
tuple(self.first_upcast+ax for ax in reduce_axes))
|
||||
if self.use_tensor_cores >= 2:
|
||||
|
||||
@@ -204,7 +204,7 @@ class ClangRenderer(CStyleLanguage):
|
||||
|
||||
if AMX:
|
||||
tc_types = [(dtype, amx_size//dtype.itemsize) for dtype, amx_size in zip([dtypes.float], [64])]
|
||||
tensor_cores = [TensorCore(dims=(sz,sz,sz), threads=[(0,sz),(1,sz)], dtype_in=dtype, dtype_out=dtype) for dtype, sz in tc_types]
|
||||
tensor_cores = [TensorCore(dims=(sz,sz,1), threads=[], dtype_in=dtype, dtype_out=dtype) for dtype, sz in tc_types]
|
||||
|
||||
def render_vector_prefix(self, dt:DType) -> str:
|
||||
return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({(sz:=dt.itemsize)}),vector_size({sz})));"
|
||||
@@ -212,12 +212,12 @@ class ClangRenderer(CStyleLanguage):
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
prefix, macros = [self.render_vector_prefix(dt) for dt in dedup(uop.dtype for uop in uops if uop.dtype.count>1)], []
|
||||
# https://github.com/corsix/amx
|
||||
for name, (N, M, K), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
|
||||
for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
|
||||
macros = [
|
||||
'#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
|
||||
'#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
|
||||
]
|
||||
prefix += [f"""{(out := self.render_dtype(dtype_in.vec(K*K)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
|
||||
prefix += [f"""{(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{
|
||||
AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}
|
||||
AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull);
|
||||
for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""] # noqa: E501
|
||||
|
||||
Reference in New Issue
Block a user