diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 055bfb27b4..07c1d92fb5 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1040,7 +1040,8 @@ class TestLinearizer(unittest.TestCase): def test_tensor_cores(self): for tc in Device[Device.DEFAULT].renderer.tensor_cores: if (getenv("EMULATE_CUDA") or getenv("EMULATE_INTEL")) and (tc.dtype_in == dtypes.bfloat16 or tc.dtype_out == dtypes.bfloat16): continue - helper_tc_allclose(tc.dims[0], tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0) + # for AMX, tc.dims[2] == 1 so reduceop is None thus tensor_cores are not triggered + helper_tc_allclose(tc.dims[0], tc.dims[1], 2 if AMX else tc.dims[2], tc.dtype_in, tc.dtype_out, axis=0, tc_opt=0) @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores") def test_tensor_cores_padded(self): @@ -1061,7 +1062,8 @@ class TestLinearizer(unittest.TestCase): # check excessive padding doesn't trigger padded TC in TC_OPT=2 helper_tc_ensure_uops_and_opts_count(tc.dims[0]//4, tc.dims[1], tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False) helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1]//4, tc.dims[2], tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False) - helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1], tc.dims[2]//4, tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False) + if not AMX: # AMX tc.dims[2] == 1 + helper_tc_ensure_uops_and_opts_count(tc.dims[0], tc.dims[1], tc.dims[2]//4, tc.dtype_in, tc.dtype_out, tc_opt=2, ensure_triggered=False) # check correctness helper_tc_allclose(tc.dims[0]+pad, tc.dims[1]+pad, tc.dims[2]+pad, tc.dtype_in, tc.dtype_out, tc_opt=2) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 0488924d3a..12c3459e22 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -330,7 +330,8 @@ class Kernel: if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False) for (tc_dim, tc_amt) in tc.threads: self.apply_opt(Opt(OptOps.LOCAL, tc_opts.axes[tc_dim], tc_amt), append_opt=False) elif self.opts.device == "CLANG": - for (i, sz) in tc.threads: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], sz), append_opt=False) + for i, sz in enumerate([prod(x) for x in [[x[1] for x in tc.threads if x[0]==dim] for dim in range(2)]]): # upcast non-local'd N, M + if tc.dims[i] > sz: self.apply_opt(Opt(OptOps.UPCAST, tc_opts.axes[i], tc.dims[i]//sz), append_opt=False) elif self.opts.device in {"CUDA", "NV"}: self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, 8), append_opt=False) self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, 2), append_opt=False) @@ -367,7 +368,7 @@ class Kernel: if extra_opts is not None: for opt in extra_opts: self.apply_opt(opt) else: - if (self.opts.device == "CLANG" and AMX): return True # skip hand-coded TC opts if CLANG, upcasting will make kernel slower + if (self.opts.device == "CLANG" and AMX): return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower # hand-coded TC opts def late_upcast_tc(tc_dim: int): if tc_opts.axes_exist[tc_dim]: @@ -661,7 +662,6 @@ class Kernel: [y + (wd if x == 0 else tcd) for x,y in pattern_2] + list(range(tcd+len(tcd_expand), len(new_shape))) return st1.reshape(new_shape).simplify().permute(tuple(permaxis)).reshape(st1.shape).simplify() - threads = prod(t[1] for t in tc.threads) if self.opts.device in {"AMD", "HIP"}: reduce_axes, upcast_axes = [0], [[(0, 16)], [(0, 16)], [(1, 8)]] # https://gpuopen.com/learn/wmma_on_rdna3/ @@ -672,8 +672,8 @@ class Kernel: fix_st1 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((1,1), (0,1), (1,0), (0,3)), ((0,0), (0,2), (1,3), (1,2))) fix_st2 = functools.partial(fix_st, (2,4,2,2), (8,2), (2,2,2,2), ((0,0), (1,1), (1,2), (0,2), (1,0)), ((0,1), (0,3), (1,3))) elif self.opts.device == "CLANG": - reduce_axes, upcast_axes = [], [[(1,tc.dims[0])],[(0,tc.dims[1])],[(1, tc.dims[2]), (0, tc.dims[2])]] - threads, fix_st1, fix_st2 = threads // tc.dims[2], None, None + reduce_axes, upcast_axes = [], [[(1,tc.dims[0])],[(0,tc.dims[1])],[(1, tc.dims[0]), (0, tc.dims[1])]] + fix_st1, fix_st2 = None, None elif self.opts.device in {"CUDA", "NV"}: reduce_axes, upcast_axes = [0, 1], [[(0, 8)], [(2, 2), (3, 2)], [(2, 2), (3, 2)]] # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float @@ -689,7 +689,7 @@ class Kernel: raise RuntimeError("unsupported device for tensor cores") assert apply_to_st is None, "double tensor core? not supported" - wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, threads, + wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, prod(t[1] for t in tc.threads), tuple(tuple((self.first_upcast+ax, sz) for ax, sz in up) for up in upcast_axes), tuple(self.first_upcast+ax for ax in reduce_axes)) if self.use_tensor_cores >= 2: diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index e18f5aa836..5ead70a5e5 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -204,7 +204,7 @@ class ClangRenderer(CStyleLanguage): if AMX: tc_types = [(dtype, amx_size//dtype.itemsize) for dtype, amx_size in zip([dtypes.float], [64])] - tensor_cores = [TensorCore(dims=(sz,sz,sz), threads=[(0,sz),(1,sz)], dtype_in=dtype, dtype_out=dtype) for dtype, sz in tc_types] + tensor_cores = [TensorCore(dims=(sz,sz,1), threads=[], dtype_in=dtype, dtype_out=dtype) for dtype, sz in tc_types] def render_vector_prefix(self, dt:DType) -> str: return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({(sz:=dt.itemsize)}),vector_size({sz})));" @@ -212,12 +212,12 @@ class ClangRenderer(CStyleLanguage): def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str: prefix, macros = [self.render_vector_prefix(dt) for dt in dedup(uop.dtype for uop in uops if uop.dtype.count>1)], [] # https://github.com/corsix/amx - for name, (N, M, K), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): + for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): macros = [ '#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")', '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")', ] - prefix += [f"""{(out := self.render_dtype(dtype_in.vec(K*K)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{ + prefix += [f"""{(out := self.render_dtype(dtype_in.vec(N*N)))} __{name}({self.render_dtype(dtype_in.vec(N))} data1, {self.render_dtype(dtype_in.vec(M))} data2, {out} data0){{ AMX_SET(0);\n for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(4, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }} AMX(0, (int *)(&data2), 0ull<<62); AMX(1, (int *)(&data1), 0ull<<62); AMX(12, 0, 0ull); for(int ridx0 = 0; ridx0 < 16; ridx0++){{ AMX(5, (int *)(&data0), 0ull<<62 | (ridx0*4ull)<<56 | ridx0*64ull); }}\n AMX_SET(1);\n return data0;\n}}"""] # noqa: E501