/*
 * dQ layout-shuffle kernel built on the `kittens` tile library ("kittens.cuh"),
 * using the llvm_amdgcn_raw_buffer_load_b128 / llvm_amdgcn_raw_buffer_store_b32
 * buffer intrinsics.
 *
 * NOTE(review): this text appears to have been damaged in transit and will not
 * compile as it stands:
 *   - line breaks were collapsed, so the first line comment on each physical
 *     line swallows everything that follows it on that line;
 *   - angle-bracketed template argument lists appear to have been stripped
 *     (see `template using qo_tile = rt;`, `template> __device__ ...`,
 *     `reinterpret_cast(src_ptr)`, `std::bit_cast(br)`, `gl dQg_in{...}`,
 *     `qo_tile dQg;`).
 *   Restore the file from version control rather than reconstructing the
 *   missing template arguments by hand; a wrong tile shape or axis argument
 *   would silently change the memory layout this kernel reads and writes.
 *
 * Compile-time configuration
 *   ATTN_B = 16, ATTN_H = 64, ATTN_N = 1024 are each wrapped in #ifndef, so a
 *   definition supplied before this point takes precedence. ATTN_D = 128 and
 *   DOT_SLICE_QO = 16 are fixed. NUM_WARPS is 4 and NUM_THREADS is
 *   kittens::WARP_THREADS * NUM_WARPS.
 *
 * load_shuffled(dst, src, idx)
 *   Takes the address of src at idx (via idx.template unit_coord), wraps it in
 *   a buffer_resource created by make_buffer_resource(ptr, buffer_size,
 *   0x00020000), then for every subtile (i, j) with i < dst.height and
 *   j < dst.width issues one 128-bit buffer load at element offset
 *       i * (row_stride * dst.base_tile_rows)
 *     + j * (dst.base_tile_rows * dst.base_tile_cols)
 *     + laneid * 8
 *   scaled by sizeof(U). The loaded 16 bytes are viewed as packed U2 values and
 *   converted one by one into dst.tiles[i][j].data[k] for
 *   k < dst.packed_per_thread. A static_assert restricts the function to bf16
 *   (per its message).
 *   NOTE(review): the j term advances by a whole subtile's element count rather
 *   than by base_tile_cols, i.e. the source is addressed as if each subtile were
 *   stored contiguously - presumably the "shuffled" layout named in the kernel
 *   comment; confirm against the producer of this buffer.
 *   NOTE(review): a 128-bit load supplies 16 bytes; the conversion loop bound is
 *   dst.packed_per_thread, so this presumably requires
 *   packed_per_thread * sizeof(U2) <= 16 - TODO confirm.
 *
 * store_shuffled(dst, src, idx)
 *   Per lane: row_offset = (laneid % 4) * 4 and
 *   col_offset = (laneid / 32) * 16 + ((laneid % 32) / 16) * 2
 *              + ((laneid % 16) / 4) * 4.
 *   For every subtile (i, j) it reads src.tiles[i][j].data[0..3] as four 32-bit
 *   words and writes them with four 32-bit buffer stores to rows
 *   row, row + 1, row + 2, row + 3 (row = src.base_tile_rows * i + row_offset)
 *   at column col = src.base_tile_cols * j + col_offset, using row_stride from
 *   dst.template stride.
 *   NOTE(review): the laneid / 32 term only contributes for lane ids >= 32, so
 *   this presumably targets 64-lane wavefronts - confirm.
 *   NOTE(review): unlike load_shuffled there is no static_assert here, yet each
 *   data[k] is reinterpreted as exactly one uint32_t and exactly four entries
 *   are stored; T2 and U2 are declared but not used in the visible body.
 *
 * Both helpers
 *   NOTE(review): buffer_size is batch * depth * rows * cols * sizeof(U) stored
 *   in a uint32_t, so the value is truncated for tensors of 4 GiB or more. With
 *   the default constants it is 16 * 64 * 1024 * 128 * 2 = 2^28 bytes.
 *   NOTE(review): the resource base is the already-offset tile pointer while
 *   the size is derived from the full tensor extents; if the second argument of
 *   make_buffer_resource is the hardware bounds-check range, it extends past
 *   the end of the tensor - verify.
 *
 * attend_dq_shuffle_ker(dQ_out_ptr, dQ_in_ptr)
 *   Declared with __launch_bounds__(NUM_THREADS, 1). Wraps dQ_in_ptr as a gl
 *   with extents (ATTN_B, ATTN_H, ATTN_N, ATTN_D) and dQ_out_ptr as a gl with
 *   extents (ATTN_B, ATTN_N, ATTN_H, ATTN_D). blockIdx.x selects the batch,
 *   blockIdx.y the query head and blockIdx.z a sequence block; each warp
 *   handles tile coordinate seq_idx * NUM_WARPS + warpid along the sequence
 *   axis. It loads one qo_tile with load_shuffled<2> from
 *   {batch, head, seq_tile, 0} and writes it with store_shuffled<1> to
 *   {batch, seq_tile, head, 0}. No shared memory or block-level
 *   synchronization appears in the visible body.
 *   NOTE(review): there are no bounds checks, so the launch grid must match the
 *   tensor extents exactly - TODO confirm with the host-side launcher, which is
 *   not visible here.
 *   The final statement is an explicit instantiation of the kernel; its
 *   template argument is among the stripped text.
 */
#include "kittens.cuh" #ifndef ATTN_B constexpr int ATTN_B = 16; // batch size #endif #ifndef ATTN_H constexpr int ATTN_H = 64; // number of query heads #endif #ifndef ATTN_N constexpr int ATTN_N = 1024; // sequence length #endif constexpr int ATTN_D = 128; // dimension constexpr int DOT_SLICE_QO = 16; #define NUM_WARPS 4 #define NUM_THREADS (kittens::WARP_THREADS * NUM_WARPS) using namespace kittens; template using qo_tile = rt; template> __device__ inline static void load_shuffled(RT &dst, const GL &src, const COORD &idx) { using T2 = RT::dtype; using U = typename GL::dtype; using U2 = base_types::packing::packed_type; static_assert(std::is_same_v, "load_shuffled is only supported for bf16"); U *src_ptr = (U*)&src[(idx.template unit_coord())]; const int row_stride = src.template stride(); int laneid = kittens::laneid(); int tile_row_stride = row_stride * dst.base_tile_rows; int tile_stride = dst.base_tile_rows * dst.base_tile_cols; uint32_t buffer_size = src.batch() * src.depth() * src.rows() * src.cols() * sizeof(U); std::uintptr_t as_int = reinterpret_cast(src_ptr); std::uint64_t as_u64 = static_cast(as_int); // widen if host is 32-bit buffer_resource br = make_buffer_resource(as_u64, buffer_size, 0x00020000); #pragma unroll for(int i = 0; i < dst.height; i++) { #pragma unroll for(int j = 0; j < dst.width; j++) { U2* tmp; float4 loaded = std::bit_cast(llvm_amdgcn_raw_buffer_load_b128( std::bit_cast(br), (i * tile_row_stride + j * tile_stride + laneid * 8) * sizeof(U), 0, 0 )); tmp = reinterpret_cast(&loaded); #pragma unroll for(int k = 0; k < dst.packed_per_thread; k++) { dst.tiles[i][j].data[k] = base_types::convertor::convert(tmp[k]); } } } } template> __device__ inline static void store_shuffled(const GL &dst, const RT &src, const COORD &idx) { using T2 = RT::dtype; using U = typename GL::dtype; using U2 = base_types::packing::packed_type; U *dst_ptr = (U*)&dst[(idx.template unit_coord())]; const int row_stride = dst.template stride(); int laneid = 
kittens::laneid(); const int row_offset = (laneid % 4) * 4; const int col_offset = ((laneid / 32) * 16) + (((laneid % 32) / 16) * 2) + (((laneid % 16) / 4) * 4); uint32_t buffer_size = dst.batch() * dst.depth() * dst.rows() * dst.cols() * sizeof(U); std::uintptr_t as_int = reinterpret_cast(dst_ptr); std::uint64_t as_u64 = static_cast(as_int); // widen if host is 32-bit buffer_resource br = make_buffer_resource(as_u64, buffer_size, 0x00020000); #pragma unroll for(int i = 0; i < src.height; i++) { int row = src.base_tile_rows * i + row_offset; #pragma unroll for(int j = 0; j < src.width; j++) { int col = src.base_tile_cols * j + col_offset; const uint32_t val_0 = *reinterpret_cast(&src.tiles[i][j].data[0]); const uint32_t val_1 = *reinterpret_cast(&src.tiles[i][j].data[1]); const uint32_t val_2 = *reinterpret_cast(&src.tiles[i][j].data[2]); const uint32_t val_3 = *reinterpret_cast(&src.tiles[i][j].data[3]); uint32_t offset_0 = (row * row_stride + col) * sizeof(U); uint32_t offset_1 = ((row + 1) * row_stride + col) * sizeof(U); uint32_t offset_2 = ((row + 2) * row_stride + col) * sizeof(U); uint32_t offset_3 = ((row + 3) * row_stride + col) * sizeof(U); llvm_amdgcn_raw_buffer_store_b32( val_0, std::bit_cast(br), offset_0, 0, 0 ); llvm_amdgcn_raw_buffer_store_b32( val_1, std::bit_cast(br), offset_1, 0, 0 ); llvm_amdgcn_raw_buffer_store_b32( val_2, std::bit_cast(br), offset_2, 0, 0 ); llvm_amdgcn_raw_buffer_store_b32( val_3, std::bit_cast(br), offset_3, 0, 0 ); } } } // Transpose dQ from (B, H, N, D) to (B, N, H, D) using shuffled load/store // to handle the warp-level layout from atomic_pk_add_bf16_with_warpid template __launch_bounds__(NUM_THREADS, 1) __global__ void attend_dq_shuffle_ker(bf16 *dQ_out_ptr, bf16 *dQ_in_ptr) { gl dQg_in{dQ_in_ptr, ATTN_B, ATTN_H, ATTN_N, ATTN_D}; gl dQg_out{dQ_out_ptr, ATTN_B, ATTN_N, ATTN_H, ATTN_D}; const int batch_idx = blockIdx.x; const int q_head_idx = blockIdx.y; const int seq_idx = blockIdx.z; const int warpid = 
kittens::warpid(); qo_tile dQg; load_shuffled<2>(dQg, dQg_in, {batch_idx, q_head_idx, seq_idx * NUM_WARPS + warpid, 0}); store_shuffled<1>(dQg_out, dQg, {batch_idx, seq_idx * NUM_WARPS + warpid, q_head_idx, 0}); } template __global__ void attend_dq_shuffle_ker(bf16 *dQ_out_ptr, bf16 *dQ_in_ptr);