#include "kittens.cuh" #include "utils.cpp" #ifndef ATTN_B constexpr int ATTN_B = 16; // batch size #endif #ifndef ATTN_H constexpr int ATTN_H = 64; // number of query heads #endif #ifndef ATTN_H_KV constexpr int ATTN_H_KV = 8; // number of key/value heads (for GQA) #endif constexpr int GROUP_SIZE = ATTN_H / ATTN_H_KV; // queries per KV head group #ifndef ATTN_N constexpr int ATTN_N = 1024; // sequence length #endif constexpr int ATTN_D = 128; // dimension constexpr int STEP_QO = 64; // block size for QO constexpr int BLOCK_SIZE_KV = 256; // block size for KV constexpr int SLICE_QO = 32; constexpr int DOT_SLICE_QO = 16; constexpr int WARP_SIZE_KV = 64; // warp size for KV constexpr bool causal = true; #define NUM_WARPS 4 #define NUM_THREADS (kittens::WARP_THREADS * NUM_WARPS) using G = kittens::group; using namespace kittens; using _gl_QdO = gl; using _gl_KV = gl; using _gl_dQ = gl; using _gl_dKV = gl; using _gl_Lvec = gl; template struct attn_bwd_combined_globals { _gl_QdO Q; _gl_KV K, V; _gl_QdO dOg; _gl_dQ dQg; _gl_dKV dKg, dVg; _gl_Lvec L_vec, delta_vec; dim3 grid() { return dim3(ATTN_H, (ATTN_N / BLOCK_SIZE_KV), ATTN_B); } dim3 block() { return dim3(NUM_THREADS); } size_t dynamic_shared_memory() { return MAX_SHARED_MEMORY; } }; template __launch_bounds__(NUM_THREADS, 1) __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr, bf16 *dO_ptr, bf16 *Q_ptr, bf16 *K_ptr, bf16 *V_ptr, float *L_vec_ptr, float *delta_vec_ptr) { const int q_head_idx_fixed = blockIdx.x; // This is the query head index [0, ATTN_H) const int kv_head_idx = q_head_idx_fixed / GROUP_SIZE; const int q_head_in_group = q_head_idx_fixed % GROUP_SIZE; const int seq_idx = blockIdx.y; const int batch_idx = blockIdx.z; const int first_q_head = q_head_idx_fixed; const int warpid = kittens::warpid(); const int j = seq_idx * NUM_WARPS + warpid; // optimization on loops bounds const int total_steps_per_head = ATTN_N / STEP_QO; const int j_min = seq_idx * NUM_WARPS; const int k_start_min = j_min * WARP_SIZE_KV; // first Q step that can overlap this K_span: const int first_step = max(0, k_start_min / STEP_QO); const int num_steps_per_head = total_steps_per_head - first_step; const int num_steps = num_steps_per_head; const int k_pos = j * WARP_SIZE_KV; constexpr float L_SCALE_FACTOR = 1.44269504089f; constexpr float P_SCALE_FACTOR = (D == 128) ? 0.08838834764f*1.44269504089f : 0.125f*1.44269504089f; constexpr float dP_SCALE_FACTOR = (D == 128) ? 0.08838834764f : 0.125f; // Shared tiles extern __shared__ alignment_dummy __shm[]; shared_allocator al((int*)&__shm[0]); st_bf (&K_j_smem) = al.allocate>(); st_bf (&Q_i_smem)[2][2] = al.allocate, 2, 2>(); st_bf (&dO_i_smem)[2][2] = al.allocate, 2, 2>(); st_bf (&attn_i_smem) = al.allocate>(); sv_fl (&L_smem)[2] = al.allocate, 2>(); sv_fl (&delta_smem)[2] = al.allocate, 2>(); // Register tiles using Q_ranges = ducks::art::split_many_t>, 4>; // 16 registers - a[112:127] using dO_ranges = ducks::art::split_many_t>, 4>; // 16 registers - v[72:87] using dO_col_ranges = ducks::art::split_many_t>, 4>; // 16 registers - v[88:103] using K_ranges = ducks::art::split_many_t, ducks::art::range<62, 77>>, 4>; // 64 registers - a[0:47] & v[56:71] using V_ranges = ducks::art::split_many_t>, 4>; // 64 registers - a[48:111] using P_ranges = ducks::art::split_many_t>, 4>; // 16 registers - v[40:55] using dP_ranges = ducks::art::split_many_t>, 4>; // 16 registers - v[56:71] using P_bf16_ranges = ducks::art::split_many_t>, 2>; // 8 registers - v[116:123] using dP_bf16_ranges = ducks::art::split_many_t>, 2>; // 8 registers - v[56:63] using P_bf16_col_ranges = ducks::art::split_many_t>, 4>; // 8 registers using dP_bf16_col_ranges = ducks::art::split_many_t>, 4>; // 8 registers using dS_ranges = ducks::art::split_many_t>, 4>; // 32 registers - v[24:55] using dQ_ranges = ducks::art::split_many_t>, 4>; // 8 registers - v[108:115] ducks::art::clobber(); ducks::art::clobber(); ducks::art::clobber(); ducks::art::clobber(); ducks::art::clobber(); ducks::art::clobber(); ducks::art::clobber(); ducks::art::clobber(); ducks::art::clobber(); ducks::art::clobber(); ducks::art::clobber(); using dV_ranges = ducks::art::split_many_t>, 16>; // 128 registers v[128:255] using dK_ranges = ducks::art::split_many_t>, 16>; // 128 registers a[128:255] ducks::art::clobber(); ducks::art::clobber(); art Q_i; // 16 registers art dO_i; // 16 registers art Q_i_col; // 16 registers art dO_i_col; // 16 registers art K_j; // 64 registers art V_j; // 64 registers constexpr int L_i = 126; constexpr int delta_i = 127; constexpr int neg_inf_v = 29; // Move -inf to VGPR neg_inf_v kittens::macros::clobber_gpr(); kittens::macros::v_mov_b32(0xff800000); art P_ij; // 16 registers art dP_ij; // 16 registers art P_ij_bf16; // 8 registers art dP_ij_bf16; // 8 registers art> dP_ij_bf16_accum_row; // 8 registers art P_ij_bf16_col; // 8 registers art dP_ij_bf16_col; // 8 registers art K_j_col; // 64 registers // for dq art dP_ij_bf16_col_T; // 32 registers // for dq art dK_j_T; // 128 registers art dV_j_T; // 128 registers art dQ_i_T; // 8 registers // for dq art> dQ_i; // 8 registers // for dq // This is used for both dK_j_T and dV_j_T art> dV_j; // Construct gl objects with compile-time dims AFTER clobbers so compiler knows which VGPRs are taken _gl_dQ dQg{dQ_ptr, nullptr, nullptr, nullptr, nullptr}; _gl_dKV dKg{dK_ptr, nullptr, nullptr, nullptr, nullptr}; _gl_dKV dVg{dV_ptr, nullptr, nullptr, nullptr, nullptr}; _gl_QdO dOg{dO_ptr, nullptr, nullptr, nullptr, nullptr}; _gl_QdO Q{Q_ptr, nullptr, nullptr, nullptr, nullptr}; _gl_KV K{K_ptr, nullptr, nullptr, nullptr, nullptr}; _gl_KV V{V_ptr, nullptr, nullptr, nullptr, nullptr}; _gl_Lvec L_vec_gl{L_vec_ptr, nullptr, nullptr, nullptr, nullptr}; _gl_Lvec delta_vec_gl{delta_vec_ptr, nullptr, nullptr, nullptr, nullptr}; attn_bwd_combined_globals g{Q, K, V, dOg, dQg, dKg, dVg, L_vec_gl, delta_vec_gl}; // Swizzled offsets for Q and dO constexpr int bytes_per_thread = st_16x32_s::template bytes_per_thread(); constexpr int bytes_per_warp = bytes_per_thread * kittens::WARP_THREADS; constexpr int memcpy_per_tile = BLOCK_SIZE_KV * DOT_SLICE_QO * sizeof(bf16) / (bytes_per_thread * NUM_THREADS); static_assert(BLOCK_SIZE_KV * DOT_SLICE_QO * sizeof(bf16) >= bytes_per_warp, "shared tile must be at least 1024 bytes"); uint32_t swizzled_offsets_Q_dO[memcpy_per_tile]; G::prefill_swizzled_offsets<1, false>(Q_i_smem[0][0], g.Q, swizzled_offsets_Q_dO); int tic = 0, toc = 1; // Load K_j from HBM to shared memory G::load<1, false>(K_j_smem, g.K, {batch_idx, seq_idx, kv_head_idx, 0}); // Load V_j from HBM to registers load<1>(V_j, g.V, {batch_idx, 0, kv_head_idx, 0}, {0, j, 0, 0}); // Load Q, dO, L, delta for this specific query head load(L_smem[tic], g.L_vec, {batch_idx, first_q_head, 0, first_step}); load(delta_smem[tic], g.delta_vec, {batch_idx, first_q_head, 0, first_step}); G::load<1, false>(Q_i_smem[tic][0], g.Q, {batch_idx, first_step * 2 + 0, first_q_head, 0}, swizzled_offsets_Q_dO); G::load<1, false>(dO_i_smem[tic][0], g.dOg, {batch_idx, first_step * 2 + 0, first_q_head, 0}, swizzled_offsets_Q_dO); G::load<1, false>(Q_i_smem[tic][1], g.Q, {batch_idx, first_step * 2 + 1, first_q_head, 0}, swizzled_offsets_Q_dO); G::load<1, false>(dO_i_smem[tic][1], g.dOg, {batch_idx, first_step * 2 + 1, first_q_head, 0}, swizzled_offsets_Q_dO); __builtin_amdgcn_s_waitcnt(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); // Addresses const uint32_t K_j_addr = get_address(K_j, subtile_inplace(K_j_smem, {warpid, 0})); // Compute K_j_col_addr // uint32_t K_j_col_addr = get_address(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid})); const uint32_t K_j_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace<256, 32>(K_j_smem, {0, warpid}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 16) * 4; const int col_offset = ((laneid % 4) * 4); const int lane_byte_offset = (row_offset * 16 + col_offset) * sizeof(bf16); const uint32_t addr = src_ptr + lane_byte_offset; return addr; }(); auto attn_i_smem_subtile = subtile_inplace(attn_i_smem, {warpid, 0}); const uint32_t dP_ij_bf16_accum_row_addr = get_address(attn_i_smem_subtile, dP_ij_bf16_accum_row); uint32_t Q_i_addr; uint32_t dO_i_addr; uint32_t dO_i_col_addr; uint32_t Q_i_col_addr; // Compute dP_ij_bf16_col_T_addr // const uint32_t dP_ij_bf16_col_T_addr = [&] { // const int laneid = kittens::laneid(); // const uint32_t src_ptr = reinterpret_cast(&attn_i_smem.data[0]); // const int row_offset = (laneid % 16) / 4 + (laneid / 16) * 4; // const int col_offset = ((laneid % 4) * 4); // const int lane_byte_offset = (row_offset * 16 + col_offset) * sizeof(bf16); // const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 7) << 3); // const uint32_t addr = src_ptr + swizzled_lane_byte_offset; // return addr; // }(); uint32_t dP_ij_bf16_col_T_addr = get_address(dP_ij_bf16_col_T, attn_i_smem); if (num_steps > 1) { // Prologue { const int q_head_idx = (0) / num_steps_per_head + first_q_head; const int q_seq_idx = ((0) % num_steps_per_head) + first_step; const int q_pos = q_seq_idx * STEP_QO; const int next_q_head_idx = (0 + 1) / num_steps_per_head + first_q_head; const int next_q_seq_idx = ((0 + 1) % num_steps_per_head) + first_step; // dot slice 0 { load(L_smem[toc], g.L_vec, {batch_idx, next_q_head_idx, 0, next_q_seq_idx}); G::load<1, false>(Q_i_smem[toc][0], g.Q, {batch_idx, next_q_seq_idx * 2, next_q_head_idx, 0}); // Load Q_i from shared memory to registers // load(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_addr = get_address(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0})); load<0, 0>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 1>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 2>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 3>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load(subvec_inplace(L_smem[tic], 0)); load(subvec_inplace(delta_smem[tic], 0)); // Load K_j from shared memory to registers // load(K_j, subtile_inplace(K_j_smem, {warpid, 0})); load<0, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); asm volatile("s_waitcnt lgkmcnt(0)"); __builtin_amdgcn_s_barrier(); // 10. S_ij = Q_i K_j^T * scale // 11. P_ij = exp2(S_ij - L_i) // 13. dP_ij = dO_i @ V_j^T // 14. dS_ij = P_ij o (dP_ij - delta_i) // mma_ABt(P_ij, Q_i, K_j); mma_ABt<0, 0, 0>(P_ij, Q_i, K_j); load<2, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 1>(P_ij, Q_i, K_j, P_ij); mul(L_SCALE_FACTOR); mma_ABt<0, 0, 2>(P_ij, Q_i, K_j, P_ij); load<2, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 3>(P_ij, Q_i, K_j, P_ij); mma_ABt<0, 1, 0>(P_ij, Q_i, K_j); load<3, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 1>(P_ij, Q_i, K_j, P_ij); mma_ABt<0, 1, 2>(P_ij, Q_i, K_j, P_ij); load<3, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 0>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 0>(P_ij, Q_i, K_j); // Load dO_i from shared memory to registers // load(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_addr = get_address(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); load<0, 0>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 1>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 0, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 3>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 1>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 3, 0>(P_ij, Q_i, K_j); // Load dO_i_col from shared memory to registers // load(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); // Compute dO_i_col_addr // dO_i_col_addr = get_address(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(dO_i_smem[tic][0], {0, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const uint32_t addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 1>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 1, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(8)"); mma_ABt<0, 3, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 3>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij); // Dot slice 0 kittens::macros::v_mov_b32(0xff800000); if constexpr (causal) { // If the query position is less than the key position, set P_ij to -inf if (q_pos < k_pos) { mov(P_ij); // If the query position is equal to the key position, we need to apply a causal mask } else if (q_pos == k_pos) { // Apply the causal mask to [0, 0] and set [0, 1:4] to -inf make_causal<0, 0, neg_inf_v>(P_ij, P_ij); mov<0, 1, neg_inf_v>(P_ij); mov<0, 2, neg_inf_v>(P_ij); mov<0, 3, neg_inf_v>(P_ij); } } mul<0, 2>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_ABt(dP_ij, dO_i, V_j); mma_ABt<0, 0, 0>(dP_ij, dO_i, V_j); sub_row<0, 2, L_i>(P_ij, P_ij); mma_ABt<0, 0, 1>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 0>(P_ij, P_ij); mma_ABt<0, 0, 2>(dP_ij, dO_i, V_j, dP_ij); // Load Q_i_col from shared memory to registers // load(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); // Compute Q_i_col_addr // uint32_t Q_i_col_addr = get_address(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(Q_i_smem[tic][0], {0, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const int addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 0, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 1>(P_ij, P_ij); mma_ABt<0, 1, 0>(dP_ij, dO_i, V_j); load<0, 1>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 1, 1>(dP_ij, dO_i, V_j, dP_ij); mul<0, 3>(P_ij, P_ij, P_SCALE_FACTOR); mma_ABt<0, 1, 2>(dP_ij, dO_i, V_j, dP_ij); sub_row<0, 3, L_i>(P_ij, P_ij); mma_ABt<0, 1, 3>(dP_ij, dO_i, V_j, dP_ij); copy<0, 0>(P_ij_bf16, P_ij); mma_ABt<0, 2, 0>(dP_ij, dO_i, V_j); exp2<0, 2>(P_ij, P_ij); mma_ABt<0, 2, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 1>(P_ij_bf16, P_ij); mma_ABt<0, 2, 2>(dP_ij, dO_i, V_j, dP_ij); load<0, 2>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 2, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 3>(P_ij, P_ij); mma_ABt<0, 3, 0>(dP_ij, dO_i, V_j); load<0, 3>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 3, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 2>(P_ij_bf16, P_ij); copy<0, 3>(P_ij_bf16, P_ij); mma_ABt<0, 3, 2>(dP_ij, dO_i, V_j, dP_ij); swap_layout_inplace(P_ij_bf16_col, P_ij_bf16); mma_ABt<0, 3, 3>(dP_ij, dO_i, V_j, dP_ij); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_AtB(dV_j_T, dO_i_col, P_ij_bf16_col); mma_AtB<0, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col); // Load K_j_col from shared memory to registers // load(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid})); load<0, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<0, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col); sub_row<0, 0, delta_i>(dP_ij, dP_ij); sub_row<0, 1, delta_i>(dP_ij, dP_ij); mma_AtB<1, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col); load<1, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<1, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<1, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col); mul<0, 0>(dP_ij, dP_ij, P_ij); mul<0, 1>(dP_ij, dP_ij, P_ij); copy<0, 0>(dP_ij_bf16, dP_ij); copy<0, 1>(dP_ij_bf16, dP_ij); sub_row<0, 2, delta_i>(dP_ij, dP_ij); mma_AtB<2, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col); load<2, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); // 12. dV_j += P_ij^T @ dO_i // 16. dK_j += dS_ij^T @ Q_i (128x64)=(128x16)x(16x64) // Store dP_ij_bf16_accum_row to shared memory // store(attn_i_smem_subtile, dP_ij_bf16_accum_row); store<0, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<1, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<2, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col); sub_row<0, 3, delta_i>(dP_ij, dP_ij); mul<0, 2>(dP_ij, dP_ij, P_ij); mul<0, 3>(dP_ij, dP_ij, P_ij); copy<0, 2>(dP_ij_bf16, dP_ij); copy<0, 3>(dP_ij_bf16, dP_ij); mma_AtB<3, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col); // dot slice 1 load(subvec_inplace(L_smem[tic], 1)); load(subvec_inplace(delta_smem[tic], 1)); store<2, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<3, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<3, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col); swap_layout_inplace(dP_ij_bf16_col, dP_ij_bf16); asm volatile("s_waitcnt lgkmcnt(12)"); // mma_AtB(dK_j_T, Q_i_col, dP_ij_bf16_col); mma_AtB<0, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col); load<2, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<4, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col); asm volatile("s_waitcnt lgkmcnt(8)"); __builtin_amdgcn_s_barrier(); mma_AtB<1, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col); // Load dP_ij_bf16_col_T from shared memory to registers // load(dP_ij_bf16_col_T, attn_i_smem); load<0, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<1, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<2, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<3, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); mma_AtB<1, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col); mul(L_SCALE_FACTOR); mma_AtB<2, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col); load<4, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<4, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<5, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<2, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col); mma_AtB<3, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col); load<6, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<7, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<3, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col); asm volatile("s_waitcnt vmcnt(0) lgkmcnt(6)"); __builtin_amdgcn_s_barrier(); // 15. dQ_i += dS_ij @ K_j (32x16)=(32x256)x(256x16) // mma_AtB(dQ_i_T, K_j_col, dP_ij_bf16_col_T); mma_AtB<0, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); load<6, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<6, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); G::load<1, false>(dO_i_smem[toc][0], g.dOg, {batch_idx, next_q_seq_idx * 2, next_q_head_idx, 0}, swizzled_offsets_Q_dO); mma_AtB<0, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<7, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<7, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load(delta_smem[toc], g.delta_vec, {batch_idx, next_q_head_idx, 0, next_q_seq_idx}); mma_AtB<0, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // Load Q_i from shared memory to registers // load(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_addr = get_address(Q_i, subtile_inplace(Q_i_smem[tic][0], {1, 0})); load<0, 0>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 1>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(4)"); __builtin_amdgcn_s_barrier(); mma_AtB<0, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 3>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); // Load K_j from shared memory to registers // load(K_j, subtile_inplace(K_j_smem, {warpid, 0})); load<0, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_AtB<1, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(10)"); mma_AtB<1, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(2)"); } // dot slice 1 { // 10. S_ij = Q_i K_j^T * scale // 11. P_ij = exp2(S_ij - L_i) // 13. dP_ij = dO_i @ V_j^T // 14. dS_ij = P_ij o (dP_ij - delta_i) // mma_ABt(P_ij, Q_i, K_j); mma_ABt<0, 0, 0>(P_ij, Q_i, K_j); load<2, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 1>(P_ij, Q_i, K_j, P_ij); mma_ABt<0, 0, 2>(P_ij, Q_i, K_j, P_ij); load<2, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 3>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 0>(P_ij, Q_i, K_j); load<3, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 1>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 2>(P_ij, Q_i, K_j, P_ij); load<3, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 0>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 0>(P_ij, Q_i, K_j); // Load dO_i from shared memory to registers // load(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_addr = get_address(dO_i, subtile_inplace(dO_i_smem[tic][0], {1, 0})); load<0, 0>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 1>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 0, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 3>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 1>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 3, 0>(P_ij, Q_i, K_j); // Load dO_i_col from shared memory to registers // load(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); // Compute dO_i_col_addr // uint32_t dO_i_col_addr = get_address(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(dO_i_smem[tic][0], {1, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const uint32_t addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 1>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 1, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(8)"); mma_ABt<0, 3, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 3>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij); // Dot slice 1 kittens::macros::v_mov_b32(0xff800000); if constexpr (causal) { // If the query position is less than the key position, set P_ij to -inf if (q_pos < k_pos) { mov(P_ij); // If the query position is equal to the key position, we need to apply a causal mask } else if (q_pos == k_pos) { // Apply the causal mask to [0, 1] and set [0, 2:4] to -inf make_causal<0, 1, neg_inf_v>(P_ij, P_ij); mov<0, 2, neg_inf_v>(P_ij); mov<0, 3, neg_inf_v>(P_ij); } } mul<0, 2>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_ABt(dP_ij, dO_i, V_j); mma_ABt<0, 0, 0>(dP_ij, dO_i, V_j); sub_row<0, 2, L_i>(P_ij, P_ij); mma_ABt<0, 0, 1>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 0>(P_ij, P_ij); mma_ABt<0, 0, 2>(dP_ij, dO_i, V_j, dP_ij); // Load Q_i_col from shared memory to registers // load(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); // Compute Q_i_col_addr // uint32_t Q_i_col_addr = get_address(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(Q_i_smem[tic][0], {1, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const int addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 0, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 1>(P_ij, P_ij); mma_ABt<0, 1, 0>(dP_ij, dO_i, V_j); load<0, 1>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 1, 1>(dP_ij, dO_i, V_j, dP_ij); mul<0, 3>(P_ij, P_ij, P_SCALE_FACTOR); mma_ABt<0, 1, 2>(dP_ij, dO_i, V_j, dP_ij); sub_row<0, 3, L_i>(P_ij, P_ij); mma_ABt<0, 1, 3>(dP_ij, dO_i, V_j, dP_ij); copy<0, 0>(P_ij_bf16, P_ij); mma_ABt<0, 2, 0>(dP_ij, dO_i, V_j); exp2<0, 2>(P_ij, P_ij); mma_ABt<0, 2, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 1>(P_ij_bf16, P_ij); mma_ABt<0, 2, 2>(dP_ij, dO_i, V_j, dP_ij); load<0, 2>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 2, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 3>(P_ij, P_ij); mma_ABt<0, 3, 0>(dP_ij, dO_i, V_j); load<0, 3>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 3, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 2>(P_ij_bf16, P_ij); copy<0, 3>(P_ij_bf16, P_ij); mma_ABt<0, 3, 2>(dP_ij, dO_i, V_j, dP_ij); swap_layout_inplace(P_ij_bf16_col, P_ij_bf16); mma_ABt<0, 3, 3>(dP_ij, dO_i, V_j, dP_ij); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_AtB(dV_j_T, dO_i_col, P_ij_bf16_col); mma_AtB<0, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // Load K_j_col from shared memory to registers // load(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid})); load<0, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<0, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 0, delta_i>(dP_ij, dP_ij); sub_row<0, 1, delta_i>(dP_ij, dP_ij); mma_AtB<1, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<1, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<1, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<1, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); mul<0, 0>(dP_ij, dP_ij, P_ij); mul<0, 1>(dP_ij, dP_ij, P_ij); copy<0, 0>(dP_ij_bf16, dP_ij); copy<0, 1>(dP_ij_bf16, dP_ij); sub_row<0, 2, delta_i>(dP_ij, dP_ij); mma_AtB<2, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<2, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); // 12. dV_j += P_ij^T @ dO_i // 16. dK_j += dS_ij^T @ Q_i (128x64)=(128x16)x(16x64) // Store dP_ij_bf16_accum_row to shared memory // store(attn_i_smem_subtile, dP_ij_bf16_accum_row); store<0, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<1, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<2, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 3, delta_i>(dP_ij, dP_ij); mul<0, 2>(dP_ij, dP_ij, P_ij); mul<0, 3>(dP_ij, dP_ij, P_ij); copy<0, 2>(dP_ij_bf16, dP_ij); copy<0, 3>(dP_ij_bf16, dP_ij); mma_AtB<3, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // dot slice 2 load(subvec_inplace(L_smem[tic], 2)); load(subvec_inplace(delta_smem[tic], 2)); store<2, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<3, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<3, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); swap_layout_inplace(dP_ij_bf16_col, dP_ij_bf16); asm volatile("s_waitcnt lgkmcnt(12)"); // mma_AtB(dK_j_T, Q_i_col, dP_ij_bf16_col); mma_AtB<0, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<2, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<4, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(8)"); __builtin_amdgcn_s_barrier(); mma_AtB<1, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); // Load dP_ij_bf16_col_T from shared memory to registers // load(dP_ij_bf16_col_T, attn_i_smem); load<0, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<1, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<2, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<3, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); mma_AtB<1, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); mul(L_SCALE_FACTOR); atomic_pk_add_bf16_with_warpid<2, 0, 0>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4, 0}, warpid); mma_AtB<2, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<4, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<4, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<5, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<2, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); atomic_pk_add_bf16_with_warpid<2, 0, 1>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4, 0}, warpid); mma_AtB<3, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<6, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<7, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<3, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt vmcnt(4) lgkmcnt(6)"); __builtin_amdgcn_s_barrier(); // 15. dQ_i += dS_ij @ K_j (32x16)=(32x256)x(256x16) // mma_AtB(dQ_i_T, K_j_col, dP_ij_bf16_col_T); mma_AtB<0, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); load<6, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<6, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); G::load<1, false>(Q_i_smem[toc][1], g.Q, {batch_idx, next_q_seq_idx * 2 + 1, next_q_head_idx, 0}, swizzled_offsets_Q_dO); mma_AtB<0, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<7, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<7, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // Load Q_i from shared memory to registers // load(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_addr = get_address(Q_i, subtile_inplace(Q_i_smem[tic][1], {0, 0})); load<0, 0>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 1>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(4)"); __builtin_amdgcn_s_barrier(); mma_AtB<0, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 3>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); // Load K_j from shared memory to registers // load(K_j, subtile_inplace(K_j_smem, {warpid, 0})); load<0, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_AtB<1, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(10)"); mma_AtB<1, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(2)"); } // dot slice 2 { // 10. S_ij = Q_i K_j^T * scale // 11. P_ij = exp2(S_ij - L_i) // 13. dP_ij = dO_i @ V_j^T // 14. dS_ij = P_ij o (dP_ij - delta_i) // mma_ABt(P_ij, Q_i, K_j); mma_ABt<0, 0, 0>(P_ij, Q_i, K_j); load<2, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 1>(P_ij, Q_i, K_j, P_ij); mma_ABt<0, 0, 2>(P_ij, Q_i, K_j, P_ij); load<2, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 3>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 0>(P_ij, Q_i, K_j); load<3, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 1>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 2>(P_ij, Q_i, K_j, P_ij); load<3, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 0>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 0>(P_ij, Q_i, K_j); // Load dO_i from shared memory to registers // load(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_addr = get_address(dO_i, subtile_inplace(dO_i_smem[tic][1], {0, 0})); load<0, 0>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 1>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 0, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 3>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 1>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 3, 0>(P_ij, Q_i, K_j); // Load dO_i_col from shared memory to registers // load(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); // Compute dO_i_col_addr // uint32_t dO_i_col_addr = get_address(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(dO_i_smem[tic][1], {0, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const uint32_t addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 1>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 1, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(8)"); mma_ABt<0, 3, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 3>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij); // Dot slice 2 kittens::macros::v_mov_b32(0xff800000); if constexpr (causal) { // If the query position is less than the key position, set P_ij to -inf if (q_pos < k_pos) { mov(P_ij); // If the query position is equal to the key position, we need to apply a causal mask } else if (q_pos == k_pos) { // Apply the causal mask to [0, 2] and set [0, 3:4] to -inf make_causal<0, 2, neg_inf_v>(P_ij, P_ij); mov<0, 3, neg_inf_v>(P_ij); } } mul<0, 2>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_ABt(dP_ij, dO_i, V_j); mma_ABt<0, 0, 0>(dP_ij, dO_i, V_j); sub_row<0, 2, L_i>(P_ij, P_ij); mma_ABt<0, 0, 1>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 0>(P_ij, P_ij); mma_ABt<0, 0, 2>(dP_ij, dO_i, V_j, dP_ij); // Load Q_i_col from shared memory to registers // load(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); // Compute Q_i_col_addr // uint32_t Q_i_col_addr = get_address(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(Q_i_smem[tic][1], {0, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const int addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 0, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 1>(P_ij, P_ij); mma_ABt<0, 1, 0>(dP_ij, dO_i, V_j); load<0, 1>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 1, 1>(dP_ij, dO_i, V_j, dP_ij); mul<0, 3>(P_ij, P_ij, P_SCALE_FACTOR); mma_ABt<0, 1, 2>(dP_ij, dO_i, V_j, dP_ij); sub_row<0, 3, L_i>(P_ij, P_ij); mma_ABt<0, 1, 3>(dP_ij, dO_i, V_j, dP_ij); copy<0, 0>(P_ij_bf16, P_ij); mma_ABt<0, 2, 0>(dP_ij, dO_i, V_j); exp2<0, 2>(P_ij, P_ij); mma_ABt<0, 2, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 1>(P_ij_bf16, P_ij); mma_ABt<0, 2, 2>(dP_ij, dO_i, V_j, dP_ij); load<0, 2>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 2, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 3>(P_ij, P_ij); mma_ABt<0, 3, 0>(dP_ij, dO_i, V_j); load<0, 3>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 3, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 2>(P_ij_bf16, P_ij); copy<0, 3>(P_ij_bf16, P_ij); mma_ABt<0, 3, 2>(dP_ij, dO_i, V_j, dP_ij); swap_layout_inplace(P_ij_bf16_col, P_ij_bf16); mma_ABt<0, 3, 3>(dP_ij, dO_i, V_j, dP_ij); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_AtB(dV_j_T, dO_i_col, P_ij_bf16_col); mma_AtB<0, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // Load K_j_col from shared memory to registers // load(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid})); load<0, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<0, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 0, delta_i>(dP_ij, dP_ij); sub_row<0, 1, delta_i>(dP_ij, dP_ij); mma_AtB<1, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<1, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<1, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<1, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); mul<0, 0>(dP_ij, dP_ij, P_ij); mul<0, 1>(dP_ij, dP_ij, P_ij); copy<0, 0>(dP_ij_bf16, dP_ij); copy<0, 1>(dP_ij_bf16, dP_ij); sub_row<0, 2, delta_i>(dP_ij, dP_ij); mma_AtB<2, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<2, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); // 12. dV_j += P_ij^T @ dO_i // 16. dK_j += dS_ij^T @ Q_i (128x64)=(128x16)x(16x64) // Store dP_ij_bf16_accum_row to shared memory // store(attn_i_smem_subtile, dP_ij_bf16_accum_row); store<0, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<1, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<2, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 3, delta_i>(dP_ij, dP_ij); mul<0, 2>(dP_ij, dP_ij, P_ij); mul<0, 3>(dP_ij, dP_ij, P_ij); copy<0, 2>(dP_ij_bf16, dP_ij); copy<0, 3>(dP_ij_bf16, dP_ij); mma_AtB<3, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // dot slice 3 load(subvec_inplace(L_smem[tic], 3)); load(subvec_inplace(delta_smem[tic], 3)); store<2, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<3, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<3, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); swap_layout_inplace(dP_ij_bf16_col, dP_ij_bf16); asm volatile("s_waitcnt lgkmcnt(12)"); // mma_AtB(dK_j_T, Q_i_col, dP_ij_bf16_col); mma_AtB<0, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<2, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<4, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(8)"); __builtin_amdgcn_s_barrier(); mma_AtB<1, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); // Load dP_ij_bf16_col_T from shared memory to registers // load(dP_ij_bf16_col_T, attn_i_smem); load<0, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<1, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<2, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<3, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); mma_AtB<1, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); mul(L_SCALE_FACTOR); atomic_pk_add_bf16_with_warpid<2, 0, 0>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 1, 0}, warpid); mma_AtB<2, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<4, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<4, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<5, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<2, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); atomic_pk_add_bf16_with_warpid<2, 0, 1>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 1, 0}, warpid); mma_AtB<3, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<6, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<7, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<3, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt vmcnt(4) lgkmcnt(6)"); __builtin_amdgcn_s_barrier(); // 15. dQ_i += dS_ij @ K_j (32x16)=(32x256)x(256x16) // mma_AtB(dQ_i_T, K_j_col, dP_ij_bf16_col_T); mma_AtB<0, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); load<6, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<6, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); G::load<1, false>(dO_i_smem[toc][1], g.dOg, {batch_idx, next_q_seq_idx * 2 + 1, next_q_head_idx, 0}, swizzled_offsets_Q_dO); mma_AtB<0, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<7, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<7, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // Load Q_i from shared memory to registers // load(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_addr = get_address(Q_i, subtile_inplace(Q_i_smem[tic][1], {1, 0})); load<0, 0>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 1>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(4)"); __builtin_amdgcn_s_barrier(); mma_AtB<0, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 3>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); // Load K_j from shared memory to registers // load(K_j, subtile_inplace(K_j_smem, {warpid, 0})); load<0, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_AtB<1, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(10)"); mma_AtB<1, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(2)"); } // dot slice 3 { // 10. S_ij = Q_i K_j^T * scale // 11. P_ij = exp2(S_ij - L_i) // 13. dP_ij = dO_i @ V_j^T // 14. dS_ij = P_ij o (dP_ij - delta_i) // mma_ABt(P_ij, Q_i, K_j); mma_ABt<0, 0, 0>(P_ij, Q_i, K_j); load<2, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 1>(P_ij, Q_i, K_j, P_ij); mma_ABt<0, 0, 2>(P_ij, Q_i, K_j, P_ij); load<2, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 3>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 0>(P_ij, Q_i, K_j); load<3, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 1>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 2>(P_ij, Q_i, K_j, P_ij); load<3, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 0>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 0>(P_ij, Q_i, K_j); // Load dO_i from shared memory to registers // load(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_addr = get_address(dO_i, subtile_inplace(dO_i_smem[tic][1], {1, 0})); load<0, 0>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 1>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 0, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 3>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 1>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 3, 0>(P_ij, Q_i, K_j); // Load dO_i_col from shared memory to registers // load(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); // Compute dO_i_col_addr // uint32_t dO_i_col_addr = get_address(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(dO_i_smem[tic][1], {1, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const uint32_t addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 1>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 1, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(8)"); mma_ABt<0, 3, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 3>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij); // Dot slice 3 kittens::macros::v_mov_b32(0xff800000); if constexpr (causal) { // If the query position is less than the key position, set P_ij to -inf if (q_pos < k_pos) { mov(P_ij); // If the query position is equal to the key position, we need to apply a causal mask } else if (q_pos == k_pos) { // Apply the causal mask to [0, 3] make_causal<0, 3, neg_inf_v>(P_ij, P_ij); } } mul<0, 2>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_ABt(dP_ij, dO_i, V_j); mma_ABt<0, 0, 0>(dP_ij, dO_i, V_j); sub_row<0, 2, L_i>(P_ij, P_ij); mma_ABt<0, 0, 1>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 0>(P_ij, P_ij); mma_ABt<0, 0, 2>(dP_ij, dO_i, V_j, dP_ij); // Load Q_i_col from shared memory to registers // load(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); // Compute Q_i_col_addr // uint32_t Q_i_col_addr = get_address(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(Q_i_smem[tic][1], {1, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const int addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 0, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 1>(P_ij, P_ij); mma_ABt<0, 1, 0>(dP_ij, dO_i, V_j); load<0, 1>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 1, 1>(dP_ij, dO_i, V_j, dP_ij); mul<0, 3>(P_ij, P_ij, P_SCALE_FACTOR); mma_ABt<0, 1, 2>(dP_ij, dO_i, V_j, dP_ij); sub_row<0, 3, L_i>(P_ij, P_ij); mma_ABt<0, 1, 3>(dP_ij, dO_i, V_j, dP_ij); copy<0, 0>(P_ij_bf16, P_ij); mma_ABt<0, 2, 0>(dP_ij, dO_i, V_j); exp2<0, 2>(P_ij, P_ij); mma_ABt<0, 2, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 1>(P_ij_bf16, P_ij); mma_ABt<0, 2, 2>(dP_ij, dO_i, V_j, dP_ij); load<0, 2>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 2, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 3>(P_ij, P_ij); mma_ABt<0, 3, 0>(dP_ij, dO_i, V_j); load<0, 3>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 3, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 2>(P_ij_bf16, P_ij); copy<0, 3>(P_ij_bf16, P_ij); mma_ABt<0, 3, 2>(dP_ij, dO_i, V_j, dP_ij); swap_layout_inplace(P_ij_bf16_col, P_ij_bf16); mma_ABt<0, 3, 3>(dP_ij, dO_i, V_j, dP_ij); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_AtB(dV_j_T, dO_i_col, P_ij_bf16_col); mma_AtB<0, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // Load K_j_col from shared memory to registers // load(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid})); load<0, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<0, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 0, delta_i>(dP_ij, dP_ij); sub_row<0, 1, delta_i>(dP_ij, dP_ij); mma_AtB<1, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<1, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<1, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<1, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); mul<0, 0>(dP_ij, dP_ij, P_ij); mul<0, 1>(dP_ij, dP_ij, P_ij); copy<0, 0>(dP_ij_bf16, dP_ij); copy<0, 1>(dP_ij_bf16, dP_ij); sub_row<0, 2, delta_i>(dP_ij, dP_ij); mma_AtB<2, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<2, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); // 12. dV_j += P_ij^T @ dO_i // 16. dK_j += dS_ij^T @ Q_i (128x64)=(128x16)x(16x64) // Store dP_ij_bf16_accum_row to shared memory // store(attn_i_smem_subtile, dP_ij_bf16_accum_row); store<0, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<1, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<2, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 3, delta_i>(dP_ij, dP_ij); mul<0, 2>(dP_ij, dP_ij, P_ij); mul<0, 3>(dP_ij, dP_ij, P_ij); copy<0, 2>(dP_ij_bf16, dP_ij); copy<0, 3>(dP_ij_bf16, dP_ij); mma_AtB<3, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // dot slice 0 - next iteration load(subvec_inplace(L_smem[toc], 0)); load(subvec_inplace(delta_smem[toc], 0)); store<2, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<3, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<3, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); swap_layout_inplace(dP_ij_bf16_col, dP_ij_bf16); asm volatile("s_waitcnt lgkmcnt(12)"); // mma_AtB(dK_j_T, Q_i_col, dP_ij_bf16_col); mma_AtB<0, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<2, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<4, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(8)"); __builtin_amdgcn_s_barrier(); mma_AtB<1, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); // Load dP_ij_bf16_col_T from shared memory to registers // load(dP_ij_bf16_col_T, attn_i_smem); load<0, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<1, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<2, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<3, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); mma_AtB<1, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); mul(L_SCALE_FACTOR); atomic_pk_add_bf16_with_warpid<2, 0, 0>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 2, 0}, warpid); mma_AtB<2, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<4, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<4, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<5, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<2, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); atomic_pk_add_bf16_with_warpid<2, 0, 1>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 2, 0}, warpid); mma_AtB<3, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<6, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<7, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<3, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt vmcnt(4) lgkmcnt(6)"); __builtin_amdgcn_s_barrier(); // 15. dQ_i += dS_ij @ K_j (32x16)=(32x256)x(256x16) // mma_AtB(dQ_i_T, K_j_col, dP_ij_bf16_col_T); mma_AtB<0, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); load<6, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<6, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<7, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<7, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // Load Q_i from shared memory to registers // load(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_addr = get_address(Q_i, subtile_inplace(Q_i_smem[toc][0], {0, 0})); load<0, 0>(Q_i, subtile_inplace(Q_i_smem[toc][0], {0, 0}), Q_i_addr); load<0, 1>(Q_i, subtile_inplace(Q_i_smem[toc][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(4)"); __builtin_amdgcn_s_barrier(); mma_AtB<0, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(Q_i, subtile_inplace(Q_i_smem[toc][0], {0, 0}), Q_i_addr); load<0, 3>(Q_i, subtile_inplace(Q_i_smem[toc][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); // Load K_j from shared memory to registers // load(K_j, subtile_inplace(K_j_smem, {warpid, 0})); load<0, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_AtB<1, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(10)"); mma_AtB<1, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(2)"); } tic ^= 1; toc ^= 1; } // 9. for 1 <= i <= T_r (1024 / 32 = 32) for (int i = 1; i < num_steps - 1; ++i, tic ^= 1, toc ^= 1) { const int last_q_head_idx = (i - 1) / num_steps_per_head + first_q_head; const int last_q_seq_idx = ((i - 1) % num_steps_per_head) + first_step; const int q_head_idx = i / num_steps_per_head + first_q_head; const int q_seq_idx = (i % num_steps_per_head) + first_step; const int q_pos = q_seq_idx * STEP_QO; const int next_q_head_idx = (i + 1) / num_steps_per_head + first_q_head; const int next_q_seq_idx = ((i + 1) % num_steps_per_head) + first_step; // dot slice 0 { // 10. S_ij = Q_i K_j^T * scale // 11. P_ij = exp2(S_ij - L_i) // 13. dP_ij = dO_i @ V_j^T // 14. dS_ij = P_ij o (dP_ij - delta_i) // mma_ABt(P_ij, Q_i, K_j); G::load<1, false>(Q_i_smem[toc][0], g.Q, {batch_idx, next_q_seq_idx * 2, next_q_head_idx, 0}, swizzled_offsets_Q_dO); mma_ABt<0, 0, 0>(P_ij, Q_i, K_j); load<2, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 1>(P_ij, Q_i, K_j, P_ij); load(L_smem[toc], g.L_vec, {batch_idx, next_q_head_idx, 0, next_q_seq_idx}); mma_ABt<0, 0, 2>(P_ij, Q_i, K_j, P_ij); load<2, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 3>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 0>(P_ij, Q_i, K_j); load<3, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 1>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 2>(P_ij, Q_i, K_j, P_ij); load<3, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 0>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 0>(P_ij, Q_i, K_j); // Load dO_i from shared memory to registers // load(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_addr = get_address(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); load<0, 0>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 1>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 0, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 3>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 1>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 3, 0>(P_ij, Q_i, K_j); // Load dO_i_col from shared memory to registers // load(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); // Compute dO_i_col_addr // uint32_t dO_i_col_addr = get_address(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(dO_i_smem[tic][0], {0, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const uint32_t addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 1>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 1, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(8)"); mma_ABt<0, 3, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 3>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij); // Dot slice 0 kittens::macros::v_mov_b32(0xff800000); if constexpr (causal) { // If the query position is less than the key position, set P_ij to -inf if (q_pos < k_pos) { mov(P_ij); // If the query position is equal to the key position, we need to apply a causal mask } else if (q_pos == k_pos) { // Apply the causal mask to [0, 0] and set [0, 1:4] to -inf make_causal<0, 0, neg_inf_v>(P_ij, P_ij); mov<0, 1, neg_inf_v>(P_ij); mov<0, 2, neg_inf_v>(P_ij); mov<0, 3, neg_inf_v>(P_ij); } } mul<0, 2>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_ABt(dP_ij, dO_i, V_j); mma_ABt<0, 0, 0>(dP_ij, dO_i, V_j); sub_row<0, 2, L_i>(P_ij, P_ij); mma_ABt<0, 0, 1>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 0>(P_ij, P_ij); mma_ABt<0, 0, 2>(dP_ij, dO_i, V_j, dP_ij); // Load Q_i_col from shared memory to registers // load(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); // Compute Q_i_col_addr // uint32_t Q_i_col_addr = get_address(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(Q_i_smem[tic][0], {0, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const int addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 0, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 1>(P_ij, P_ij); mma_ABt<0, 1, 0>(dP_ij, dO_i, V_j); load<0, 1>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 1, 1>(dP_ij, dO_i, V_j, dP_ij); mul<0, 3>(P_ij, P_ij, P_SCALE_FACTOR); mma_ABt<0, 1, 2>(dP_ij, dO_i, V_j, dP_ij); sub_row<0, 3, L_i>(P_ij, P_ij); mma_ABt<0, 1, 3>(dP_ij, dO_i, V_j, dP_ij); copy<0, 0>(P_ij_bf16, P_ij); mma_ABt<0, 2, 0>(dP_ij, dO_i, V_j); exp2<0, 2>(P_ij, P_ij); mma_ABt<0, 2, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 1>(P_ij_bf16, P_ij); mma_ABt<0, 2, 2>(dP_ij, dO_i, V_j, dP_ij); load<0, 2>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 2, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 3>(P_ij, P_ij); mma_ABt<0, 3, 0>(dP_ij, dO_i, V_j); load<0, 3>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 3, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 2>(P_ij_bf16, P_ij); copy<0, 3>(P_ij_bf16, P_ij); mma_ABt<0, 3, 2>(dP_ij, dO_i, V_j, dP_ij); swap_layout_inplace(P_ij_bf16_col, P_ij_bf16); mma_ABt<0, 3, 3>(dP_ij, dO_i, V_j, dP_ij); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_AtB(dV_j_T, dO_i_col, P_ij_bf16_col); mma_AtB<0, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // Load K_j_col from shared memory to registers // load(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid})); load<0, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<0, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 0, delta_i>(dP_ij, dP_ij); sub_row<0, 1, delta_i>(dP_ij, dP_ij); mma_AtB<1, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<1, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<1, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<1, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); mul<0, 0>(dP_ij, dP_ij, P_ij); mul<0, 1>(dP_ij, dP_ij, P_ij); copy<0, 0>(dP_ij_bf16, dP_ij); copy<0, 1>(dP_ij_bf16, dP_ij); sub_row<0, 2, delta_i>(dP_ij, dP_ij); mma_AtB<2, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<2, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); // 12. dV_j += P_ij^T @ dO_i // 16. dK_j += dS_ij^T @ Q_i (128x64)=(128x16)x(16x64) // Store dP_ij_bf16_accum_row to shared memory // store(attn_i_smem_subtile, dP_ij_bf16_accum_row); store<0, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<1, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<2, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 3, delta_i>(dP_ij, dP_ij); mul<0, 2>(dP_ij, dP_ij, P_ij); mul<0, 3>(dP_ij, dP_ij, P_ij); copy<0, 2>(dP_ij_bf16, dP_ij); copy<0, 3>(dP_ij_bf16, dP_ij); mma_AtB<3, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // dot slice 1 load(subvec_inplace(L_smem[tic], 1)); load(subvec_inplace(delta_smem[tic], 1)); store<2, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<3, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<3, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); swap_layout_inplace(dP_ij_bf16_col, dP_ij_bf16); asm volatile("s_waitcnt lgkmcnt(12)"); // mma_AtB(dK_j_T, Q_i_col, dP_ij_bf16_col); mma_AtB<0, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<2, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<4, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(8)"); __builtin_amdgcn_s_barrier(); mma_AtB<1, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); // Load dP_ij_bf16_col_T from shared memory to registers // load(dP_ij_bf16_col_T, attn_i_smem); load<0, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<1, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<2, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<3, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); mma_AtB<1, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); mul(L_SCALE_FACTOR); atomic_pk_add_bf16_with_warpid<2, 0, 0>(g.dQg, dQ_i, {batch_idx, last_q_head_idx, last_q_seq_idx * 4 + 3, 0}, warpid); mma_AtB<2, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<4, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<4, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<5, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<2, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); atomic_pk_add_bf16_with_warpid<2, 0, 1>(g.dQg, dQ_i, {batch_idx, last_q_head_idx, last_q_seq_idx * 4 + 3, 0}, warpid); mma_AtB<3, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<6, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<7, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<3, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt vmcnt(4) lgkmcnt(6)"); __builtin_amdgcn_s_barrier(); // 15. dQ_i += dS_ij @ K_j (32x16)=(32x256)x(256x16) // mma_AtB(dQ_i_T, K_j_col, dP_ij_bf16_col_T); mma_AtB<0, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); load<6, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<6, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); G::load<1, false>(dO_i_smem[toc][0], g.dOg, {batch_idx, next_q_seq_idx * 2, next_q_head_idx, 0}, swizzled_offsets_Q_dO); mma_AtB<0, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<7, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<7, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load(delta_smem[toc], g.delta_vec, {batch_idx, next_q_head_idx, 0, next_q_seq_idx}); mma_AtB<0, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // Load Q_i from shared memory to registers // load(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_addr = get_address(Q_i, subtile_inplace(Q_i_smem[tic][0], {1, 0})); load<0, 0>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 1>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(4)"); __builtin_amdgcn_s_barrier(); mma_AtB<0, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 3>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); // Load K_j from shared memory to registers // load(K_j, subtile_inplace(K_j_smem, {warpid, 0})); load<0, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_AtB<1, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(10)"); mma_AtB<1, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(2)"); } // dot slice 1 { // 10. S_ij = Q_i K_j^T * scale // 11. P_ij = exp2(S_ij - L_i) // 13. dP_ij = dO_i @ V_j^T // 14. dS_ij = P_ij o (dP_ij - delta_i) // mma_ABt(P_ij, Q_i, K_j); mma_ABt<0, 0, 0>(P_ij, Q_i, K_j); load<2, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 1>(P_ij, Q_i, K_j, P_ij); mma_ABt<0, 0, 2>(P_ij, Q_i, K_j, P_ij); load<2, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 3>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 0>(P_ij, Q_i, K_j); load<3, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 1>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 1, 2>(P_ij, Q_i, K_j, P_ij); load<3, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 0>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 0>(P_ij, Q_i, K_j); // Load dO_i from shared memory to registers // load(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_addr = get_address(dO_i, subtile_inplace(dO_i_smem[tic][0], {1, 0})); load<0, 0>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 1>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 0, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 3>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 1>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 3, 0>(P_ij, Q_i, K_j); // Load dO_i_col from shared memory to registers // load(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); // Compute dO_i_col_addr // uint32_t dO_i_col_addr = get_address(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(dO_i_smem[tic][0], {1, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const uint32_t addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 1>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 1, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(8)"); mma_ABt<0, 3, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 3>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij); // Dot slice 1 kittens::macros::v_mov_b32(0xff800000); if constexpr (causal) { // If the query position is less than the key position, set P_ij to -inf if (q_pos < k_pos) { mov(P_ij); // If the query position is equal to the key position, we need to apply a causal mask } else if (q_pos == k_pos) { // Apply the causal mask to [0, 1] and set [0, 2:4] to -inf make_causal<0, 1, neg_inf_v>(P_ij, P_ij); mov<0, 2, neg_inf_v>(P_ij); mov<0, 3, neg_inf_v>(P_ij); } } mul<0, 2>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_ABt(dP_ij, dO_i, V_j); mma_ABt<0, 0, 0>(dP_ij, dO_i, V_j); sub_row<0, 2, L_i>(P_ij, P_ij); mma_ABt<0, 0, 1>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 0>(P_ij, P_ij); mma_ABt<0, 0, 2>(dP_ij, dO_i, V_j, dP_ij); // Load Q_i_col from shared memory to registers // load(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); // Compute Q_i_col_addr // uint32_t Q_i_col_addr = get_address(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(Q_i_smem[tic][0], {1, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const int addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 0, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 1>(P_ij, P_ij); mma_ABt<0, 1, 0>(dP_ij, dO_i, V_j); load<0, 1>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 1, 1>(dP_ij, dO_i, V_j, dP_ij); mul<0, 3>(P_ij, P_ij, P_SCALE_FACTOR); mma_ABt<0, 1, 2>(dP_ij, dO_i, V_j, dP_ij); sub_row<0, 3, L_i>(P_ij, P_ij); mma_ABt<0, 1, 3>(dP_ij, dO_i, V_j, dP_ij); copy<0, 0>(P_ij_bf16, P_ij); mma_ABt<0, 2, 0>(dP_ij, dO_i, V_j); exp2<0, 2>(P_ij, P_ij); mma_ABt<0, 2, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 1>(P_ij_bf16, P_ij); mma_ABt<0, 2, 2>(dP_ij, dO_i, V_j, dP_ij); load<0, 2>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 2, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 3>(P_ij, P_ij); mma_ABt<0, 3, 0>(dP_ij, dO_i, V_j); load<0, 3>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 3, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 2>(P_ij_bf16, P_ij); copy<0, 3>(P_ij_bf16, P_ij); mma_ABt<0, 3, 2>(dP_ij, dO_i, V_j, dP_ij); swap_layout_inplace(P_ij_bf16_col, P_ij_bf16); mma_ABt<0, 3, 3>(dP_ij, dO_i, V_j, dP_ij); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_AtB(dV_j_T, dO_i_col, P_ij_bf16_col); mma_AtB<0, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // Load K_j_col from shared memory to registers // load(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid})); load<0, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<0, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 0, delta_i>(dP_ij, dP_ij); sub_row<0, 1, delta_i>(dP_ij, dP_ij); mma_AtB<1, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<1, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<1, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<1, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); mul<0, 0>(dP_ij, dP_ij, P_ij); mul<0, 1>(dP_ij, dP_ij, P_ij); copy<0, 0>(dP_ij_bf16, dP_ij); copy<0, 1>(dP_ij_bf16, dP_ij); sub_row<0, 2, delta_i>(dP_ij, dP_ij); mma_AtB<2, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<2, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); // 12. dV_j += P_ij^T @ dO_i // 16. dK_j += dS_ij^T @ Q_i (128x64)=(128x16)x(16x64) // Store dP_ij_bf16_accum_row to shared memory // store(attn_i_smem_subtile, dP_ij_bf16_accum_row); store<0, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<1, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<2, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 3, delta_i>(dP_ij, dP_ij); mul<0, 2>(dP_ij, dP_ij, P_ij); mul<0, 3>(dP_ij, dP_ij, P_ij); copy<0, 2>(dP_ij_bf16, dP_ij); copy<0, 3>(dP_ij_bf16, dP_ij); mma_AtB<3, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // dot slice 2 load(subvec_inplace(L_smem[tic], 2)); load(subvec_inplace(delta_smem[tic], 2)); store<2, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<3, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<3, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); swap_layout_inplace(dP_ij_bf16_col, dP_ij_bf16); asm volatile("s_waitcnt lgkmcnt(12)"); // mma_AtB(dK_j_T, Q_i_col, dP_ij_bf16_col); mma_AtB<0, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<2, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<4, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(8)"); __builtin_amdgcn_s_barrier(); mma_AtB<1, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); // Load dP_ij_bf16_col_T from shared memory to registers // load(dP_ij_bf16_col_T, attn_i_smem); load<0, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<1, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<2, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<3, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); mma_AtB<1, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); mul(L_SCALE_FACTOR); atomic_pk_add_bf16_with_warpid<2, 0, 0>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 0, 0}, warpid); mma_AtB<2, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<4, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<4, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<5, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<2, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); atomic_pk_add_bf16_with_warpid<2, 0, 1>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 0, 0}, warpid); mma_AtB<3, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<6, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<7, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<3, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt vmcnt(4) lgkmcnt(6)"); __builtin_amdgcn_s_barrier(); // 15. dQ_i += dS_ij @ K_j (32x16)=(32x256)x(256x16) // mma_AtB(dQ_i_T, K_j_col, dP_ij_bf16_col_T); mma_AtB<0, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); load<6, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<6, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); G::load<1, false>(Q_i_smem[toc][1], g.Q, {batch_idx, next_q_seq_idx * 2 + 1, next_q_head_idx, 0}, swizzled_offsets_Q_dO); mma_AtB<0, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<7, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<7, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // Load Q_i from shared memory to registers // load(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_addr = get_address(Q_i, subtile_inplace(Q_i_smem[tic][1], {0, 0})); load<0, 0>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 1>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(4)"); __builtin_amdgcn_s_barrier(); mma_AtB<0, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 3>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); // Load K_j from shared memory to registers // load(K_j, subtile_inplace(K_j_smem, {warpid, 0})); load<0, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_AtB<1, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(10)"); mma_AtB<1, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(2)"); } // dot slice 2 { // 10. S_ij = Q_i K_j^T * scale // 11. P_ij = exp2(S_ij - L_i) // 13. dP_ij = dO_i @ V_j^T // 14. dS_ij = P_ij o (dP_ij - delta_i) // mma_ABt(P_ij, Q_i, K_j); mma_ABt<0, 0, 0>(P_ij, Q_i, K_j); load<2, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 1>(P_ij, Q_i, K_j, P_ij); mma_ABt<0, 0, 2>(P_ij, Q_i, K_j, P_ij); load<2, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 3>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 0>(P_ij, Q_i, K_j); load<3, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 1>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 2>(P_ij, Q_i, K_j, P_ij); load<3, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 0>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 0>(P_ij, Q_i, K_j); // Load dO_i from shared memory to registers // load(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_addr = get_address(dO_i, subtile_inplace(dO_i_smem[tic][1], {0, 0})); load<0, 0>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 1>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 0, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 3>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 1>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 3, 0>(P_ij, Q_i, K_j); // Load dO_i_col from shared memory to registers // load(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); // Compute dO_i_col_addr // uint32_t dO_i_col_addr = get_address(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(dO_i_smem[tic][1], {0, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const uint32_t addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 1>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 1, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(8)"); mma_ABt<0, 3, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 3>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij); // Dot slice 2 kittens::macros::v_mov_b32(0xff800000); if constexpr (causal) { // If the query position is less than the key position, set P_ij to -inf if (q_pos < k_pos) { mov(P_ij); // If the query position is equal to the key position, we need to apply a causal mask } else if (q_pos == k_pos) { // Apply the causal mask to [0, 2] and set [0, 3:4] to -inf make_causal<0, 2, neg_inf_v>(P_ij, P_ij); mov<0, 3, neg_inf_v>(P_ij); } } mul<0, 2>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_ABt(dP_ij, dO_i, V_j); mma_ABt<0, 0, 0>(dP_ij, dO_i, V_j); sub_row<0, 2, L_i>(P_ij, P_ij); mma_ABt<0, 0, 1>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 0>(P_ij, P_ij); mma_ABt<0, 0, 2>(dP_ij, dO_i, V_j, dP_ij); // Load Q_i_col from shared memory to registers // load(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); // Compute Q_i_col_addr // uint32_t Q_i_col_addr = get_address(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(Q_i_smem[tic][1], {0, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const int addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 0, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 1>(P_ij, P_ij); mma_ABt<0, 1, 0>(dP_ij, dO_i, V_j); load<0, 1>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 1, 1>(dP_ij, dO_i, V_j, dP_ij); mul<0, 3>(P_ij, P_ij, P_SCALE_FACTOR); mma_ABt<0, 1, 2>(dP_ij, dO_i, V_j, dP_ij); sub_row<0, 3, L_i>(P_ij, P_ij); mma_ABt<0, 1, 3>(dP_ij, dO_i, V_j, dP_ij); copy<0, 0>(P_ij_bf16, P_ij); mma_ABt<0, 2, 0>(dP_ij, dO_i, V_j); exp2<0, 2>(P_ij, P_ij); mma_ABt<0, 2, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 1>(P_ij_bf16, P_ij); mma_ABt<0, 2, 2>(dP_ij, dO_i, V_j, dP_ij); load<0, 2>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 2, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 3>(P_ij, P_ij); mma_ABt<0, 3, 0>(dP_ij, dO_i, V_j); load<0, 3>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 3, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 2>(P_ij_bf16, P_ij); copy<0, 3>(P_ij_bf16, P_ij); mma_ABt<0, 3, 2>(dP_ij, dO_i, V_j, dP_ij); swap_layout_inplace(P_ij_bf16_col, P_ij_bf16); mma_ABt<0, 3, 3>(dP_ij, dO_i, V_j, dP_ij); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_AtB(dV_j_T, dO_i_col, P_ij_bf16_col); mma_AtB<0, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // Load K_j_col from shared memory to registers // load(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid})); load<0, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<0, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 0, delta_i>(dP_ij, dP_ij); sub_row<0, 1, delta_i>(dP_ij, dP_ij); mma_AtB<1, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<1, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<1, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<1, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); mul<0, 0>(dP_ij, dP_ij, P_ij); mul<0, 1>(dP_ij, dP_ij, P_ij); copy<0, 0>(dP_ij_bf16, dP_ij); copy<0, 1>(dP_ij_bf16, dP_ij); sub_row<0, 2, delta_i>(dP_ij, dP_ij); mma_AtB<2, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<2, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); // 12. dV_j += P_ij^T @ dO_i // 16. dK_j += dS_ij^T @ Q_i (128x64)=(128x16)x(16x64) // Store dP_ij_bf16_accum_row to shared memory // store(attn_i_smem_subtile, dP_ij_bf16_accum_row); store<0, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<1, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<2, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 3, delta_i>(dP_ij, dP_ij); mul<0, 2>(dP_ij, dP_ij, P_ij); mul<0, 3>(dP_ij, dP_ij, P_ij); copy<0, 2>(dP_ij_bf16, dP_ij); copy<0, 3>(dP_ij_bf16, dP_ij); mma_AtB<3, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // dot slice 3 load(subvec_inplace(L_smem[tic], 3)); load(subvec_inplace(delta_smem[tic], 3)); store<2, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<3, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<3, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); swap_layout_inplace(dP_ij_bf16_col, dP_ij_bf16); asm volatile("s_waitcnt lgkmcnt(12)"); // mma_AtB(dK_j_T, Q_i_col, dP_ij_bf16_col); mma_AtB<0, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<2, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<4, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(8)"); __builtin_amdgcn_s_barrier(); mma_AtB<1, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); // Load dP_ij_bf16_col_T from shared memory to registers // load(dP_ij_bf16_col_T, attn_i_smem); load<0, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<1, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<2, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<3, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); mma_AtB<1, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); mul(L_SCALE_FACTOR); atomic_pk_add_bf16_with_warpid<2, 0, 0>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 1, 0}, warpid); mma_AtB<2, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<4, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<4, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<5, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<2, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); atomic_pk_add_bf16_with_warpid<2, 0, 1>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 1, 0}, warpid); mma_AtB<3, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<6, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<7, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<3, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt vmcnt(4) lgkmcnt(6)"); __builtin_amdgcn_s_barrier(); // 15. dQ_i += dS_ij @ K_j (32x16)=(32x256)x(256x16) // mma_AtB(dQ_i_T, K_j_col, dP_ij_bf16_col_T); mma_AtB<0, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); load<6, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<6, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); G::load<1, false>(dO_i_smem[toc][1], g.dOg, {batch_idx, next_q_seq_idx * 2 + 1, next_q_head_idx, 0}, swizzled_offsets_Q_dO); mma_AtB<0, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<7, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<7, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // Load Q_i from shared memory to registers // load(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_addr = get_address(Q_i, subtile_inplace(Q_i_smem[tic][1], {1, 0})); load<0, 0>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 1>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(4)"); __builtin_amdgcn_s_barrier(); mma_AtB<0, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 3>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); // Load K_j from shared memory to registers // load(K_j, subtile_inplace(K_j_smem, {warpid, 0})); load<0, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_AtB<1, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(10)"); mma_AtB<1, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(2)"); } // dot slice 3 { // 10. S_ij = Q_i K_j^T * scale // 11. P_ij = exp2(S_ij - L_i) // 13. dP_ij = dO_i @ V_j^T // 14. dS_ij = P_ij o (dP_ij - delta_i) // mma_ABt(P_ij, Q_i, K_j); mma_ABt<0, 0, 0>(P_ij, Q_i, K_j); load<2, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 1>(P_ij, Q_i, K_j, P_ij); mma_ABt<0, 0, 2>(P_ij, Q_i, K_j, P_ij); load<2, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 3>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 0>(P_ij, Q_i, K_j); load<3, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 1>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 2>(P_ij, Q_i, K_j, P_ij); load<3, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 0>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 0>(P_ij, Q_i, K_j); // Load dO_i from shared memory to registers // load(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_addr = get_address(dO_i, subtile_inplace(dO_i_smem[tic][1], {1, 0})); load<0, 0>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 1>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 0, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 3>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 1>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 3, 0>(P_ij, Q_i, K_j); // Load dO_i_col from shared memory to registers // load(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); // Compute dO_i_col_addr // uint32_t dO_i_col_addr = get_address(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(dO_i_smem[tic][1], {1, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const uint32_t addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 1>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 1, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(8)"); mma_ABt<0, 3, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 3>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij); // Dot slice 3 kittens::macros::v_mov_b32(0xff800000); if constexpr (causal) { // If the query position is less than the key position, set P_ij to -inf if (q_pos < k_pos) { mov(P_ij); // If the query position is equal to the key position, we need to apply a causal mask } else if (q_pos == k_pos) { // Apply the causal mask to [0, 3] make_causal<0, 3, neg_inf_v>(P_ij, P_ij); } } mul<0, 2>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_ABt(dP_ij, dO_i, V_j); mma_ABt<0, 0, 0>(dP_ij, dO_i, V_j); sub_row<0, 2, L_i>(P_ij, P_ij); mma_ABt<0, 0, 1>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 0>(P_ij, P_ij); mma_ABt<0, 0, 2>(dP_ij, dO_i, V_j, dP_ij); // Load Q_i_col from shared memory to registers // load(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); // Compute Q_i_col_addr // uint32_t Q_i_col_addr = get_address(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(Q_i_smem[tic][1], {1, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const int addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 0, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 1>(P_ij, P_ij); mma_ABt<0, 1, 0>(dP_ij, dO_i, V_j); load<0, 1>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 1, 1>(dP_ij, dO_i, V_j, dP_ij); mul<0, 3>(P_ij, P_ij, P_SCALE_FACTOR); mma_ABt<0, 1, 2>(dP_ij, dO_i, V_j, dP_ij); sub_row<0, 3, L_i>(P_ij, P_ij); mma_ABt<0, 1, 3>(dP_ij, dO_i, V_j, dP_ij); copy<0, 0>(P_ij_bf16, P_ij); mma_ABt<0, 2, 0>(dP_ij, dO_i, V_j); exp2<0, 2>(P_ij, P_ij); mma_ABt<0, 2, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 1>(P_ij_bf16, P_ij); mma_ABt<0, 2, 2>(dP_ij, dO_i, V_j, dP_ij); load<0, 2>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 2, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 3>(P_ij, P_ij); mma_ABt<0, 3, 0>(dP_ij, dO_i, V_j); load<0, 3>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 3, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 2>(P_ij_bf16, P_ij); copy<0, 3>(P_ij_bf16, P_ij); mma_ABt<0, 3, 2>(dP_ij, dO_i, V_j, dP_ij); swap_layout_inplace(P_ij_bf16_col, P_ij_bf16); mma_ABt<0, 3, 3>(dP_ij, dO_i, V_j, dP_ij); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_AtB(dV_j_T, dO_i_col, P_ij_bf16_col); mma_AtB<0, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // Load K_j_col from shared memory to registers // load(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid})); load<0, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<0, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 0, delta_i>(dP_ij, dP_ij); sub_row<0, 1, delta_i>(dP_ij, dP_ij); mma_AtB<1, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<1, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<1, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<1, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); mul<0, 0>(dP_ij, dP_ij, P_ij); mul<0, 1>(dP_ij, dP_ij, P_ij); copy<0, 0>(dP_ij_bf16, dP_ij); copy<0, 1>(dP_ij_bf16, dP_ij); sub_row<0, 2, delta_i>(dP_ij, dP_ij); mma_AtB<2, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<2, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); // 12. dV_j += P_ij^T @ dO_i // 16. dK_j += dS_ij^T @ Q_i (128x64)=(128x16)x(16x64) // Store dP_ij_bf16_accum_row to shared memory // store(attn_i_smem_subtile, dP_ij_bf16_accum_row); store<0, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<1, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<2, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 3, delta_i>(dP_ij, dP_ij); mul<0, 2>(dP_ij, dP_ij, P_ij); mul<0, 3>(dP_ij, dP_ij, P_ij); copy<0, 2>(dP_ij_bf16, dP_ij); copy<0, 3>(dP_ij_bf16, dP_ij); mma_AtB<3, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // dot slice 0 - next iteration load(subvec_inplace(L_smem[toc], 0)); load(subvec_inplace(delta_smem[toc], 0)); store<2, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<3, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<3, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); swap_layout_inplace(dP_ij_bf16_col, dP_ij_bf16); asm volatile("s_waitcnt lgkmcnt(12)"); // mma_AtB(dK_j_T, Q_i_col, dP_ij_bf16_col); mma_AtB<0, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<2, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<4, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(8)"); __builtin_amdgcn_s_barrier(); mma_AtB<1, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); // Load dP_ij_bf16_col_T from shared memory to registers // load(dP_ij_bf16_col_T, attn_i_smem); load<0, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<1, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<2, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<3, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); mma_AtB<1, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); mul(L_SCALE_FACTOR); atomic_pk_add_bf16_with_warpid<2, 0, 0>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 2, 0}, warpid); mma_AtB<2, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<4, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<4, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<5, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<2, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); atomic_pk_add_bf16_with_warpid<2, 0, 1>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 2, 0}, warpid); mma_AtB<3, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<6, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<7, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<3, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt vmcnt(4) lgkmcnt(6)"); __builtin_amdgcn_s_barrier(); // 15. dQ_i += dS_ij @ K_j (32x16)=(32x256)x(256x16) // mma_AtB(dQ_i_T, K_j_col, dP_ij_bf16_col_T); mma_AtB<0, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); load<6, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<6, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<7, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<7, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // Load Q_i from shared memory to registers // load(Q_i, subtile_inplace(Q_i_smem[toc][0], {0, 0})); Q_i_addr = get_address(Q_i, subtile_inplace(Q_i_smem[toc][0], {0, 0})); load<0, 0>(Q_i, subtile_inplace(Q_i_smem[toc][0], {0, 0}), Q_i_addr); load<0, 1>(Q_i, subtile_inplace(Q_i_smem[toc][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(4)"); __builtin_amdgcn_s_barrier(); mma_AtB<0, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(Q_i, subtile_inplace(Q_i_smem[toc][0], {0, 0}), Q_i_addr); load<0, 3>(Q_i, subtile_inplace(Q_i_smem[toc][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); // Load K_j from shared memory to registers // load(K_j, subtile_inplace(K_j_smem, {warpid, 0})); load<0, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_AtB<1, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(10)"); mma_AtB<1, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(2)"); } } } const int last_q_head_idx = (num_steps - 2) / num_steps_per_head + first_q_head; const int last_q_seq_idx = ((num_steps - 2) % num_steps_per_head) + first_step; const int q_head_idx = (num_steps - 1) / num_steps_per_head + first_q_head; const int q_seq_idx = ((num_steps - 1) % num_steps_per_head) + first_step; const int q_pos = q_seq_idx * STEP_QO; // Epilogue { // dot slice 0 { // 10. S_ij = Q_i K_j^T * scale // 11. P_ij = exp2(S_ij - L_i) // 13. dP_ij = dO_i @ V_j^T // 14. dS_ij = P_ij o (dP_ij - delta_i) // mma_ABt(P_ij, Q_i, K_j); mma_ABt<0, 0, 0>(P_ij, Q_i, K_j); load<2, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 1>(P_ij, Q_i, K_j, P_ij); mma_ABt<0, 0, 2>(P_ij, Q_i, K_j, P_ij); load<2, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 3>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 0>(P_ij, Q_i, K_j); load<3, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 1>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 2>(P_ij, Q_i, K_j, P_ij); load<3, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 0>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 0>(P_ij, Q_i, K_j); // Load dO_i from shared memory to registers // load(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_addr = get_address(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); load<0, 0>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 1>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 0, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 3>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 1>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 3, 0>(P_ij, Q_i, K_j); // Load dO_i_col from shared memory to registers // load(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); // Compute dO_i_col_addr // uint32_t dO_i_col_addr = get_address(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(dO_i_smem[tic][0], {0, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const uint32_t addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 1>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 1, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(8)"); mma_ABt<0, 3, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 3>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij); // Dot slice 0 kittens::macros::v_mov_b32(0xff800000); if constexpr (causal) { // If the query position is less than the key position, set P_ij to -inf if (q_pos < k_pos) { mov(P_ij); // If the query position is equal to the key position, we need to apply a causal mask } else if (q_pos == k_pos) { // Apply the causal mask to [0, 0] and set [0, 1:4] to -inf make_causal<0, 0, neg_inf_v>(P_ij, P_ij); mov<0, 1, neg_inf_v>(P_ij); mov<0, 2, neg_inf_v>(P_ij); mov<0, 3, neg_inf_v>(P_ij); } } mul<0, 2>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_ABt(dP_ij, dO_i, V_j); mma_ABt<0, 0, 0>(dP_ij, dO_i, V_j); sub_row<0, 2, L_i>(P_ij, P_ij); mma_ABt<0, 0, 1>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 0>(P_ij, P_ij); mma_ABt<0, 0, 2>(dP_ij, dO_i, V_j, dP_ij); // Load Q_i_col from shared memory to registers // load(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); // Compute Q_i_col_addr // uint32_t Q_i_col_addr = get_address(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(Q_i_smem[tic][0], {0, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const int addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 0, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 1>(P_ij, P_ij); mma_ABt<0, 1, 0>(dP_ij, dO_i, V_j); load<0, 1>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 1, 1>(dP_ij, dO_i, V_j, dP_ij); mul<0, 3>(P_ij, P_ij, P_SCALE_FACTOR); mma_ABt<0, 1, 2>(dP_ij, dO_i, V_j, dP_ij); sub_row<0, 3, L_i>(P_ij, P_ij); mma_ABt<0, 1, 3>(dP_ij, dO_i, V_j, dP_ij); copy<0, 0>(P_ij_bf16, P_ij); mma_ABt<0, 2, 0>(dP_ij, dO_i, V_j); exp2<0, 2>(P_ij, P_ij); mma_ABt<0, 2, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 1>(P_ij_bf16, P_ij); mma_ABt<0, 2, 2>(dP_ij, dO_i, V_j, dP_ij); load<0, 2>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 2, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 3>(P_ij, P_ij); mma_ABt<0, 3, 0>(dP_ij, dO_i, V_j); load<0, 3>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 3, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 2>(P_ij_bf16, P_ij); copy<0, 3>(P_ij_bf16, P_ij); mma_ABt<0, 3, 2>(dP_ij, dO_i, V_j, dP_ij); swap_layout_inplace(P_ij_bf16_col, P_ij_bf16); mma_ABt<0, 3, 3>(dP_ij, dO_i, V_j, dP_ij); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_AtB(dV_j_T, dO_i_col, P_ij_bf16_col); mma_AtB<0, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // Load K_j_col from shared memory to registers // load(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid})); load<0, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<0, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 0, delta_i>(dP_ij, dP_ij); sub_row<0, 1, delta_i>(dP_ij, dP_ij); mma_AtB<1, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<1, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<1, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<1, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); mul<0, 0>(dP_ij, dP_ij, P_ij); mul<0, 1>(dP_ij, dP_ij, P_ij); copy<0, 0>(dP_ij_bf16, dP_ij); copy<0, 1>(dP_ij_bf16, dP_ij); sub_row<0, 2, delta_i>(dP_ij, dP_ij); mma_AtB<2, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<2, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); // 12. dV_j += P_ij^T @ dO_i // 16. dK_j += dS_ij^T @ Q_i (128x64)=(128x16)x(16x64) // Store dP_ij_bf16_accum_row to shared memory // store(attn_i_smem_subtile, dP_ij_bf16_accum_row); store<0, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<1, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<2, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 3, delta_i>(dP_ij, dP_ij); mul<0, 2>(dP_ij, dP_ij, P_ij); mul<0, 3>(dP_ij, dP_ij, P_ij); copy<0, 2>(dP_ij_bf16, dP_ij); copy<0, 3>(dP_ij_bf16, dP_ij); mma_AtB<3, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // dot slice 1 load(subvec_inplace(L_smem[tic], 1)); load(subvec_inplace(delta_smem[tic], 1)); store<2, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<3, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<3, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); swap_layout_inplace(dP_ij_bf16_col, dP_ij_bf16); asm volatile("s_waitcnt lgkmcnt(12)"); // mma_AtB(dK_j_T, Q_i_col, dP_ij_bf16_col); mma_AtB<0, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<2, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<4, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(8)"); __builtin_amdgcn_s_barrier(); mma_AtB<1, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); // Load dP_ij_bf16_col_T from shared memory to registers // load(dP_ij_bf16_col_T, attn_i_smem); load<0, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<1, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<2, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<3, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); mma_AtB<1, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); mul(L_SCALE_FACTOR); if (num_steps > 1) { atomic_pk_add_bf16_with_warpid<2, 0, 0>(g.dQg, dQ_i, {batch_idx, last_q_head_idx, last_q_seq_idx * 4 + 3, 0}, warpid); } mma_AtB<2, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<4, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<4, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<5, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<2, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); if (num_steps > 1) { atomic_pk_add_bf16_with_warpid<2, 0, 1>(g.dQg, dQ_i, {batch_idx, last_q_head_idx, last_q_seq_idx * 4 + 3, 0}, warpid); } mma_AtB<3, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<6, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<7, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<3, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(6)"); __builtin_amdgcn_s_barrier(); // 15. dQ_i += dS_ij @ K_j (32x16)=(32x256)x(256x16) // mma_AtB(dQ_i_T, K_j_col, dP_ij_bf16_col_T); mma_AtB<0, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); load<6, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<6, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<7, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<7, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // Load Q_i from shared memory to registers // load(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_addr = get_address(Q_i, subtile_inplace(Q_i_smem[tic][0], {1, 0})); load<0, 0>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 1>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(4)"); __builtin_amdgcn_s_barrier(); mma_AtB<0, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 3>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); // Load K_j from shared memory to registers // load(K_j, subtile_inplace(K_j_smem, {warpid, 0})); load<0, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_AtB<1, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(10)"); mma_AtB<1, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(2)"); } // dot slice 1 { // 10. S_ij = Q_i K_j^T * scale // 11. P_ij = exp2(S_ij - L_i) // 13. dP_ij = dO_i @ V_j^T // 14. dS_ij = P_ij o (dP_ij - delta_i) // mma_ABt(P_ij, Q_i, K_j); mma_ABt<0, 0, 0>(P_ij, Q_i, K_j); load<2, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 1>(P_ij, Q_i, K_j, P_ij); mma_ABt<0, 0, 2>(P_ij, Q_i, K_j, P_ij); load<2, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 3>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 0>(P_ij, Q_i, K_j); load<3, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 1>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 2>(P_ij, Q_i, K_j, P_ij); load<3, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 0>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 0>(P_ij, Q_i, K_j); // Load dO_i from shared memory to registers // load(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_addr = get_address(dO_i, subtile_inplace(dO_i_smem[tic][0], {1, 0})); load<0, 0>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 1>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 0, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 3>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 1>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 3, 0>(P_ij, Q_i, K_j); // Load dO_i_col from shared memory to registers // load(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); // Compute dO_i_col_addr // uint32_t dO_i_col_addr = get_address(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(dO_i_smem[tic][0], {1, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const uint32_t addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 1>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 1, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(8)"); mma_ABt<0, 3, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 3>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij); // Dot slice 1 kittens::macros::v_mov_b32(0xff800000); if constexpr (causal) { // If the query position is less than the key position, set P_ij to -inf if (q_pos < k_pos) { mov(P_ij); // If the query position is equal to the key position, we need to apply a causal mask } else if (q_pos == k_pos) { // Apply the causal mask to [0, 1] and set [0, 2:4] to -inf make_causal<0, 1, neg_inf_v>(P_ij, P_ij); mov<0, 2, neg_inf_v>(P_ij); mov<0, 3, neg_inf_v>(P_ij); } } mul<0, 2>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_ABt(dP_ij, dO_i, V_j); mma_ABt<0, 0, 0>(dP_ij, dO_i, V_j); sub_row<0, 2, L_i>(P_ij, P_ij); mma_ABt<0, 0, 1>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 0>(P_ij, P_ij); mma_ABt<0, 0, 2>(dP_ij, dO_i, V_j, dP_ij); // Load Q_i_col from shared memory to registers // load(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); // Compute Q_i_col_addr // uint32_t Q_i_col_addr = get_address(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(Q_i_smem[tic][0], {1, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const int addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 0, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 1>(P_ij, P_ij); mma_ABt<0, 1, 0>(dP_ij, dO_i, V_j); load<0, 1>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 1, 1>(dP_ij, dO_i, V_j, dP_ij); mul<0, 3>(P_ij, P_ij, P_SCALE_FACTOR); mma_ABt<0, 1, 2>(dP_ij, dO_i, V_j, dP_ij); sub_row<0, 3, L_i>(P_ij, P_ij); mma_ABt<0, 1, 3>(dP_ij, dO_i, V_j, dP_ij); copy<0, 0>(P_ij_bf16, P_ij); mma_ABt<0, 2, 0>(dP_ij, dO_i, V_j); exp2<0, 2>(P_ij, P_ij); mma_ABt<0, 2, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 1>(P_ij_bf16, P_ij); mma_ABt<0, 2, 2>(dP_ij, dO_i, V_j, dP_ij); load<0, 2>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 2, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 3>(P_ij, P_ij); mma_ABt<0, 3, 0>(dP_ij, dO_i, V_j); load<0, 3>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 3, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 2>(P_ij_bf16, P_ij); copy<0, 3>(P_ij_bf16, P_ij); mma_ABt<0, 3, 2>(dP_ij, dO_i, V_j, dP_ij); swap_layout_inplace(P_ij_bf16_col, P_ij_bf16); mma_ABt<0, 3, 3>(dP_ij, dO_i, V_j, dP_ij); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_AtB(dV_j_T, dO_i_col, P_ij_bf16_col); mma_AtB<0, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // Load K_j_col from shared memory to registers // load(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid})); load<0, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<0, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 0, delta_i>(dP_ij, dP_ij); sub_row<0, 1, delta_i>(dP_ij, dP_ij); mma_AtB<1, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<1, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<1, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<1, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); mul<0, 0>(dP_ij, dP_ij, P_ij); mul<0, 1>(dP_ij, dP_ij, P_ij); copy<0, 0>(dP_ij_bf16, dP_ij); copy<0, 1>(dP_ij_bf16, dP_ij); sub_row<0, 2, delta_i>(dP_ij, dP_ij); mma_AtB<2, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<2, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); // 12. dV_j += P_ij^T @ dO_i // 16. dK_j += dS_ij^T @ Q_i (128x64)=(128x16)x(16x64) // Store dP_ij_bf16_accum_row to shared memory // store(attn_i_smem_subtile, dP_ij_bf16_accum_row); store<0, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<1, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<2, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 3, delta_i>(dP_ij, dP_ij); mul<0, 2>(dP_ij, dP_ij, P_ij); mul<0, 3>(dP_ij, dP_ij, P_ij); copy<0, 2>(dP_ij_bf16, dP_ij); copy<0, 3>(dP_ij_bf16, dP_ij); mma_AtB<3, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // dot slice 2 load(subvec_inplace(L_smem[tic], 2)); load(subvec_inplace(delta_smem[tic], 2)); store<2, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<3, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<3, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); swap_layout_inplace(dP_ij_bf16_col, dP_ij_bf16); asm volatile("s_waitcnt lgkmcnt(12)"); // mma_AtB(dK_j_T, Q_i_col, dP_ij_bf16_col); mma_AtB<0, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<2, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<4, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(8)"); __builtin_amdgcn_s_barrier(); mma_AtB<1, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); // Load dP_ij_bf16_col_T from shared memory to registers // load(dP_ij_bf16_col_T, attn_i_smem); load<0, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<1, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<2, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<3, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); mma_AtB<1, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); mul(L_SCALE_FACTOR); atomic_pk_add_bf16_with_warpid<2, 0, 0>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4, 0}, warpid); mma_AtB<2, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<4, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<4, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<5, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<2, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); atomic_pk_add_bf16_with_warpid<2, 0, 1>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4, 0}, warpid); mma_AtB<3, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<6, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<7, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<3, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(6)"); __builtin_amdgcn_s_barrier(); // 15. dQ_i += dS_ij @ K_j (32x16)=(32x256)x(256x16) // mma_AtB(dQ_i_T, K_j_col, dP_ij_bf16_col_T); mma_AtB<0, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); load<6, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<6, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<7, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<7, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // Load Q_i from shared memory to registers // load(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_addr = get_address(Q_i, subtile_inplace(Q_i_smem[tic][1], {0, 0})); load<0, 0>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 1>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(4)"); __builtin_amdgcn_s_barrier(); mma_AtB<0, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 3>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); // Load K_j from shared memory to registers // load(K_j, subtile_inplace(K_j_smem, {warpid, 0})); load<0, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_AtB<1, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(10)"); mma_AtB<1, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(2)"); } // dot slice 2 { // 10. S_ij = Q_i K_j^T * scale // 11. P_ij = exp2(S_ij - L_i) // 13. dP_ij = dO_i @ V_j^T // 14. dS_ij = P_ij o (dP_ij - delta_i) // mma_ABt(P_ij, Q_i, K_j); mma_ABt<0, 0, 0>(P_ij, Q_i, K_j); load<2, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 1>(P_ij, Q_i, K_j, P_ij); mma_ABt<0, 0, 2>(P_ij, Q_i, K_j, P_ij); load<2, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 3>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 0>(P_ij, Q_i, K_j); load<3, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 1>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 2>(P_ij, Q_i, K_j, P_ij); load<3, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 0>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 0>(P_ij, Q_i, K_j); // Load dO_i from shared memory to registers // load(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_addr = get_address(dO_i, subtile_inplace(dO_i_smem[tic][1], {0, 0})); load<0, 0>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 1>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 0, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 3>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 1>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 3, 0>(P_ij, Q_i, K_j); // Load dO_i_col from shared memory to registers // load(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); // Compute dO_i_col_addr // uint32_t dO_i_col_addr = get_address(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(dO_i_smem[tic][1], {0, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const uint32_t addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 1>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 1, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(8)"); mma_ABt<0, 3, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 3>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij); // Dot slice 2 kittens::macros::v_mov_b32(0xff800000); if constexpr (causal) { // If the query position is less than the key position, set P_ij to -inf if (q_pos < k_pos) { mov(P_ij); // If the query position is equal to the key position, we need to apply a causal mask } else if (q_pos == k_pos) { // Apply the causal mask to [0, 2] and set [0, 3:4] to -inf make_causal<0, 2, neg_inf_v>(P_ij, P_ij); mov<0, 3, neg_inf_v>(P_ij); } } mul<0, 2>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_ABt(dP_ij, dO_i, V_j); mma_ABt<0, 0, 0>(dP_ij, dO_i, V_j); sub_row<0, 2, L_i>(P_ij, P_ij); mma_ABt<0, 0, 1>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 0>(P_ij, P_ij); mma_ABt<0, 0, 2>(dP_ij, dO_i, V_j, dP_ij); // Load Q_i_col from shared memory to registers // load(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); // Compute Q_i_col_addr // uint32_t Q_i_col_addr = get_address(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(Q_i_smem[tic][1], {0, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const int addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 0, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 1>(P_ij, P_ij); mma_ABt<0, 1, 0>(dP_ij, dO_i, V_j); load<0, 1>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 1, 1>(dP_ij, dO_i, V_j, dP_ij); mul<0, 3>(P_ij, P_ij, P_SCALE_FACTOR); mma_ABt<0, 1, 2>(dP_ij, dO_i, V_j, dP_ij); sub_row<0, 3, L_i>(P_ij, P_ij); mma_ABt<0, 1, 3>(dP_ij, dO_i, V_j, dP_ij); copy<0, 0>(P_ij_bf16, P_ij); mma_ABt<0, 2, 0>(dP_ij, dO_i, V_j); exp2<0, 2>(P_ij, P_ij); mma_ABt<0, 2, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 1>(P_ij_bf16, P_ij); mma_ABt<0, 2, 2>(dP_ij, dO_i, V_j, dP_ij); load<0, 2>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 2, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 3>(P_ij, P_ij); mma_ABt<0, 3, 0>(dP_ij, dO_i, V_j); load<0, 3>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 3, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 2>(P_ij_bf16, P_ij); copy<0, 3>(P_ij_bf16, P_ij); mma_ABt<0, 3, 2>(dP_ij, dO_i, V_j, dP_ij); swap_layout_inplace(P_ij_bf16_col, P_ij_bf16); mma_ABt<0, 3, 3>(dP_ij, dO_i, V_j, dP_ij); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_AtB(dV_j_T, dO_i_col, P_ij_bf16_col); mma_AtB<0, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // Load K_j_col from shared memory to registers // load(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid})); load<0, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<0, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 0, delta_i>(dP_ij, dP_ij); sub_row<0, 1, delta_i>(dP_ij, dP_ij); mma_AtB<1, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<1, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<1, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<1, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); mul<0, 0>(dP_ij, dP_ij, P_ij); mul<0, 1>(dP_ij, dP_ij, P_ij); copy<0, 0>(dP_ij_bf16, dP_ij); copy<0, 1>(dP_ij_bf16, dP_ij); sub_row<0, 2, delta_i>(dP_ij, dP_ij); mma_AtB<2, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<2, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); // 12. dV_j += P_ij^T @ dO_i // 16. dK_j += dS_ij^T @ Q_i (128x64)=(128x16)x(16x64) // Store dP_ij_bf16_accum_row to shared memory // store(attn_i_smem_subtile, dP_ij_bf16_accum_row); store<0, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<1, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<2, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 3, delta_i>(dP_ij, dP_ij); mul<0, 2>(dP_ij, dP_ij, P_ij); mul<0, 3>(dP_ij, dP_ij, P_ij); copy<0, 2>(dP_ij_bf16, dP_ij); copy<0, 3>(dP_ij_bf16, dP_ij); mma_AtB<3, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // dot slice 3 load(subvec_inplace(L_smem[tic], 3)); load(subvec_inplace(delta_smem[tic], 3)); store<2, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<3, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<3, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); swap_layout_inplace(dP_ij_bf16_col, dP_ij_bf16); asm volatile("s_waitcnt lgkmcnt(12)"); // mma_AtB(dK_j_T, Q_i_col, dP_ij_bf16_col); mma_AtB<0, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<2, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<4, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(8)"); __builtin_amdgcn_s_barrier(); mma_AtB<1, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); // Load dP_ij_bf16_col_T from shared memory to registers // load(dP_ij_bf16_col_T, attn_i_smem); load<0, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<1, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<2, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<3, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); mma_AtB<1, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); mul(L_SCALE_FACTOR); atomic_pk_add_bf16_with_warpid<2, 0, 0>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 1, 0}, warpid); mma_AtB<2, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<4, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<4, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<5, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<2, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); atomic_pk_add_bf16_with_warpid<2, 0, 1>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 1, 0}, warpid); mma_AtB<3, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<6, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<7, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<3, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(6)"); __builtin_amdgcn_s_barrier(); // 15. dQ_i += dS_ij @ K_j (32x16)=(32x256)x(256x16) // mma_AtB(dQ_i_T, K_j_col, dP_ij_bf16_col_T); mma_AtB<0, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); load<6, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<6, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<7, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<7, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // Load Q_i from shared memory to registers // load(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_addr = get_address(Q_i, subtile_inplace(Q_i_smem[tic][1], {1, 0})); load<0, 0>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 1>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(4)"); __builtin_amdgcn_s_barrier(); mma_AtB<0, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); load<0, 3>(Q_i, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_addr); mma_AtB<0, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); // Load K_j from shared memory to registers // load(K_j, subtile_inplace(K_j_smem, {warpid, 0})); load<0, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<0, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<0, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_AtB<1, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mul<0, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<0, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(10)"); mma_AtB<1, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<1, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<1, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_AtB<1, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(2)"); } // dot slice 3 { // 10. S_ij = Q_i K_j^T * scale // 11. P_ij = exp2(S_ij - L_i) // 13. dP_ij = dO_i @ V_j^T // 14. dS_ij = P_ij o (dP_ij - delta_i) // mma_ABt(P_ij, Q_i, K_j); mma_ABt<0, 0, 0>(P_ij, Q_i, K_j); load<2, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 1>(P_ij, Q_i, K_j, P_ij); mma_ABt<0, 0, 2>(P_ij, Q_i, K_j, P_ij); load<2, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<2, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 0, 3>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 0>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 1>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 0>(P_ij, Q_i, K_j); load<3, 0>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 1>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 1>(P_ij, Q_i, K_j, P_ij); mul<1, 0, 2>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mul<1, 0, 3>(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); mma_ABt<0, 1, 2>(P_ij, Q_i, K_j, P_ij); load<3, 2>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); load<3, 3>(K_j, subtile_inplace(K_j_smem, {warpid, 0}), K_j_addr); mma_ABt<0, 1, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 0>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 0>(P_ij, Q_i, K_j); // Load dO_i from shared memory to registers // load(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_addr = get_address(dO_i, subtile_inplace(dO_i_smem[tic][1], {1, 0})); load<0, 0>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 1>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 0, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 2, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); load<0, 3>(dO_i, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_addr); mma_ABt<0, 2, 3>(P_ij, Q_i, K_j, P_ij); mul<0, 1>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(6)"); mma_ABt<0, 3, 0>(P_ij, Q_i, K_j); // Load dO_i_col from shared memory to registers // load(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); // Compute dO_i_col_addr // uint32_t dO_i_col_addr = get_address(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0})); dO_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(dO_i_smem[tic][1], {1, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const uint32_t addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 1>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 1>(P_ij, Q_i, K_j, P_ij); sub_row<0, 1, L_i>(P_ij, P_ij); asm volatile("s_waitcnt lgkmcnt(8)"); mma_ABt<0, 3, 2>(P_ij, Q_i, K_j, P_ij); load<0, 2>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); load<0, 3>(dO_i_col, subtile_inplace(dO_i_smem[tic][0], {0, 0}), dO_i_col_addr); mma_ABt<0, 3, 3>(P_ij, Q_i, K_j, P_ij); // Dot slice 3 kittens::macros::v_mov_b32(0xff800000); if constexpr (causal) { // If the query position is less than the key position, set P_ij to -inf if (q_pos < k_pos) { mov(P_ij); // If the query position is equal to the key position, we need to apply a causal mask } else if (q_pos == k_pos) { // Apply the causal mask to [0, 3] make_causal<0, 3, neg_inf_v>(P_ij, P_ij); } } mul<0, 2>(P_ij, P_ij, P_SCALE_FACTOR); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_ABt(dP_ij, dO_i, V_j); mma_ABt<0, 0, 0>(dP_ij, dO_i, V_j); sub_row<0, 2, L_i>(P_ij, P_ij); mma_ABt<0, 0, 1>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 0>(P_ij, P_ij); mma_ABt<0, 0, 2>(dP_ij, dO_i, V_j, dP_ij); // Load Q_i_col from shared memory to registers // load(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); // Compute Q_i_col_addr // uint32_t Q_i_col_addr = get_address(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0})); Q_i_col_addr = [&] { const int laneid = kittens::laneid(); const uint32_t src_ptr = reinterpret_cast(&subtile_inplace(Q_i_smem[tic][1], {1, 0}).data[0]); const int row_offset = (laneid % 16) / 4 + (laneid / 32) * 8; const int col_offset = ((laneid % 4) * 4) + 16*((laneid % 32)/16); const int lane_byte_offset = (row_offset * 32 + col_offset) * sizeof(bf16); const int swizzled_lane_byte_offset = lane_byte_offset ^ ((lane_byte_offset >> 9) << 5); const int addr = src_ptr + swizzled_lane_byte_offset; return addr; }(); load<0, 0>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 0, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 1>(P_ij, P_ij); mma_ABt<0, 1, 0>(dP_ij, dO_i, V_j); load<0, 1>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 1, 1>(dP_ij, dO_i, V_j, dP_ij); mul<0, 3>(P_ij, P_ij, P_SCALE_FACTOR); mma_ABt<0, 1, 2>(dP_ij, dO_i, V_j, dP_ij); sub_row<0, 3, L_i>(P_ij, P_ij); mma_ABt<0, 1, 3>(dP_ij, dO_i, V_j, dP_ij); copy<0, 0>(P_ij_bf16, P_ij); mma_ABt<0, 2, 0>(dP_ij, dO_i, V_j); exp2<0, 2>(P_ij, P_ij); mma_ABt<0, 2, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 1>(P_ij_bf16, P_ij); mma_ABt<0, 2, 2>(dP_ij, dO_i, V_j, dP_ij); load<0, 2>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 2, 3>(dP_ij, dO_i, V_j, dP_ij); exp2<0, 3>(P_ij, P_ij); mma_ABt<0, 3, 0>(dP_ij, dO_i, V_j); load<0, 3>(Q_i_col, subtile_inplace(Q_i_smem[tic][0], {0, 0}), Q_i_col_addr); mma_ABt<0, 3, 1>(dP_ij, dO_i, V_j, dP_ij); copy<0, 2>(P_ij_bf16, P_ij); copy<0, 3>(P_ij_bf16, P_ij); mma_ABt<0, 3, 2>(dP_ij, dO_i, V_j, dP_ij); swap_layout_inplace(P_ij_bf16_col, P_ij_bf16); mma_ABt<0, 3, 3>(dP_ij, dO_i, V_j, dP_ij); asm volatile("s_waitcnt lgkmcnt(8)"); // mma_AtB(dV_j_T, dO_i_col, P_ij_bf16_col); mma_AtB<0, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // Load K_j_col from shared memory to registers // load(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid})); load<0, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<0, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 0, delta_i>(dP_ij, dP_ij); sub_row<0, 1, delta_i>(dP_ij, dP_ij); mma_AtB<1, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<1, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<1, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<1, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); mul<0, 0>(dP_ij, dP_ij, P_ij); mul<0, 1>(dP_ij, dP_ij, P_ij); copy<0, 0>(dP_ij_bf16, dP_ij); copy<0, 1>(dP_ij_bf16, dP_ij); sub_row<0, 2, delta_i>(dP_ij, dP_ij); mma_AtB<2, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); load<2, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); // 12. dV_j += P_ij^T @ dO_i // 16. dK_j += dS_ij^T @ Q_i (128x64)=(128x16)x(16x64) // Store dP_ij_bf16_accum_row to shared memory // store(attn_i_smem_subtile, dP_ij_bf16_accum_row); store<0, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<1, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<2, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); sub_row<0, 3, delta_i>(dP_ij, dP_ij); mul<0, 2>(dP_ij, dP_ij, P_ij); mul<0, 3>(dP_ij, dP_ij, P_ij); copy<0, 2>(dP_ij_bf16, dP_ij); copy<0, 3>(dP_ij_bf16, dP_ij); mma_AtB<3, 0, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); // dot slice 0 - next iteration store<2, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); store<3, 0>(attn_i_smem_subtile, dP_ij_bf16_accum_row, dP_ij_bf16_accum_row_addr); mma_AtB<3, 1, 0>(dV_j_T, dO_i_col, P_ij_bf16_col, dV_j_T); swap_layout_inplace(dP_ij_bf16_col, dP_ij_bf16); asm volatile("s_waitcnt lgkmcnt(12)"); // mma_AtB(dK_j_T, Q_i_col, dP_ij_bf16_col); mma_AtB<0, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<2, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<3, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<4, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(8)"); __builtin_amdgcn_s_barrier(); mma_AtB<1, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); // Load dP_ij_bf16_col_T from shared memory to registers // load(dP_ij_bf16_col_T, attn_i_smem); load<0, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<1, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<2, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<3, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); mma_AtB<1, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); atomic_pk_add_bf16_with_warpid<2, 0, 0>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 2, 0}, warpid); mma_AtB<2, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<4, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<4, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<5, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<2, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); atomic_pk_add_bf16_with_warpid<2, 0, 1>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 2, 0}, warpid); mma_AtB<3, 0, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); load<6, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<7, 0>(dP_ij_bf16_col_T, attn_i_smem, dP_ij_bf16_col_T_addr); load<5, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<3, 1, 0>(dK_j_T, Q_i_col, dP_ij_bf16_col, dK_j_T); asm volatile("s_waitcnt lgkmcnt(6)"); __builtin_amdgcn_s_barrier(); // 15. dQ_i += dS_ij @ K_j (32x16)=(32x256)x(256x16) // mma_AtB(dQ_i_T, K_j_col, dP_ij_bf16_col_T); mma_AtB<0, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); load<6, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<6, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); load<7, 0>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); load<7, 1>(K_j_col, subtile_inplace<256, 32>(K_j_smem, {0, warpid}), K_j_col_addr); mma_AtB<0, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<0, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // ds_read_b128 a[112:115] // ds_read_b128 a[116:119] mma_AtB<0, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(4)"); __builtin_amdgcn_s_barrier(); mma_AtB<0, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // ds_read_b128 a[120:123] // ds_read_b128 a[124:127] mma_AtB<0, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 0>(dQ_i_T, K_j_col, dP_ij_bf16_col_T); // ds_read_b128 a[0:3] // ds_read_b128 a[4:7] mma_AtB<1, 0, 1>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 2>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // ds_read_b128 a[8:11] // ds_read_b128 a[12:15] mma_AtB<1, 0, 3>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); mma_AtB<1, 0, 4>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // ds_read_b128 a[16:19] // ds_read_b128 a[20:23] mma_AtB<1, 0, 5>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(10)"); mma_AtB<1, 0, 6>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); // ds_read_b128 a[24:27] // ds_read_b128 a[28:31] mma_AtB<1, 0, 7>(dQ_i_T, K_j_col, dP_ij_bf16_col_T, dQ_i_T); asm volatile("s_waitcnt lgkmcnt(2)"); } } store<1>(g.dVg, dV_j, {batch_idx * GROUP_SIZE + q_head_in_group, 0, kv_head_idx, 0}, {0, j, 0, 0}); __builtin_amdgcn_s_waitcnt(0); __builtin_amdgcn_s_barrier(); // We first copy dV_j_T from accumulator GPRs to vector GPRs and then perform the store accvgpr_read(dV_j_T, dK_j_T); mul(dV_j_T, dV_j_T, dP_SCALE_FACTOR); store<1>(g.dKg, dV_j, {batch_idx * GROUP_SIZE + q_head_in_group, 0, kv_head_idx, 0}, {0, j, 0, 0}); // Write out final dQ_i slice mul(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR); atomic_pk_add_bf16_with_warpid<2>(g.dQg, dQ_i, {batch_idx, q_head_idx, q_seq_idx * 4 + 3, 0}, warpid); } template __global__ void attend_bwd_combined_ker(bf16*, bf16*, bf16*, bf16*, bf16*, bf16*, bf16*, float*, float*);