#include "kittens.cuh"

// ---------------------------------------------------------------------------
// Problem-size configuration.
// The #ifndef guards let ATTN_B / ATTN_H / ATTN_H_KV / ATTN_N be overridden at
// build time (e.g. -DATTN_N=2048). ATTN_D and the tiling constants below have
// no guard, so they are fixed in this file.
//
// NOTE(review): the text inside template angle brackets (<...>) appears to be
// missing throughout this copy of the file, e.g. `template using qo_tile = rt;`
// and `gl Og;`. The code below is reproduced token-for-token as found.
// Restore the template parameter/argument lists from upstream before building.
// ---------------------------------------------------------------------------
#ifndef ATTN_B
constexpr int ATTN_B = 16; // batch size
#endif
#ifndef ATTN_H
constexpr int ATTN_H = 64; // number of query heads
#endif
#ifndef ATTN_H_KV
constexpr int ATTN_H_KV = 8; // number of key/value heads (for GQA)
#endif
// NOTE(review): integer division. This silently truncates unless ATTN_H is a
// multiple of ATTN_H_KV, and nothing here checks that (a static_assert would).
// Not referenced within this excerpt.
constexpr int GROUP_SIZE = ATTN_H / ATTN_H_KV; // queries per KV head group
#ifndef ATTN_N
constexpr int ATTN_N = 1024; // sequence length
#endif
constexpr int ATTN_D = 128; // dimension

// Tiling constants. Of these, only DOT_SLICE_QO is referenced in this excerpt,
// together with NUM_WARPS in attn_prep_globals::grid(). The rest are
// presumably used by other kernels in the full file.
constexpr int STEP_QO = 64; // block size for QO
constexpr int BLOCK_SIZE_KV = 256; // block size for KV
constexpr int SLICE_QO = 32;
constexpr int DOT_SLICE_QO = 16;
constexpr int WARP_SIZE_KV = 64; // warp size for KV

// Threads per block: NUM_WARPS warps of kittens::WARP_THREADS lanes each.
#define NUM_WARPS 4
#define NUM_THREADS (kittens::WARP_THREADS * NUM_WARPS)

// Alias for kittens::group. Its template argument, if any, is missing in this
// copy. Not referenced within this excerpt.
using G = kittens::group;

using namespace kittens;

// Register-tile aliases over kittens `rt`: Q/O tiles, K/V tiles and attention
// tiles, plus (judging by the names) transposed `_T` and dQ-path `_dq`
// variants.
// NOTE(review): both the alias template parameters and the `rt` arguments are
// missing here, so tile shapes, element types and layouts cannot be read from
// this copy.
template using qo_tile = rt;
template using kv_tile = rt;
template using qo_tile_T_dq = rt;
template using qo_tile_dq = rt;
template using kv_tile_T = rt;
template using attn_tile = rt;
template using attn_tile_T = rt;
template using attn_tile_T_dq = rt;
template using kv_tile_dq = rt;

// Global-memory views used by attend_prep_ker, plus that kernel's launch
// configuration (grid, block, dynamic shared-memory size).
// The `gl` template arguments are missing in this copy.
template struct attn_prep_globals {
    gl Og;     // O, presumably the forward attention output (read by the kernel)
    gl dOg;    // dO, presumably the upstream output gradient (read by the kernel)
    gl delta;  // per-row Δ = row_sum(dO ⊙ O), written by attend_prep_ker

    // One block per (batch, query head, chunk of DOT_SLICE_QO * NUM_WARPS rows).
    // With the defaults this is (16, 64, 1024 / 64 = 16).
    // NOTE(review): the integer division drops any remainder. If ATTN_N is not
    // a multiple of DOT_SLICE_QO * NUM_WARPS, the trailing rows are never
    // processed, and nothing asserts this.
    dim3 grid() { return dim3(ATTN_B, ATTN_H, ATTN_N / (DOT_SLICE_QO * NUM_WARPS)); }

    dim3 block() { return dim3(NUM_THREADS); }

    // NOTE(review): this requests MAX_SHARED_MEMORY bytes of dynamic shared
    // memory, but attend_prep_ker declares and uses no shared memory in the
    // visible code. Confirm this is intentional, since it can limit the number
    // of resident blocks per SM.
    size_t dynamic_shared_memory() { return MAX_SHARED_MEMORY; }
};

// attend_prep_ker: appears to be the preprocessing pass for the attention
// backward kernel(s).
//
// For each row-tile it handles, a warp:
//   1. loads dO and O,
//   2. copies them into the *_float tiles,
//   3. computes Δ_i = row_sum(dO ⊙ O) (one value per tile row),
//   4. stores Δ into `delta`,
//   5. overwrites the matching rows of dQ with zeros.
//
// Launch layout (see attn_prep_globals::grid() / block()):
//   blockIdx.x = batch index
//   blockIdx.y = query-head index
//   blockIdx.z = sequence chunk
//   Within a block, warp w handles row-tile index seq_idx * NUM_WARPS + w.
//   Presumably each qo_tile spans DOT_SLICE_QO rows, consistent with grid().
//   TODO confirm once the template arguments are restored.
//
// The raw device pointers are wrapped in `gl` views with these extents
// (assuming kittens lists `gl` dims outermost-first; confirm):
//   O, dO : {ATTN_B, ATTN_N, ATTN_H, ATTN_D}  bf16, sequence before heads
//   dQ    : {ATTN_B, ATTN_H, ATTN_N, ATTN_D}  bf16, heads before sequence
//   delta : {ATTN_B, ATTN_H, 1, ATTN_N}       float
//
// __launch_bounds__(NUM_THREADS, 1): at most NUM_THREADS threads per block,
// and the compiler is asked to allow at least 1 resident block per SM.
// No shared memory and no block-level synchronization appear in this body;
// each warp works independently.
template __launch_bounds__(NUM_THREADS, 1) __global__ void attend_prep_ker(float *delta_ptr, bf16 *dq_ptr, bf16 *O_ptr, bf16 *dO_ptr) {
    // Wrap the raw pointers in global-layout views (extents listed above).
    gl delta{delta_ptr, ATTN_B, ATTN_H, 1, ATTN_N};
    gl dQg{dq_ptr, ATTN_B, ATTN_H, ATTN_N, ATTN_D};
    gl Og{O_ptr, ATTN_B, ATTN_N, ATTN_H, ATTN_D};
    gl dOg{dO_ptr, ATTN_B, ATTN_N, ATTN_H, ATTN_D};
    // dQg is used directly below and is not part of the globals struct.
    attn_prep_globals g{Og, dOg, delta};

    // blockIdx mapping matches attn_prep_globals::grid().
    const int batch_idx = blockIdx.x;
    const int head_idx = blockIdx.y;
    const int seq_idx = blockIdx.z;
    const int warpid = kittens::warpid();

    // dO / O receive the loaded data. dO_float / O_float (presumably a float
    // tile type, per their names) hold copies used for the multiply and the
    // row reduction.
    qo_tile dO, O;
    qo_tile dO_float, O_float;
    // Receives one Δ value per tile row.
    typename qo_tile::col_vec delta_vec;

    // Load this warp's row-tile of dO and O. The coordinate is
    // {batch, row-tile, head, 0}. The <1> argument presumably names the gl
    // axis the tile index walks (axis 1 = sequence for O/dO). TODO confirm
    // against the kittens load/store signatures.
    load<1>(dO, g.dOg, {batch_idx, seq_idx * NUM_WARPS + warpid, head_idx, 0});
    load<1>(O, g.Og, {batch_idx, seq_idx * NUM_WARPS + warpid, head_idx, 0});
    copy(O_float, O);
    copy(dO_float, dO);

    // Δ_i = row_sum(dO ⊙ O)
    // After mul, dO_float holds the element-wise product dO ⊙ O in place.
    mul(dO_float, dO_float, O_float);
    row_sum(delta_vec, dO_float);
    // delta is {B, H, 1, N}: the row-tile index goes in the last coordinate.
    store(g.delta, delta_vec, {batch_idx, head_idx, 0, seq_idx * NUM_WARPS + warpid});

    // Zero out dq
    // Presumably so a later backward kernel can accumulate into dQ. Confirm
    // against the consumer of dq_ptr.
    qo_tile dQ_zero;
    zero(dQ_zero);
    // dQ is {B, H, N, D}: the row-tile index goes in coordinate 2, and <2>
    // presumably names that axis.
    store<2>(dQg, dQ_zero, {batch_idx, head_idx, seq_idx * NUM_WARPS + warpid, 0});
}

// Explicit instantiation of the kernel template. Its template argument is
// missing in this copy; presumably ATTN_D.
template __global__ void attend_prep_ker(float *delta_ptr, bf16 *dq_ptr, bf16 *O_ptr, bf16 *dO_ptr);