Files
tinygrad/extra/gemm/cdna_asm_gemm.py
2026-06-05 15:39:41 -07:00

2890 lines
154 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import atexit, functools, pathlib
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import AddrSpace
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from tinygrad.renderer import Estimates
from tinygrad.helpers import getenv, all_same, DEBUG
from tinygrad.runtime.support.compiler_amd import HIPCCCompiler
from tinygrad.runtime.autogen.amd.cdna.ins import *
from examples.mlperf.models.flat_llama import FP8_DTYPE, FP8_GRAD_DTYPE, quantize_fp8
# ** CDNA4 assembly gemm
# NOTE(review): WORKGROUP_SIZE is not referenced in the visible part of this file; presumably the number of
# threads per workgroup used at launch -- confirm against the launch code.
WORKGROUP_SIZE = 256
# M0 is encoded with 124 (NULL in RDNA) in CDNA
# The helpers below pass this alias wherever an instruction takes M0 as an operand.
M0 = NULL
# TILE_M/TILE_N/TILE_K: the problem shape must be a multiple of these (asserted in compute_gemm_args), and
# K // TILE_K is used there as the iteration count. NUM_WG is returned unchanged as the first element of
# compute_gemm_args.
TILE_M, TILE_N, TILE_K, NUM_WG = 256, 256, 64, 256
def _magicgu_mulhi(d:int, vmax:int) -> tuple[int,int]:
"""Compute magic number and shift for mul_hi-based unsigned division by d, valid for all 32-bit n.
Adapted from magicgu in tinygrad.uop.decompositions (Hacker's Delight, Chapter 10) but targeting the mul_hi encoding:
- If shift bit 31 is clear: result = mul_hi(n, magic) >> shift
- If shift bit 31 is set: result = (mul_hi(n, magic) + n) >> (shift & 0x7FFFFFFF) (wrapping 32-bit add)
"""
if d == 1: return 0, (1 << 31) # (mul_hi(n, 0) + n) >> 0 = n
nc = (1 << 32) // d * d - 1
for s in range(32, 65):
if 2**s > nc * (d - 1 - (2**s - 1) % d):
m = (2**s + d - 1 - (2**s - 1) % d) // d
shift = s - 32
if m < (1 << 32): return m, shift
if m < (1 << 33):
m_enc = m - (1 << 32)
if ((((vmax * m_enc) >> 32) + vmax) & 0xFFFFFFFF) >> shift == vmax // d: return m_enc, shift | (1 << 31)
raise AssertionError(f"cannot compute magic for d={d}, vmax={vmax}")
def compute_gemm_args(M:int, N:int, K:int, batch:int) -> tuple[int, int, int, int, int]:
  """Derive the scalar launch parameters for a tile-aligned (M, N, K) problem.

  Returns (NUM_WG, iters, total, magic, shift) where iters = K // TILE_K, total is the number of
  (M tile, N tile, K iteration) work items for a single batch entry, and (magic, shift) come from
  _magicgu_mulhi(iters, total * batch). Note that the returned total is NOT multiplied by batch.

  Raises AssertionError when any dimension is not a multiple of its tile size.
  """
  assert not (M % TILE_M or N % TILE_N or K % TILE_K), f"shape ({M},{N},{K}) not a multiple of ({TILE_M},{TILE_N},{TILE_K})"
  tiles_m, tiles_n, k_iters = M // TILE_M, N // TILE_N, K // TILE_K
  work_items = tiles_m * tiles_n * k_iters
  magic, shift = _magicgu_mulhi(k_iters, work_items * batch)
  return NUM_WG, k_iters, work_items, magic, shift
class Kernel:
  """Ordered instruction list with byte positions and named labels for branch patching."""
  def __init__(self):
    self.instructions = []   # instructions in emission order
    self.labels = {}         # label name -> byte position
    self.label_at_pos = {}   # byte position -> label name (a later label at the same position replaces it)
    self.pos = 0             # running byte position: sum of size() over everything emitted so far

  def label(self, name):
    """Bind name to the current position."""
    here = self.pos
    self.labels[name] = here
    self.label_at_pos[here] = name

  def emit(self, inst, target=None):
    """Append inst, tag it with its position and optional branch target label, and return it."""
    self.instructions.append(inst)
    inst._target = target
    inst._pos = self.pos
    self.pos += inst.size()
    return inst

  def waitcnt(self, lgkm=None, vm=None):
    """Emit s_waitcnt with the given counters; an omitted counter is set to its all-ones value (vm=63, lgkm=15)."""
    vmcnt = 63 if vm is None else vm
    lgkmcnt = 15 if lgkm is None else lgkm
    expcnt = 7
    # bit layout used here: vmcnt[3:0] -> bits 3:0, expcnt -> bits 6:4, lgkmcnt -> bits 11:8, vmcnt[5:4] -> bits 15:14
    imm = (vmcnt & 0xF) | ((expcnt & 0x7) << 4) | ((lgkmcnt & 0xF) << 8) | (((vmcnt >> 4) & 0x3) << 14)
    self.emit(s_waitcnt(imm))

  def finalize(self):
    """Patch branch offsets and return the finalized instruction list."""
    for inst in self.instructions:
      if inst._target is not None:
        # offset is measured from the end of the branch instruction, in units of 4 bytes
        branch_end = inst._pos + inst.size()
        inst.simm16 = (self.labels[inst._target] - branch_end) // 4
    return self.instructions
def buffer_load_x4_m0_stride(k, vaddr_start, srd, stride, count, *args):
  """Emit `count` buffer_load_dwordx4 instructions, adding `stride` to M0 after every load except the last.

  Load i uses v[vaddr_start + i] as its address register and s[srd:srd+3] as its resource operand;
  *args is forwarded unchanged to every buffer_load_dwordx4. The destination operand is always v[0:3].
  NOTE(review): the fixed destination together with the M0 stepping suggests the data lands in LDS at an
  M0-derived offset -- confirm against the modifier flags callers pass in *args.
  """
  last = count - 1
  for n in range(count):
    load = buffer_load_dwordx4(v[0:3], v[vaddr_start + n], s[srd:srd+3], *args)
    k.emit(load)
    if n != last:
      k.emit(s_add_u32(M0, M0, stride))
def mfma_64(k, mfma, a_base, b_base):
  """Emit 64 instructions built by `mfma`, one per (row, col) cell of an 8x8 grid, row-major.

  Cell (row, col) accumulates in place into v[acc:acc+3] with acc = (row*8 + col)*4 and passes
  v[a_base + row*4 : +3] and v[b_base + col*4 : +3] as the two source operands. The trailing
  constant operands are passed through exactly as written for every cell.
  """
  for cell in range(64):
    row, col = divmod(cell, 8)
    acc, a_lo, b_lo = cell * 4, a_base + row * 4, b_base + col * 4
    k.emit(mfma(v[acc:acc+3], v[a_lo:a_lo+3], v[b_lo:b_lo+3], v[acc:acc+3], 0, 0, 1, 0, 0, 0, 0, 0, 1))
def ds_read_ab_16(k, b_base=114):
  """Emit 16 ds_read_b128: eight addressed by v[16] into v[82..113], then eight addressed by v[17] into v[b_base..b_base+31].

  Each read fills four consecutive registers. The first read of each group uses the short two-operand
  form; later reads append v[0], v[0], 0 and then one or two numeric operands.
  NOTE(review): the trailing numeric operands look like the DS offset fields -- confirm their meaning
  against the ds_read_b128 signature.
  """
  # reads addressed by v[16]: the last operand steps by 2 per block
  for blk in range(8):
    lo = 82 + blk * 4
    if blk == 0:
      k.emit(ds_read_b128(v[lo:lo+3], v[16]))
    else:
      k.emit(ds_read_b128(v[lo:lo+3], v[16], v[0], v[0], 0, 0, blk * 2))
  # reads addressed by v[17]: odd blocks carry 128 in the first numeric slot, blk // 2 goes in the second
  for blk in range(8):
    lo = b_base + blk * 4
    odd, pair = blk % 2, blk // 2
    if blk == 0:
      k.emit(ds_read_b128(v[lo:lo+3], v[17]))
    elif odd and pair:
      k.emit(ds_read_b128(v[lo:lo+3], v[17], v[0], v[0], 0, 128, pair))
    elif odd:
      k.emit(ds_read_b128(v[lo:lo+3], v[17], v[0], v[0], 0, 128))
    else:
      k.emit(ds_read_b128(v[lo:lo+3], v[17], v[0], v[0], 0, 0, pair))
def zero_out_mask_32(k, base):
  """Emit compares against s[8] and 32 v_cndmask writes over v[base .. base+31].

  v[181] is set to (((v[180] & 63) >> 4) << 3) and v[182] starts as a copy of it. The mask s[88:89]
  is v[182] >= s[8] (signed compare). Registers base+4*i and base+4*i+1 are processed with that
  mask; v[182] is then advanced by 4, the mask recomputed, and registers base+4*i+2 and base+4*i+3
  processed. Clobbers v[181], v[182] and s[88:89].
  NOTE(review): the v_cndmask presumably writes 0 into lanes where the mask is set -- confirm operand order.
  """
  emit = k.emit
  emit(v_and_b32_e32(v[181], 63, v[180]))
  emit(v_lshrrev_b32_e32(v[181], 4, v[181]))
  emit(v_lshlrev_b32_e32(v[181], 3, v[181]))
  emit(v_add_u32_e64(v[182], v[181], 0))
  for half in (0, 1):
    emit(v_cmp_ge_i32_e64(s[88:89], v[182], s[8]))
    for lane in (0, 1):
      first = base + half * 2 + lane
      for reg in range(first, first + 32, 4):
        emit(v_cndmask_b32_e64(v[reg], v[reg], 0, s[88:89]))
    if not half:
      # step the compared value before the second half's compare
      emit(v_add_u32_e64(v[182], v[182], 4))
def buffer_load_d16_32(k, dst_lo, dst_hi, idx_start, srd_start, *tail):
  """Emit 32 buffer_load_short_d16 / buffer_load_short_d16_hi pairs, wait for them, then OR the halves together.

  Slot n (0..31) uses address register v[idx_start + n // 4]. The d16 load targets v[dst_lo + n] with
  numeric operand (n % 4) * 4 and the d16_hi load targets v[dst_hi + n] with (n % 4) * 4 + 2. After
  k.waitcnt(vm=0), v[dst_lo + n] |= v[dst_hi + n] for every slot. *tail is appended to each load.
  """
  for slot in range(32):
    row, j = divmod(slot, 4)
    k.emit(buffer_load_short_d16(v[dst_lo + slot], v[idx_start + row], s[srd_start:srd_start+3], 0, j*4, 1, 0, 0, *tail))
    k.emit(buffer_load_short_d16_hi(v[dst_hi + slot], v[idx_start + row], s[srd_start:srd_start+3], 0, j*4+2, 1, 0, 0, *tail))
  # every load must have landed before the halves are merged
  k.waitcnt(vm=0)
  for slot in range(32):
    k.emit(v_or_b32_e32(v[dst_lo + slot], v[dst_lo + slot], v[dst_hi + slot]))
def shift_mask(k, regs):
  """For each r in regs, conditionally replace v[r..r+3] with a 64-bit-shifted copy of itself.

  Per r: v[184:185] = v[r:r+1] << s[87] and v[186:187] = v[r+2:r+3] << s[87]. Then for each of the
  four components a v_cmp_ge_i32 of v[182] against s[8] is emitted, followed by a v_cndmask between
  v[r+c] and v[184+c]. v[182] is v[181] + 4 for components 0 and 1 and v[181] + 8 for components 2 and 3.
  Reads v[181] (not written here -- presumably set up earlier, verify against the caller); clobbers
  v[182], v[184:187] and s[88:89].
  """
  emit = k.emit
  for r in regs:
    emit(v_lshlrev_b64(v[184:185], s[87], v[r:r+1]))
    emit(v_lshlrev_b64(v[186:187], s[87], v[r+2:r+3]))
    for comp in range(4):
      if comp == 0:
        emit(v_add_u32_e64(v[182], v[181], 4))
      elif comp == 2:
        emit(v_add_u32_e64(v[182], v[182], 4))
      emit(v_cmp_ge_i32_e64(s[88:89], v[182], s[8]))
      emit(v_cndmask_b32_e64(v[r + comp], v[r + comp], v[184 + comp], s[88:89]))
def perm_b32_32(k, dst_base, src_base=82):
  """Emit 32 v_perm_b32_e64 writing v[dst_base .. dst_base+31] in ascending order.

  Destination n decomposes as word = n // 8, sel = (n % 8) // 4, grp = n % 4. The two data sources are
  v[src_base + grp*8 + 4 + word] and v[src_base + grp*8 + word]; the selector operand is s[85] when
  sel == 0 and s[86] when sel == 1.
  """
  for n in range(32):
    word, rest = divmod(n, 8)
    sel, grp = divmod(rest, 4)
    lo = src_base + grp * 8 + word
    k.emit(v_perm_b32_e64(v[dst_base + n], v[lo + 4], v[lo], s[85 + sel]))
def v_divmod(k, divisor, dividend):
  """Integer divmod via RCP with EXEC-mask correction. Quotient in v[18], remainder in v[19].

  Emits a float-reciprocal estimate of dividend / divisor followed by two fix-up passes that run under
  an EXEC mask produced by v_cmpx and restore EXEC to -1 afterwards. Clobbers v[18], v[19] and EXEC.
  NOTE(review): the products use v_mul_u32_u24, so quotient and divisor presumably fit in 24 bits -- confirm.
  """
  emit = k.emit
  # estimate: v[18] = u32(f32(dividend) * rcp(f32(divisor)))
  emit(v_cvt_f32_u32_e32(v[18], divisor))
  emit(v_rcp_iflag_f32_e32(v[18], v[18]))
  emit(v_cvt_f32_u32_e32(v[19], dividend))
  emit(v_mul_f32_e32(v[18], v[18], v[19]))
  emit(v_cvt_u32_f32_e32(v[18], v[18]))
  # v[19] = dividend - v[18] * divisor
  emit(v_mul_u32_u24_e64(v[19], v[18], divisor))
  emit(v_sub_u32_e32(v[19], dividend, v[19]))
  # fix-up 1, where v[19] == divisor: quotient + 1, remainder = 0
  emit(v_cmpx_eq_u32_e64(EXEC, v[19], divisor))
  emit(v_add_u32_e32(v[18], 1, v[18]))
  emit(v_mov_b32_e32(v[19], 0))
  emit(s_mov_b64(EXEC, -1))
  # fix-up 2, where v[19] > divisor (unsigned): quotient - 1, then recompute the remainder
  emit(v_cmpx_gt_u32_e64(EXEC, v[19], divisor))
  emit(v_sub_u32_e64(v[18], v[18], 1))
  emit(v_mul_u32_u24_e64(v[19], v[18], divisor))
  emit(v_sub_u32_e32(v[19], dividend, v[19]))
  emit(s_mov_b64(EXEC, -1))
def v_div(k, divisor, dividend):
  """Integer division via RCP with EXEC-mask correction (quotient only). Quotient in v[18].

  Same instruction sequence as v_divmod minus the remainder updates inside the two fix-up passes;
  v[19] is still used as scratch and is left holding the uncorrected dividend - estimate * divisor.
  Clobbers v[18], v[19] and EXEC (restored to -1).
  """
  sequence = (
    # estimate: v[18] = u32(f32(dividend) * rcp(f32(divisor)))
    v_cvt_f32_u32_e32(v[18], divisor),
    v_rcp_iflag_f32_e32(v[18], v[18]),
    v_cvt_f32_u32_e32(v[19], dividend),
    v_mul_f32_e32(v[18], v[18], v[19]),
    v_cvt_u32_f32_e32(v[18], v[18]),
    # v[19] = dividend - v[18] * divisor
    v_mul_u32_u24_e64(v[19], v[18], divisor),
    v_sub_u32_e32(v[19], dividend, v[19]),
    # where v[19] == divisor: quotient + 1
    v_cmpx_eq_u32_e64(EXEC, v[19], divisor),
    v_add_u32_e32(v[18], 1, v[18]),
    s_mov_b64(EXEC, -1),
    # where v[19] > divisor (unsigned): quotient - 1
    v_cmpx_gt_u32_e64(EXEC, v[19], divisor),
    v_sub_u32_e64(v[18], v[18], 1),
    s_mov_b64(EXEC, -1),
  )
  for inst in sequence:
    k.emit(inst)
def v_ceildiv(k, divisor_v, dividend_v, tmp_v):
  """Ceil integer division via RCP with VCC correction. Quotient in v[18].

  Computes the float-reciprocal quotient estimate into v[18], the residual
  dividend_v - v[18] * divisor_v into tmp_v, sets VCC where the residual is non-zero, and adds that
  carry into v[18]. Clobbers v[18], tmp_v and VCC.
  NOTE(review): unlike v_div there is no fix-up of the estimate itself -- presumably the callers'
  operands are small enough for it to be exact; confirm.
  """
  sequence = (
    v_cvt_f32_u32_e32(v[18], divisor_v),
    v_rcp_iflag_f32_e32(v[18], v[18]),
    v_cvt_f32_u32_e32(tmp_v, dividend_v),
    v_mul_f32_e32(v[18], v[18], tmp_v),
    v_cvt_u32_f32_e32(v[18], v[18]),
    # residual = dividend - estimate * divisor
    v_mul_u32_u24_e32(tmp_v, v[18], divisor_v),
    v_sub_u32_e32(tmp_v, dividend_v, tmp_v),
    # round up: add 1 where the residual is non-zero
    v_cmp_ne_u32_e64(VCC, tmp_v, 0),
    v_addc_co_u32(v[18], VCC, v[18], 0, VCC),
  )
  for inst in sequence:
    k.emit(inst)
def gw_m_addr_elem(k, col_off, lds_v, addr_v, bias_v=None, scale_v=None, barrier=False):
  """Compute address for one edge-tile element. col_off=0 uses v[18], else v[22]=v[18]+col_off.

  Emitted work, in order:
  - if col_off > 0: v[22] = v[18] + col_off (the column register is v[18] when col_off == 0, v[22] otherwise)
  - s[82:83] = (column < s[20]) & (v[19] < s[21]); s[78:79] is used as scratch
  - v[lds_v] = (column - 256 * s[2]) << 2; s[78] is overwritten with 256 * s[2]
  - if barrier: wait for lgkmcnt == 0, then s_barrier
  - if bias_v is not None: ds_read_b32 into v[bias_v] and into v[scale_v], both addressed by v[lds_v];
    scale_v must therefore be supplied together with bias_v
  - v[addr_v] = (v[21] + column) << 1, followed by a v_cndmask with v[30] under s[82:83]
    (presumably v[30] is substituted for out-of-bounds lanes -- confirm operand order)

  The barrier path goes through k.waitcnt(lgkm=0), like gw_n_addr_row, instead of a hand-encoded
  s_waitcnt(49279); Kernel.waitcnt produces that same immediate, so the emitted instruction is unchanged.
  """
  col = v[18] if col_off == 0 else v[22]
  if col_off > 0:
    k.emit(v_add_co_u32(v[22], VCC, v[18], col_off))
  # bounds mask for this element
  k.emit(v_cmp_lt_u32_e64(s[78:79], col, s[20]))
  k.emit(v_cmp_lt_u32_e64(s[82:83], v[19], s[21]))
  k.emit(s_and_b64(s[82:83], s[78:79], s[82:83]))
  # LDS byte offset of the column relative to 256 * s[2]
  k.emit(s_mul_i32(s[78], 256, s[2]))
  k.emit(v_sub_u32_e64(v[lds_v], col, s[78]))
  k.emit(v_lshlrev_b32_e32(v[lds_v], 2, v[lds_v]))
  if barrier:
    k.waitcnt(lgkm=0)
    k.emit(s_barrier())
  if bias_v is not None:
    k.emit(ds_read_b32(v[bias_v], v[lds_v]))
    k.emit(ds_read_b32(v[scale_v], v[lds_v], v[0], v[0], 0, 0, 4))
  k.emit(v_add_lshl_u32_e64(v[addr_v], v[21], col, 1))
  k.emit(v_cndmask_b32_e64(v[addr_v], v[30], v[addr_v], s[82:83]))
def gw_n_addr_row(k, lds_v, addr_v, ds_base=None, barrier=False):
  """N-edge address computation for one row. Optional ds_reads (4 x b128) and barrier.

  Emits: s[82:83] = (v[18] < s[20]) & (v[19] < s[21]); v[lds_v] = (v[18] - 256 * s[2]) << 2;
  optionally waitcnt(lgkm=0) + s_barrier; optionally four ds_read_b128 addressed by v[lds_v] into
  v[ds_base .. ds_base+15]; then v[addr_v] = (v[21] + v[18]) << 1 followed by a v_cndmask with v[30]
  under s[82:83]. Clobbers s[78:79] and s[82:83].
  """
  emit = k.emit
  emit(v_cmp_lt_u32_e64(s[78:79], v[18], s[20]))
  emit(v_cmp_lt_u32_e64(s[82:83], v[19], s[21]))
  emit(s_and_b64(s[82:83], s[78:79], s[82:83]))
  emit(s_mul_i32(s[78], 256, s[2]))
  emit(v_sub_u32_e64(v[lds_v], v[18], s[78]))
  emit(v_lshlrev_b32_e32(v[lds_v], 2, v[lds_v]))
  if barrier:
    k.waitcnt(lgkm=0)
    emit(s_barrier())
  if ds_base is not None:
    # trailing numeric operands of each read; the first read uses the short two-operand form
    for n, tail in enumerate(((), (0, 16), (0, 0, 4), (0, 16, 4))):
      lo = ds_base + 4 * n
      if tail:
        emit(ds_read_b128(v[lo:lo+3], v[lds_v], v[0], v[0], *tail))
      else:
        emit(ds_read_b128(v[lo:lo+3], v[lds_v]))
  emit(v_add_lshl_u32_e64(v[addr_v], v[21], v[18], 1))
  emit(v_cndmask_b32_e64(v[addr_v], v[30], v[addr_v], s[82:83]))
def gw_m_row_inc(k):
  """Increment row pointer for edge-tile global write: v[19] += 1, v[20] += s[38], v[21] += s[36].

  The v[19] increment uses the carry-out add form and therefore also writes VCC.
  """
  for inst in (
    v_add_co_u32(v[19], VCC, v[19], 1),
    v_add_u32_e64(v[20], v[20], s[38]),
    v_add_u32_e64(v[21], v[21], s[36]),
  ):
    k.emit(inst)
def gw_m_element(k, v_cvt_pk, data_v, scale_v, bias_v, addr_v):
  """Scale, bias, convert, and store a single f32 element as f16/bf16.

  v[data_v] = v[scale_v] * v[data_v]; v[22] = v[bias_v] + v[data_v]; the sum is copied back to
  v[data_v], passed through v_cvt_pk with itself as both sources, and written with buffer_store_short
  using v[addr_v] as the address register and s[12:15] as the resource. Clobbers v[22] and v[data_v].
  """
  emit = k.emit
  emit(v_mul_f32_e32(v[data_v], v[scale_v], v[data_v]))
  emit(v_add_f32_e32(v[22], v[bias_v], v[data_v]))
  emit(v_mov_b32_e32(v[data_v], v[22]))
  # v_cvt_pk is the dtype-specific pack instruction chosen by the caller
  emit(v_cvt_pk(v[data_v], v[data_v], v[data_v]))
  emit(buffer_store_short(v[data_v], v[addr_v], s[12:15], 0, 0, 1, 0, 0, 0, 0, 0, 1))
def gw_convert_and_store(k, v_cvt_pk, base, addr_v, sb=88, stride=True):
  """Scale, bias, pack, and store 8 f32 accumulator values as 4 packed values. sb=bias/scale base reg.

  Works on v[base .. base+7] in register pairs: each pair is multiplied by the matching pair of
  v[sb+8 .. sb+15], added to the matching pair of v[sb .. sb+7] into v[22 .. 29], and copied back.
  v_cvt_pk then packs pair i into v[base + i]. When stride is true, s[68] = s[36] << 1 is added to the
  64-bit value in s[12:13] before the buffer_store_dwordx4 of v[base:base+3]. Clobbers v[22 .. 29]
  and, when stride is true, s[68] and s[12:13].
  """
  emit = k.emit
  pair_offsets = range(0, 8, 2)
  for off in pair_offsets:
    emit(v_pk_mul_f32(v[base+off:base+off+1], v[sb+8+off:sb+8+off+1], v[base+off:base+off+1]))
  for off in pair_offsets:
    emit(v_pk_add_f32(v[22+off:22+off+1], v[sb+off:sb+off+1], v[base+off:base+off+1]))
  for off in pair_offsets:
    emit(v_mov_b64_e32(v[base+off:base+off+1], v[22+off:22+off+1]))
  for n in range(4):
    emit(v_cvt_pk(v[base + n], v[base + 2*n], v[base + 2*n + 1]))
  if stride:
    # advance the store base in s[12:13] by s[36] << 1
    emit(s_lshl_b32(s[68], s[36], 1))
    emit(s_add_u32(s[12], s[12], s[68]))
    emit(s_addc_u32(s[13], s[13], 0))
  emit(buffer_store_dwordx4(v[base:base+3], v[addr_v], s[12:15], 0, 0, 1, 0, 0, 0, 0, 0, 1))
def shift_vector_components(k, glvw):
  """Emit accvgpr shuffle for ShiftVectorComponents with given GLVW width.

  For every byte_off in 0..3 and every block start (block * 32 + byte_off, block in 0..7), glvw
  registers spaced 4 apart are read with v_accvgpr_read starting (8 - glvw) * 4 above the block start
  into v[25 .. 25+glvw-1], an s_nop(1) is emitted, and the staged values are written with
  v_accvgpr_write to the same spacing starting at the block start. Clobbers v[25 .. 25+glvw-1].
  """
  delta = (8 - glvw) * 4
  lanes = range(glvw)
  for byte_off in range(4):
    for dst0 in range(byte_off, byte_off + 256, 32):
      src0 = dst0 + delta
      for j in lanes:
        k.emit(v_accvgpr_read(v[25 + j], v[src0 + 4 * j]))
      k.emit(s_nop(1))
      for j in lanes:
        k.emit(v_accvgpr_write(v[dst0 + 4 * j], v[25 + j]))
def build_kernel(batch, M, N, K, dtype):
numWG, iters, total, magic, shift = compute_gemm_args(M, N, K, batch)
total *= batch
v_mfma_16x16x32 = {dtypes.half:v_mfma_f32_16x16x32_f16, dtypes.bfloat16:v_mfma_f32_16x16x32_bf16}[dtype]
v_cvt_pk = {dtypes.half:v_cvt_pk_f16_f32, dtypes.bfloat16:v_cvt_pk_bf16_f32}[dtype]
v_cvt = {dtypes.half:v_cvt_f32_f16_e32, dtypes.bfloat16:v_cvt_f32_bf16_e32}[dtype]
k = Kernel()
# load D, A, B pointers
k.emit(s_load_dwordx2(s[24:25], s[0:1], s[0], 0, 0, 0, 0, 1))
k.emit(s_load_dwordx2(s[30:31], s[0:1], s[0], 8, 0, 0, 0, 1))
k.emit(s_load_dwordx2(s[28:29], s[0:1], s[0], 16, 0, 0, 0, 1))
k.waitcnt(lgkm=0)
# params as constants
k.emit(s_mov_b32(s[69], numWG))
k.emit(s_mov_b32(s[20], N))
k.emit(s_mov_b32(s[21], batch * M))
k.emit(s_mov_b32(s[22], 1))
k.emit(s_mov_b32(s[23], K))
k.emit(s_mov_b32(s[36], N))
k.emit(s_mov_b32(s[37], 0))
k.emit(s_mov_b32(s[40], N))
k.emit(s_mov_b32(s[41], 0))
k.emit(s_mov_b32(s[42], K))
k.emit(s_mov_b32(s[43], 0))
k.emit(s_mov_b32(s[46], iters))
k.emit(s_mov_b32(s[47], magic))
k.emit(s_mov_b32(s[48], shift))
k.emit(s_mov_b32(s[49], total))
k.emit(s_mov_b32(s[62], 0))
k.emit(s_mov_b32(s[68], 0))
# kernel size is 256x256
k.emit(s_mov_b32(s[51], 256)); k.emit(s_mov_b32(s[52], 256))
k.emit(s_mov_b32(s[38], s[36]))
k.emit(s_mov_b32(s[39], s[37]))
k.emit(s_mov_b64(s[26:27], s[24:25]))
k.emit(s_and_b32(s[6], s[68], 4294901760))
k.emit(s_lshr_b32(s[6], s[6], 16))
k.emit(s_mov_b32(s[63], 0))
k.emit(s_setprio(3))
k.emit(s_mov_b32(M0, 133120))
k.emit(v_mov_b32_e32(v[180], v[0]))
# XCCG=256
# labels are named based on function:
# PGR = Prefetch Global Read (the global→LDS pipeline stage)
# SK = Stream-K (work partitioning by K-iterations, not tiles)
# WGM = WorkGroup Mapping (tile assignment scheme for cache locality)
# GLVW = Global Load Vector Width (edge tile width handling)
# BM0 = Block M offset 0 (register block position)
# OrdNLL = Ordered No-Load-Loop (final iteration without prefetch loads)
k.emit(s_mov_b32(s[75], 256))
v_divmod(k, s[75], s[2]) # v[18]=quotient, v[19]=remainder
k.emit(v_readfirstlane_b32_e32(v[71], v[18]))
k.emit(v_readfirstlane_b32_e32(v[72], v[19]))
k.emit(s_mul_i32(s[71], s[71], s[75]))
k.emit(s_lshr_b32(s[72], s[72], 1))
k.emit(s_add_u32(s[71], s[71], s[72]))
v_div(k, s[75], s[69]) # v[18]=quotient
k.emit(v_readfirstlane_b32_e32(v[72], v[18]))
k.emit(s_mul_i32(s[72], s[72], s[75]))
k.emit(s_sub_u32(s[73], s[69], s[72]))
k.emit(s_cmp_gt_u32(s[2], s[72]))
k.emit(s_cselect_b32(s[72], s[73], s[75]))
k.emit(s_lshr_b32(s[72], s[72], 1))
k.emit(s_bfm_b32(s[73], 1, 0))
k.emit(s_and_b32(s[73], s[2], s[73]))
k.emit(s_mul_i32(s[72], s[72], s[73]))
k.emit(s_add_u32(s[2], s[71], s[72]))
k.label('skip_WGMXCC')
k.emit(v_mov_b32_e32(v[20], 256))
k.emit(v_mov_b32_e32(v[19], s[20]))
v_ceildiv(k, v[20], v[19], v[21]) # ceil(N / 256) → v[18]
k.emit(v_mov_b32_e32(v[20], 256))
k.emit(v_mov_b32_e32(v[19], s[21]))
k.emit(v_readfirstlane_b32_e32(v[10], v[18]))
v_ceildiv(k, v[20], v[19], v[21]) # ceil(batch*M / 256) → v[18]
k.emit(s_nop())
k.emit(v_readfirstlane_b32_e32(v[11], v[18]))
k.waitcnt(lgkm=0)
k.emit(s_mov_b32(s[85], 84148480))
k.emit(s_mov_b32(s[86], 117834498))
k.emit(s_sub_u32(s[28], s[28], 16))
k.emit(s_subb_u32(s[29], s[29], 0))
k.emit(s_sub_u32(s[30], s[30], 16))
k.emit(s_subb_u32(s[31], s[31], 0))
k.label('AlphaNonZero')
k.emit(s_mov_b32(s[57], s[2]))
k.emit(s_mul_i32(s[58], s[57], s[46]))
k.emit(s_mov_b32(s[59], s[49]))
k.emit(s_mul_i32(s[87], s[52], s[46]))
k.emit(s_cmp_lt_u32(s[87], s[49]))
k.emit(s_cbranch_scc1(), target='SK_InitDone')
k.emit(s_mul_i32(s[87], s[52], s[46]))
k.emit(s_mul_i32(s[88], s[46], s[51]))
k.emit(s_sub_u32(s[87], s[87], s[88]))
k.emit(s_mul_i32(s[58], s[57], s[46]))
k.emit(s_add_u32(s[58], s[58], s[87]))
k.emit(s_add_u32(s[59], s[58], s[46]))
k.emit(s_add_u32(s[89], s[46], 1))
k.emit(s_mul_i32(s[88], s[57], s[89]))
k.emit(s_add_u32(s[89], s[88], s[89]))
k.emit(s_cmp_lt_u32(s[57], s[87]))
k.emit(s_cselect_b32(s[58], s[88], s[58]))
k.emit(s_cselect_b32(s[59], s[89], s[59]))
k.emit(s_mul_i32(s[87], s[52], s[46]))
k.emit(s_min_u32(s[59], s[59], s[87]))
k.label('SK_InitDone')
k.emit(s_cmp_ge_u32(s[58], s[49]))
k.emit(s_cbranch_scc1(), target='KernelEnd')
k.label('PersistentLoopStart')
k.emit(v_xor_b32_e32(v[18], v[178], v[16]))
k.emit(v_min_i32_e32(v[16], v[16], v[18]))
k.emit(v_xor_b32_e32(v[18], v[179], v[17]))
k.emit(v_min_i32_e32(v[17], v[17], v[18]))
k.emit(s_mul_hi_u32(s[89], s[58], s[47]))
k.emit(s_lshr_b32(s[90], s[48], 31))
k.emit(s_mul_i32(s[88], s[58], s[90]))
k.emit(s_add_u32(s[88], s[88], s[89]))
k.emit(s_and_b32(s[90], s[48], 2147483647))
k.emit(s_lshr_b32(s[88], s[88], s[90]))
k.emit(s_mul_i32(s[89], s[88], s[46]))
k.emit(s_add_u32(s[90], s[89], s[46]))
k.emit(s_sub_u32(s[60], s[58], s[89]))
k.emit(s_min_u32(s[61], s[59], s[90]))
k.emit(s_sub_u32(s[61], s[61], s[89]))
k.emit(s_mul_i32(s[91], s[52], s[46]))
k.emit(s_sub_u32(s[91], s[49], s[91]))
k.emit(s_mul_i32(s[89], s[51], s[46]))
k.emit(s_add_u32(s[89], s[89], s[58]))
k.emit(s_cmp_lt_u32(s[89], s[91]))
k.emit(s_cbranch_scc1(), target='NoBranch_8G3ZEUE1ZDJOP9IU')
k.emit(s_mov_b32(s[89], s[90]))
k.emit(s_cmp_le_u32(s[91], s[58]))
k.emit(s_cbranch_scc1(), target='NoBranch_8G3ZEUE1ZDJOP9IU')
k.emit(s_mul_i32(s[87], s[52], s[46]))
k.emit(s_mul_i32(s[92], s[46], s[51]))
k.emit(s_sub_u32(s[87], s[87], s[92]))
k.emit(s_mul_i32(s[58], s[57], s[46]))
k.emit(s_add_u32(s[58], s[58], s[87]))
k.emit(s_add_u32(s[59], s[58], s[46]))
k.emit(s_add_u32(s[93], s[46], 1))
k.emit(s_mul_i32(s[92], s[57], s[93]))
k.emit(s_add_u32(s[93], s[92], s[93]))
k.emit(s_cmp_lt_u32(s[57], s[87]))
k.emit(s_cselect_b32(s[58], s[92], s[58]))
k.emit(s_cselect_b32(s[59], s[93], s[59]))
k.emit(s_add_u32(s[89], s[58], s[91]))
k.emit(s_add_u32(s[59], s[59], s[91]))
k.emit(s_min_u32(s[59], s[59], s[49]))
k.emit(s_cmp_ge_u32(s[58], s[49]))
k.emit(s_cbranch_scc1(), target='KernelEnd')
k.label('NoBranch_8G3ZEUE1ZDJOP9IU')
k.emit(s_mov_b32(s[58], s[89]))
k.emit(s_mul_i32(s[89], s[10], s[11]))
v_divmod(k, s[89], s[88]) # batch tile index → quotient=v[18], remainder=v[19]
k.emit(v_readfirstlane_b32_e32(v[4], v[18]))
k.emit(v_readfirstlane_b32_e32(v[90], v[19]))
v_divmod(k, s[10], s[90]) # row tile index → quotient=v[18], remainder=v[19]
k.emit(v_readfirstlane_b32_e32(v[3], v[18]))
k.emit(v_readfirstlane_b32_e32(v[2], v[19]))
k.label('SKAlphaCheck')
k.emit(s_mov_b32(s[91], 16))
v_div(k, s[91], s[3]) # s[3] / 16 → v[18]
k.emit(v_readfirstlane_b32_e32(v[87], v[18]))
k.emit(s_mul_i32(s[90], s[87], s[91]))
k.emit(s_sub_u32(s[90], s[3], s[90]))
k.emit(s_mul_i32(s[90], s[90], s[10]))
k.emit(s_add_u32(s[90], s[90], s[2]))
v_div(k, s[91], s[11]) # s[11] / 16 → v[18]
k.emit(v_readfirstlane_b32_e32(v[88], v[18]))
k.emit(s_mul_i32(s[89], s[91], s[88]))
k.emit(s_sub_u32(s[89], s[11], s[89]))
k.emit(s_cmp_eq_u32(s[89], 0))
k.emit(s_cmov_b32(s[89], s[91]))
k.emit(s_cmp_ge_u32(s[87], s[88]))
k.emit(s_cselect_b32(s[88], s[89], s[91]))
v_divmod(k, s[88], s[90]) # WGM tile divmod → v[18]=quotient, v[19]=remainder
k.emit(v_readfirstlane_b32_e32(v[2], v[18]))
k.emit(v_readfirstlane_b32_e32(v[3], v[19]))
k.emit(s_mul_i32(s[3], s[2], s[88]))
k.emit(s_sub_u32(s[3], s[90], s[3]))
k.emit(s_mul_i32(s[87], s[87], s[91]))
k.emit(s_add_u32(s[3], s[3], s[87]))
k.label('WGM')
k.emit(v_and_b32_e32(v[19], 63, v[180]))
k.emit(v_and_b32_e32(v[18], 15, v[19]))
k.emit(v_lshlrev_b32_e32(v[18], 3, v[18]))
k.emit(v_lshrrev_b32_e32(v[19], 4, v[19]))
k.emit(v_lshl_add_u32_e64(v[18], v[19], 11, v[18]))
k.emit(v_lshrrev_b32_e32(v[22], 6, v[180]))
k.emit(v_and_b32_e32(v[22], 1, v[22]))
k.emit(v_lshl_add_u32_e64(v[18], v[22], 7, v[18]))
k.emit(v_and_b32_e32(v[20], 63, v[180]))
k.emit(v_and_b32_e32(v[19], 15, v[20]))
k.emit(v_lshlrev_b32_e32(v[19], 6, v[19]))
k.emit(v_lshlrev_b32_e32(v[19], 3, v[19]))
k.emit(v_lshrrev_b32_e32(v[20], 4, v[20]))
k.emit(v_lshl_add_u32_e64(v[19], v[20], 3, v[19]))
k.emit(v_lshrrev_b32_e32(v[21], 7, v[180]))
k.emit(v_and_b32_e32(v[21], 1, v[21]))
k.emit(v_lshl_add_u32_e64(v[19], v[21], 13, v[19]))
k.emit(v_lshrrev_b32_e32(v[20], 6, v[180]))
k.emit(v_lshrrev_b32_e32(v[20], 2, v[20]))
k.emit(s_mov_b32(s[87], 16384))
k.emit(v_mul_lo_u32(v[20], s[87], v[20]))
k.emit(v_add_lshl_u32_e64(v[16], v[20], v[18], 1))
k.emit(v_lshrrev_b32_e32(v[18], 6, v[180]))
k.emit(v_lshrrev_b32_e32(v[18], 2, v[18]))
k.emit(s_mov_b32(s[87], 64))
k.emit(v_mul_lo_u32(v[18], s[87], v[18]))
k.emit(v_add_lshl_u32_e64(v[17], v[18], v[19], 1))
k.emit(v_lshrrev_b32_e32(v[20], 10, v[17]))
k.emit(v_lshl_add_u32_e64(v[17], v[20], 5, v[17]))
k.emit(v_add_co_u32_e32(v[17], 32768, v[17]))
k.emit(v_add_u32_e32(v[178], 66560, v[16]))
k.emit(v_xor_b32_e32(v[178], v[178], v[16]))
k.emit(v_add_u32_e32(v[179], 66560, v[17]))
k.emit(v_xor_b32_e32(v[179], v[179], v[17]))
k.emit(v_lshrrev_b32_e32(v[19], 5, v[180]))
k.emit(v_and_b32_e32(v[18], 31, v[180]))
k.emit(v_lshlrev_b32_e32(v[18], 3, v[18]))
k.emit(v_mov_b32_e32(v[22], v[19]))
k.emit(v_lshrrev_b32_e32(v[20], 3, v[180]))
k.emit(v_and_b32_e32(v[21], 7, v[180]))
k.emit(v_lshlrev_b32_e32(v[21], 3, v[21]))
k.emit(v_mov_b32_e32(v[23], v[21]))
k.emit(v_mul_u32_u24_e32(v[24], 256, v[22]))
k.emit(v_add_lshl_u32_e64(v[24], v[18], v[24], 1))
k.emit(s_nop())
k.emit(v_readfirstlane_b32_e32(v[53], v[24]))
k.emit(s_nop())
k.emit(s_add_u32(s[55], s[53], 66560))
k.emit(s_xor_b32(s[55], s[55], s[53]))
k.emit(v_mul_u32_u24_e32(v[24], 64, v[20]))
k.emit(v_add_lshl_u32_e64(v[24], v[23], v[24], 1))
k.emit(v_lshrrev_b32_e32(v[26], 10, v[24]))
k.emit(v_lshl_add_u32_e64(v[24], v[26], 5, v[24]))
k.emit(v_add_co_u32_e32(v[24], 32768, v[24]))
k.emit(s_nop())
k.emit(v_readfirstlane_b32_e32(v[54], v[24]))
k.emit(s_nop())
k.emit(s_add_u32(s[56], s[54], 66560))
k.emit(s_xor_b32(s[56], s[56], s[54]))
k.emit(v_mov_b32_e32(v[24], v[18]))
# v[25:32] = B row indices with stride 32
k.emit(v_mov_b32_e32(v[25], v[20]))
for i in range(7): k.emit(v_add_co_u32_e32(v[26+i], 32, v[25+i]))
# v[33:40] = A col indices with stride 8
k.emit(v_mov_b32_e32(v[33], v[19]))
for i in range(7): k.emit(v_add_co_u32_e32(v[34+i], 8, v[33+i]))
k.emit(v_mov_b32_e32(v[41], v[21]))
k.emit(s_mul_i32(s[87], s[2], 256))
k.emit(s_sub_u32(s[87], s[20], s[87]))
k.emit(s_sub_u32(s[87], s[87], 8))
k.emit(v_mov_b32_e32(v[42], s[87]))
k.emit(v_min_i32_e32(v[24], v[42], v[24]))
# compute 8 A tile global load addresses: v[i] = (col_offset + stride_A * row_idx + 8) << 1
k.emit(v_mul_lo_u32(v[42], s[40], v[33]))
k.emit(v_add_co_u32_e32(v[0], v[24], v[42]))
k.emit(v_add_u32_e32(v[0], 8))
k.emit(v_lshlrev_b32_e32(v[0], 1))
for i in range(1, 8):
k.emit(v_mul_lo_u32(v[42], s[40], v[33+i]))
k.emit(v_add_co_u32_e32(v[i], v[24], v[42]))
k.emit(v_add_u32_e32(v[i], 8, v[i]))
k.emit(v_lshlrev_b32_e32(v[i], 1, v[i]))
# compute 8 B tile global load addresses: v[8+i] = (row_offset + stride_B * col_idx + 8) << 1
for i in range(8):
k.emit(v_mul_lo_u32(v[33], s[42], v[25+i]))
k.emit(v_add_co_u32_e32(v[8+i], v[41], v[33]))
k.emit(v_add_u32_e32(v[8+i], 8, v[8+i]))
k.emit(v_lshlrev_b32_e32(v[8+i], 1, v[8+i]))
k.emit(s_mul_hi_u32(s[91], s[2], 256))
k.emit(s_mul_i32(s[90], s[2], 256))
k.emit(s_mul_i32(s[88], s[60], 64))
k.emit(s_mul_hi_u32(s[89], s[88], s[40]))
k.emit(s_mul_i32(s[88], s[88], s[40]))
k.emit(s_add_u32(s[90], s[90], s[88]))
k.emit(s_addc_u32(s[91], s[91], s[89]))
k.emit(s_mov_b64(s[62:63], 1))
k.emit(s_sub_u32(s[88], s[20], 1))
k.emit(s_mul_hi_u32(s[89], 1, s[88]))
k.emit(s_mul_i32(s[88], 1, s[88]))
k.emit(s_add_u32(s[62], s[62], s[88]))
k.emit(s_addc_u32(s[63], s[63], s[89]))
k.emit(s_sub_u32(s[88], s[23], 1))
k.emit(s_mul_hi_u32(s[89], s[40], s[88]))
k.emit(s_mul_i32(s[88], s[40], s[88]))
k.emit(s_add_u32(s[62], s[62], s[88]))
k.emit(s_addc_u32(s[63], s[63], s[89]))
k.emit(s_sub_u32(s[62], s[62], s[90]))
k.emit(s_subb_u32(s[63], s[63], s[91]))
k.emit(s_lshl_b64(s[62:63], s[62:63], 1))
k.emit(s_add_u32(s[62], s[62], 16))
k.emit(s_addc_u32(s[63], s[63], 0))
k.emit(s_cmp_eq_u32(s[63], 0))
k.emit(s_cselect_b32(s[70], s[62], -1))
k.emit(s_mul_hi_u32(s[89], s[41], s[4]))
k.emit(s_mul_i32(s[88], s[41], s[4]))
k.emit(s_add_u32(s[90], s[90], s[88]))
k.emit(s_addc_u32(s[91], s[91], s[89]))
k.emit(s_lshl_b64(s[90:91], s[90:91], 1))
k.emit(s_add_u32(s[68], s[28], s[90]))
k.emit(s_addc_u32(s[69], s[29], s[91]))
k.emit(s_mov_b32(s[71], 131072))
k.emit(s_mul_hi_u32(s[91], s[3], 256))
k.emit(s_mul_i32(s[90], s[3], 256))
k.emit(s_mul_hi_u32(s[91], s[90], s[42]))
k.emit(s_mul_i32(s[90], s[90], s[42]))
k.emit(s_mul_i32(s[88], s[60], 64))
k.emit(s_mul_hi_u32(s[89], s[88], 1))
k.emit(s_mul_i32(s[88], s[88], 1))
k.emit(s_add_u32(s[90], s[90], s[88]))
k.emit(s_addc_u32(s[91], s[91], s[89]))
k.emit(s_mov_b64(s[76:77], 1))
k.emit(s_sub_u32(s[88], s[23], 1))
k.emit(s_mul_hi_u32(s[89], 1, s[88]))
k.emit(s_mul_i32(s[88], 1, s[88]))
k.emit(s_add_u32(s[76], s[76], s[88]))
k.emit(s_addc_u32(s[77], s[77], s[89]))
k.emit(s_sub_u32(s[88], s[21], 1))
k.emit(s_mul_hi_u32(s[89], s[42], s[88]))
k.emit(s_mul_i32(s[88], s[42], s[88]))
k.emit(s_add_u32(s[76], s[76], s[88]))
k.emit(s_addc_u32(s[77], s[77], s[89]))
k.emit(s_sub_u32(s[76], s[76], s[90]))
k.emit(s_subb_u32(s[77], s[77], s[91]))
k.emit(s_lshl_b64(s[76:77], s[76:77], 1))
k.emit(s_add_u32(s[76], s[76], 16))
k.emit(s_addc_u32(s[77], s[77], 0))
k.emit(s_cmp_eq_u32(s[77], 0))
k.emit(s_cselect_b32(s[74], s[76], -1))
k.emit(s_mul_hi_u32(s[89], s[43], s[4]))
k.emit(s_mul_i32(s[88], s[43], s[4]))
k.emit(s_add_u32(s[90], s[90], s[88]))
k.emit(s_addc_u32(s[91], s[91], s[89]))
k.emit(s_lshl_b64(s[90:91], s[90:91], 1))
k.emit(s_add_u32(s[72], s[30], s[90]))
k.emit(s_addc_u32(s[73], s[31], s[91]))
k.emit(s_mov_b32(s[75], 131072))
k.emit(s_mul_i32(s[83], 128, s[40]))
k.emit(s_mov_b32(s[84], 128))
k.emit(s_sub_u32(s[8], s[61], s[60]))
k.label('SKAlphaCheck2')
k.emit(s_and_b32(s[89], 63, s[23]))
k.emit(s_cmp_eq_u32(s[89], 0))
k.emit(s_cselect_b32(s[88], 0, 1))
k.emit(s_cmp_eq_u32(s[61], s[46]))
k.emit(s_cselect_b32(s[88], s[88], 0))
k.emit(s_sub_u32(s[8], s[8], s[88]))
k.emit(s_mov_b32(s[9], s[8]))
k.emit(s_and_b32(s[90], s[6], 7936))
k.emit(s_lshr_b32(s[90], s[90], 8))
k.emit(s_and_b32(s[91], s[6], 57344))
k.emit(s_and_b32(s[6], s[6], 255))
k.emit(s_mov_b32(s[88], s[6]))
k.label('beginStaggerUIter')
k.emit(s_lshl_b32(s[89], s[88], s[90]))
k.emit(s_cmp_ge_u32(s[9], s[89]))
k.emit(s_cbranch_scc1(), target='endStaggerUIter')
k.emit(s_lshr_b32(s[88], s[88], 1))
k.emit(s_branch(), target='beginStaggerUIter')
k.label('endStaggerUIter')
k.emit(s_sub_u32(s[89], s[88], 1))
k.emit(s_cmp_ge_u32(s[88], 1))
k.emit(s_cselect_b32(s[78], s[89], 0))
k.emit(s_cmp_eq_u32(s[91], 0))
k.emit(s_cbranch_scc1(), target='StaggerUMapping_1')
k.emit(s_mov_b32(s[88], s[2]))
k.emit(s_branch(), target='staggerInputEnd')
k.label('StaggerUMapping_1')
k.emit(s_cmp_eq_u32(s[91], 8192))
k.emit(s_cbranch_scc1(), target='StaggerUMapping_2')
k.emit(s_mov_b32(s[88], s[3]))
k.emit(s_branch(), target='staggerInputEnd')
k.label('StaggerUMapping_2')
k.emit(s_cmp_eq_u32(s[91], 16384))
k.emit(s_cbranch_scc1(), target='StaggerUMapping_3')
k.emit(s_mov_b32(s[88], -1))
k.emit(s_branch(), target='staggerInputEnd')
k.label('StaggerUMapping_3')
k.emit(s_cmp_eq_u32(s[91], 24576))
k.emit(s_cbranch_scc1(), target='StaggerUMapping_4')
k.emit(s_mul_i32(s[89], s[10], s[3]))
k.emit(s_add_u32(s[88], s[88], s[89]))
k.emit(s_add_u32(s[88], s[88], s[2]))
k.emit(s_branch(), target='staggerInputEnd')
k.label('StaggerUMapping_4')
k.emit(s_cmp_eq_u32(s[91], 32768))
k.emit(s_cbranch_scc1(), target='staggerInputEnd')
k.emit(s_mov_b32(s[88], -1))
k.emit(s_branch(), target='staggerInputEnd')
k.label('staggerInputEnd')
k.emit(s_and_b32(s[78], s[78], s[88]))
k.emit(s_lshl_b32(s[78], s[78], s[90]))
k.emit(s_cmp_gt_u32(s[60], 0))
k.emit(s_cmov_b32(s[78], 0))
k.emit(s_cmp_lt_u32(s[61], s[46]))
k.emit(s_cmov_b32(s[78], 0))
k.emit(s_mul_hi_i32(s[89], s[78], s[83]))
k.emit(s_mul_i32(s[88], s[78], s[83]))
k.emit(s_mul_hi_i32(s[80], s[8], s[83]))
k.emit(s_mul_i32(s[79], s[8], s[83]))
k.emit(s_sub_u32(s[79], s[83], s[79]))
k.emit(s_subb_u32(s[80], 0, s[80]))
k.emit(s_add_u32(s[68], s[68], s[88]))
k.emit(s_addc_u32(s[69], s[69], s[89]))
k.emit(s_sub_u32(s[62], s[62], s[88]))
k.emit(s_subb_u32(s[63], s[63], s[89]))
k.emit(s_cmp_eq_u32(s[63], 0))
k.emit(s_cselect_b32(s[70], s[62], -1))
k.emit(s_mul_hi_i32(s[89], s[78], s[84]))
k.emit(s_mul_i32(s[88], s[78], s[84]))
k.emit(s_mul_hi_i32(s[82], s[8], s[84]))
k.emit(s_mul_i32(s[81], s[8], s[84]))
k.emit(s_sub_u32(s[81], s[84], s[81]))
k.emit(s_subb_u32(s[82], 0, s[82]))
k.emit(s_add_u32(s[72], s[72], s[88]))
k.emit(s_addc_u32(s[73], s[73], s[89]))
k.emit(s_sub_u32(s[76], s[76], s[88]))
k.emit(s_subb_u32(s[77], s[77], s[89]))
k.emit(s_cmp_eq_u32(s[77], 0))
k.emit(s_cselect_b32(s[74], s[76], -1))
k.emit(s_add_u32(s[78], s[78], 2))
k.emit(s_cmp_eq_u32(s[8], 0))
k.emit(s_setprio())
k.emit(s_cbranch_scc1(), target='ShadowInitStart')
k.emit(s_mov_b32(M0, s[53]))
k.waitcnt(lgkm=0)
k.emit(s_barrier())
buffer_load_x4_m0_stride(k, 0, 68, 4096, 8, 0, 0, 1, 0, 0, 1, 0, 1)
k.emit(s_mov_b32(M0, 133120))
k.emit(s_mov_b32(M0, s[54]))
buffer_load_x4_m0_stride(k, 8, 72, 4224, 8, 0, 0, 1, 0, 0, 0, 1, 1, 1)
k.emit(s_mov_b32(M0, 133120))
k.emit(s_add_u32(s[90], s[8], 1))
k.emit(s_cmp_eq_u32(s[78], s[90]))
k.emit(s_cselect_b32(s[88], s[79], s[83]))
k.emit(s_cselect_b32(s[89], s[80], 0))
k.emit(s_add_u32(s[68], s[68], s[88]))
k.emit(s_addc_u32(s[69], s[69], s[89]))
k.emit(s_sub_u32(s[62], s[62], s[88]))
k.emit(s_subb_u32(s[63], s[63], s[89]))
k.emit(s_cmp_eq_u32(s[63], 0))
k.emit(s_cselect_b32(s[70], s[62], -1))
k.emit(s_add_u32(s[90], s[8], 1))
k.emit(s_cmp_eq_u32(s[78], s[90]))
k.emit(s_cselect_b32(s[88], s[81], s[84]))
k.emit(s_cselect_b32(s[89], s[82], 0))
k.emit(s_add_u32(s[72], s[72], s[88]))
k.emit(s_addc_u32(s[73], s[73], s[89]))
k.emit(s_sub_u32(s[76], s[76], s[88]))
k.emit(s_subb_u32(s[77], s[77], s[89]))
k.emit(s_cmp_eq_u32(s[77], 0))
k.emit(s_cselect_b32(s[74], s[76], -1))
k.label('ShadowInitStart')
k.emit(s_mov_b64(s[12:13], s[24:25]))
k.emit(s_mov_b32(s[14], 2147483648))
k.emit(s_mov_b32(s[15], 131072))
k.emit(s_mov_b64(s[16:17], s[24:25]))
k.emit(s_mov_b32(s[18], 2147483648))
k.emit(s_mov_b32(s[19], 131072))
k.emit(s_mov_b32(s[87], 1))
k.emit(s_mov_b32(s[88], 1))
k.emit(s_mul_i32(s[92], 256, s[3]))
k.emit(s_mul_hi_u32(s[91], s[92], s[38]))
k.emit(s_mul_i32(s[90], s[92], s[38]))
k.emit(s_lshl_b64(s[90:91], s[90:91], s[87]))
k.emit(s_add_u32(s[16], s[26], s[90]))
k.emit(s_addc_u32(s[17], s[27], s[91]))
k.emit(s_mul_hi_u32(s[91], s[92], s[36]))
k.emit(s_mul_i32(s[90], s[92], s[36]))
k.emit(s_lshl_b64(s[90:91], s[90:91], s[88]))
k.emit(s_add_u32(s[12], s[24], s[90]))
k.emit(s_addc_u32(s[13], s[25], s[91]))
k.emit(s_mul_hi_u32(s[91], s[4], s[39]))
k.emit(s_mul_i32(s[90], s[4], s[39]))
k.emit(s_lshl_b64(s[90:91], s[90:91], s[87]))
k.emit(s_add_u32(s[16], s[16], s[90]))
k.emit(s_addc_u32(s[17], s[17], s[91]))
k.emit(s_mul_hi_u32(s[91], s[4], s[37]))
k.emit(s_mul_i32(s[90], s[4], s[37]))
k.emit(s_lshl_b64(s[90:91], s[90:91], s[88]))
k.emit(s_add_u32(s[12], s[12], s[90]))
k.emit(s_addc_u32(s[13], s[13], s[91]))
k.emit(v_mov_b64_e32(v[182:183], 0))
# zero 16 accumulators
for i in range(16): k.emit(v_accvgpr_write(v[i], 0))
# zero all 256 accvgprs via mfma (16 regs per call, 15 calls for v[16:255])
for i in range(15): k.emit(v_mfma_i32_32x32x16_i8(v[16+i*16:31+i*16], v[182:183], v[182:183], v[0:15], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cmp_eq_u32(s[8], 0))
k.emit(s_cbranch_scc1(), target='toPGR1end_OrdNLL')
k.waitcnt(vm=0)
k.emit(s_barrier())
k.emit(s_xor_b32(s[53], s[55], s[53]))
k.emit(s_xor_b32(s[54], s[56], s[54]))
k.emit(s_cmp_eq_u32(s[8], 1))
k.emit(s_cbranch_scc1(), target='skipPGR2')
k.emit(s_mov_b32(M0, s[53]))
k.waitcnt(lgkm=0)
k.emit(s_barrier())
buffer_load_x4_m0_stride(k, 0, 68, 4096, 8, 0, 0, 1, 0, 0, 1, 0, 1)
k.emit(s_mov_b32(M0, 133120))
k.emit(s_mov_b32(M0, s[54]))
k.waitcnt(lgkm=0)
k.emit(s_barrier())
buffer_load_x4_m0_stride(k, 8, 72, 4224, 8, 0, 0, 1, 0, 0, 0, 1, 1, 1)
k.emit(s_mov_b32(M0, 133120))
k.emit(s_xor_b32(s[53], s[55], s[53]))
k.emit(s_xor_b32(s[54], s[56], s[54]))
k.label('skipPGR2')
k.emit(s_barrier())
ds_read_ab_16(k)
k.waitcnt(lgkm=0)
perm_b32_32(k, 18)
k.label('openLoopL')
k.emit(s_cmp_eq_u32(s[8], 1))
k.emit(s_cbranch_scc1(), target='toPGR1')
k.emit(s_cmp_le_u32(s[8], 2))
k.emit(s_cbranch_scc1(), target='LoopEndL')
k.label('LoopBeginL')
k.emit(s_getreg_b32(s[87], 260))
k.waitcnt(lgkm=0)
k.emit(s_cmp_eq_u32(s[87], 0))
k.emit(s_cbranch_scc1(), target='LoopBeginL_0')
k.emit(s_cmp_eq_u32(s[87], 1))
k.emit(s_cbranch_scc1(), target='LoopBeginL_1')
k.label('LoopBeginL_0')
k.emit(v_mfma_16x16x32(v[0:3], v[114:117], v[18:21], v[0:3], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cmp_eq_u32(s[8], s[78]))
k.emit(ds_read_b128(v[82:85], v[16], v[0], v[0], 0, 0, 64))
k.emit(ds_read_b128(v[86:89], v[16], v[0], v[0], 0, 0, 66))
k.emit(v_mfma_16x16x32(v[4:7], v[114:117], v[22:25], v[4:7], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cselect_b32(s[88], s[79], s[83]))
k.emit(v_mfma_16x16x32(v[8:11], v[114:117], v[26:29], v[8:11], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cselect_b32(s[89], s[80], 0))
k.emit(ds_read_b128(v[90:93], v[16], v[0], v[0], 0, 0, 68))
k.emit(ds_read_b128(v[94:97], v[16], v[0], v[0], 0, 0, 70))
k.emit(v_mfma_16x16x32(v[12:15], v[114:117], v[30:33], v[12:15], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(s[68], s[68], s[88]))
k.emit(v_mfma_16x16x32(v[16:19], v[114:117], v[34:37], v[16:19], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_addc_u32(s[69], s[69], s[89]))
k.emit(ds_read_b128(v[98:101], v[16], v[0], v[0], 0, 0, 72))
k.emit(ds_read_b128(v[102:105], v[16], v[0], v[0], 0, 0, 74))
k.emit(v_mfma_16x16x32(v[20:23], v[114:117], v[38:41], v[20:23], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_sub_u32(s[62], s[62], s[88]))
k.emit(v_mfma_16x16x32(v[24:27], v[114:117], v[42:45], v[24:27], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_subb_u32(s[63], s[63], s[89]))
k.emit(ds_read_b128(v[106:109], v[16], v[0], v[0], 0, 0, 76))
k.emit(ds_read_b128(v[110:113], v[16], v[0], v[0], 0, 0, 78))
k.emit(v_mfma_16x16x32(v[28:31], v[114:117], v[46:49], v[28:31], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cmp_eq_u32(s[63], 0))
k.emit(v_mfma_16x16x32(v[32:35], v[118:121], v[18:21], v[32:35], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(lgkm=4)
k.emit(s_cselect_b32(s[70], s[62], -1))
k.emit(v_perm_b32_e64(v[50], v[86], v[82], s[85]))
k.emit(v_perm_b32_e64(v[51], v[94], v[90], s[85]))
k.emit(v_mfma_16x16x32(v[36:39], v[118:121], v[22:25], v[36:39], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[146:149], v[17], v[0], v[0], 0, 64))
k.emit(v_mfma_16x16x32(v[40:43], v[118:121], v[26:29], v[40:43], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[44:47], v[118:121], v[30:33], v[44:47], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[150:153], v[17], v[0], v[0], 0, 192))
k.emit(v_mfma_16x16x32(v[48:51], v[118:121], v[34:37], v[48:51], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(lgkm=1)
k.emit(v_mfma_16x16x32(v[52:55], v[118:121], v[38:41], v[52:55], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(v_mfma_16x16x32(v[56:59], v[118:121], v[42:45], v[56:59], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_mov_b32(M0, s[53]))
k.emit(buffer_load_dwordx4(v[0:3], v[0], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[60:63], v[118:121], v[46:49], v[60:63], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[154:157], v[17], v[0], v[0], 0, 64, 1))
k.emit(v_mfma_16x16x32(v[64:67], v[122:125], v[18:21], v[64:67], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[52], v[102], v[98], s[85]))
k.emit(v_perm_b32_e64(v[53], v[110], v[106], s[85]))
k.emit(v_mfma_16x16x32(v[68:71], v[122:125], v[22:25], v[68:71], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4096))
k.emit(buffer_load_dwordx4(v[0:3], v[1], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[72:75], v[122:125], v[26:29], v[72:75], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[158:161], v[17], v[0], v[0], 0, 192, 1))
k.emit(v_mfma_16x16x32(v[76:79], v[122:125], v[30:33], v[76:79], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[54], v[86], v[82], s[86]))
k.emit(v_perm_b32_e64(v[55], v[94], v[90], s[86]))
k.emit(v_mfma_16x16x32(v[80:83], v[122:125], v[34:37], v[80:83], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4096))
k.emit(buffer_load_dwordx4(v[0:3], v[2], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[84:87], v[122:125], v[38:41], v[84:87], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[162:165], v[17], v[0], v[0], 0, 64, 2))
k.emit(v_mfma_16x16x32(v[88:91], v[122:125], v[42:45], v[88:91], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[56], v[102], v[98], s[86]))
k.emit(v_perm_b32_e64(v[57], v[110], v[106], s[86]))
k.emit(v_mfma_16x16x32(v[92:95], v[122:125], v[46:49], v[92:95], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4096))
k.emit(buffer_load_dwordx4(v[0:3], v[3], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[96:99], v[126:129], v[18:21], v[96:99], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[166:169], v[17], v[0], v[0], 0, 192, 2))
k.emit(v_mfma_16x16x32(v[100:103], v[126:129], v[22:25], v[100:103], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[58], v[87], v[83], s[85]))
k.emit(v_perm_b32_e64(v[59], v[95], v[91], s[85]))
k.emit(v_mfma_16x16x32(v[104:107], v[126:129], v[26:29], v[104:107], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4096))
k.emit(buffer_load_dwordx4(v[0:3], v[4], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[108:111], v[126:129], v[30:33], v[108:111], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[170:173], v[17], v[0], v[0], 0, 64, 3))
k.emit(v_mfma_16x16x32(v[112:115], v[126:129], v[34:37], v[112:115], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cmp_eq_u32(s[8], s[78]))
k.emit(v_perm_b32_e64(v[60], v[103], v[99], s[85]))
k.emit(v_perm_b32_e64(v[61], v[111], v[107], s[85]))
k.emit(v_mfma_16x16x32(v[116:119], v[126:129], v[38:41], v[116:119], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cselect_b32(s[88], s[81], s[84]))
k.emit(v_perm_b32_e64(v[62], v[87], v[83], s[86]))
k.emit(v_perm_b32_e64(v[63], v[95], v[91], s[86]))
k.emit(v_mfma_16x16x32(v[120:123], v[126:129], v[42:45], v[120:123], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cselect_b32(s[89], s[82], 0))
k.emit(ds_read_b128(v[174:177], v[17], v[0], v[0], 0, 192, 3))
k.emit(v_mfma_16x16x32(v[124:127], v[126:129], v[46:49], v[124:127], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(s[72], s[72], s[88]))
k.emit(v_perm_b32_e64(v[64], v[103], v[99], s[86]))
k.emit(v_perm_b32_e64(v[65], v[111], v[107], s[86]))
k.emit(v_mfma_16x16x32(v[128:131], v[130:133], v[18:21], v[128:131], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_addc_u32(s[73], s[73], s[89]))
k.emit(v_perm_b32_e64(v[66], v[88], v[84], s[85]))
k.emit(v_perm_b32_e64(v[67], v[96], v[92], s[85]))
k.emit(v_mfma_16x16x32(v[132:135], v[130:133], v[22:25], v[132:135], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_sub_u32(s[76], s[76], s[88]))
k.emit(v_perm_b32_e64(v[68], v[104], v[100], s[85]))
k.emit(v_perm_b32_e64(v[69], v[112], v[108], s[85]))
k.emit(v_mfma_16x16x32(v[136:139], v[130:133], v[26:29], v[136:139], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_subb_u32(s[77], s[77], s[89]))
k.emit(v_perm_b32_e64(v[70], v[88], v[84], s[86]))
k.emit(v_perm_b32_e64(v[71], v[96], v[92], s[86]))
k.emit(v_mfma_16x16x32(v[140:143], v[130:133], v[30:33], v[140:143], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cmp_eq_u32(s[77], 0))
k.emit(v_perm_b32_e64(v[72], v[104], v[100], s[86]))
k.emit(v_perm_b32_e64(v[73], v[112], v[108], s[86]))
k.emit(v_mfma_16x16x32(v[144:147], v[130:133], v[34:37], v[144:147], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(lgkm=0)
k.emit(s_cselect_b32(s[74], s[76], -1))
k.emit(v_perm_b32_e64(v[74], v[89], v[85], s[85]))
k.emit(v_perm_b32_e64(v[75], v[97], v[93], s[85]))
k.emit(v_mfma_16x16x32(v[148:151], v[130:133], v[38:41], v[148:151], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[76], v[105], v[101], s[85]))
k.emit(v_perm_b32_e64(v[77], v[113], v[109], s[85]))
k.emit(v_mfma_16x16x32(v[152:155], v[130:133], v[42:45], v[152:155], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[78], v[89], v[85], s[86]))
k.emit(v_perm_b32_e64(v[79], v[97], v[93], s[86]))
k.emit(v_mfma_16x16x32(v[156:159], v[130:133], v[46:49], v[156:159], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[80], v[105], v[101], s[86]))
k.emit(v_perm_b32_e64(v[81], v[113], v[109], s[86]))
k.emit(v_mfma_16x16x32(v[160:163], v[134:137], v[18:21], v[160:163], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[164:167], v[134:137], v[22:25], v[164:167], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[168:171], v[134:137], v[26:29], v[168:171], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[172:175], v[134:137], v[30:33], v[172:175], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[176:179], v[134:137], v[34:37], v[176:179], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(v_mfma_16x16x32(v[180:183], v[134:137], v[38:41], v[180:183], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4096))
k.emit(buffer_load_dwordx4(v[0:3], v[5], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[184:187], v[134:137], v[42:45], v[184:187], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[188:191], v[134:137], v[46:49], v[188:191], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[192:195], v[138:141], v[18:21], v[192:195], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4096))
k.emit(buffer_load_dwordx4(v[0:3], v[6], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[196:199], v[138:141], v[22:25], v[196:199], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[200:203], v[138:141], v[26:29], v[200:203], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[204:207], v[138:141], v[30:33], v[204:207], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4096))
k.emit(buffer_load_dwordx4(v[0:3], v[7], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[208:211], v[138:141], v[34:37], v[208:211], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[212:215], v[138:141], v[38:41], v[212:215], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[216:219], v[138:141], v[42:45], v[216:219], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_mov_b32(M0, s[54]))
k.emit(buffer_load_dwordx4(v[0:3], v[8], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[220:223], v[138:141], v[46:49], v[220:223], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[224:227], v[142:145], v[18:21], v[224:227], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(vm=17)
k.emit(v_mfma_16x16x32(v[228:231], v[142:145], v[22:25], v[228:231], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4224))
k.emit(buffer_load_dwordx4(v[0:3], v[9], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[232:235], v[142:145], v[26:29], v[232:235], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[236:239], v[142:145], v[30:33], v[236:239], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(v_xor_b32_e32(v[16], v[178], v[16]))
k.emit(v_xor_b32_e32(v[17], v[179], v[17]))
k.emit(v_mfma_16x16x32(v[240:243], v[142:145], v[34:37], v[240:243], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[82:85], v[16]))
k.emit(ds_read_b128(v[86:89], v[16], v[0], v[0], 0, 0, 2))
k.emit(v_mfma_16x16x32(v[244:247], v[142:145], v[38:41], v[244:247], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[248:251], v[142:145], v[42:45], v[248:251], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[90:93], v[16], v[0], v[0], 0, 0, 4))
k.emit(ds_read_b128(v[94:97], v[16], v[0], v[0], 0, 0, 6))
k.emit(v_mfma_16x16x32(v[252:255], v[142:145], v[46:49], v[252:255], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[0:3], v[146:149], v[50:53], v[0:3], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[98:101], v[16], v[0], v[0], 0, 0, 8))
k.emit(ds_read_b128(v[102:105], v[16], v[0], v[0], 0, 0, 10))
k.emit(v_mfma_16x16x32(v[4:7], v[146:149], v[54:57], v[4:7], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[8:11], v[146:149], v[58:61], v[8:11], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(vm=9)
k.emit(ds_read_b128(v[106:109], v[16], v[0], v[0], 0, 0, 12))
k.emit(ds_read_b128(v[110:113], v[16], v[0], v[0], 0, 0, 14))
k.emit(v_mfma_16x16x32(v[12:15], v[146:149], v[62:65], v[12:15], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[16:19], v[146:149], v[66:69], v[16:19], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(ds_read_b128(v[114:117], v[17]))
k.emit(v_mfma_16x16x32(v[20:23], v[146:149], v[70:73], v[20:23], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[24:27], v[146:149], v[74:77], v[24:27], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[118:121], v[17], v[0], v[0], 0, 128))
k.emit(v_mfma_16x16x32(v[28:31], v[146:149], v[78:81], v[28:31], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[32:35], v[150:153], v[50:53], v[32:35], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[122:125], v[17], v[0], v[0], 0, 0, 1))
k.emit(v_mfma_16x16x32(v[36:39], v[150:153], v[54:57], v[36:39], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(lgkm=4)
k.emit(v_mfma_16x16x32(v[40:43], v[150:153], v[58:61], v[40:43], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[126:129], v[17], v[0], v[0], 0, 128, 1))
k.emit(v_mfma_16x16x32(v[44:47], v[150:153], v[62:65], v[44:47], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[18], v[86], v[82], s[85]))
k.emit(v_perm_b32_e64(v[19], v[94], v[90], s[85]))
k.emit(v_mfma_16x16x32(v[48:51], v[150:153], v[66:69], v[48:51], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[130:133], v[17], v[0], v[0], 0, 0, 2))
k.emit(v_mfma_16x16x32(v[52:55], v[150:153], v[70:73], v[52:55], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[20], v[102], v[98], s[85]))
k.emit(v_perm_b32_e64(v[21], v[110], v[106], s[85]))
k.emit(v_mfma_16x16x32(v[56:59], v[150:153], v[74:77], v[56:59], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[134:137], v[17], v[0], v[0], 0, 128, 2))
k.emit(v_mfma_16x16x32(v[60:63], v[150:153], v[78:81], v[60:63], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[22], v[86], v[82], s[86]))
k.emit(v_perm_b32_e64(v[23], v[94], v[90], s[86]))
k.emit(v_mfma_16x16x32(v[64:67], v[154:157], v[50:53], v[64:67], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[138:141], v[17], v[0], v[0], 0, 0, 3))
k.emit(v_mfma_16x16x32(v[68:71], v[154:157], v[54:57], v[68:71], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[24], v[102], v[98], s[86]))
k.emit(v_perm_b32_e64(v[25], v[110], v[106], s[86]))
k.emit(v_mfma_16x16x32(v[72:75], v[154:157], v[58:61], v[72:75], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[142:145], v[17], v[0], v[0], 0, 128, 3))
k.emit(v_mfma_16x16x32(v[76:79], v[154:157], v[62:65], v[76:79], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[26], v[87], v[83], s[85]))
k.emit(v_perm_b32_e64(v[27], v[95], v[91], s[85]))
k.emit(v_mfma_16x16x32(v[80:83], v[154:157], v[66:69], v[80:83], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[28], v[103], v[99], s[85]))
k.emit(v_perm_b32_e64(v[29], v[111], v[107], s[85]))
k.emit(v_mfma_16x16x32(v[84:87], v[154:157], v[70:73], v[84:87], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[30], v[87], v[83], s[86]))
k.emit(v_perm_b32_e64(v[31], v[95], v[91], s[86]))
k.emit(v_mfma_16x16x32(v[88:91], v[154:157], v[74:77], v[88:91], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[32], v[103], v[99], s[86]))
k.emit(v_perm_b32_e64(v[33], v[111], v[107], s[86]))
k.emit(v_mfma_16x16x32(v[92:95], v[154:157], v[78:81], v[92:95], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4224))
k.emit(buffer_load_dwordx4(v[0:3], v[10], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[96:99], v[158:161], v[50:53], v[96:99], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[34], v[88], v[84], s[85]))
k.emit(v_perm_b32_e64(v[35], v[96], v[92], s[85]))
k.emit(v_mfma_16x16x32(v[100:103], v[158:161], v[54:57], v[100:103], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[36], v[104], v[100], s[85]))
k.emit(v_perm_b32_e64(v[37], v[112], v[108], s[85]))
k.emit(v_mfma_16x16x32(v[104:107], v[158:161], v[58:61], v[104:107], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4224))
k.emit(buffer_load_dwordx4(v[0:3], v[11], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[108:111], v[158:161], v[62:65], v[108:111], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[38], v[88], v[84], s[86]))
k.emit(v_perm_b32_e64(v[39], v[96], v[92], s[86]))
k.emit(v_mfma_16x16x32(v[112:115], v[158:161], v[66:69], v[112:115], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[40], v[104], v[100], s[86]))
k.emit(v_perm_b32_e64(v[41], v[112], v[108], s[86]))
k.emit(v_mfma_16x16x32(v[116:119], v[158:161], v[70:73], v[116:119], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4224))
k.emit(buffer_load_dwordx4(v[0:3], v[12], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[120:123], v[158:161], v[74:77], v[120:123], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[42], v[89], v[85], s[85]))
k.emit(v_perm_b32_e64(v[43], v[97], v[93], s[85]))
k.emit(v_mfma_16x16x32(v[124:127], v[158:161], v[78:81], v[124:127], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[44], v[105], v[101], s[85]))
k.emit(v_perm_b32_e64(v[45], v[113], v[109], s[85]))
k.emit(v_mfma_16x16x32(v[128:131], v[162:165], v[50:53], v[128:131], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4224))
k.emit(buffer_load_dwordx4(v[0:3], v[13], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[132:135], v[162:165], v[54:57], v[132:135], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[46], v[89], v[85], s[86]))
k.emit(v_perm_b32_e64(v[47], v[97], v[93], s[86]))
k.emit(v_mfma_16x16x32(v[136:139], v[162:165], v[58:61], v[136:139], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[48], v[105], v[101], s[86]))
k.emit(v_perm_b32_e64(v[49], v[113], v[109], s[86]))
k.emit(v_mfma_16x16x32(v[140:143], v[162:165], v[62:65], v[140:143], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4224))
k.emit(buffer_load_dwordx4(v[0:3], v[14], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[144:147], v[162:165], v[66:69], v[144:147], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[148:151], v[162:165], v[70:73], v[148:151], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[152:155], v[162:165], v[74:77], v[152:155], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[156:159], v[162:165], v[78:81], v[156:159], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[160:163], v[166:169], v[50:53], v[160:163], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[164:167], v[166:169], v[54:57], v[164:167], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[168:171], v[166:169], v[58:61], v[168:171], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[172:175], v[166:169], v[62:65], v[172:175], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[176:179], v[166:169], v[66:69], v[176:179], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[180:183], v[166:169], v[70:73], v[180:183], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[184:187], v[166:169], v[74:77], v[184:187], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[188:191], v[166:169], v[78:81], v[188:191], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[192:195], v[170:173], v[50:53], v[192:195], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[196:199], v[170:173], v[54:57], v[196:199], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[200:203], v[170:173], v[58:61], v[200:203], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[204:207], v[170:173], v[62:65], v[204:207], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[208:211], v[170:173], v[66:69], v[208:211], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[212:215], v[170:173], v[70:73], v[212:215], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[216:219], v[170:173], v[74:77], v[216:219], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[220:223], v[170:173], v[78:81], v[220:223], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[224:227], v[174:177], v[50:53], v[224:227], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[228:231], v[174:177], v[54:57], v[228:231], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[232:235], v[174:177], v[58:61], v[232:235], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[236:239], v[174:177], v[62:65], v[236:239], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4224))
k.emit(buffer_load_dwordx4(v[0:3], v[15], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[240:243], v[174:177], v[66:69], v[240:243], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[244:247], v[174:177], v[70:73], v[244:247], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_xor_b32(s[53], s[55], s[53]))
k.emit(s_xor_b32(s[54], s[56], s[54]))
k.emit(v_mfma_16x16x32(v[248:251], v[174:177], v[74:77], v[248:251], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_sub_u32(s[8], s[8], 1))
k.emit(s_cmp_eq_i32(s[8], 2))
k.emit(v_mfma_16x16x32(v[252:255], v[174:177], v[78:81], v[252:255], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cbranch_scc0(), target='LoopBeginL_0')
k.emit(s_branch(), target='LoopEndL')
k.label('LoopBeginL_1')
k.emit(v_mfma_16x16x32(v[0:3], v[114:117], v[18:21], v[0:3], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cmp_eq_u32(s[8], s[78]))
k.emit(v_mfma_16x16x32(v[4:7], v[114:117], v[22:25], v[4:7], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cselect_b32(s[88], s[79], s[83]))
k.emit(ds_read_b128(v[82:85], v[16], v[0], v[0], 0, 0, 64))
k.emit(ds_read_b128(v[86:89], v[16], v[0], v[0], 0, 0, 66))
k.emit(v_mfma_16x16x32(v[8:11], v[114:117], v[26:29], v[8:11], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cselect_b32(s[89], s[80], 0))
k.emit(v_mfma_16x16x32(v[12:15], v[114:117], v[30:33], v[12:15], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(s[68], s[68], s[88]))
k.emit(ds_read_b128(v[90:93], v[16], v[0], v[0], 0, 0, 68))
k.emit(ds_read_b128(v[94:97], v[16], v[0], v[0], 0, 0, 70))
k.emit(v_mfma_16x16x32(v[16:19], v[114:117], v[34:37], v[16:19], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_addc_u32(s[69], s[69], s[89]))
k.emit(v_mfma_16x16x32(v[20:23], v[114:117], v[38:41], v[20:23], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_sub_u32(s[62], s[62], s[88]))
k.emit(ds_read_b128(v[98:101], v[16], v[0], v[0], 0, 0, 72))
k.emit(ds_read_b128(v[102:105], v[16], v[0], v[0], 0, 0, 74))
k.emit(v_mfma_16x16x32(v[24:27], v[114:117], v[42:45], v[24:27], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_subb_u32(s[63], s[63], s[89]))
k.emit(v_mfma_16x16x32(v[28:31], v[114:117], v[46:49], v[28:31], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cmp_eq_u32(s[63], 0))
k.emit(ds_read_b128(v[106:109], v[16], v[0], v[0], 0, 0, 76))
k.emit(ds_read_b128(v[110:113], v[16], v[0], v[0], 0, 0, 78))
k.emit(v_mfma_16x16x32(v[32:35], v[118:121], v[18:21], v[32:35], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(lgkm=4)
k.emit(s_cselect_b32(s[70], s[62], -1))
k.emit(v_perm_b32_e64(v[50], v[86], v[82], s[85]))
k.emit(v_perm_b32_e64(v[51], v[94], v[90], s[85]))
k.emit(v_mfma_16x16x32(v[36:39], v[118:121], v[22:25], v[36:39], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[40:43], v[118:121], v[26:29], v[40:43], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[146:149], v[17], v[0], v[0], 0, 64))
k.emit(v_mfma_16x16x32(v[44:47], v[118:121], v[30:33], v[44:47], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[48:51], v[118:121], v[34:37], v[48:51], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(lgkm=1)
k.emit(ds_read_b128(v[150:153], v[17], v[0], v[0], 0, 192))
k.emit(v_mfma_16x16x32(v[52:55], v[118:121], v[38:41], v[52:55], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(v_mfma_16x16x32(v[56:59], v[118:121], v[42:45], v[56:59], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[154:157], v[17], v[0], v[0], 0, 64, 1))
k.emit(v_mfma_16x16x32(v[60:63], v[118:121], v[46:49], v[60:63], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_mov_b32(M0, s[53]))
k.emit(buffer_load_dwordx4(v[0:3], v[0], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[64:67], v[122:125], v[18:21], v[64:67], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[52], v[102], v[98], s[85]))
k.emit(v_perm_b32_e64(v[53], v[110], v[106], s[85]))
k.emit(v_mfma_16x16x32(v[68:71], v[122:125], v[22:25], v[68:71], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[158:161], v[17], v[0], v[0], 0, 192, 1))
k.emit(v_mfma_16x16x32(v[72:75], v[122:125], v[26:29], v[72:75], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4096))
k.emit(buffer_load_dwordx4(v[0:3], v[1], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[76:79], v[122:125], v[30:33], v[76:79], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[54], v[86], v[82], s[86]))
k.emit(v_perm_b32_e64(v[55], v[94], v[90], s[86]))
k.emit(v_mfma_16x16x32(v[80:83], v[122:125], v[34:37], v[80:83], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[162:165], v[17], v[0], v[0], 0, 64, 2))
k.emit(v_mfma_16x16x32(v[84:87], v[122:125], v[38:41], v[84:87], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4096))
k.emit(buffer_load_dwordx4(v[0:3], v[2], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[88:91], v[122:125], v[42:45], v[88:91], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[56], v[102], v[98], s[86]))
k.emit(v_perm_b32_e64(v[57], v[110], v[106], s[86]))
k.emit(v_mfma_16x16x32(v[92:95], v[122:125], v[46:49], v[92:95], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[166:169], v[17], v[0], v[0], 0, 192, 2))
k.emit(v_mfma_16x16x32(v[96:99], v[126:129], v[18:21], v[96:99], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4096))
k.emit(buffer_load_dwordx4(v[0:3], v[3], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[100:103], v[126:129], v[22:25], v[100:103], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[58], v[87], v[83], s[85]))
k.emit(v_perm_b32_e64(v[59], v[95], v[91], s[85]))
k.emit(v_mfma_16x16x32(v[104:107], v[126:129], v[26:29], v[104:107], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[170:173], v[17], v[0], v[0], 0, 64, 3))
k.emit(v_mfma_16x16x32(v[108:111], v[126:129], v[30:33], v[108:111], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4096))
k.emit(buffer_load_dwordx4(v[0:3], v[4], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[112:115], v[126:129], v[34:37], v[112:115], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cmp_eq_u32(s[8], s[78]))
k.emit(v_perm_b32_e64(v[60], v[103], v[99], s[85]))
k.emit(v_perm_b32_e64(v[61], v[111], v[107], s[85]))
k.emit(v_mfma_16x16x32(v[116:119], v[126:129], v[38:41], v[116:119], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cselect_b32(s[88], s[81], s[84]))
k.emit(v_perm_b32_e64(v[62], v[87], v[83], s[86]))
k.emit(v_perm_b32_e64(v[63], v[95], v[91], s[86]))
k.emit(ds_read_b128(v[174:177], v[17], v[0], v[0], 0, 192, 3))
k.emit(v_mfma_16x16x32(v[120:123], v[126:129], v[42:45], v[120:123], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cselect_b32(s[89], s[82], 0))
k.emit(v_mfma_16x16x32(v[124:127], v[126:129], v[46:49], v[124:127], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(s[72], s[72], s[88]))
k.emit(v_perm_b32_e64(v[64], v[103], v[99], s[86]))
k.emit(v_perm_b32_e64(v[65], v[111], v[107], s[86]))
k.emit(v_mfma_16x16x32(v[128:131], v[130:133], v[18:21], v[128:131], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_addc_u32(s[73], s[73], s[89]))
k.emit(v_perm_b32_e64(v[66], v[88], v[84], s[85]))
k.emit(v_perm_b32_e64(v[67], v[96], v[92], s[85]))
k.emit(v_mfma_16x16x32(v[132:135], v[130:133], v[22:25], v[132:135], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_sub_u32(s[76], s[76], s[88]))
k.emit(v_perm_b32_e64(v[68], v[104], v[100], s[85]))
k.emit(v_perm_b32_e64(v[69], v[112], v[108], s[85]))
k.emit(v_mfma_16x16x32(v[136:139], v[130:133], v[26:29], v[136:139], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_subb_u32(s[77], s[77], s[89]))
k.emit(v_perm_b32_e64(v[70], v[88], v[84], s[86]))
k.emit(v_perm_b32_e64(v[71], v[96], v[92], s[86]))
k.emit(v_mfma_16x16x32(v[140:143], v[130:133], v[30:33], v[140:143], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_cmp_eq_u32(s[77], 0))
k.emit(v_perm_b32_e64(v[72], v[104], v[100], s[86]))
k.emit(v_perm_b32_e64(v[73], v[112], v[108], s[86]))
k.emit(v_mfma_16x16x32(v[144:147], v[130:133], v[34:37], v[144:147], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(lgkm=0)
k.emit(s_cselect_b32(s[74], s[76], -1))
k.emit(v_perm_b32_e64(v[74], v[89], v[85], s[85]))
k.emit(v_perm_b32_e64(v[75], v[97], v[93], s[85]))
k.emit(v_mfma_16x16x32(v[148:151], v[130:133], v[38:41], v[148:151], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[76], v[105], v[101], s[85]))
k.emit(v_perm_b32_e64(v[77], v[113], v[109], s[85]))
k.emit(v_mfma_16x16x32(v[152:155], v[130:133], v[42:45], v[152:155], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[78], v[89], v[85], s[86]))
k.emit(v_perm_b32_e64(v[79], v[97], v[93], s[86]))
k.emit(v_mfma_16x16x32(v[156:159], v[130:133], v[46:49], v[156:159], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[80], v[105], v[101], s[86]))
k.emit(v_perm_b32_e64(v[81], v[113], v[109], s[86]))
k.emit(v_mfma_16x16x32(v[160:163], v[134:137], v[18:21], v[160:163], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[164:167], v[134:137], v[22:25], v[164:167], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[168:171], v[134:137], v[26:29], v[168:171], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[172:175], v[134:137], v[30:33], v[172:175], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[176:179], v[134:137], v[34:37], v[176:179], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(v_mfma_16x16x32(v[180:183], v[134:137], v[38:41], v[180:183], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[184:187], v[134:137], v[42:45], v[184:187], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4096))
k.emit(buffer_load_dwordx4(v[0:3], v[5], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[188:191], v[134:137], v[46:49], v[188:191], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[192:195], v[138:141], v[18:21], v[192:195], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[196:199], v[138:141], v[22:25], v[196:199], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4096))
k.emit(buffer_load_dwordx4(v[0:3], v[6], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[200:203], v[138:141], v[26:29], v[200:203], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[204:207], v[138:141], v[30:33], v[204:207], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[208:211], v[138:141], v[34:37], v[208:211], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4096))
k.emit(buffer_load_dwordx4(v[0:3], v[7], s[68:71], 0, 0, 1, 0, 0, 1, 0, 1))
k.emit(v_mfma_16x16x32(v[212:215], v[138:141], v[38:41], v[212:215], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[216:219], v[138:141], v[42:45], v[216:219], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[220:223], v[138:141], v[46:49], v[220:223], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_mov_b32(M0, s[54]))
k.emit(buffer_load_dwordx4(v[0:3], v[8], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[224:227], v[142:145], v[18:21], v[224:227], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(vm=17)
k.emit(v_mfma_16x16x32(v[228:231], v[142:145], v[22:25], v[228:231], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[232:235], v[142:145], v[26:29], v[232:235], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4224))
k.emit(buffer_load_dwordx4(v[0:3], v[9], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[236:239], v[142:145], v[30:33], v[236:239], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(v_xor_b32_e32(v[16], v[178], v[16]))
k.emit(v_xor_b32_e32(v[17], v[179], v[17]))
k.emit(v_mfma_16x16x32(v[240:243], v[142:145], v[34:37], v[240:243], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[244:247], v[142:145], v[38:41], v[244:247], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[82:85], v[16]))
k.emit(ds_read_b128(v[86:89], v[16], v[0], v[0], 0, 0, 2))
k.emit(v_mfma_16x16x32(v[248:251], v[142:145], v[42:45], v[248:251], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[252:255], v[142:145], v[46:49], v[252:255], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[90:93], v[16], v[0], v[0], 0, 0, 4))
k.emit(ds_read_b128(v[94:97], v[16], v[0], v[0], 0, 0, 6))
k.emit(v_mfma_16x16x32(v[0:3], v[146:149], v[50:53], v[0:3], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[4:7], v[146:149], v[54:57], v[4:7], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[98:101], v[16], v[0], v[0], 0, 0, 8))
k.emit(ds_read_b128(v[102:105], v[16], v[0], v[0], 0, 0, 10))
k.emit(v_mfma_16x16x32(v[8:11], v[146:149], v[58:61], v[8:11], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(vm=9)
k.emit(v_mfma_16x16x32(v[12:15], v[146:149], v[62:65], v[12:15], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[106:109], v[16], v[0], v[0], 0, 0, 12))
k.emit(ds_read_b128(v[110:113], v[16], v[0], v[0], 0, 0, 14))
k.emit(v_mfma_16x16x32(v[16:19], v[146:149], v[66:69], v[16:19], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(v_mfma_16x16x32(v[20:23], v[146:149], v[70:73], v[20:23], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[114:117], v[17]))
k.emit(v_mfma_16x16x32(v[24:27], v[146:149], v[74:77], v[24:27], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[28:31], v[146:149], v[78:81], v[28:31], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[118:121], v[17], v[0], v[0], 0, 128))
k.emit(v_mfma_16x16x32(v[32:35], v[150:153], v[50:53], v[32:35], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[36:39], v[150:153], v[54:57], v[36:39], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(lgkm=4)
k.emit(ds_read_b128(v[122:125], v[17], v[0], v[0], 0, 0, 1))
k.emit(v_mfma_16x16x32(v[40:43], v[150:153], v[58:61], v[40:43], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[18], v[86], v[82], s[85]))
k.emit(v_perm_b32_e64(v[19], v[94], v[90], s[85]))
k.emit(v_mfma_16x16x32(v[44:47], v[150:153], v[62:65], v[44:47], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[126:129], v[17], v[0], v[0], 0, 128, 1))
k.emit(v_mfma_16x16x32(v[48:51], v[150:153], v[66:69], v[48:51], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[20], v[102], v[98], s[85]))
k.emit(v_perm_b32_e64(v[21], v[110], v[106], s[85]))
k.emit(v_mfma_16x16x32(v[52:55], v[150:153], v[70:73], v[52:55], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[130:133], v[17], v[0], v[0], 0, 0, 2))
k.emit(v_mfma_16x16x32(v[56:59], v[150:153], v[74:77], v[56:59], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[22], v[86], v[82], s[86]))
k.emit(v_perm_b32_e64(v[23], v[94], v[90], s[86]))
k.emit(v_mfma_16x16x32(v[60:63], v[150:153], v[78:81], v[60:63], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[134:137], v[17], v[0], v[0], 0, 128, 2))
k.emit(v_mfma_16x16x32(v[64:67], v[154:157], v[50:53], v[64:67], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[24], v[102], v[98], s[86]))
k.emit(v_perm_b32_e64(v[25], v[110], v[106], s[86]))
k.emit(v_mfma_16x16x32(v[68:71], v[154:157], v[54:57], v[68:71], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[138:141], v[17], v[0], v[0], 0, 0, 3))
k.emit(v_mfma_16x16x32(v[72:75], v[154:157], v[58:61], v[72:75], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[26], v[87], v[83], s[85]))
k.emit(v_perm_b32_e64(v[27], v[95], v[91], s[85]))
k.emit(v_mfma_16x16x32(v[76:79], v[154:157], v[62:65], v[76:79], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[142:145], v[17], v[0], v[0], 0, 128, 3))
k.emit(v_mfma_16x16x32(v[80:83], v[154:157], v[66:69], v[80:83], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[28], v[103], v[99], s[85]))
k.emit(v_perm_b32_e64(v[29], v[111], v[107], s[85]))
k.emit(v_mfma_16x16x32(v[84:87], v[154:157], v[70:73], v[84:87], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[30], v[87], v[83], s[86]))
k.emit(v_perm_b32_e64(v[31], v[95], v[91], s[86]))
k.emit(v_mfma_16x16x32(v[88:91], v[154:157], v[74:77], v[88:91], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[32], v[103], v[99], s[86]))
k.emit(v_perm_b32_e64(v[33], v[111], v[107], s[86]))
k.emit(v_mfma_16x16x32(v[92:95], v[154:157], v[78:81], v[92:95], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[34], v[88], v[84], s[85]))
k.emit(v_perm_b32_e64(v[35], v[96], v[92], s[85]))
k.emit(v_mfma_16x16x32(v[96:99], v[158:161], v[50:53], v[96:99], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4224))
k.emit(buffer_load_dwordx4(v[0:3], v[10], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[100:103], v[158:161], v[54:57], v[100:103], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[36], v[104], v[100], s[85]))
k.emit(v_perm_b32_e64(v[37], v[112], v[108], s[85]))
k.emit(v_mfma_16x16x32(v[104:107], v[158:161], v[58:61], v[104:107], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[38], v[88], v[84], s[86]))
k.emit(v_perm_b32_e64(v[39], v[96], v[92], s[86]))
k.emit(v_mfma_16x16x32(v[108:111], v[158:161], v[62:65], v[108:111], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4224))
k.emit(buffer_load_dwordx4(v[0:3], v[11], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[112:115], v[158:161], v[66:69], v[112:115], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[40], v[104], v[100], s[86]))
k.emit(v_perm_b32_e64(v[41], v[112], v[108], s[86]))
k.emit(v_mfma_16x16x32(v[116:119], v[158:161], v[70:73], v[116:119], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[42], v[89], v[85], s[85]))
k.emit(v_perm_b32_e64(v[43], v[97], v[93], s[85]))
k.emit(v_mfma_16x16x32(v[120:123], v[158:161], v[74:77], v[120:123], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4224))
k.emit(buffer_load_dwordx4(v[0:3], v[12], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[124:127], v[158:161], v[78:81], v[124:127], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[44], v[105], v[101], s[85]))
k.emit(v_perm_b32_e64(v[45], v[113], v[109], s[85]))
k.emit(v_mfma_16x16x32(v[128:131], v[162:165], v[50:53], v[128:131], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[46], v[89], v[85], s[86]))
k.emit(v_perm_b32_e64(v[47], v[97], v[93], s[86]))
k.emit(v_mfma_16x16x32(v[132:135], v[162:165], v[54:57], v[132:135], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4224))
k.emit(buffer_load_dwordx4(v[0:3], v[13], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[136:139], v[162:165], v[58:61], v[136:139], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[48], v[105], v[101], s[86]))
k.emit(v_perm_b32_e64(v[49], v[113], v[109], s[86]))
k.emit(v_mfma_16x16x32(v[140:143], v[162:165], v[62:65], v[140:143], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[144:147], v[162:165], v[66:69], v[144:147], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4224))
k.emit(buffer_load_dwordx4(v[0:3], v[14], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[148:151], v[162:165], v[70:73], v[148:151], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[152:155], v[162:165], v[74:77], v[152:155], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[156:159], v[162:165], v[78:81], v[156:159], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[160:163], v[166:169], v[50:53], v[160:163], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[164:167], v[166:169], v[54:57], v[164:167], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[168:171], v[166:169], v[58:61], v[168:171], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[172:175], v[166:169], v[62:65], v[172:175], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[176:179], v[166:169], v[66:69], v[176:179], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[180:183], v[166:169], v[70:73], v[180:183], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[184:187], v[166:169], v[74:77], v[184:187], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[188:191], v[166:169], v[78:81], v[188:191], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[192:195], v[170:173], v[50:53], v[192:195], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[196:199], v[170:173], v[54:57], v[196:199], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[200:203], v[170:173], v[58:61], v[200:203], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[204:207], v[170:173], v[62:65], v[204:207], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[208:211], v[170:173], v[66:69], v[208:211], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[212:215], v[170:173], v[70:73], v[212:215], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[216:219], v[170:173], v[74:77], v[216:219], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[220:223], v[170:173], v[78:81], v[220:223], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[224:227], v[174:177], v[50:53], v[224:227], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[228:231], v[174:177], v[54:57], v[228:231], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[232:235], v[174:177], v[58:61], v[232:235], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[236:239], v[174:177], v[62:65], v[236:239], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[240:243], v[174:177], v[66:69], v[240:243], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_add_u32(M0, M0, 4224))
k.emit(buffer_load_dwordx4(v[0:3], v[15], s[72:75], 0, 0, 1, 0, 0, 0, 1, 1, 1))
k.emit(v_mfma_16x16x32(v[244:247], v[174:177], v[70:73], v[244:247], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# Loop tail: scalar bookkeeping interleaved with the last two MFMAs of the iteration.
# XOR s[53]/s[54] with s[55]/s[56] in place. s[54] is the value copied into M0 (s_mov_b32(M0, s[54]))
# ahead of the s[72:75] buffer loads earlier in the loop body.
# NOTE(review): presumably this flips between two LDS write buffers each iteration; confirm where s[55]/s[56] are set.
k.emit(s_xor_b32(s[53], s[55], s[53]))
k.emit(s_xor_b32(s[54], s[56], s[54]))
k.emit(v_mfma_16x16x32(v[248:251], v[174:177], v[74:77], v[248:251], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# s[8] is decremented once per iteration and compared against 2 (assuming (dst, src0, src1) operand order)
k.emit(s_sub_u32(s[8], s[8], 1))
k.emit(s_cmp_eq_i32(s[8], 2))
k.emit(v_mfma_16x16x32(v[252:255], v[174:177], v[78:81], v[252:255], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# SCC == 0 means s[8] != 2: jump back to LoopBeginL_1 for another iteration
k.emit(s_cbranch_scc0(), target='LoopBeginL_1')
# NOTE(review): this unconditional branch targets the label emitted on the very next line, so it looks redundant.
# It is kept as-is because removing it changes the emitted instruction stream; confirm before dropping it.
k.emit(s_branch(), target='LoopEndL')
# LoopEndL: reached when the LoopBeginL_1 back-edge is not taken (s[8] == 2 after the decrement).
# The scalar ops below are interleaved with MFMAs and LDS reads; they advance the s[68:71] resource
# (the descriptor operand of the buffer_load_dwordx4 on v[7] in the loop body).
k.label('LoopEndL')
# NOTE(review): s_waitcnt() is called with no counters; presumably this waits for everything outstanding. Confirm defaults.
k.emit(s_waitcnt())
k.emit(s_barrier())
k.emit(v_mfma_16x16x32(v[0:3], v[114:117], v[18:21], v[0:3], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# SCC = (s[8] == s[78]); it selects which 64-bit increment is applied below
k.emit(s_cmp_eq_u32(s[8], s[78]))
k.emit(ds_read_b128(v[82:85], v[16], v[0], v[0], 0, 0, 64))
k.emit(ds_read_b128(v[86:89], v[16], v[0], v[0], 0, 0, 66))
k.emit(v_mfma_16x16x32(v[4:7], v[114:117], v[22:25], v[4:7], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# increment low word: s[88] = s[79] if SCC else s[83]
k.emit(s_cselect_b32(s[88], s[79], s[83]))
k.emit(v_mfma_16x16x32(v[8:11], v[114:117], v[26:29], v[8:11], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# increment high word: s[89] = s[80] if SCC else 0 (both selects must precede the s_add below, which overwrites SCC)
k.emit(s_cselect_b32(s[89], s[80], 0))
k.emit(ds_read_b128(v[90:93], v[16], v[0], v[0], 0, 0, 68))
k.emit(ds_read_b128(v[94:97], v[16], v[0], v[0], 0, 0, 70))
k.emit(v_mfma_16x16x32(v[12:15], v[114:117], v[30:33], v[12:15], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# 64-bit add with carry: s[68:69] += s[88:89]
k.emit(s_add_u32(s[68], s[68], s[88]))
k.emit(v_mfma_16x16x32(v[16:19], v[114:117], v[34:37], v[16:19], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_addc_u32(s[69], s[69], s[89]))
k.emit(ds_read_b128(v[98:101], v[16], v[0], v[0], 0, 0, 72))
k.emit(ds_read_b128(v[102:105], v[16], v[0], v[0], 0, 0, 74))
k.emit(v_mfma_16x16x32(v[20:23], v[114:117], v[38:41], v[20:23], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# 64-bit subtract with borrow: s[62:63] -= s[88:89]
# NOTE(review): s[62:63] looks like the remaining byte count for this buffer; confirm against the prologue.
k.emit(s_sub_u32(s[62], s[62], s[88]))
k.emit(v_mfma_16x16x32(v[24:27], v[114:117], v[42:45], v[24:27], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_subb_u32(s[63], s[63], s[89]))
k.emit(ds_read_b128(v[106:109], v[16], v[0], v[0], 0, 0, 76))
k.emit(ds_read_b128(v[110:113], v[16], v[0], v[0], 0, 0, 78))
k.emit(v_mfma_16x16x32(v[28:31], v[114:117], v[46:49], v[28:31], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# SCC = (high word s[63] == 0), consumed by the s_cselect two instructions later
k.emit(s_cmp_eq_u32(s[63], 0))
k.emit(v_mfma_16x16x32(v[32:35], v[118:121], v[18:21], v[32:35], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# NOTE(review): presumably allows 4 LDS reads to stay in flight, so the first four of the eight ds_reads above
# (v[82:97]) have landed before the byte permutes that follow this span read them. Confirm k.waitcnt semantics.
k.waitcnt(lgkm=4)
# s[70] = s[62] when the remaining count fits in 32 bits, else -1
# NOTE(review): s[70] is dword 2 of the s[68:71] resource, presumably its num_records field. Confirm descriptor layout.
k.emit(s_cselect_b32(s[70], s[62], -1))
k.emit(v_perm_b32_e64(v[50], v[86], v[82], s[85]))
k.emit(v_perm_b32_e64(v[51], v[94], v[90], s[85]))
k.emit(v_mfma_16x16x32(v[36:39], v[118:121], v[22:25], v[36:39], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[146:149], v[17], v[0], v[0], 0, 64))
k.emit(v_mfma_16x16x32(v[40:43], v[118:121], v[26:29], v[40:43], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[44:47], v[118:121], v[30:33], v[44:47], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[150:153], v[17], v[0], v[0], 0, 192))
k.emit(v_mfma_16x16x32(v[48:51], v[118:121], v[34:37], v[48:51], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(lgkm=1)
k.emit(v_mfma_16x16x32(v[52:55], v[118:121], v[38:41], v[52:55], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(v_mfma_16x16x32(v[56:59], v[118:121], v[42:45], v[56:59], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[60:63], v[118:121], v[46:49], v[60:63], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[154:157], v[17], v[0], v[0], 0, 64, 1))
k.emit(v_mfma_16x16x32(v[64:67], v[122:125], v[18:21], v[64:67], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[52], v[102], v[98], s[85]))
k.emit(v_perm_b32_e64(v[53], v[110], v[106], s[85]))
k.emit(v_mfma_16x16x32(v[68:71], v[122:125], v[22:25], v[68:71], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[72:75], v[122:125], v[26:29], v[72:75], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[158:161], v[17], v[0], v[0], 0, 192, 1))
k.emit(v_mfma_16x16x32(v[76:79], v[122:125], v[30:33], v[76:79], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[54], v[86], v[82], s[86]))
k.emit(v_perm_b32_e64(v[55], v[94], v[90], s[86]))
k.emit(v_mfma_16x16x32(v[80:83], v[122:125], v[34:37], v[80:83], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[84:87], v[122:125], v[38:41], v[84:87], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[162:165], v[17], v[0], v[0], 0, 64, 2))
k.emit(v_mfma_16x16x32(v[88:91], v[122:125], v[42:45], v[88:91], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[56], v[102], v[98], s[86]))
k.emit(v_perm_b32_e64(v[57], v[110], v[106], s[86]))
k.emit(v_mfma_16x16x32(v[92:95], v[122:125], v[46:49], v[92:95], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[96:99], v[126:129], v[18:21], v[96:99], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[166:169], v[17], v[0], v[0], 0, 192, 2))
k.emit(v_mfma_16x16x32(v[100:103], v[126:129], v[22:25], v[100:103], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[58], v[87], v[83], s[85]))
k.emit(v_perm_b32_e64(v[59], v[95], v[91], s[85]))
k.emit(v_mfma_16x16x32(v[104:107], v[126:129], v[26:29], v[104:107], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[108:111], v[126:129], v[30:33], v[108:111], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[170:173], v[17], v[0], v[0], 0, 64, 3))
k.emit(v_mfma_16x16x32(v[112:115], v[126:129], v[34:37], v[112:115], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# Advance the s[72:75] resource (the descriptor operand of the buffer_load_dwordx4 on v[8]..v[15] in the
# loop body). Same scheme as the s[68:71] update earlier in LoopEndL, interleaved with v_perm/MFMA work.
# SCC = (s[8] == s[78]); it selects which 64-bit increment is applied below
k.emit(s_cmp_eq_u32(s[8], s[78]))
k.emit(v_perm_b32_e64(v[60], v[103], v[99], s[85]))
k.emit(v_perm_b32_e64(v[61], v[111], v[107], s[85]))
k.emit(v_mfma_16x16x32(v[116:119], v[126:129], v[38:41], v[116:119], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# increment low word: s[88] = s[81] if SCC else s[84]
k.emit(s_cselect_b32(s[88], s[81], s[84]))
k.emit(v_perm_b32_e64(v[62], v[87], v[83], s[86]))
k.emit(v_perm_b32_e64(v[63], v[95], v[91], s[86]))
k.emit(v_mfma_16x16x32(v[120:123], v[126:129], v[42:45], v[120:123], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# increment high word: s[89] = s[82] if SCC else 0 (both selects must precede the s_add below, which overwrites SCC)
k.emit(s_cselect_b32(s[89], s[82], 0))
k.emit(ds_read_b128(v[174:177], v[17], v[0], v[0], 0, 192, 3))
k.emit(v_mfma_16x16x32(v[124:127], v[126:129], v[46:49], v[124:127], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# 64-bit add with carry: s[72:73] += s[88:89]
k.emit(s_add_u32(s[72], s[72], s[88]))
k.emit(v_perm_b32_e64(v[64], v[103], v[99], s[86]))
k.emit(v_perm_b32_e64(v[65], v[111], v[107], s[86]))
k.emit(v_mfma_16x16x32(v[128:131], v[130:133], v[18:21], v[128:131], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_addc_u32(s[73], s[73], s[89]))
k.emit(v_perm_b32_e64(v[66], v[88], v[84], s[85]))
k.emit(v_perm_b32_e64(v[67], v[96], v[92], s[85]))
k.emit(v_mfma_16x16x32(v[132:135], v[130:133], v[22:25], v[132:135], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# 64-bit subtract with borrow: s[76:77] -= s[88:89]
# NOTE(review): s[76:77] looks like the remaining byte count for this buffer; confirm against the prologue.
k.emit(s_sub_u32(s[76], s[76], s[88]))
k.emit(v_perm_b32_e64(v[68], v[104], v[100], s[85]))
k.emit(v_perm_b32_e64(v[69], v[112], v[108], s[85]))
k.emit(v_mfma_16x16x32(v[136:139], v[130:133], v[26:29], v[136:139], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_subb_u32(s[77], s[77], s[89]))
k.emit(v_perm_b32_e64(v[70], v[88], v[84], s[86]))
k.emit(v_perm_b32_e64(v[71], v[96], v[92], s[86]))
k.emit(v_mfma_16x16x32(v[140:143], v[130:133], v[30:33], v[140:143], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# SCC = (high word s[77] == 0), consumed by the final s_cselect of this span
k.emit(s_cmp_eq_u32(s[77], 0))
k.emit(v_perm_b32_e64(v[72], v[104], v[100], s[86]))
k.emit(v_perm_b32_e64(v[73], v[112], v[108], s[86]))
k.emit(v_mfma_16x16x32(v[144:147], v[130:133], v[34:37], v[144:147], 0, 0, 1, 0, 0, 0, 0, 0, 1))
# NOTE(review): presumably waits for every outstanding LDS read issued so far in LoopEndL. Confirm k.waitcnt semantics.
k.waitcnt(lgkm=0)
# s[74] = s[76] when the remaining count fits in 32 bits, else -1
# NOTE(review): s[74] is dword 2 of the s[72:75] resource, presumably its num_records field. Confirm descriptor layout.
k.emit(s_cselect_b32(s[74], s[76], -1))
k.emit(v_perm_b32_e64(v[74], v[89], v[85], s[85]))
k.emit(v_perm_b32_e64(v[75], v[97], v[93], s[85]))
k.emit(v_mfma_16x16x32(v[148:151], v[130:133], v[38:41], v[148:151], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[76], v[105], v[101], s[85]))
k.emit(v_perm_b32_e64(v[77], v[113], v[109], s[85]))
k.emit(v_mfma_16x16x32(v[152:155], v[130:133], v[42:45], v[152:155], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[78], v[89], v[85], s[86]))
k.emit(v_perm_b32_e64(v[79], v[97], v[93], s[86]))
k.emit(v_mfma_16x16x32(v[156:159], v[130:133], v[46:49], v[156:159], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[80], v[105], v[101], s[86]))
k.emit(v_perm_b32_e64(v[81], v[113], v[109], s[86]))
k.emit(v_mfma_16x16x32(v[160:163], v[134:137], v[18:21], v[160:163], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[164:167], v[134:137], v[22:25], v[164:167], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[168:171], v[134:137], v[26:29], v[168:171], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[172:175], v[134:137], v[30:33], v[172:175], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[176:179], v[134:137], v[34:37], v[176:179], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(v_mfma_16x16x32(v[180:183], v[134:137], v[38:41], v[180:183], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[184:187], v[134:137], v[42:45], v[184:187], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[188:191], v[134:137], v[46:49], v[188:191], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[192:195], v[138:141], v[18:21], v[192:195], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[196:199], v[138:141], v[22:25], v[196:199], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[200:203], v[138:141], v[26:29], v[200:203], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[204:207], v[138:141], v[30:33], v[204:207], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[208:211], v[138:141], v[34:37], v[208:211], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[212:215], v[138:141], v[38:41], v[212:215], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[216:219], v[138:141], v[42:45], v[216:219], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[220:223], v[138:141], v[46:49], v[220:223], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[224:227], v[142:145], v[18:21], v[224:227], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(vm=17)
k.emit(v_mfma_16x16x32(v[228:231], v[142:145], v[22:25], v[228:231], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[232:235], v[142:145], v[26:29], v[232:235], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[236:239], v[142:145], v[30:33], v[236:239], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(v_xor_b32_e32(v[16], v[178], v[16]))
k.emit(v_xor_b32_e32(v[17], v[179], v[17]))
k.emit(v_mfma_16x16x32(v[240:243], v[142:145], v[34:37], v[240:243], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[82:85], v[16]))
k.emit(ds_read_b128(v[86:89], v[16], v[0], v[0], 0, 0, 2))
k.emit(v_mfma_16x16x32(v[244:247], v[142:145], v[38:41], v[244:247], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[248:251], v[142:145], v[42:45], v[248:251], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[90:93], v[16], v[0], v[0], 0, 0, 4))
k.emit(ds_read_b128(v[94:97], v[16], v[0], v[0], 0, 0, 6))
k.emit(v_mfma_16x16x32(v[252:255], v[142:145], v[46:49], v[252:255], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[0:3], v[146:149], v[50:53], v[0:3], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[98:101], v[16], v[0], v[0], 0, 0, 8))
k.emit(ds_read_b128(v[102:105], v[16], v[0], v[0], 0, 0, 10))
k.emit(v_mfma_16x16x32(v[4:7], v[146:149], v[54:57], v[4:7], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[8:11], v[146:149], v[58:61], v[8:11], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(vm=9)
k.emit(ds_read_b128(v[106:109], v[16], v[0], v[0], 0, 0, 12))
k.emit(ds_read_b128(v[110:113], v[16], v[0], v[0], 0, 0, 14))
k.emit(v_mfma_16x16x32(v[12:15], v[146:149], v[62:65], v[12:15], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[16:19], v[146:149], v[66:69], v[16:19], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(ds_read_b128(v[114:117], v[17]))
k.emit(v_mfma_16x16x32(v[20:23], v[146:149], v[70:73], v[20:23], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[24:27], v[146:149], v[74:77], v[24:27], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[118:121], v[17], v[0], v[0], 0, 128))
k.emit(v_mfma_16x16x32(v[28:31], v[146:149], v[78:81], v[28:31], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[32:35], v[150:153], v[50:53], v[32:35], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[122:125], v[17], v[0], v[0], 0, 0, 1))
k.emit(v_mfma_16x16x32(v[36:39], v[150:153], v[54:57], v[36:39], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(lgkm=4)
k.emit(v_mfma_16x16x32(v[40:43], v[150:153], v[58:61], v[40:43], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[126:129], v[17], v[0], v[0], 0, 128, 1))
k.emit(v_mfma_16x16x32(v[44:47], v[150:153], v[62:65], v[44:47], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[18], v[86], v[82], s[85]))
k.emit(v_perm_b32_e64(v[19], v[94], v[90], s[85]))
k.emit(v_mfma_16x16x32(v[48:51], v[150:153], v[66:69], v[48:51], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[130:133], v[17], v[0], v[0], 0, 0, 2))
k.emit(v_mfma_16x16x32(v[52:55], v[150:153], v[70:73], v[52:55], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[20], v[102], v[98], s[85]))
k.emit(v_perm_b32_e64(v[21], v[110], v[106], s[85]))
k.emit(v_mfma_16x16x32(v[56:59], v[150:153], v[74:77], v[56:59], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[134:137], v[17], v[0], v[0], 0, 128, 2))
k.emit(v_mfma_16x16x32(v[60:63], v[150:153], v[78:81], v[60:63], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[22], v[86], v[82], s[86]))
k.emit(v_perm_b32_e64(v[23], v[94], v[90], s[86]))
k.emit(v_mfma_16x16x32(v[64:67], v[154:157], v[50:53], v[64:67], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[138:141], v[17], v[0], v[0], 0, 0, 3))
k.emit(v_mfma_16x16x32(v[68:71], v[154:157], v[54:57], v[68:71], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[24], v[102], v[98], s[86]))
k.emit(v_perm_b32_e64(v[25], v[110], v[106], s[86]))
k.emit(v_mfma_16x16x32(v[72:75], v[154:157], v[58:61], v[72:75], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[142:145], v[17], v[0], v[0], 0, 128, 3))
k.emit(v_mfma_16x16x32(v[76:79], v[154:157], v[62:65], v[76:79], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[26], v[87], v[83], s[85]))
k.emit(v_perm_b32_e64(v[27], v[95], v[91], s[85]))
k.emit(v_mfma_16x16x32(v[80:83], v[154:157], v[66:69], v[80:83], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[28], v[103], v[99], s[85]))
k.emit(v_perm_b32_e64(v[29], v[111], v[107], s[85]))
k.emit(v_mfma_16x16x32(v[84:87], v[154:157], v[70:73], v[84:87], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[30], v[87], v[83], s[86]))
k.emit(v_perm_b32_e64(v[31], v[95], v[91], s[86]))
k.emit(v_mfma_16x16x32(v[88:91], v[154:157], v[74:77], v[88:91], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[32], v[103], v[99], s[86]))
k.emit(v_perm_b32_e64(v[33], v[111], v[107], s[86]))
k.emit(v_mfma_16x16x32(v[92:95], v[154:157], v[78:81], v[92:95], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[96:99], v[158:161], v[50:53], v[96:99], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[34], v[88], v[84], s[85]))
k.emit(v_perm_b32_e64(v[35], v[96], v[92], s[85]))
k.emit(v_mfma_16x16x32(v[100:103], v[158:161], v[54:57], v[100:103], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[36], v[104], v[100], s[85]))
k.emit(v_perm_b32_e64(v[37], v[112], v[108], s[85]))
k.emit(v_mfma_16x16x32(v[104:107], v[158:161], v[58:61], v[104:107], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[108:111], v[158:161], v[62:65], v[108:111], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[38], v[88], v[84], s[86]))
k.emit(v_perm_b32_e64(v[39], v[96], v[92], s[86]))
k.emit(v_mfma_16x16x32(v[112:115], v[158:161], v[66:69], v[112:115], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[40], v[104], v[100], s[86]))
k.emit(v_perm_b32_e64(v[41], v[112], v[108], s[86]))
k.emit(v_mfma_16x16x32(v[116:119], v[158:161], v[70:73], v[116:119], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[120:123], v[158:161], v[74:77], v[120:123], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[42], v[89], v[85], s[85]))
k.emit(v_perm_b32_e64(v[43], v[97], v[93], s[85]))
k.emit(v_mfma_16x16x32(v[124:127], v[158:161], v[78:81], v[124:127], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[44], v[105], v[101], s[85]))
k.emit(v_perm_b32_e64(v[45], v[113], v[109], s[85]))
k.emit(v_mfma_16x16x32(v[128:131], v[162:165], v[50:53], v[128:131], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[132:135], v[162:165], v[54:57], v[132:135], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[46], v[89], v[85], s[86]))
k.emit(v_perm_b32_e64(v[47], v[97], v[93], s[86]))
k.emit(v_mfma_16x16x32(v[136:139], v[162:165], v[58:61], v[136:139], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[48], v[105], v[101], s[86]))
k.emit(v_perm_b32_e64(v[49], v[113], v[109], s[86]))
k.emit(v_mfma_16x16x32(v[140:143], v[162:165], v[62:65], v[140:143], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[144:147], v[162:165], v[66:69], v[144:147], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[148:151], v[162:165], v[70:73], v[148:151], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[152:155], v[162:165], v[74:77], v[152:155], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[156:159], v[162:165], v[78:81], v[156:159], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[160:163], v[166:169], v[50:53], v[160:163], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[164:167], v[166:169], v[54:57], v[164:167], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[168:171], v[166:169], v[58:61], v[168:171], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[172:175], v[166:169], v[62:65], v[172:175], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[176:179], v[166:169], v[66:69], v[176:179], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[180:183], v[166:169], v[70:73], v[180:183], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[184:187], v[166:169], v[74:77], v[184:187], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[188:191], v[166:169], v[78:81], v[188:191], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[192:195], v[170:173], v[50:53], v[192:195], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[196:199], v[170:173], v[54:57], v[196:199], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[200:203], v[170:173], v[58:61], v[200:203], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[204:207], v[170:173], v[62:65], v[204:207], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[208:211], v[170:173], v[66:69], v[208:211], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[212:215], v[170:173], v[70:73], v[212:215], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[216:219], v[170:173], v[74:77], v[216:219], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[220:223], v[170:173], v[78:81], v[220:223], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[224:227], v[174:177], v[50:53], v[224:227], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[228:231], v[174:177], v[54:57], v[228:231], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[232:235], v[174:177], v[58:61], v[232:235], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[236:239], v[174:177], v[62:65], v[236:239], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[240:243], v[174:177], v[66:69], v[240:243], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[244:247], v[174:177], v[70:73], v[244:247], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[248:251], v[174:177], v[74:77], v[248:251], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[252:255], v[174:177], v[78:81], v[252:255], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.label('toPGR1')
k.emit(s_waitcnt())
k.emit(s_barrier())
k.emit(v_mfma_16x16x32(v[0:3], v[114:117], v[18:21], v[0:3], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[82:85], v[16], v[0], v[0], 0, 0, 64))
k.emit(ds_read_b128(v[86:89], v[16], v[0], v[0], 0, 0, 66))
k.emit(v_mfma_16x16x32(v[4:7], v[114:117], v[22:25], v[4:7], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[8:11], v[114:117], v[26:29], v[8:11], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[90:93], v[16], v[0], v[0], 0, 0, 68))
k.emit(ds_read_b128(v[94:97], v[16], v[0], v[0], 0, 0, 70))
k.emit(v_mfma_16x16x32(v[12:15], v[114:117], v[30:33], v[12:15], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[16:19], v[114:117], v[34:37], v[16:19], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[98:101], v[16], v[0], v[0], 0, 0, 72))
k.emit(ds_read_b128(v[102:105], v[16], v[0], v[0], 0, 0, 74))
k.emit(v_mfma_16x16x32(v[20:23], v[114:117], v[38:41], v[20:23], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[24:27], v[114:117], v[42:45], v[24:27], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[106:109], v[16], v[0], v[0], 0, 0, 76))
k.emit(ds_read_b128(v[110:113], v[16], v[0], v[0], 0, 0, 78))
k.emit(v_mfma_16x16x32(v[28:31], v[114:117], v[46:49], v[28:31], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[32:35], v[118:121], v[18:21], v[32:35], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(lgkm=4)
k.emit(v_perm_b32_e64(v[50], v[86], v[82], s[85]))
k.emit(v_perm_b32_e64(v[51], v[94], v[90], s[85]))
k.emit(v_mfma_16x16x32(v[36:39], v[118:121], v[22:25], v[36:39], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[146:149], v[17], v[0], v[0], 0, 64))
k.emit(v_mfma_16x16x32(v[40:43], v[118:121], v[26:29], v[40:43], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[44:47], v[118:121], v[30:33], v[44:47], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[150:153], v[17], v[0], v[0], 0, 192))
k.emit(v_mfma_16x16x32(v[48:51], v[118:121], v[34:37], v[48:51], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(lgkm=1)
k.emit(v_mfma_16x16x32(v[52:55], v[118:121], v[38:41], v[52:55], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(v_mfma_16x16x32(v[56:59], v[118:121], v[42:45], v[56:59], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[60:63], v[118:121], v[46:49], v[60:63], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[154:157], v[17], v[0], v[0], 0, 64, 1))
k.emit(v_mfma_16x16x32(v[64:67], v[122:125], v[18:21], v[64:67], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[52], v[102], v[98], s[85]))
k.emit(v_perm_b32_e64(v[53], v[110], v[106], s[85]))
k.emit(v_mfma_16x16x32(v[68:71], v[122:125], v[22:25], v[68:71], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[72:75], v[122:125], v[26:29], v[72:75], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[158:161], v[17], v[0], v[0], 0, 192, 1))
k.emit(v_mfma_16x16x32(v[76:79], v[122:125], v[30:33], v[76:79], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[54], v[86], v[82], s[86]))
k.emit(v_perm_b32_e64(v[55], v[94], v[90], s[86]))
k.emit(v_mfma_16x16x32(v[80:83], v[122:125], v[34:37], v[80:83], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[84:87], v[122:125], v[38:41], v[84:87], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[162:165], v[17], v[0], v[0], 0, 64, 2))
k.emit(v_mfma_16x16x32(v[88:91], v[122:125], v[42:45], v[88:91], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[56], v[102], v[98], s[86]))
k.emit(v_perm_b32_e64(v[57], v[110], v[106], s[86]))
k.emit(v_mfma_16x16x32(v[92:95], v[122:125], v[46:49], v[92:95], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[96:99], v[126:129], v[18:21], v[96:99], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[166:169], v[17], v[0], v[0], 0, 192, 2))
k.emit(v_mfma_16x16x32(v[100:103], v[126:129], v[22:25], v[100:103], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[58], v[87], v[83], s[85]))
k.emit(v_perm_b32_e64(v[59], v[95], v[91], s[85]))
k.emit(v_mfma_16x16x32(v[104:107], v[126:129], v[26:29], v[104:107], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[108:111], v[126:129], v[30:33], v[108:111], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[170:173], v[17], v[0], v[0], 0, 64, 3))
k.emit(v_mfma_16x16x32(v[112:115], v[126:129], v[34:37], v[112:115], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[60], v[103], v[99], s[85]))
k.emit(v_perm_b32_e64(v[61], v[111], v[107], s[85]))
k.emit(v_mfma_16x16x32(v[116:119], v[126:129], v[38:41], v[116:119], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[62], v[87], v[83], s[86]))
k.emit(v_perm_b32_e64(v[63], v[95], v[91], s[86]))
k.emit(v_mfma_16x16x32(v[120:123], v[126:129], v[42:45], v[120:123], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(ds_read_b128(v[174:177], v[17], v[0], v[0], 0, 192, 3))
k.emit(v_mfma_16x16x32(v[124:127], v[126:129], v[46:49], v[124:127], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[64], v[103], v[99], s[86]))
k.emit(v_perm_b32_e64(v[65], v[111], v[107], s[86]))
k.emit(v_mfma_16x16x32(v[128:131], v[130:133], v[18:21], v[128:131], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[66], v[88], v[84], s[85]))
k.emit(v_perm_b32_e64(v[67], v[96], v[92], s[85]))
k.emit(v_mfma_16x16x32(v[132:135], v[130:133], v[22:25], v[132:135], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[68], v[104], v[100], s[85]))
k.emit(v_perm_b32_e64(v[69], v[112], v[108], s[85]))
k.emit(v_mfma_16x16x32(v[136:139], v[130:133], v[26:29], v[136:139], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[70], v[88], v[84], s[86]))
k.emit(v_perm_b32_e64(v[71], v[96], v[92], s[86]))
k.emit(v_mfma_16x16x32(v[140:143], v[130:133], v[30:33], v[140:143], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[72], v[104], v[100], s[86]))
k.emit(v_perm_b32_e64(v[73], v[112], v[108], s[86]))
k.emit(v_mfma_16x16x32(v[144:147], v[130:133], v[34:37], v[144:147], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(lgkm=0)
k.emit(v_perm_b32_e64(v[74], v[89], v[85], s[85]))
k.emit(v_perm_b32_e64(v[75], v[97], v[93], s[85]))
k.emit(v_mfma_16x16x32(v[148:151], v[130:133], v[38:41], v[148:151], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[76], v[105], v[101], s[85]))
k.emit(v_perm_b32_e64(v[77], v[113], v[109], s[85]))
k.emit(v_mfma_16x16x32(v[152:155], v[130:133], v[42:45], v[152:155], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[78], v[89], v[85], s[86]))
k.emit(v_perm_b32_e64(v[79], v[97], v[93], s[86]))
k.emit(v_mfma_16x16x32(v[156:159], v[130:133], v[46:49], v[156:159], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[80], v[105], v[101], s[86]))
k.emit(v_perm_b32_e64(v[81], v[113], v[109], s[86]))
k.emit(v_mfma_16x16x32(v[160:163], v[134:137], v[18:21], v[160:163], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[164:167], v[134:137], v[22:25], v[164:167], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[168:171], v[134:137], v[26:29], v[168:171], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[172:175], v[134:137], v[30:33], v[172:175], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[176:179], v[134:137], v[34:37], v[176:179], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(v_mfma_16x16x32(v[180:183], v[134:137], v[38:41], v[180:183], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[184:187], v[134:137], v[42:45], v[184:187], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[188:191], v[134:137], v[46:49], v[188:191], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[192:195], v[138:141], v[18:21], v[192:195], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[196:199], v[138:141], v[22:25], v[196:199], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[200:203], v[138:141], v[26:29], v[200:203], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[204:207], v[138:141], v[30:33], v[204:207], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[208:211], v[138:141], v[34:37], v[208:211], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[212:215], v[138:141], v[38:41], v[212:215], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[216:219], v[138:141], v[42:45], v[216:219], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[220:223], v[138:141], v[46:49], v[220:223], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[224:227], v[142:145], v[18:21], v[224:227], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(vm=17)
k.emit(v_mfma_16x16x32(v[228:231], v[142:145], v[22:25], v[228:231], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[232:235], v[142:145], v[26:29], v[232:235], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[236:239], v[142:145], v[30:33], v[236:239], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(v_mfma_16x16x32(v[240:243], v[142:145], v[34:37], v[240:243], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[244:247], v[142:145], v[38:41], v[244:247], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[248:251], v[142:145], v[42:45], v[248:251], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[252:255], v[142:145], v[46:49], v[252:255], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[0:3], v[146:149], v[50:53], v[0:3], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[4:7], v[146:149], v[54:57], v[4:7], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[8:11], v[146:149], v[58:61], v[8:11], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(vm=9)
k.emit(v_mfma_16x16x32(v[12:15], v[146:149], v[62:65], v[12:15], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[16:19], v[146:149], v[66:69], v[16:19], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(s_barrier())
k.emit(v_mfma_16x16x32(v[20:23], v[146:149], v[70:73], v[20:23], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[24:27], v[146:149], v[74:77], v[24:27], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[28:31], v[146:149], v[78:81], v[28:31], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[32:35], v[150:153], v[50:53], v[32:35], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[36:39], v[150:153], v[54:57], v[36:39], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.waitcnt(lgkm=4)
k.emit(v_mfma_16x16x32(v[40:43], v[150:153], v[58:61], v[40:43], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[44:47], v[150:153], v[62:65], v[44:47], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[18], v[86], v[82], s[85]))
k.emit(v_perm_b32_e64(v[19], v[94], v[90], s[85]))
k.emit(v_mfma_16x16x32(v[48:51], v[150:153], v[66:69], v[48:51], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[52:55], v[150:153], v[70:73], v[52:55], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[20], v[102], v[98], s[85]))
k.emit(v_perm_b32_e64(v[21], v[110], v[106], s[85]))
k.emit(v_mfma_16x16x32(v[56:59], v[150:153], v[74:77], v[56:59], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[60:63], v[150:153], v[78:81], v[60:63], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[22], v[86], v[82], s[86]))
k.emit(v_perm_b32_e64(v[23], v[94], v[90], s[86]))
k.emit(v_mfma_16x16x32(v[64:67], v[154:157], v[50:53], v[64:67], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[68:71], v[154:157], v[54:57], v[68:71], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[24], v[102], v[98], s[86]))
k.emit(v_perm_b32_e64(v[25], v[110], v[106], s[86]))
k.emit(v_mfma_16x16x32(v[72:75], v[154:157], v[58:61], v[72:75], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[76:79], v[154:157], v[62:65], v[76:79], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[26], v[87], v[83], s[85]))
k.emit(v_perm_b32_e64(v[27], v[95], v[91], s[85]))
k.emit(v_mfma_16x16x32(v[80:83], v[154:157], v[66:69], v[80:83], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[28], v[103], v[99], s[85]))
k.emit(v_perm_b32_e64(v[29], v[111], v[107], s[85]))
k.emit(v_mfma_16x16x32(v[84:87], v[154:157], v[70:73], v[84:87], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[30], v[87], v[83], s[86]))
k.emit(v_perm_b32_e64(v[31], v[95], v[91], s[86]))
k.emit(v_mfma_16x16x32(v[88:91], v[154:157], v[74:77], v[88:91], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[32], v[103], v[99], s[86]))
k.emit(v_perm_b32_e64(v[33], v[111], v[107], s[86]))
k.emit(v_mfma_16x16x32(v[92:95], v[154:157], v[78:81], v[92:95], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[96:99], v[158:161], v[50:53], v[96:99], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[34], v[88], v[84], s[85]))
k.emit(v_perm_b32_e64(v[35], v[96], v[92], s[85]))
k.emit(v_mfma_16x16x32(v[100:103], v[158:161], v[54:57], v[100:103], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[36], v[104], v[100], s[85]))
k.emit(v_perm_b32_e64(v[37], v[112], v[108], s[85]))
k.emit(v_mfma_16x16x32(v[104:107], v[158:161], v[58:61], v[104:107], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[108:111], v[158:161], v[62:65], v[108:111], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[38], v[88], v[84], s[86]))
k.emit(v_perm_b32_e64(v[39], v[96], v[92], s[86]))
k.emit(v_mfma_16x16x32(v[112:115], v[158:161], v[66:69], v[112:115], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[40], v[104], v[100], s[86]))
k.emit(v_perm_b32_e64(v[41], v[112], v[108], s[86]))
k.emit(v_mfma_16x16x32(v[116:119], v[158:161], v[70:73], v[116:119], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[120:123], v[158:161], v[74:77], v[120:123], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[42], v[89], v[85], s[85]))
k.emit(v_perm_b32_e64(v[43], v[97], v[93], s[85]))
k.emit(v_mfma_16x16x32(v[124:127], v[158:161], v[78:81], v[124:127], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[44], v[105], v[101], s[85]))
k.emit(v_perm_b32_e64(v[45], v[113], v[109], s[85]))
k.emit(v_mfma_16x16x32(v[128:131], v[162:165], v[50:53], v[128:131], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[132:135], v[162:165], v[54:57], v[132:135], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[46], v[89], v[85], s[86]))
k.emit(v_perm_b32_e64(v[47], v[97], v[93], s[86]))
k.emit(v_mfma_16x16x32(v[136:139], v[162:165], v[58:61], v[136:139], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_perm_b32_e64(v[48], v[105], v[101], s[86]))
k.emit(v_perm_b32_e64(v[49], v[113], v[109], s[86]))
k.emit(v_mfma_16x16x32(v[140:143], v[162:165], v[62:65], v[140:143], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[144:147], v[162:165], v[66:69], v[144:147], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[148:151], v[162:165], v[70:73], v[148:151], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[152:155], v[162:165], v[74:77], v[152:155], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[156:159], v[162:165], v[78:81], v[156:159], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[160:163], v[166:169], v[50:53], v[160:163], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[164:167], v[166:169], v[54:57], v[164:167], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[168:171], v[166:169], v[58:61], v[168:171], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[172:175], v[166:169], v[62:65], v[172:175], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[176:179], v[166:169], v[66:69], v[176:179], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[180:183], v[166:169], v[70:73], v[180:183], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[184:187], v[166:169], v[74:77], v[184:187], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[188:191], v[166:169], v[78:81], v[188:191], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[192:195], v[170:173], v[50:53], v[192:195], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[196:199], v[170:173], v[54:57], v[196:199], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[200:203], v[170:173], v[58:61], v[200:203], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[204:207], v[170:173], v[62:65], v[204:207], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[208:211], v[170:173], v[66:69], v[208:211], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[212:215], v[170:173], v[70:73], v[212:215], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[216:219], v[170:173], v[74:77], v[216:219], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[220:223], v[170:173], v[78:81], v[220:223], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[224:227], v[174:177], v[50:53], v[224:227], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[228:231], v[174:177], v[54:57], v[228:231], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[232:235], v[174:177], v[58:61], v[232:235], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[236:239], v[174:177], v[62:65], v[236:239], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[240:243], v[174:177], v[66:69], v[240:243], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[244:247], v[174:177], v[70:73], v[244:247], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[248:251], v[174:177], v[74:77], v[248:251], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.emit(v_mfma_16x16x32(v[252:255], v[174:177], v[78:81], v[252:255], 0, 0, 1, 0, 0, 0, 0, 0, 1))
k.label('toPGR1end_OrdNLL')
k.emit(s_xor_b32(s[87], s[55], s[53]))
k.emit(s_min_u32(s[53], s[53], s[87]))
k.emit(s_xor_b32(s[87], s[56], s[54]))
k.emit(s_min_u32(s[54], s[54], s[87]))
k.emit(s_and_b32(s[8], 63, s[23]))
k.emit(s_cmp_lt_u32(s[61], s[46]))
k.emit(s_cmov_b32(s[8], 0))
k.emit(s_cmp_eq_u32(s[8], 0))
k.emit(s_mov_b32(s[9], 0))
k.emit(s_cbranch_scc1(), target='SkipTailLoopL')
k.emit(s_sub_i32(s[88], 3, s[78]))
k.emit(s_cmp_ge_i32(s[88], 0))
k.emit(s_cbranch_scc0(), target='Negative_LHNOKZ26V2FLOONQ')
k.emit(s_mul_hi_u32(s[89], s[88], s[83]))
k.emit(s_mul_i32(s[88], s[88], s[83]))
k.emit(s_branch(), target='MultiplyDone_L9DK3KJL31S8WWGN')
k.label('Negative_LHNOKZ26V2FLOONQ')
k.emit(s_abs_i32(s[88], s[88]))
k.emit(s_mul_hi_u32(s[89], s[88], s[83]))
k.emit(s_mul_i32(s[88], s[88], s[83]))
k.emit(s_xor_b32(s[88], s[88], -1))
k.emit(s_xor_b32(s[89], s[89], -1))
k.emit(s_add_u32(s[88], s[88], 1))
k.emit(s_addc_u32(s[89], s[89], 0))
k.label('MultiplyDone_L9DK3KJL31S8WWGN')
k.emit(s_sub_u32(s[88], s[88], s[79]))
k.emit(s_subb_u32(s[89], s[89], s[80]))
k.emit(s_add_u32(s[68], s[68], s[88]))
k.emit(s_addc_u32(s[69], s[69], s[89]))
k.emit(s_sub_u32(s[62], s[62], s[88]))
k.emit(s_subb_u32(s[63], s[63], s[89]))
k.emit(s_cmp_eq_u32(s[63], 0))
k.emit(s_cselect_b32(s[70], s[62], -1))
k.emit(s_sub_i32(s[88], 3, s[78]))
k.emit(s_cmp_ge_i32(s[88], 0))
k.emit(s_cbranch_scc0(), target='Negative_3U2TZUPK3AVX5ODG')
k.emit(s_mul_hi_u32(s[89], s[88], s[84]))
k.emit(s_mul_i32(s[88], s[88], s[84]))
k.emit(s_branch(), target='MultiplyDone_NW6XNGOG77EAT0NM')
k.label('Negative_3U2TZUPK3AVX5ODG')
k.emit(s_abs_i32(s[88], s[88]))
k.emit(s_mul_hi_u32(s[89], s[88], s[84]))
k.emit(s_mul_i32(s[88], s[88], s[84]))
k.emit(s_xor_b32(s[88], s[88], -1))
k.emit(s_xor_b32(s[89], s[89], -1))
k.emit(s_add_u32(s[88], s[88], 1))
k.emit(s_addc_u32(s[89], s[89], 0))
k.label('MultiplyDone_NW6XNGOG77EAT0NM')
k.emit(s_sub_u32(s[88], s[88], s[81]))
k.emit(s_subb_u32(s[89], s[89], s[82]))
k.emit(s_add_u32(s[72], s[72], s[88]))
k.emit(s_addc_u32(s[73], s[73], s[89]))
k.emit(s_sub_u32(s[76], s[76], s[88]))
k.emit(s_subb_u32(s[77], s[77], s[89]))
k.emit(s_cmp_eq_u32(s[77], 0))
k.emit(s_cselect_b32(s[74], s[76], -1))
k.emit(s_mov_b32(M0, s[53]))
k.waitcnt(lgkm=0)
k.emit(s_barrier())
buffer_load_d16_32(k, 18, 84, 0, 68, 1)
k.emit(s_mov_b32(M0, 133120))
k.emit(s_mov_b32(M0, s[54]))
buffer_load_d16_32(k, 50, 84, 8, 72, 0, 1, 0, 1)
k.emit(s_mov_b32(M0, 133120))
k.waitcnt(vm=0)
k.emit(s_barrier())
k.emit(v_and_b32_e32(v[82], 63, v[180]))
k.emit(v_lshlrev_b32_e32(v[82], 4, v[82]))
k.emit(v_add_u32_e32(v[82], s[53], v[82]))
k.emit(v_and_b32_e32(v[83], 63, v[180]))
k.emit(v_lshlrev_b32_e32(v[83], 4, v[83]))
k.emit(v_add_u32_e32(v[83], s[54], v[83]))
# A-tile LDS writes: 8 blocks at stride 16
for i in range(8):
d = v[18+i*4:21+i*4]
k.emit(ds_write_b128(v[0], v[82], d) if i == 0 else ds_write_b128(v[0], v[82], d, v[0], 0, 0, i*16))
# B-tile LDS writes: 8 blocks with bank-striped offsets
b_offsets = [(), (128, 16), (0, 33), (128, 49), (0, 66), (128, 82), (0, 99), (128, 115)]
for i, offs in enumerate(b_offsets):
d = v[50+i*4:53+i*4]
k.emit(ds_write_b128(v[0], v[83], d) if i == 0 else ds_write_b128(v[0], v[83], d, v[0], 0, *offs))
k.waitcnt(lgkm=0)
k.emit(s_barrier())
k.emit(v_xor_b32_e32(v[181], v[178], v[16]))
k.emit(v_min_i32_e32(v[16], v[16], v[181]))
k.emit(v_xor_b32_e32(v[181], v[179], v[17]))
k.emit(v_min_i32_e32(v[17], v[17], v[181]))
k.label('TailLoopBeginL')
ds_read_ab_16(k)
k.emit(s_mov_b32(s[87], 16384))
k.emit(v_add_co_u32_e32(v[16], s[87], v[16]))
k.emit(s_mov_b32(s[87], 64))
k.emit(v_add_co_u32_e32(v[17], s[87], v[17]))
k.waitcnt(lgkm=0)
perm_b32_32(k, 18)
zero_out_mask_32(k, 18)
zero_out_mask_32(k, 114)
k.emit(s_and_b32(s[87], s[23], 7))
k.emit(s_cmp_eq_u32(s[87], 0))
k.emit(s_cbranch_scc1(), target='TailLoop_SkipZeroOutMask_0FMPG10PI1CDGWZ9')
k.emit(s_and_b32(s[87], s[8], 7))
k.emit(s_sub_u32(s[87], 8, s[87]))
k.emit(s_lshl_b32(s[87], s[87], 4))
shift_mask(k, list(range(18, 50, 4)) + list(range(114, 146, 4)))
k.label('TailLoop_SkipZeroOutMask_0FMPG10PI1CDGWZ9')
k.emit(s_nop(1))
mfma_64(k, v_mfma_16x16x32, 114, 18)
k.emit(s_sub_i32(s[8], s[8], 32))
k.emit(s_add_u32(s[9], s[9], 32))
k.emit(s_cmp_le_i32(s[8], 0))
k.emit(s_cbranch_scc1(), target='TailLoopEndL')
ds_read_ab_16(k, 146)
k.emit(s_mov_b32(s[87], 16384))
k.emit(v_add_co_u32_e32(v[16], s[87], v[16]))
k.emit(s_mov_b32(s[87], 64))
k.emit(v_add_co_u32_e32(v[17], s[87], v[17]))
k.waitcnt(lgkm=0)
perm_b32_32(k, 50)
zero_out_mask_32(k, 50)
zero_out_mask_32(k, 146)
k.emit(s_and_b32(s[87], s[23], 7))
k.emit(s_cmp_eq_u32(s[87], 0))
k.emit(s_cbranch_scc1(), target='TailLoop_SkipZeroOutMask_YVWB1RHZO1Z7SCZY')
k.emit(s_and_b32(s[87], s[8], 7))
k.emit(s_sub_u32(s[87], 8, s[87]))
k.emit(s_lshl_b32(s[87], s[87], 4))
shift_mask(k, list(range(50, 82, 4)) + list(range(146, 178, 4)))
k.label('TailLoop_SkipZeroOutMask_YVWB1RHZO1Z7SCZY')
k.emit(s_nop(1))
mfma_64(k, v_mfma_16x16x32, 146, 50)
k.emit(s_sub_i32(s[8], s[8], 32))
k.emit(s_add_u32(s[9], s[9], 32))
k.emit(s_cmp_le_i32(s[8], 0))
k.emit(s_cbranch_scc0(), target='TailLoopBeginL')
k.label('TailLoopEndL')
k.emit(s_mov_b32(s[87], 512))
k.emit(s_mul_i32(s[87], s[9], s[87]))
k.emit(v_sub_u32_e64(v[16], v[16], s[87]))
k.emit(s_mov_b32(s[87], 2))
k.emit(s_mul_i32(s[87], s[9], s[87]))
k.emit(v_sub_u32_e64(v[17], v[17], s[87]))
k.label('SkipTailLoopL')
k.emit(s_setprio())
k.emit(s_mov_b64(s[68:69], 0))
k.emit(s_mov_b32(s[72], 0))
k.emit(v_mov_b32_e32(v[21], s[2]))
k.emit(v_mul_i32_i24_e32(v[21], 4294967040, v[21]))
k.emit(v_add_co_u32_e32(v[21], s[20], v[21]))
k.emit(v_mov_b32_e32(v[22], 256))
k.emit(v_cmp_lt_u32_e64(s[8:9], v[21], v[22]))
k.emit(v_cndmask_b32_e64(v[21], v[22], v[21], s[8:9]))
k.emit(v_lshrrev_b32_e32(v[23], 6, v[180]))
k.emit(v_and_b32_e32(v[23], 1, v[23]))
k.emit(v_lshrrev_b32_e32(v[24], 7, v[21]))
k.emit(v_and_b32_e32(v[24], 1, v[24]))
k.emit(v_cmp_eq_u32_e64(s[8:9], v[24], v[23]))
k.emit(v_cndmask_b32_e64(v[21], v[22], v[21], s[8:9]))
k.emit(v_lshrrev_b32_e32(v[22], 7, v[21]))
k.emit(v_lshlrev_b32_e32(v[24], 0, v[23]))
k.emit(v_sub_u32_e32(v[22], v[22], v[24]))
k.emit(v_lshrrev_b32_e32(v[24], 3, v[21]))
k.emit(v_lshrrev_b32_e32(v[25], 0, v[180]))
k.emit(v_and_b32_e32(v[25], 15, v[25]))
k.emit(v_lshlrev_b32_e32(v[25], 3, v[25]))
k.emit(v_lshrrev_b32_e32(v[25], 3, v[25]))
k.emit(v_lshlrev_b32_e32(v[23], 4, v[23]))
k.emit(v_add_co_u32_e32(v[25], v[23], v[25]))
k.emit(v_sub_u32_e32(v[24], v[24], v[25]))
k.emit(v_and_b32_e32(v[23], 7, v[21]))
k.emit(v_lshrrev_b32_e32(v[23], 3, v[23]))
k.emit(v_and_b32_e32(v[25], 7, v[21]))
# GLVW dispatch table: branch to shift handler based on vector width
for glvw in range(1, 8):
k.emit(v_cmp_eq_u32_e64(VCC, v[25], glvw))
k.emit(s_cbranch_vccnz(), target=f'ShiftVectorComponents0_GLVW{glvw}')
k.emit(s_branch(), target='ShiftVectorComponents0_GLVW0')
# GLVW → BM0 fallthrough chain
for glvw in range(1, 8):
k.label(f'ShiftVectorComponents0_GLVW{glvw}')
k.emit(v_cmp_eq_u32_e64(VCC, v[22], 0))
k.emit(s_cbranch_vccnz(), target=f'ShiftVectorComponents0_GLVW{glvw}_BM0')
# BM0 → VW0 fallthrough chain
for glvw in range(1, 8):
k.label(f'ShiftVectorComponents0_GLVW{glvw}_BM0')
k.emit(v_cmp_eq_u32_e64(VCC, v[23], 0))
k.emit(s_cbranch_vccnz(), target=f'ShiftVectorComponents0_GLVW{glvw}_BM0_VW0')
for glvw in range(1, 8):
k.label(f'ShiftVectorComponents0_GLVW{glvw}_BM0_VW0')
k.emit(s_mov_b32(s[8], 0))
k.emit(v_cmpx_eq_u32_e64(s[8:9], v[24], s[8]))
k.emit(v_and_b32_e32(v[18], 63, v[180]))
k.emit(v_lshlrev_b32_e32(v[18], 2, v[18]))
shift_vector_components(k, glvw)
k.emit(s_mov_b64(s[8:9], -1))
k.emit(s_or_saveexec_b64(VCC, s[8:9]))
if glvw < 7:
k.emit(s_branch(), target='ShiftVectorComponents0_GLVW0')
k.label('ShiftVectorComponents0_GLVW0')
k.emit(v_lshrrev_b32_e32(v[22], 6, v[180]))
k.emit(v_lshrrev_b32_e32(v[23], 1, v[22]))
k.emit(v_mul_lo_u32(v[23], 16, v[23]))
k.emit(v_and_b32_e32(v[19], 63, v[180]))
k.emit(v_lshrrev_b32_e32(v[19], 4, v[19]))
k.emit(v_lshlrev_b32_e32(v[19], 2, v[19]))
k.emit(v_add_lshl_u32_e64(v[19], v[23], v[19], 3))
k.emit(v_mul_lo_u32(v[20], v[19], s[38]))
k.emit(v_mul_lo_u32(v[21], v[19], s[36]))
k.emit(v_and_b32_e32(v[18], 1, v[22]))
k.emit(v_mul_lo_u32(v[18], 16, v[18]))
k.emit(v_and_b32_e32(v[23], 15, v[180]))
k.emit(v_add_lshl_u32_e64(v[18], v[23], v[18], 3))
k.emit(s_mul_i32(s[8], 256, s[2]))
k.emit(v_add_u32_e32(v[18], s[8], v[18]))
k.emit(s_mul_i32(s[8], 256, s[3]))
k.emit(v_add_u32_e32(v[19], s[8], v[19]))
k.waitcnt(lgkm=0)
k.emit(s_add_u32(s[8], s[4], 1))
k.emit(s_mul_i32(s[8], s[73], s[8]))
k.emit(s_cmp_eq_u32(s[8], 0))
k.emit(s_cselect_b32(s[8], s[20], s[8]))
k.emit(s_mov_b32(s[91], 131072))
k.emit(s_mov_b32(s[90], 0))
k.emit(s_mul_i32(s[8], 256, s[2]))
k.emit(v_add_u32_e32(v[26], s[8], v[180]))
k.emit(s_mul_i32(s[90], 4, s[90]))
k.emit(s_mul_i32(s[8], s[73], s[4]))
k.emit(v_add_u32_e32(v[24], s[8], v[26]))
k.emit(v_lshlrev_b32_e32(v[24], 2, v[24]))
k.emit(s_mul_i32(s[8], 256, s[3]))
k.emit(v_add_u32_e32(v[26], s[8], v[180]))
k.emit(buffer_load_dword(v[22], v[24], s[88:91], 0, 0, 1))
k.emit(v_lshlrev_b32_e32(v[26], 2, v[180]))
k.emit(s_barrier())
k.waitcnt(vm=0)
k.emit(ds_write_b32(v[0], v[26], v[22]))
k.emit(v_mov_b32_e32(v[23], 1.0))
k.emit(ds_write_b32(v[0], v[26], v[23], v[0], 0, 0, 4))
k.emit(s_mul_i32(s[8], 256, s[2]))
k.emit(v_add_u32_e32(v[26], s[8], v[180]))
k.emit(s_mul_i32(s[90], 2, s[90]))
k.emit(s_mul_i32(s[8], s[73], s[4]))
k.emit(v_add_u32_e32(v[24], s[8], v[26]))
k.emit(v_lshlrev_b32_e32(v[24], 1, v[24]))
k.emit(s_mul_i32(s[8], 256, s[3]))
k.emit(v_add_u32_e32(v[26], s[8], v[180]))
k.emit(buffer_load_short_d16(v[22], v[24], s[88:91], 0, 0, 1))
k.emit(v_lshlrev_b32_e32(v[26], 2, v[180]))
k.emit(s_barrier())
k.waitcnt(vm=0)
k.emit(v_cvt(v[22], SDWA, v[22], 0, 0, 0, 0, 0, 0, 6, 2, 4))
k.emit(ds_write_b32(v[0], v[26], v[22]))
k.emit(v_mov_b32_e32(v[23], 1.0))
k.emit(ds_write_b32(v[0], v[26], v[23], v[0], 0, 0, 4))
k.emit(s_and_b32(s[78], 255, s[20]))
k.emit(s_add_u32(s[79], -1, s[10]))
k.emit(s_cmp_ge_u32(s[2], s[79]))
k.emit(s_cselect_b32(s[78], s[78], 0))
k.emit(s_cmpk_gt_u32(s[78]))
k.emit(s_cbranch_scc1(), target='GW_B0_E1_M_1')
k.emit(s_and_b32(s[78], 255, s[21]))
k.emit(s_add_u32(s[79], -1, s[11]))
k.emit(s_cmp_ge_u32(s[3], s[79]))
k.emit(s_cselect_b32(s[78], s[78], 0))
k.emit(s_cmpk_gt_u32(s[78]))
k.emit(s_cbranch_scc0(), target='GW_B0_E0_1')
k.emit(s_cbranch_scc1(), target='GW_B0_E1_N_1')
k.label('GW_B0_E0_1')
k.emit(s_mul_i32(s[68], 256, s[2]))
k.emit(v_sub_u32_e64(v[37], v[18], s[68]))
k.emit(v_lshlrev_b32_e32(v[37], 2, v[37]))
k.waitcnt(lgkm=0)
k.emit(s_barrier())
# accvgpr source order: index = byte_off + i*4; for each byte_off in 0..3, all 64 values of i are listed before moving to the next byte_off
accvgpr_srcs = [byte_off + i * 4 for byte_off in range(4) for i in range(64)]
first_store = True
for batch_start in range(0, 256, 48):
batch = accvgpr_srcs[batch_start:batch_start+48]
num_blocks = len(batch) // 8
# last batch (16 reads) uses v[56:71] for scale/bias since v[56:87] are free
sb = 56 if len(batch) <= 16 else 88
if batch_start == 0:
k.emit(ds_read_b128(v[sb:sb+3], v[37]))
k.emit(ds_read_b128(v[sb+4:sb+7], v[37], v[0], v[0], 0, 16))
k.emit(ds_read_b128(v[sb+8:sb+11], v[37], v[0], v[0], 0, 0, 4))
k.emit(ds_read_b128(v[sb+12:sb+15], v[37], v[0], v[0], 0, 16, 4))
k.emit(v_add_lshl_u32_e64(v[35], v[21], v[18], 1))
else:
k.emit(s_nop())
k.emit(ds_read_b128(v[sb:sb+3], v[37]))
k.emit(ds_read_b128(v[sb+4:sb+7], v[37], v[0], v[0], 0, 16))
k.emit(ds_read_b128(v[sb+8:sb+11], v[37], v[0], v[0], 0, 0, 4))
k.emit(ds_read_b128(v[sb+12:sb+15], v[37], v[0], v[0], 0, 16, 4))
for j, src in enumerate(batch):
k.emit(v_accvgpr_read(v[40 + j], v[src]))
k.emit(v_mov_b32_e32(v[32], 4294901760))
k.emit(v_mov_b32_e32(v[33], 2147418112))
k.emit(v_mov_b32_e32(v[34], 32767))
k.waitcnt(lgkm=0)
for b in range(num_blocks):
gw_convert_and_store(k, v_cvt_pk, 40 + b * 8, 35, sb=sb, stride=not first_store)
first_store = False
k.emit(s_nop())
k.emit(s_branch(), target='GW_End_1')
k.label('GW_B0_E1_N_1')
k.emit(v_mov_b32_e32(v[30], 2147483648))
n_lds_vs = [36, 38, 104, 106, 108, 110]
n_addr_vs = [35, 37, 39, 105, 107, 109]
accvgpr_srcs = [b + i*4 for b in range(4) for i in range(64)]
for batch in range(6):
count = 16 if batch == 5 else 48
rows = 2 if batch == 5 else 6
sb = 56 if batch == 5 else 88
if batch > 0:
k.emit(s_nop())
k.emit(v_mov_b32_e32(v[30], 2147483648))
gw_m_row_inc(k)
gw_n_addr_row(k, n_lds_vs[0], n_addr_vs[0], ds_base=sb, barrier=(batch == 0))
for r in range(1, rows):
gw_m_row_inc(k)
gw_n_addr_row(k, n_lds_vs[r], n_addr_vs[r])
srcs = accvgpr_srcs[batch*48 : batch*48 + count]
for j, src in enumerate(srcs):
k.emit(v_accvgpr_read(v[40 + j], v[src]))
k.waitcnt(lgkm=0)
k.emit(v_mov_b32_e32(v[32], 4294901760))
k.emit(v_mov_b32_e32(v[33], 2147418112))
k.emit(v_mov_b32_e32(v[34], 32767))
for r in range(rows):
gw_convert_and_store(k, v_cvt_pk, 40 + r*8, n_addr_vs[r], sb=sb, stride=False)
k.emit(s_nop())
k.emit(s_branch(), target='GW_End_1')
k.label('GW_B0_E1_M_1')
k.emit(v_mov_b32_e32(v[30], 2147483648))
gw_m_addr_elem(k, 0, 78, 77, 75, 76, barrier=True)
gw_m_addr_elem(k, 1, 82, 81, 79, 80)
gw_m_addr_elem(k, 2, 86, 85, 83, 84)
gw_m_addr_elem(k, 3, 90, 89, 87, 88)
gw_m_addr_elem(k, 4, 94, 93, 91, 92)
gw_m_addr_elem(k, 5, 98, 97, 95, 96)
gw_m_addr_elem(k, 6, 102, 101, 99, 100)
gw_m_addr_elem(k, 7, 106, 105, 103, 104)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 108, 107)
gw_m_addr_elem(k, 1, 110, 109)
gw_m_addr_elem(k, 2, 112, 111)
gw_m_addr_elem(k, 3, 114, 113)
gw_m_addr_elem(k, 4, 116, 115)
gw_m_addr_elem(k, 5, 118, 117)
gw_m_addr_elem(k, 6, 120, 119)
gw_m_addr_elem(k, 7, 122, 121)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 124, 123)
gw_m_addr_elem(k, 1, 126, 125)
gw_m_addr_elem(k, 2, 128, 127)
gw_m_addr_elem(k, 3, 130, 129)
gw_m_addr_elem(k, 4, 132, 131)
gw_m_addr_elem(k, 5, 134, 133)
gw_m_addr_elem(k, 6, 136, 135)
gw_m_addr_elem(k, 7, 138, 137)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 140, 139)
gw_m_addr_elem(k, 1, 142, 141)
gw_m_addr_elem(k, 2, 144, 143)
gw_m_addr_elem(k, 3, 146, 145)
gw_m_addr_elem(k, 4, 148, 147)
gw_m_addr_elem(k, 5, 150, 149)
gw_m_addr_elem(k, 6, 152, 151)
gw_m_addr_elem(k, 7, 154, 153)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 156, 155)
gw_m_addr_elem(k, 1, 158, 157)
gw_m_addr_elem(k, 2, 160, 159)
gw_m_addr_elem(k, 3, 162, 161)
gw_m_addr_elem(k, 4, 164, 163)
gw_m_addr_elem(k, 5, 166, 165)
gw_m_addr_elem(k, 6, 168, 167)
gw_m_addr_elem(k, 7, 170, 169)
k.emit(v_accvgpr_read(v[35], v[0]))
k.emit(v_accvgpr_read(v[36], v[4]))
k.emit(v_accvgpr_read(v[37], v[8]))
k.emit(v_accvgpr_read(v[38], v[12]))
k.emit(v_accvgpr_read(v[39], v[16]))
for _j, _src in enumerate([20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 132, 136, 140, 144, 148, 152, 156]):
k.emit(v_accvgpr_read(v[40 + _j], v[_src]))
k.waitcnt(lgkm=0)
k.emit(v_mov_b32_e32(v[32], 4294901760))
k.emit(v_mov_b32_e32(v[33], 2147418112))
k.emit(v_mov_b32_e32(v[34], 32767))
for _i, _addr in enumerate([77, 81, 85, 89, 93, 97, 101, 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127, 129, 131, 133, 135, 137, 139, 141, 143, 145, 147, 149, 151, 153, 155, 157, 159, 161, 163, 165, 167, 169]):
gw_m_element(k, v_cvt_pk, 35+_i, 76+(_i%8)*4, 75+(_i%8)*4, _addr)
k.emit(s_nop())
k.emit(v_mov_b32_e32(v[30], 2147483648))
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 78, 77, 75, 76)
gw_m_addr_elem(k, 1, 82, 81, 79, 80)
gw_m_addr_elem(k, 2, 86, 85, 83, 84)
gw_m_addr_elem(k, 3, 90, 89, 87, 88)
gw_m_addr_elem(k, 4, 94, 93, 91, 92)
gw_m_addr_elem(k, 5, 98, 97, 95, 96)
gw_m_addr_elem(k, 6, 102, 101, 99, 100)
gw_m_addr_elem(k, 7, 106, 105, 103, 104)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 108, 107)
gw_m_addr_elem(k, 1, 110, 109)
gw_m_addr_elem(k, 2, 112, 111)
gw_m_addr_elem(k, 3, 114, 113)
gw_m_addr_elem(k, 4, 116, 115)
gw_m_addr_elem(k, 5, 118, 117)
gw_m_addr_elem(k, 6, 120, 119)
gw_m_addr_elem(k, 7, 122, 121)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 124, 123)
gw_m_addr_elem(k, 1, 126, 125)
gw_m_addr_elem(k, 2, 128, 127)
gw_m_addr_elem(k, 3, 130, 129)
gw_m_addr_elem(k, 4, 132, 131)
gw_m_addr_elem(k, 5, 134, 133)
gw_m_addr_elem(k, 6, 136, 135)
gw_m_addr_elem(k, 7, 138, 137)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 140, 139)
gw_m_addr_elem(k, 1, 142, 141)
gw_m_addr_elem(k, 2, 144, 143)
gw_m_addr_elem(k, 3, 146, 145)
gw_m_addr_elem(k, 4, 148, 147)
gw_m_addr_elem(k, 5, 150, 149)
gw_m_addr_elem(k, 6, 152, 151)
gw_m_addr_elem(k, 7, 154, 153)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 156, 155)
gw_m_addr_elem(k, 1, 158, 157)
gw_m_addr_elem(k, 2, 160, 159)
gw_m_addr_elem(k, 3, 162, 161)
gw_m_addr_elem(k, 4, 164, 163)
gw_m_addr_elem(k, 5, 166, 165)
gw_m_addr_elem(k, 6, 168, 167)
gw_m_addr_elem(k, 7, 170, 169)
k.emit(v_accvgpr_read(v[35], v[160]))
k.emit(v_accvgpr_read(v[36], v[164]))
k.emit(v_accvgpr_read(v[37], v[168]))
k.emit(v_accvgpr_read(v[38], v[172]))
k.emit(v_accvgpr_read(v[39], v[176]))
for _j, _src in enumerate([180, 184, 188, 192, 196, 200, 204, 208, 212, 216, 220, 224, 228, 232, 236, 240, 244, 248, 252, 1, 5, 9, 13, 17, 21, 25, 29, 33, 37, 41, 45, 49, 53, 57, 61]):
k.emit(v_accvgpr_read(v[40 + _j], v[_src]))
k.waitcnt(lgkm=0)
k.emit(v_mov_b32_e32(v[32], 4294901760))
k.emit(v_mov_b32_e32(v[33], 2147418112))
k.emit(v_mov_b32_e32(v[34], 32767))
for _i, _addr in enumerate([77, 81, 85, 89, 93, 97, 101, 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127, 129, 131, 133, 135, 137, 139, 141, 143, 145, 147, 149, 151, 153, 155, 157, 159, 161, 163, 165, 167, 169]):
gw_m_element(k, v_cvt_pk, 35+_i, 76+(_i%8)*4, 75+(_i%8)*4, _addr)
k.emit(s_nop())
k.emit(v_mov_b32_e32(v[30], 2147483648))
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 78, 77, 75, 76)
gw_m_addr_elem(k, 1, 82, 81, 79, 80)
gw_m_addr_elem(k, 2, 86, 85, 83, 84)
gw_m_addr_elem(k, 3, 90, 89, 87, 88)
gw_m_addr_elem(k, 4, 94, 93, 91, 92)
gw_m_addr_elem(k, 5, 98, 97, 95, 96)
gw_m_addr_elem(k, 6, 102, 101, 99, 100)
gw_m_addr_elem(k, 7, 106, 105, 103, 104)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 108, 107)
gw_m_addr_elem(k, 1, 110, 109)
gw_m_addr_elem(k, 2, 112, 111)
gw_m_addr_elem(k, 3, 114, 113)
gw_m_addr_elem(k, 4, 116, 115)
gw_m_addr_elem(k, 5, 118, 117)
gw_m_addr_elem(k, 6, 120, 119)
gw_m_addr_elem(k, 7, 122, 121)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 124, 123)
gw_m_addr_elem(k, 1, 126, 125)
gw_m_addr_elem(k, 2, 128, 127)
gw_m_addr_elem(k, 3, 130, 129)
gw_m_addr_elem(k, 4, 132, 131)
gw_m_addr_elem(k, 5, 134, 133)
gw_m_addr_elem(k, 6, 136, 135)
gw_m_addr_elem(k, 7, 138, 137)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 140, 139)
gw_m_addr_elem(k, 1, 142, 141)
gw_m_addr_elem(k, 2, 144, 143)
gw_m_addr_elem(k, 3, 146, 145)
gw_m_addr_elem(k, 4, 148, 147)
gw_m_addr_elem(k, 5, 150, 149)
gw_m_addr_elem(k, 6, 152, 151)
gw_m_addr_elem(k, 7, 154, 153)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 156, 155)
gw_m_addr_elem(k, 1, 158, 157)
gw_m_addr_elem(k, 2, 160, 159)
gw_m_addr_elem(k, 3, 162, 161)
gw_m_addr_elem(k, 4, 164, 163)
gw_m_addr_elem(k, 5, 166, 165)
gw_m_addr_elem(k, 6, 168, 167)
gw_m_addr_elem(k, 7, 170, 169)
k.emit(v_accvgpr_read(v[35], v[65]))
k.emit(v_accvgpr_read(v[36], v[69]))
k.emit(v_accvgpr_read(v[37], v[73]))
k.emit(v_accvgpr_read(v[38], v[77]))
k.emit(v_accvgpr_read(v[39], v[81]))
for _j, _src in enumerate([85, 89, 93, 97, 101, 105, 109, 113, 117, 121, 125, 129, 133, 137, 141, 145, 149, 153, 157, 161, 165, 169, 173, 177, 181, 185, 189, 193, 197, 201, 205, 209, 213, 217, 221]):
k.emit(v_accvgpr_read(v[40 + _j], v[_src]))
k.waitcnt(lgkm=0)
k.emit(v_mov_b32_e32(v[32], 4294901760))
k.emit(v_mov_b32_e32(v[33], 2147418112))
k.emit(v_mov_b32_e32(v[34], 32767))
for _i, _addr in enumerate([77, 81, 85, 89, 93, 97, 101, 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127, 129, 131, 133, 135, 137, 139, 141, 143, 145, 147, 149, 151, 153, 155, 157, 159, 161, 163, 165, 167, 169]):
gw_m_element(k, v_cvt_pk, 35+_i, 76+(_i%8)*4, 75+(_i%8)*4, _addr)
k.emit(s_nop())
k.emit(v_mov_b32_e32(v[30], 2147483648))
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 78, 77, 75, 76)
gw_m_addr_elem(k, 1, 82, 81, 79, 80)
gw_m_addr_elem(k, 2, 86, 85, 83, 84)
gw_m_addr_elem(k, 3, 90, 89, 87, 88)
gw_m_addr_elem(k, 4, 94, 93, 91, 92)
gw_m_addr_elem(k, 5, 98, 97, 95, 96)
gw_m_addr_elem(k, 6, 102, 101, 99, 100)
gw_m_addr_elem(k, 7, 106, 105, 103, 104)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 108, 107)
gw_m_addr_elem(k, 1, 110, 109)
gw_m_addr_elem(k, 2, 112, 111)
gw_m_addr_elem(k, 3, 114, 113)
gw_m_addr_elem(k, 4, 116, 115)
gw_m_addr_elem(k, 5, 118, 117)
gw_m_addr_elem(k, 6, 120, 119)
gw_m_addr_elem(k, 7, 122, 121)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 124, 123)
gw_m_addr_elem(k, 1, 126, 125)
gw_m_addr_elem(k, 2, 128, 127)
gw_m_addr_elem(k, 3, 130, 129)
gw_m_addr_elem(k, 4, 132, 131)
gw_m_addr_elem(k, 5, 134, 133)
gw_m_addr_elem(k, 6, 136, 135)
gw_m_addr_elem(k, 7, 138, 137)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 140, 139)
gw_m_addr_elem(k, 1, 142, 141)
gw_m_addr_elem(k, 2, 144, 143)
gw_m_addr_elem(k, 3, 146, 145)
gw_m_addr_elem(k, 4, 148, 147)
gw_m_addr_elem(k, 5, 150, 149)
gw_m_addr_elem(k, 6, 152, 151)
gw_m_addr_elem(k, 7, 154, 153)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 156, 155)
gw_m_addr_elem(k, 1, 158, 157)
gw_m_addr_elem(k, 2, 160, 159)
gw_m_addr_elem(k, 3, 162, 161)
gw_m_addr_elem(k, 4, 164, 163)
gw_m_addr_elem(k, 5, 166, 165)
gw_m_addr_elem(k, 6, 168, 167)
gw_m_addr_elem(k, 7, 170, 169)
k.emit(v_accvgpr_read(v[35], v[225]))
k.emit(v_accvgpr_read(v[36], v[229]))
k.emit(v_accvgpr_read(v[37], v[233]))
k.emit(v_accvgpr_read(v[38], v[237]))
k.emit(v_accvgpr_read(v[39], v[241]))
for _j, _src in enumerate([245, 249, 253, 2, 6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54, 58, 62, 66, 70, 74, 78, 82, 86, 90, 94, 98, 102, 106, 110, 114, 118, 122, 126]):
k.emit(v_accvgpr_read(v[40 + _j], v[_src]))
k.waitcnt(lgkm=0)
k.emit(v_mov_b32_e32(v[32], 4294901760))
k.emit(v_mov_b32_e32(v[33], 2147418112))
k.emit(v_mov_b32_e32(v[34], 32767))
for _i, _addr in enumerate([77, 81, 85, 89, 93, 97, 101, 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127, 129, 131, 133, 135, 137, 139, 141, 143, 145, 147, 149, 151, 153, 155, 157, 159, 161, 163, 165, 167, 169]):
gw_m_element(k, v_cvt_pk, 35+_i, 76+(_i%8)*4, 75+(_i%8)*4, _addr)
k.emit(s_nop())
k.emit(v_mov_b32_e32(v[30], 2147483648))
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 78, 77, 75, 76)
gw_m_addr_elem(k, 1, 82, 81, 79, 80)
gw_m_addr_elem(k, 2, 86, 85, 83, 84)
gw_m_addr_elem(k, 3, 90, 89, 87, 88)
gw_m_addr_elem(k, 4, 94, 93, 91, 92)
gw_m_addr_elem(k, 5, 98, 97, 95, 96)
gw_m_addr_elem(k, 6, 102, 101, 99, 100)
gw_m_addr_elem(k, 7, 106, 105, 103, 104)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 108, 107)
gw_m_addr_elem(k, 1, 110, 109)
gw_m_addr_elem(k, 2, 112, 111)
gw_m_addr_elem(k, 3, 114, 113)
gw_m_addr_elem(k, 4, 116, 115)
gw_m_addr_elem(k, 5, 118, 117)
gw_m_addr_elem(k, 6, 120, 119)
gw_m_addr_elem(k, 7, 122, 121)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 124, 123)
gw_m_addr_elem(k, 1, 126, 125)
gw_m_addr_elem(k, 2, 128, 127)
gw_m_addr_elem(k, 3, 130, 129)
gw_m_addr_elem(k, 4, 132, 131)
gw_m_addr_elem(k, 5, 134, 133)
gw_m_addr_elem(k, 6, 136, 135)
gw_m_addr_elem(k, 7, 138, 137)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 140, 139)
gw_m_addr_elem(k, 1, 142, 141)
gw_m_addr_elem(k, 2, 144, 143)
gw_m_addr_elem(k, 3, 146, 145)
gw_m_addr_elem(k, 4, 148, 147)
gw_m_addr_elem(k, 5, 150, 149)
gw_m_addr_elem(k, 6, 152, 151)
gw_m_addr_elem(k, 7, 154, 153)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 156, 155)
gw_m_addr_elem(k, 1, 158, 157)
gw_m_addr_elem(k, 2, 160, 159)
gw_m_addr_elem(k, 3, 162, 161)
gw_m_addr_elem(k, 4, 164, 163)
gw_m_addr_elem(k, 5, 166, 165)
gw_m_addr_elem(k, 6, 168, 167)
gw_m_addr_elem(k, 7, 170, 169)
k.emit(v_accvgpr_read(v[35], v[130]))
k.emit(v_accvgpr_read(v[36], v[134]))
k.emit(v_accvgpr_read(v[37], v[138]))
k.emit(v_accvgpr_read(v[38], v[142]))
k.emit(v_accvgpr_read(v[39], v[146]))
for _j, _src in enumerate([150, 154, 158, 162, 166, 170, 174, 178, 182, 186, 190, 194, 198, 202, 206, 210, 214, 218, 222, 226, 230, 234, 238, 242, 246, 250, 254, 3, 7, 11, 15, 19, 23, 27, 31]):
k.emit(v_accvgpr_read(v[40 + _j], v[_src]))
k.waitcnt(lgkm=0)
k.emit(v_mov_b32_e32(v[32], 4294901760))
k.emit(v_mov_b32_e32(v[33], 2147418112))
k.emit(v_mov_b32_e32(v[34], 32767))
for _i, _addr in enumerate([77, 81, 85, 89, 93, 97, 101, 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127, 129, 131, 133, 135, 137, 139, 141, 143, 145, 147, 149, 151, 153, 155, 157, 159, 161, 163, 165, 167, 169]):
gw_m_element(k, v_cvt_pk, 35+_i, 76+(_i%8)*4, 75+(_i%8)*4, _addr)
k.emit(s_nop())
k.emit(v_mov_b32_e32(v[30], 2147483648))
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 78, 77, 75, 76)
gw_m_addr_elem(k, 1, 82, 81, 79, 80)
gw_m_addr_elem(k, 2, 86, 85, 83, 84)
gw_m_addr_elem(k, 3, 90, 89, 87, 88)
gw_m_addr_elem(k, 4, 94, 93, 91, 92)
gw_m_addr_elem(k, 5, 98, 97, 95, 96)
gw_m_addr_elem(k, 6, 102, 101, 99, 100)
gw_m_addr_elem(k, 7, 106, 105, 103, 104)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 108, 107)
gw_m_addr_elem(k, 1, 110, 109)
gw_m_addr_elem(k, 2, 112, 111)
gw_m_addr_elem(k, 3, 114, 113)
gw_m_addr_elem(k, 4, 116, 115)
gw_m_addr_elem(k, 5, 118, 117)
gw_m_addr_elem(k, 6, 120, 119)
gw_m_addr_elem(k, 7, 122, 121)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 124, 123)
gw_m_addr_elem(k, 1, 126, 125)
gw_m_addr_elem(k, 2, 128, 127)
gw_m_addr_elem(k, 3, 130, 129)
gw_m_addr_elem(k, 4, 132, 131)
gw_m_addr_elem(k, 5, 134, 133)
gw_m_addr_elem(k, 6, 136, 135)
gw_m_addr_elem(k, 7, 138, 137)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 140, 139)
gw_m_addr_elem(k, 1, 142, 141)
gw_m_addr_elem(k, 2, 144, 143)
gw_m_addr_elem(k, 3, 146, 145)
gw_m_addr_elem(k, 4, 148, 147)
gw_m_addr_elem(k, 5, 150, 149)
gw_m_addr_elem(k, 6, 152, 151)
gw_m_addr_elem(k, 7, 154, 153)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 156, 155)
gw_m_addr_elem(k, 1, 158, 157)
gw_m_addr_elem(k, 2, 160, 159)
gw_m_addr_elem(k, 3, 162, 161)
gw_m_addr_elem(k, 4, 164, 163)
gw_m_addr_elem(k, 5, 166, 165)
gw_m_addr_elem(k, 6, 168, 167)
gw_m_addr_elem(k, 7, 170, 169)
k.emit(v_accvgpr_read(v[35], v[35]))
k.emit(v_accvgpr_read(v[36], v[39]))
k.emit(v_accvgpr_read(v[37], v[43]))
k.emit(v_accvgpr_read(v[38], v[47]))
k.emit(v_accvgpr_read(v[39], v[51]))
for _j, _src in enumerate([55, 59, 63, 67, 71, 75, 79, 83, 87, 91, 95, 99, 103, 107, 111, 115, 119, 123, 127, 131, 135, 139, 143, 147, 151, 155, 159, 163, 167, 171, 175, 179, 183, 187, 191]):
k.emit(v_accvgpr_read(v[40 + _j], v[_src]))
k.waitcnt(lgkm=0)
k.emit(v_mov_b32_e32(v[32], 4294901760))
k.emit(v_mov_b32_e32(v[33], 2147418112))
k.emit(v_mov_b32_e32(v[34], 32767))
for _i, _addr in enumerate([77, 81, 85, 89, 93, 97, 101, 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127, 129, 131, 133, 135, 137, 139, 141, 143, 145, 147, 149, 151, 153, 155, 157, 159, 161, 163, 165, 167, 169]):
gw_m_element(k, v_cvt_pk, 35+_i, 76+(_i%8)*4, 75+(_i%8)*4, _addr)
k.emit(s_nop())
k.emit(v_mov_b32_e32(v[30], 2147483648))
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 54, 53, 51, 52)
gw_m_addr_elem(k, 1, 58, 57, 55, 56)
gw_m_addr_elem(k, 2, 62, 61, 59, 60)
gw_m_addr_elem(k, 3, 66, 65, 63, 64)
gw_m_addr_elem(k, 4, 70, 69, 67, 68)
gw_m_addr_elem(k, 5, 74, 73, 71, 72)
gw_m_addr_elem(k, 6, 78, 77, 75, 76)
gw_m_addr_elem(k, 7, 82, 81, 79, 80)
gw_m_row_inc(k)
gw_m_addr_elem(k, 0, 84, 83)
gw_m_addr_elem(k, 1, 86, 85)
gw_m_addr_elem(k, 2, 88, 87)
gw_m_addr_elem(k, 3, 90, 89)
gw_m_addr_elem(k, 4, 92, 91)
gw_m_addr_elem(k, 5, 94, 93)
gw_m_addr_elem(k, 6, 96, 95)
gw_m_addr_elem(k, 7, 98, 97)
for i in range(16): k.emit(v_accvgpr_read(v[35+i], v[195+i*4]))
k.waitcnt(lgkm=0)
k.emit(v_mov_b32_e32(v[32], 4294901760))
k.emit(v_mov_b32_e32(v[33], 2147418112))
k.emit(v_mov_b32_e32(v[34], 32767))
gw_m_element(k, v_cvt_pk, 35, 52, 51, 53)
gw_m_element(k, v_cvt_pk, 36, 56, 55, 57)
gw_m_element(k, v_cvt_pk, 37, 60, 59, 61)
gw_m_element(k, v_cvt_pk, 38, 64, 63, 65)
gw_m_element(k, v_cvt_pk, 39, 68, 67, 69)
gw_m_element(k, v_cvt_pk, 40, 72, 71, 73)
gw_m_element(k, v_cvt_pk, 41, 76, 75, 77)
gw_m_element(k, v_cvt_pk, 42, 80, 79, 81)
gw_m_element(k, v_cvt_pk, 43, 52, 51, 83)
gw_m_element(k, v_cvt_pk, 44, 56, 55, 85)
gw_m_element(k, v_cvt_pk, 45, 60, 59, 87)
gw_m_element(k, v_cvt_pk, 46, 64, 63, 89)
gw_m_element(k, v_cvt_pk, 47, 68, 67, 91)
gw_m_element(k, v_cvt_pk, 48, 72, 71, 93)
gw_m_element(k, v_cvt_pk, 49, 76, 75, 95)
gw_m_element(k, v_cvt_pk, 50, 80, 79, 97)
k.emit(s_nop())
k.emit(s_branch(), target='GW_End_1')
k.label('GW_End_1')
k.emit(s_cmp_ge_u32(s[58], s[59]))
k.emit(s_cbranch_scc1(), target='KernelEnd')
k.emit(s_branch(), target='PersistentLoopStart')
k.label('KernelEnd')
k.emit(s_endpgm())
return k.finalize()
# ** ASM_GEMM custom kernel
@functools.cache
def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
  """Build the PROGRAM UOp that runs the instruction stream from `build_kernel` on device `dname`.

  A is unpacked as (batch, M, K) and B as (K, N); when B is 3D its leading dim is ignored.
  The result is memoized per argument tuple by `functools.cache`.
  """
  batch, M, K = A.shape
  b_k, N = B.shape[1:] if B.ndim == 3 else B.shape
  assert K == b_k
  insts = build_kernel(batch, M, N, K, A.dtype.base)
  # uint8 buffer of 133,120 entries in the LOCAL address space, declared as a kernel input
  lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=133_120, addrspace=AddrSpace.LOCAL), (), 'lds')
  # lidx0/gidx0 specials sized by WORKGROUP_SIZE and NUM_WG
  lidx, gidx = UOp.special(WORKGROUP_SIZE, "lidx0"), UOp.special(NUM_WG, "gidx0")
  # NOTE(review): the *2 in mem is presumably bytes per 16-bit element -- confirm if other dtypes reach this path
  estimates = Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2)
  sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx, arg=KernelInfo(name=f"gemm_{batch}_{M}_{N}_{K}", estimates=estimates))
  # one INS node per instruction, in the order build_kernel returned them
  linear = UOp(Ops.LINEAR, src=tuple(UOp(Ops.INS, arg=inst) for inst in insts))
  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), linear))
# ** FP8 GEMM custom kernel
@functools.cache
def custom_hk_fp8_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str, scale_mode:int=3) -> UOp:
  """Build the PROGRAM UOp for the FP8 GEMM compiled from thunder/amd/gemm_fp8.cpp.

  A's two leading dims are flattened into M; B is read as (N, K), ignoring its leading dim when B is 3D.
  `scale_mode` is a bitmask (0=no scale, 1=x only, 2=w only, 3=both); that many leading entries of `args` are the
  scale buffers passed to the kernel. Any further `args` are accepted but are not kernel inputs.
  Asserts that both operands agree on K and that M and N are multiples of the 256 block.
  """
  # scale_mode: 0=no scale, 1=x only, 2=w only, 3=both
  n_scales = (1 if scale_mode & 1 else 0) + (1 if scale_mode & 2 else 0)
  scales = args[:n_scales]
  M, K = A.shape[0]*A.shape[1], A.shape[2]
  N, K2 = B.shape[(1 if B.ndim == 3 else 0):]
  assert K == K2, f"{A.shape} {B.shape}"
  block_size = 256
  # the workgroup count below uses floor division, so a partial tile would be silently left uncomputed
  assert M % block_size == 0 and N % block_size == 0, f"invalid fp8 tile {(block_size, block_size)} for {(M, N)}"
  threads = UOp.special(64 * 8, "lidx0")
  workgroups = UOp.special((M // block_size) * (N // block_size), "gidx0")
  sink_inputs = (C.base, A.base, B.base) + tuple(s.base for s in scales) + (threads, workgroups)
  sink = UOp.sink(*sink_inputs,
                  arg=KernelInfo(f"hk_fp8_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K)*A.dtype.itemsize+M*N*C.dtype.itemsize)))
  kittens_path = pathlib.Path(__file__).parent.parent/"thunder"/"amd"
  src = (kittens_path/"gemm_fp8.cpp").read_text()
  # problem sizes and scale mode are baked into the binary as preprocessor defines
  lib = HIPCCCompiler("gfx950", [f"-I{(kittens_path/'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-ffast-math",
                                 "-DHIP_ENABLE_WARP_SYNC_BUILTINS", f"-DGEMM_M={M}", f"-DGEMM_N={N}", f"-DGEMM_K={K}",
                                 f"-DSCALE_MODE={scale_mode}"]).compile_cached(src)
  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src),
                               UOp(Ops.BINARY, arg=lib)))
# usage statistics: number of asm_gemm calls, plus one message per rejected matmul
counters = dict(used=0, todos=[])

def todo(msg:str) -> bool:
  """Record `msg` as a reason asm_gemm was not used; always returns False so callers can `return todo(...)`."""
  counters["todos"].append(msg)
  return False

def _asm_gemm_report():
  """Print the used / not-used totals; with DEBUG >= 2 also print each rejection reason with its count."""
  print(f'asm_gemm: {counters["used"]} used, {len(counters["todos"])} not used')
  if counters["todos"] and DEBUG >= 2:
    from collections import Counter
    for reason, count in Counter(counters["todos"]).most_common():
      print(f' {count:3d}x {reason}')

# emit the report once, when the interpreter exits
atexit.register(_asm_gemm_report)
def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool:
  """Return True if `a @ b` is eligible for `asm_gemm`; otherwise record the reason via `todo` and return False.

  Checks, in order: matching dtypes, a supported dtype (bfloat16/float16/fp8), `a` being 2D or 3D, a supported
  sharding combination for multi-device tensors, per-device batch in {1, 2}, a shape blacklist, and on gfx950
  targets that the per-device (M, N, K) are multiples of (TILE_M, TILE_N, TILE_K).
  """
  if a.dtype != b.dtype: return todo(f"dtypes must match {a.dtype} != {b.dtype}")
  if a.dtype not in {dtypes.bfloat16, dtypes.float16, FP8_DTYPE}: return todo(f"only bfloat16/float16/fp8, got {a.dtype}")
  # reject unsupported ranks here rather than failing the unpack below with a ValueError
  if a.ndim not in {2, 3}: return todo(f"a must be 2D or 3D, got a.ndim={a.ndim}")
  batch, M, K = (1, *a.shape) if a.ndim == 2 else a.shape
  N = b.shape[1]
  if isinstance(a.device, tuple):
    # multi-device: shrink whichever dim is split so the checks below see per-device sizes
    ndev = len(a.device)
    if a.ndim == 2 and a.uop.axis == 0 and b.uop.axis is None: M //= ndev
    elif a.ndim == 2 and a.uop.axis == 1 and b.uop.axis == 0: K //= ndev
    elif a.ndim == 2 and a.uop.axis is None and b.uop.axis == 1: N //= ndev
    elif a.ndim == 3 and a.uop.axis == 0 and b.uop.axis is None: batch //= ndev
    elif a.ndim == 3 and a.uop.axis is None and b.uop.axis == 1: N //= ndev
    elif a.ndim == 3 and a.uop.axis == 2 and b.uop.axis == 0: K //= ndev
    else: return todo(f"sharding mismatch a.ndim={a.ndim} a.uop.axis={a.uop.axis} b.uop.axis={b.uop.axis}")
    dname = a.device[0]
  else: dname = a.device
  arch = Device[dname].renderer.target.arch
  if batch not in {1, 2}: return todo(f"GEMM batch size {batch}")
  # blacklist slow matmul
  # TODO: why is this slow?
  if (M,N,K) == (8192, 2304, 16384): return todo("blacklisted slow matmul")
  # same prefix test asm_gemm uses to pick the gfx950 kernels, so every arch that dispatches there is tile-checked
  if (M % TILE_M != 0 or N % TILE_N != 0 or K % TILE_K != 0) and arch.startswith("gfx950"):
    return todo(f"GEMM shape ({M},{N},{K}) not a multiple of ({TILE_M},{TILE_N},{TILE_K})")
  return True
# ** UOp gemm to test Tensor.custom_kernel multi and backward correctness on non cdna4
# note: this can be removed after we have GEMM on mixins
def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
  """Express C = A @ B directly as UOps: a float32 sum over K, cast to C's base dtype and stored per (m, n).

  A's two leading dims are flattened into M; B is read as (K, N), ignoring its leading dim when B is 3D.
  All three buffers are addressed through flattened row-major offsets.
  """
  M, K = A.shape[0]*A.shape[1], A.shape[2]
  b_k, N = B.shape[1:] if B.ndim == 3 else B.shape
  assert K == b_k
  # second argument of UOp.range is the id given to each loop: 0 for the K reduction, 1 for m, 2 for n
  k_idx = UOp.range(K, 0, AxisType.REDUCE)
  m_idx = UOp.range(M, 1, AxisType.LOOP)
  n_idx = UOp.range(N, 2, AxisType.LOOP)
  row_len_a, row_len_bc = UOp.const(dtypes.weakint, K), UOp.const(dtypes.weakint, N)
  a_elem = A.flatten().index(m_idx*row_len_a + k_idx)
  b_elem = B.flatten().index(k_idx*row_len_bc + n_idx)
  # multiply, widen the product to float32, then accumulate over k in float32
  acc = (a_elem*b_elem).cast(dtypes.float32).reduce(k_idx, arg=Ops.ADD, dtype=dtypes.float32)
  c_ptr = C.flatten().index(m_idx*row_len_bc + n_idx, ptr=True)
  stored = c_ptr.store(acc.cast(C.dtype.base)).end(m_idx, n_idx)
  return stored.sink(arg=KernelInfo(name=f'uop_gemm_{M}_{N}_{K}'))
# ** bf16 A @ B.T kernel in C++ (HIP)
@functools.cache
def custom_hk_bf16_gemm(C:UOp, A:UOp, B:UOp, *args:UOp, dname:str) -> UOp:
  """Build the PROGRAM UOp for the bf16 GEMM compiled from thunder/amd/gemm_bf16.cpp.

  A's two leading dims are flattened into M; B is read as (N, K), ignoring its leading dim when B is 3D.
  Extra positional `args` are accepted and ignored. Asserts matching K and (256, 256, 64) tile divisibility.
  """
  M, K = A.shape[0]*A.shape[1], A.shape[2]
  N, K2 = B.shape[1:] if B.ndim == 3 else B.shape
  assert K == K2, f"{A.shape} {B.shape}"
  block_m, block_n, block_k, num_warps = 256, 256, 64, 8
  tiles_ok = all(dim % blk == 0 for dim, blk in zip((M, N, K), (block_m, block_n, block_k)))
  assert tiles_ok, f"invalid bf16 tile {(block_m, block_n, block_k)} for {(M, N, K)}"
  # 64 * num_warps threads per workgroup, one workgroup per (block_m x block_n) output tile
  n_threads, n_workgroups = 64 * num_warps, (M // block_m) * (N // block_n)
  info = KernelInfo(f"hk_bf16_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K+M*N)*A.dtype.itemsize))
  sink = UOp.sink(C.base, A.base, B.base, UOp.special(n_threads, "lidx0"), UOp.special(n_workgroups, "gidx0"), arg=info)
  kittens_path = pathlib.Path(__file__).parents[1]/"thunder"/"amd"
  src = (kittens_path/"gemm_bf16.cpp").read_text()
  # problem sizes are baked into the binary as preprocessor defines
  flags = [f"-I{(kittens_path/'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-ffast-math",
           "-DHIP_ENABLE_WARP_SYNC_BUILTINS", f"-DGEMM_M={M}", f"-DGEMM_N={N}", f"-DGEMM_K={K}"]
  lib = HIPCCCompiler("gfx950", flags).compile_cached(src)
  program_src = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src),
                 UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=program_src)
@functools.cache
def custom_hk_bf16_atb_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
  """Build the PROGRAM UOp for the bf16 kernel compiled from thunder/amd/gemm_bf16_atb.cpp.

  Both A and B have their two leading dims flattened into the shared K; A's last dim is M and B's last dim is N.
  Asserts matching K and (256, 256, 64) tile divisibility.
  """
  K, M = A.shape[0]*A.shape[1], A.shape[2]
  K2, N = B.shape[0]*B.shape[1], B.shape[2]
  assert K == K2, f"{A.shape} {B.shape}"
  block_m, block_n, block_k, num_warps = 256, 256, 64, 8
  tiles_ok = all(dim % blk == 0 for dim, blk in zip((M, N, K), (block_m, block_n, block_k)))
  assert tiles_ok, f"invalid bf16 atb tile {(block_m, block_n, block_k)} for {(M, N, K)}"
  # 64 * num_warps threads per workgroup, one workgroup per (block_m x block_n) output tile
  n_threads, n_workgroups = 64 * num_warps, (M // block_m) * (N // block_n)
  info = KernelInfo(f"hk_bf16_atb_gemm_{M}_{N}_{K}", estimates=Estimates(ops=2*M*N*K, mem=(M*K+N*K+M*N)*A.dtype.itemsize))
  sink = UOp.sink(C.base, A.base, B.base, UOp.special(n_threads, "lidx0"), UOp.special(n_workgroups, "gidx0"), arg=info)
  kittens_path = pathlib.Path(__file__).parents[1]/"thunder"/"amd"
  src = (kittens_path/"gemm_bf16_atb.cpp").read_text()
  # problem sizes are baked into the binary as preprocessor defines
  flags = [f"-I{(kittens_path/'include').as_posix()}", "-std=c++20", "-DKITTENS_CDNA4", "-ffast-math",
           "-DHIP_ENABLE_WARP_SYNC_BUILTINS", f"-DGEMM_M={M}", f"-DGEMM_N={N}", f"-DGEMM_K={K}"]
  lib = HIPCCCompiler("gfx950", flags).compile_cached(src)
  program_src = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src),
                 UOp(Ops.BINARY, arg=lib))
  return UOp(Ops.PROGRAM, src=program_src)
def hk_bf16_atb_gemm(a:Tensor, b:Tensor) -> Tensor:
  """Run `custom_hk_bf16_atb_gemm` on two 3D bf16 tensors that share their leading two dims.

  With a of shape (batch, rows, M) and b of shape (batch, rows, N), the kernel writes a (1, M, N) buffer.
  For multi-device inputs that buffer is summed over axis 0; a remaining leading dim is squeezed before returning.
  """
  assert a.dtype == b.dtype == dtypes.bfloat16, f"expected bf16, got {a.dtype} {b.dtype}"
  assert a.ndim == b.ndim == 3 and a.shape[:2] == b.shape[:2], f"{a.shape} {b.shape}"
  batch, rows, M = a.shape
  N = b.shape[2]
  assert M % TILE_M == 0 and N % TILE_N == 0 and (batch * rows) % TILE_K == 0, \
    f"atb shape {a.shape} {b.shape} must produce (M,N,K) multiples of ({TILE_M},{TILE_N},{TILE_K})"
  is_multi = isinstance(a.device, tuple)
  out = Tensor.invalids(1, M, N, dtype=a.dtype, device=a.device)
  # multi-device: re-wrap the output buffer as a multi tensor on axis 0
  if is_multi: out = Tensor(out.uop.multi(0), device=a.device)
  # keep only the part of the (first) device name before any ":"
  dname = (a.device[0] if is_multi else a.device).split(":")[0]
  kernel_fxn = functools.partial(custom_hk_bf16_atb_gemm, dname=dname)
  out = Tensor.custom_kernel(out, a, b, fxn=kernel_fxn)[0]
  if is_multi: out = out.sum(0)
  if out.ndim == 3: return out.squeeze(0)
  return out
# ** backward gemm, might use the asm gemm
def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=False, has_w_post:bool=False):
  """Backward function registered as `grad_fxn` for the GEMM custom kernels launched by `asm_gemm`.

  `gradient` is the incoming gradient and `kernel` the forward kernel UOp, whose `src[1:]` are the forward inputs
  in launch order. `n_scales`, `has_grad_amax` and `has_w_post` describe which optional inputs the fp8 launch
  appended after (out, a, b).

  Returns a tuple with one entry per forward input: gradient UOps for the two matmul operands, None elsewhere.
  """
  # forward inputs; kernel.src[0] is skipped (NOTE(review): presumably the kernel program itself -- confirm)
  inputs = kernel.src[1:]
  if inputs[1].dtype == FP8_DTYPE:
    # fp8 layout: (out, a, b, x_scale[, w_scale][, grad_amax][, w_post_scale]); asm_gemm passed b.T as `b`
    out, a, b = inputs[:3]
    i = 3
    s_x = inputs[i]; i += 1
    has_w = n_scales == 2
    # the bool flags are added as 0/1 to step over each optional input that is present
    s_w = inputs[i] if has_w else None; i += has_w
    grad_amax_state = inputs[i] if has_grad_amax else None; i += has_grad_amax
    w_post = inputs[i] if has_w_post else None
    a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
    s_x_t = Tensor(s_x, device=a.device)
    s_w_t = Tensor(s_w, device=a.device) if has_w else None
    w_post_t = Tensor(w_post, device=a.device) if has_w_post else None
    # keep only the first a.shape[0] entries along dim 0 (NOTE(review): reason not visible here -- confirm)
    g_t = g_t[:a.shape[0]]
    from extra.llama_kernels.cast_amax import _grad_fp8_mailbox
    from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed
    # look for an already-quantized gradient, keyed by the gradient's base first and then the UOp itself;
    # pop() removes the entry, so it is consumed at most once
    gbase = gradient.base if hasattr(gradient, "base") else gradient
    mailbox_entry = _grad_fp8_mailbox.pop(gbase, None) or _grad_fp8_mailbox.pop(gradient, None)
    if mailbox_entry is not None:
      g_fp8_u, inv_scale_u = mailbox_entry
      g_fp8 = Tensor(g_fp8_u, device=a.device)[:a.shape[0]]
      g_scale = Tensor(inv_scale_u, device=a.device)
    else:
      assert grad_amax_state is not None, "fp8 matmul bwd needs either a mailbox entry or a grad_amax_state"
      # no mailbox entry: quantize the gradient here, in one of three env-selected modes
      if getenv("CURRENT_GRAD_SCALE", 0):
        # quantize without passing any amax state
        g_fp8, g_scale, _ = quantize_fp8(g_t, amax_state=None)
      elif getenv("FUSED_GRAD_QUANTIZE", 0):
        g_fp8, g_scale, _, store_effect = quantize_fp8_delayed(g_t, Tensor(grad_amax_state, device=a.device))
        assert g_fp8.uop.op is Ops.AFTER, f"expected AFTER, got {g_fp8.uop.op}"
        # append the returned store effect to the AFTER node's sources
        g_fp8 = Tensor(g_fp8.uop.replace(src=g_fp8.uop.src + (store_effect,)), device=a.device)
      else:
        grad_amax_t = Tensor(grad_amax_state, device=a.device)
        g_fp8, g_scale, new_grad_amax = quantize_fp8(g_t, amax_state=grad_amax_t)
        # write the new amax back into the state buffer and order the fp8 gradient after that store
        store_effect = grad_amax_state.store(new_grad_amax.uop)
        g_fp8 = Tensor(g_fp8.contiguous().uop.after(store_effect), device=a.device)
    # dgrad: uses g_scale * x_scale * w_scale (only when scalar)
    grad_a = asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t, w_scale=s_w_t) if has_w else asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t)
    # wgrad: no w_scale
    g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1])
    # optional dedicated transpose kernel, only taken when both 2D dims are multiples of 64
    if getenv("FAST_FP8_TRANSPOSE", 0) and g_fp8_2d.shape[0] % 64 == 0 and g_fp8_2d.shape[1] % 64 == 0:
      from extra.llama_kernels.fp8_transpose import fast_fp8_transpose
      g_fp8_T = fast_fp8_transpose(g_fp8_2d)
    else:
      g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1)
    grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=g_scale * s_x_t)
    # wgrad: rescale if not scalar
    if w_post_t is not None:
      # pad w_post with trailing size-1 dims so the division broadcasts against grad_b
      grad_b = grad_b / w_post_t.reshape(*w_post_t.shape, *([1]*(grad_b.ndim - w_post_t.ndim)))
    # one None per input: (out, a, b, x_scale[, w_scale][, grad_amax][, w_post_scale])
    ret = (None, grad_a.uop, grad_b.uop) + tuple(None for _ in inputs[3:])
    return ret
  else:
    # non-fp8: either the 4-input bf16 launch (out, a, b.T, b) or the plain 3-input launch (out, a, b)
    hk_bf16 = len(inputs) == 4 and inputs[1].dtype == dtypes.bfloat16
    if hk_bf16:
      out, a, b_t, b = inputs
      assert all_same([gradient.device, a.device, b_t.device, b.device, out.device])
    else:
      assert len(inputs) == 3, f"regular gemm must have exactly 3 sources, got: {len(inputs)}"
      out, a, b = inputs
      assert all_same([gradient.device, a.device, b.device, out.device])
    # b_t is rebound here to a Tensor wrapping b (the untransposed operand)
    a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
    # keep only the first a.shape[0] entries along dim 0 (NOTE(review): reason not visible here -- confirm)
    g_t = g_t[:a.shape[0]]
    if hk_bf16 and g_t.dtype != b_t.dtype: g_t = g_t.cast(b_t.dtype)
    # grad wrt a = g @ b.T, through asm_gemm when eligible, otherwise a regular matmul
    if can_use_asm_gemm(g_t, b_t.T): grad_a = asm_gemm(g_t, b_t.T).uop
    else: grad_a = (g_t @ b_t.T).uop
    # grad wrt b = a.T @ g with the leading dims flattened; the bf16 path defaults to the dedicated atb kernel
    if hk_bf16 and getenv("USE_HK_BF16_ATB", 1):
      grad_b = hk_bf16_atb_gemm(a_t, g_t).uop
    else:
      a_t_flat, g_t_flat = a_t.permute(2, 0, 1).reshape(a_t.shape[2], -1), g_t.reshape(-1, g_t.shape[-1])
      if can_use_asm_gemm(a_t_flat, g_t_flat): grad_b = asm_gemm(a_t_flat, g_t_flat).uop
      else: grad_b = (a_t_flat @ g_t_flat).uop
    # hk_bf16 uses b.T, writes gradients only for a and b
    return (None, grad_a, None, grad_b) if hk_bf16 else (None, grad_a, grad_b)
# ** main gemm function
def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=None, grad_amax_state:Tensor|None=None,
             w_post_scale:Tensor|None=None) -> Tensor:
  """Compute `a @ b` through a custom kernel, with `custom_gemm_bw` registered as the backward function.

  `a` is 2D or 3D and `b` is indexed as (K, N). On gfx950 targets with USE_ASM enabled (the default) the kernel
  is chosen by dtype: the fp8 kernel, optionally the bf16 kernel (USE_HK_BF16_GEMM), or the assembly kernel.
  Any other target uses `custom_uop_gemm`.

  `x_scale` / `w_scale` are passed to the fp8 kernel as scale inputs. `grad_amax_state` and `w_post_scale` are
  forwarded as extra fp8 kernel inputs. `w_post_scale`, when given, also multiplies the output along its last dim.

  Asserts `can_use_asm_gemm(a, b)`. fp8 inputs produce a bfloat16 output; other dtypes keep `a.dtype`.
  """
  # on failure the message is the most recent rejection reason recorded by can_use_asm_gemm
  assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}"
  counters["used"] += 1
  # 3D a sharded on axis 2 with b sharded on axis 0: fold batch into the row dim, restored before returning
  unfold_batch = a.ndim == 3 and isinstance(a.device, tuple) and a.uop.axis == 2 and b.uop.axis == 0
  if unfold_batch:
    orig_batch = a.shape[0]
    a = a.reshape(a.shape[0]*a.shape[1], a.shape[2])
  # 2D a gets a leading batch dim of 1 for the kernel; it is squeezed away at the end
  squeeze = a.ndim == 2
  if squeeze: a = a.unsqueeze(0)
  out_dtype = dtypes.bfloat16 if a.dtype == FP8_DTYPE else a.dtype
  batch, M, K = a.shape
  N = b.shape[1]
  is_multi = isinstance(a.device, tuple)
  # per-device sizes for the sharded dims
  if (k_sharded:=is_multi and a.uop.axis == 2): K //= len(a.device)
  if (m_sharded:=is_multi and a.uop.axis == 1): M //= len(a.device)
  n_sharded = is_multi and b.uop.axis == 1
  # allocate the output with Tensor.invalids; multi-device outputs are wrapped as multi on the matching axis
  if is_multi:
    if n_sharded:
      out = Tensor(Tensor.invalids(batch, M, N//len(a.device), dtype=out_dtype, device=a.device).uop.multi(2), device=a.device)
    elif m_sharded:
      out = Tensor(Tensor.invalids(batch, M, N, dtype=out_dtype, device=a.device).uop.multi(1), device=a.device)
    else:
      out = Tensor(Tensor.invalids(batch//len(a.device) if a.uop.axis==0 else batch, M, N, dtype=out_dtype, device=a.device).uop.multi(0),
                   device=a.device)
  else:
    out = Tensor.invalids(batch, M, N, dtype=out_dtype, device=a.device)
  # renderer is looked up with the full device name; dname then keeps only the part before any ":"
  renderer = Device[dname:=(a.device[0] if is_multi else a.device)].renderer
  dname, arch = dname.split(":")[0], renderer.target.arch
  if arch.startswith("gfx950") and getenv("USE_ASM", 1):
    # fp8 gemm computes a@b.T, kernel multiplies output by x_scale * w_scale before bf16 store
    if a.dtype == FP8_DTYPE:
      scales = tuple(s for s in (x_scale, w_scale) if s is not None)
      # bit 0 = x_scale present, bit 1 = w_scale present
      scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0)
      extra = ([grad_amax_state] if grad_amax_state is not None else []) + ([w_post_scale] if w_post_scale is not None else [])
      fxn = functools.partial(custom_hk_fp8_gemm, dname=dname, scale_mode=scale_mode)
      # the backward function needs to know which optional inputs were appended after (out, a, b)
      bw = functools.partial(custom_gemm_bw, n_scales=len(scales), has_grad_amax=grad_amax_state is not None, has_w_post=w_post_scale is not None)
      out = Tensor.custom_kernel(out, a, b.T, *scales, *extra, fxn=fxn, grad_fxn=bw)[0]
    elif a.dtype == dtypes.bfloat16 and getenv("USE_HK_BF16_GEMM"):
      # b.T is the kernel operand; b is passed as an extra input, which custom_gemm_bw unpacks for the backward
      out = Tensor.custom_kernel(out, a, b.T, b, fxn=functools.partial(custom_hk_bf16_gemm, dname=dname), grad_fxn=custom_gemm_bw)[0]
    else:
      out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname), grad_fxn=custom_gemm_bw)[0]
  else:
    out = Tensor.custom_kernel(out, a, b, fxn=custom_uop_gemm, grad_fxn=custom_gemm_bw)[0]
  # K-sharded: the output was created as multi on axis 0, so sum that axis
  if k_sharded: out = out.sum(0)
  out = out.squeeze(0) if squeeze else out
  if unfold_batch: out = out.reshape(orig_batch, -1, out.shape[-1])
  # w_post_scale is reshaped to broadcast along the last dim; the product is cast back to out's dtype
  if w_post_scale is not None: out = (out * w_post_scale.reshape(*([1]*(out.ndim-1)), -1)).cast(out.dtype)
  return out