diff --git a/extra/thunder/cuda/include/types/global/gl.cuh b/extra/thunder/cuda/include/types/global/gl.cuh index d7eceae2ee..cbbc3f3f9d 100644 --- a/extra/thunder/cuda/include/types/global/gl.cuh +++ b/extra/thunder/cuda/include/types/global/gl.cuh @@ -65,8 +65,8 @@ template struct descripto namespace detail { template struct descriptor_dict { - __host__ descriptor_dict() {} - template __host__ descriptor_dict(T _, int b, int d, int r, int c) {} + __host__ __device__ descriptor_dict() {} + template __host__ __device__ descriptor_dict(T _, int b, int d, int r, int c) {} __host__ __device__ descriptor_dict(const descriptor_dict &other) {} #ifdef KITTENS_HOPPER template __device__ const CUtensorMap* get() const { @@ -85,8 +85,8 @@ struct descriptor_dict<_T, Args...> { using DESC = kittens::tma::descriptor<_T>; // copy or initialize with a default value CUtensorMap tma_desc; descriptor_dict other_descs; - __host__ descriptor_dict() {} - __host__ descriptor_dict(typename DESC::T::dtype *data, int b, int d, int r, int c): other_descs(data, b, d, r, c) { + __host__ __device__ descriptor_dict() {} + __host__ __device__ descriptor_dict(typename DESC::T::dtype *data, int b, int d, int r, int c): other_descs(data, b, d, r, c) { kittens::detail::tma::create_tensor_map(&tma_desc, data, b, d, r, c); } __host__ __device__ inline descriptor_dict(const descriptor_dict &other) : @@ -135,7 +135,7 @@ struct gl { detail::descriptor_dict tma_descs; - __host__ inline gl(T *_data, + __host__ __device__ inline gl(T *_data, ducks::gl::make_arg_t _batch, ducks::gl::make_arg_t _depth, ducks::gl::make_arg_t _rows, diff --git a/extra/thunder/cuda/include/types/global/tma.cuh b/extra/thunder/cuda/include/types/global/tma.cuh index c52c266d80..4ffa9ba0bd 100644 --- a/extra/thunder/cuda/include/types/global/tma.cuh +++ b/extra/thunder/cuda/include/types/global/tma.cuh @@ -425,4 +425,4 @@ __host__ static inline CUtensorMap* allocate_and_create_tensor_map(const typenam } // namespace tma } // 
namespace detail
-} // namespace kittens
\ No newline at end of file
+} // namespace kittens
diff --git a/extra/thunder/cuda/matmul.cu b/extra/thunder/cuda/matmul.cu
new file mode 100644
index 0000000000..29cab02292
--- /dev/null
+++ b/extra/thunder/cuda/matmul.cu
@@ -0,0 +1,55 @@
+// https://github.com/HazyResearch/ThunderKittens/blob/main/kernels/matmul/educational/level_04.cu
+#include "kittens.cuh"
+using namespace kittens;
+
+// NOTE(review): the template argument lists (on st_bf, gl, rt_bf, rt_fl, allocate) are absent in this
+// copy of the patch -- presumably lost in extraction. Restore them from the original commit.
+
+// Matrix side length; matmul.py allocates N x N tensors with the same value.
+constexpr int g_N = 8192;
+// Tile side length; matmul.py derives its launch grid from the same value.
+constexpr int BLOCK_SIZE = 32;
+#define NUM_WORKERS (1)
+#define NUM_THREADS (NUM_WORKERS*kittens::WARP_THREADS)
+
+using sub_tile = st_bf;
+using tile_gl = gl;
+
+// Tiled matmul: each block produces the output tile at tile coordinates (row, col) =
+// (blockIdx.y, blockIdx.x), stepping through num_tiles tile pairs of A and B.
+// matmul.py launches it with a (N/32, N/32, 1) grid, 32 threads per block and prg.smem = 10000.
+// c_ptr/a_ptr/b_ptr: device pointers to the output and the two input matrices.
+__global__ void kernel(bf16 *c_ptr, bf16 *a_ptr, bf16 *b_ptr) {
+    tile_gl g_C{c_ptr, nullptr, nullptr, nullptr, nullptr};
+    tile_gl g_A{a_ptr, nullptr, nullptr, nullptr, nullptr};
+    tile_gl g_B{b_ptr, nullptr, nullptr, nullptr, nullptr};
+
+    // As/Bs are obtained from an allocator constructed over the dynamic shared-memory array.
+    extern __shared__ alignment_dummy __shm[];
+    shared_allocator al((int*)&__shm[0]);
+    st_bf &As = al.allocate>();
+    st_bf &Bs = al.allocate>();
+
+    rt_bf A_reg;
+    rt_bf B_reg;
+    rt_bf B_reg_col;
+    rt_fl C_accum;
+
+    int col = blockIdx.x;
+    int row = blockIdx.y;
+
+    warp::zero(C_accum);
+    int num_tiles = (g_N + BLOCK_SIZE - 1) / BLOCK_SIZE;
+    for (int tile = 0; tile < num_tiles; ++tile) {
+        warp::load(As, g_A, {0, 0, row, tile});
+        warp::load(Bs, g_B, {0, 0, tile, col});
+        __syncthreads();  // As/Bs are written above and read below
+        warp::load(A_reg, As);
+        warp::load(B_reg, Bs);
+        warp::swap_layout(B_reg_col, B_reg);
+        __syncthreads();
+        warp::mma_AB(C_accum, A_reg, B_reg_col, C_accum);
+        __syncthreads();
+    }
+    warp::store(g_C, C_accum, {0, 0, row, col});
+}
diff --git a/extra/thunder/cuda/matmul.py b/extra/thunder/cuda/matmul.py
new file mode 100644
index 0000000000..fe3bd577e4
--- /dev/null
+++ b/extra/thunder/cuda/matmul.py
@@ -0,0 +1,40 @@
+import pathlib
+from tinygrad import Device, Tensor
+from tinygrad.helpers import Context
+from tinygrad.runtime.support.compiler_cuda import pretty_ptx, NVCCCompiler
+
+# Compiles matmul.cu with NVCC, times the kernel on N x N bfloat16 inputs, and compares with a@b.
+if __name__ == "__main__":
+  code = (pathlib.Path(__file__).parent / "matmul.cu").read_text()
+  device = Device["CUDA"]
+  kitten_args = [f"-I{(pathlib.Path(__file__).parent / 'include').as_posix()}", "-std=c++20", "--expt-relaxed-constexpr", "-DKITTENS_HOPPER"]
+  lib = NVCCCompiler(device.compiler.arch, kitten_args).compile(code)
+  # the kernel's symbol name is taken from the first .globl directive in the compiler output
+  kernel_name = lib.decode().split(".globl\t")[1].split("\n")[0]
+  print("kernel name", kernel_name)
+  print(pretty_ptx(lib.decode()))
+
+  prg = device.runtime(kernel_name, lib)
+  prg.smem = 10000
+
+  N = 8192  # must match g_N in matmul.cu
+  a = Tensor.randn(N, N, device='CUDA', dtype="bfloat16")
+  b = Tensor.randn(N, N, device='CUDA', dtype="bfloat16")
+  c = Tensor.empty(N, N, device='CUDA', dtype="bfloat16")
+  Tensor.realize(a, b, c)
+
+  BLOCK_SIZE = 32  # must match BLOCK_SIZE in matmul.cu
+
+  gsz = (N // BLOCK_SIZE, N // BLOCK_SIZE, 1)
+  for _ in range(5):
+    et = prg(c.uop.buffer.ensure_allocated()._buf, a.uop.buffer._buf, b.uop.buffer._buf,
+             global_size=gsz, local_size=(32,1,1), wait=True)
+    # 2*N^3 floating-point operations per matmul; ".2f" prints two decimals (":2f" only set a minimum width)
+    print(f"{N*N*N*2/(et*1e9):.2f} GFLOPS")
+
+  for _ in range(5):
+    with Context(DEBUG=2):
+      ref = (a@b).realize()
+
+  ref, c = ref.float(), c.float()
+  print((ref-c).mean().item(), (ref-c).max().item())