Files
tinygrad/extra/gemm/amd_asm_matmul.py
qazal a37b605523 remove arch from asm kernel class (#15977)
* rm arch from kernel

* update other tests

* update abstractions4.py
2026-04-30 03:39:52 +09:00

501 lines
24 KiB
Python

# RDNA3 128x128 tiled GEMM kernel - DSL version
# Computes C = A @ B for NxN float32 matrices using 128x128 tiles
#
# Architecture: RDNA3 (gfx1100)
# Tile size: 128x128 (each workgroup computes one tile of C)
# Workgroup: 128 threads (arranged as 32x4 for coalesced memory access)
# Inner loop: 8 iterations per K-block, processing 8 columns of A and 8 rows of B
#
# Accumulators: 128 vgprs (v[2-129])
import numpy as np
from tinygrad import Tensor, Device, Context, GlobalCounters
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.helpers import getenv, colored
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.engine.realize import Estimates, run_linear
from tinygrad.renderer.amd.dsl import s, v, VCC_LO, NULL
from tinygrad.runtime.autogen.amd.rdna3.ins import *
# =============================================================================
# Kernel constants
# =============================================================================
# LDS layout implied by the constants below: the region read through V_LDS_A_BASE spans
# [0, 8 * LDS_A_STRIDE) = [0, 0x1080); the region read through V_LDS_B_BASE starts at
# LDS_BASE_OFFSET and spans 8 * LDS_B_STRIDE bytes, ending at LDS_SIZE.
LDS_SIZE = 8320 # Local data share size in bytes (= LDS_BASE_OFFSET + 8 * LDS_B_STRIDE)
LDS_A_STRIDE = 0x210 # LDS stride for A tile (528 bytes), one stride per inner-loop iteration
LDS_B_STRIDE = 0x200 # LDS stride for B tile (512 bytes), one stride per inner-loop iteration
LDS_BASE_OFFSET = 0x1080 # Base LDS offset for tiles (4224 bytes = 8 * LDS_A_STRIDE)
ADDR_MASK = 0x3fffff80 # Address alignment mask: keeps bits 7-29, i.e. clears the low 7 bits
# =============================================================================
# Named register assignments (VGPRs)
# =============================================================================
V_LANE_ID = 0 # lane_id set on startup
# Named regs sit in the gaps between the B tile register pairs (v146-159), plus v162 directly
# after the last pair (v160-161), to minimize max VGPR
V_LANE_ID_MOD8 = 146 # lane_id & 7
V_LANE_MOD8_X4 = 147 # (lane_id & 7) << 2
V_LANE_DIV8_X4 = 150 # ((lane_id >> 3) & 3) << 2
V_LDS_B_BASE = 151 # LDS B-tile base address for inner loop
V_LDS_A_BASE = 154 # LDS A-tile base address for inner loop
V_GLOBAL_A_ADDR = 155 # global memory A prefetch address
V_GLOBAL_B_ADDR = 158 # global memory B prefetch address
V_LDS_A_ADDR = 159 # single base register for A stores
V_LDS_B_ADDR = 162 # single base register for B stores
# LDS tile register destinations - SEPARATE from DATA to avoid overlap
# Each entry is the first register of a pair (vdst, vdst+1) filled by one ds_load_b64
# A on banks 2-3, B on banks 0-1 to avoid bank conflicts in VOPD
V_A_TILE_REGS = [130, 134, 138, 142] # A tile: banks 2,2,2,2 (130%4=2, etc.)
V_B_TILE_REGS = [132, 136, 140, 144, 148, 152, 156, 160] # B tile: banks 0,0,0,0,0,0,0,0
# =============================================================================
# Named register assignments (SGPRs)
# =============================================================================
# NOTE: s[0:1] is used as the base for the kernarg loads in the prologue and is then overwritten
# with the output pointer by the second of those loads.
S_OUT_PTR = (0, 1) # output C matrix base pointer
S_WORKGROUP_X = 2 # workgroup_id_x (system SGPR, follows user SGPRs)
S_WORKGROUP_Y = 3 # workgroup_id_y (system SGPR)
S_DIM_N = 4 # matrix dimension N
S_LOOP_BOUND = 7 # K-8 (loop termination bound); K == N for the square problem
S_LOOP_CTR = 12 # loop counter (increments by 8)
S_PREFETCH_FLAG = 13 # prefetch condition flag / row stride in epilogue
S_TILE_X = 14 # workgroup_x << 7
S_TILE_Y = 15 # workgroup_y << 7
# Kernarg load destinations (one s_load_b128 fills s[20:23] with both pointers)
S_KERNARG_A = (20, 21) # A pointer from kernarg
S_KERNARG_B = (22, 23) # B pointer from kernarg
# Prefetch base pointers (8 pairs each, B: N*4 bytes apart, A: N*64 bytes apart)
# Each pointer is a 64-bit value held in two consecutive SGPRs (lo, hi)
S_PREFETCH_B = 24 # s[24:39] - 8 B tile pointers
S_PREFETCH_A = 40 # s[40:55] - 8 A tile pointers
# =============================================================================
# Data tables
# =============================================================================
# Accumulator grid: ACC_GRID[a_idx][b_idx] = vgpr for C[a,b]
# a_idx: which A value (0-7), b_idx: which B value (0-15)
# Scattered due to VOPD bank constraints (vdst_x % 4 != vdst_y % 4)
# Range is from v2 - v129 (8 x 16 = 128 registers, each used exactly once)
ACC_GRID = [
  [ 5, 3, 9, 8, 37, 35, 41, 40, 69, 67, 73, 72, 101, 99,105,104], # a0
  [ 4, 2, 7, 6, 36, 34, 39, 38, 68, 66, 71, 70, 100, 98,103,102], # a1
  [ 17, 16, 13, 11, 49, 48, 45, 43, 81, 80, 77, 75, 113,112,109,107], # a2
  [ 15, 14, 12, 10, 47, 46, 44, 42, 79, 78, 76, 74, 111,110,108,106], # a3
  [ 21, 19, 25, 24, 53, 51, 57, 56, 85, 83, 89, 88, 117,115,121,120], # a4
  [ 20, 18, 23, 22, 52, 50, 55, 54, 84, 82, 87, 86, 116,114,123,122], # a5
  [125,128, 29, 27, 33, 32, 61, 59, 65, 64, 93, 91, 97, 96,129,127], # a6
  [119,118, 28, 26, 31, 30, 60, 58, 63, 62, 92, 90, 95, 94,124,126], # a7
]
# Optimized (a_pair, b_pair) iteration order for better GPU scheduling
# Interleaves A and B pairs to maximize instruction-level parallelism
# a_pair indexes V_A_TILE_REGS (0-3), b_pair indexes V_B_TILE_REGS (0-7); each of the 4x8 = 32
# combinations appears exactly once, and each entry expands to two dual-FMAC ops
FMAC_PAIR_ORDER = [
  (0,0),(0,1),(1,1),(1,0), (2,0),(2,1),(3,1),(3,2), (0,2),(0,3),(1,3),(1,2), (2,2),(2,3),(3,3),(3,4),
  (0,4),(0,5),(1,5),(1,4), (2,4),(2,5),(3,5),(3,6), (0,6),(0,7),(1,7),(1,6), (2,6),(2,7),(3,7),(3,0),
]
def derive_fmac_pattern(acc_grid, a_tile_regs=None, b_tile_regs=None, pair_order=None):
  """Expand an accumulator grid into the operand tuples of the dual FMAC ops.

  Every (a_pair, b_pair) entry of the iteration order covers a 2x2 block of the grid and yields
  two dual ops, so the result has 2 * len(pair_order) entries (64 for the default order).

  Args:
    acc_grid: acc_grid[a][b] is the accumulator register for C[a, b].
    a_tile_regs: first register of each A pair (pair p holds values 2p and 2p+1 in regs r, r+1).
      Defaults to V_A_TILE_REGS.
    b_tile_regs: first register of each B pair, same convention. Defaults to V_B_TILE_REGS.
    pair_order: sequence of (a_pair, b_pair) index pairs to visit. Defaults to FMAC_PAIR_ORDER.

  Returns:
    List of (vdst_x, vdst_y, src_a_x, src_b_x, src_a_y, src_b_y) register-number tuples.
  """
  # The signature advertises None defaults; resolve them here instead of failing on None[...]
  if a_tile_regs is None: a_tile_regs = V_A_TILE_REGS
  if b_tile_regs is None: b_tile_regs = V_B_TILE_REGS
  if pair_order is None: pair_order = FMAC_PAIR_ORDER
  pattern = []
  for idx, (a_pair, b_pair) in enumerate(pair_order):
    a_even, a_odd = a_pair * 2, a_pair * 2 + 1
    b_even, b_odd = b_pair * 2, b_pair * 2 + 1
    a_base, b_base = a_tile_regs[a_pair], b_tile_regs[b_pair]
    # Op 1: normal order -> C[a_even, b_even] + C[a_odd, b_odd]
    pattern.append((acc_grid[a_even][b_even], acc_grid[a_odd][b_odd],
                    a_base, b_base, a_base+1, b_base+1))
    # Op 2: the two off-diagonal elements; alternate swapping A vs B to vary register banks
    if idx % 2 == 0: # swap B
      pattern.append((acc_grid[a_even][b_odd], acc_grid[a_odd][b_even],
                      a_base, b_base+1, a_base+1, b_base))
    else: # swap A
      pattern.append((acc_grid[a_odd][b_even], acc_grid[a_even][b_odd],
                      a_base+1, b_base, a_base, b_base+1))
  return pattern
# Derived: 64 dual FMAC operations, each a (vdst_x, vdst_y, ax, bx, ay, by) tuple of VGPR numbers.
# build_kernel emits the whole list once per inner-loop iteration.
FMAC_PATTERN = derive_fmac_pattern(ACC_GRID, V_A_TILE_REGS, V_B_TILE_REGS)
def derive_permute_swaps(acc_grid, out_regs):
  """Compute the register exchanges that move accumulators from FMAC layout to store order.

  On entry register acc_grid[a][b] is taken to hold C[a, b] for an 8 x 16 grid. Output slot i
  (0..127) must end up in out_regs[i]. Slots are enumerated with, from slowest to fastest index:
  row_half (0-1), col_group (0-3), row_in_group (0-3), b_off (0-3), and slot i wants
  C[row_half*4 + row_in_group, col_group*4 + b_off].

  Returns a list of (reg_x, reg_y) pairs; exchanging the contents of each pair in list order
  produces the requested arrangement.
  """
  # forward map: register -> grid element it currently holds
  held_by_reg = {}
  for a in range(8):
    for b in range(16):
      held_by_reg[acc_grid[a][b]] = (a, b)
  # inverse map: grid element -> register currently holding it
  reg_of = {elem: reg for reg, elem in held_by_reg.items()}
  swaps = []
  for slot in range(128):
    # peel the mixed-radix digits of the slot number, fastest-varying first
    rest, b_off = divmod(slot, 4)
    rest, row_in_group = divmod(rest, 4)
    row_half, col_group = divmod(rest, 4)
    wanted = (row_half * 4 + row_in_group, col_group * 4 + b_off)
    dst, src = out_regs[slot], reg_of[wanted]
    if src == dst: continue # already in place
    swaps.append((src, dst))
    # track the exchange: dst now holds the wanted element, src gets whatever dst used to hold
    displaced = held_by_reg.get(dst)
    held_by_reg[dst], reg_of[wanted] = wanted, dst
    if displaced is not None:
      held_by_reg[src], reg_of[displaced] = displaced, src
  return swaps
# Derived: swap sequence to arrange accumulators for output
# Each group of 4 registers is ascending for direct global_store_b128
# Groups run downward from v126-129 to v2-5, so OUT_REGS covers exactly v2-v129
OUT_REGS = [r for i in range(32) for r in range(126 - i*4, 130 - i*4)]
PERMUTE_SWAPS = derive_permute_swaps(ACC_GRID, OUT_REGS)
# =============================================================================
# LDS tile staging registers
# =============================================================================
# DATA regs receive contiguous global prefetch, then write to LDS
# TILE regs receive scattered LDS loads (ds_load_b64 pairs), then feed FMACs
# Contiguous layout with mod4=[3,0,1,2,3,0,1,2] for bank conflict avoidance
V_LDS_A_DATA = [163, 164, 165, 166, 167, 168, 169, 170]
V_LDS_B_DATA = [171, 172, 173, 174, 175, 176, 177, 178]
# Initial tile prefetch: (vdst, saddr_lo) - load into A data regs using B prefetch pointers (s[24:31])
# Covers V_LDS_A_DATA[0:4]; the remaining A data regs are filled by the first two PREFETCH_LOADS entries
INIT_PREFETCH = [(V_LDS_A_DATA[i], S_PREFETCH_B+2*i) for i in range(4)]
# Global memory prefetch schedule: (vdst1, vdst2, addr_vreg, saddr_lo1, saddr_lo2)
# First 2 pairs from B prefetch pointers (s[32:39]), next 4 pairs from A prefetch pointers (s[40:55])
# 6 entries in total, indexed by inner-loop iteration number in build_kernel
PREFETCH_LOADS = [(V_LDS_A_DATA[4+2*i], V_LDS_A_DATA[4+2*i+1], V_GLOBAL_B_ADDR, S_PREFETCH_B+8+4*i, S_PREFETCH_B+10+4*i) for i in range(2)] + \
                 [(V_LDS_B_DATA[2*(i-2)], V_LDS_B_DATA[2*(i-2)+1], V_GLOBAL_A_ADDR, S_PREFETCH_A+4*(i-2), S_PREFETCH_A+2+4*(i-2)) for i in range(2, 6)]
# =============================================================================
# Kernel class
# =============================================================================
class Kernel:
  """Ordered instruction list with symbolic labels and deferred branch-offset patching.

  Attributes:
    instructions: emitted instruction objects, in program order.
    labels: label name -> byte position at the time the label was defined.
    pos: running byte position, i.e. the sum of inst.size() over all emitted instructions.
  """
  def __init__(self): self.instructions, self.labels, self.pos = [], {}, 0

  def label(self, name):
    """Bind `name` to the current position. Raises ValueError if the name is already bound."""
    # silently rebinding would redirect every branch to the later definition
    if name in self.labels: raise ValueError(f"label '{name}' is already defined")
    self.labels[name] = self.pos

  def emit(self, inst, target=None):
    """Append `inst`, remembering its position and optional branch `target` label; returns `inst`."""
    self.instructions.append(inst)
    inst._target, inst._pos = target, self.pos
    self.pos += inst.size()
    return inst

  def waitcnt(self, lgkm=None, vm=None):
    """Wait for memory operations. lgkm=N waits until N lgkm ops remain, vm=N waits until N vmem ops remain.

    A counter left as None is encoded at its field maximum (63), as is expcnt (7).
    Raises ValueError if a count does not fit its 6-bit field (previously it was silently masked).
    """
    vmcnt = 63 if vm is None else vm
    lgkmcnt = 63 if lgkm is None else lgkm
    for field, value in (("vm", vmcnt), ("lgkm", lgkmcnt)):
      if not 0 <= value <= 63: raise ValueError(f"{field}={value} does not fit the 6-bit s_waitcnt counter field")
    expcnt = 7
    # simm16 layout used here: expcnt in bits [2:0], lgkmcnt in bits [9:4], vmcnt in bits [15:10]
    self.emit(s_waitcnt(simm16=(expcnt & 0x7) | (lgkmcnt << 4) | (vmcnt << 10)))

  def finalize(self):
    """Patch branch offsets and return the finalized instruction list.

    For every instruction emitted with a target, simm16 is set to the distance in dwords from the
    end of that instruction to the label.

    Raises:
      KeyError: a branch targets a label that was never defined.
      ValueError: a branch offset does not fit in a signed 16-bit immediate.
    """
    for inst in self.instructions:
      if inst._target is None: continue
      if inst._target not in self.labels: raise KeyError(f"branch target label '{inst._target}' was never defined")
      offset_dwords = (self.labels[inst._target] - inst._pos - inst.size()) // 4
      if not -32768 <= offset_dwords <= 32767: raise ValueError(f"branch to '{inst._target}' offset {offset_dwords} exceeds simm16 range")
      inst.simm16 = offset_dwords
    return self.instructions
# =============================================================================
# Kernel builder
# =============================================================================
def build_kernel(N):
  """Build the instruction list of the 128x128-tiled float32 GEMM kernel for an NxN problem.

  N is baked into the instruction stream as immediates (pointer strides, loop bound, address
  multipliers), so the returned program is only valid for that N.

  The environment variables NO_ALU, NO_DS and NO_GLOBAL are read at build time. When set they
  leave out, respectively, the FMACs, the main-loop LDS loads/stores, and the main-loop global
  prefetch (pointer advance and loads) - presumably for profiling; results are then not a valid GEMM.

  Args:
    N: matrix dimension; asserted to be a multiple of 128 and at least 256.

  Returns:
    The instruction list produced by Kernel.finalize(), with branch offsets patched.
  """
  assert N % 128 == 0, f"N must be a multiple of 128 (tile size), got {N}"
  assert N >= 256, f"N must be >= 256 (prefetch pipeline requires at least 2 K-blocks), got {N}"
  k = Kernel()
  # ===========================================================================
  # PROLOGUE: Load kernel arguments, compute tile coordinates and addresses
  # ===========================================================================
  # A and B pointers come from kernarg offset 0x0, the output pointer from offset 0x10
  k.emit(s_load_b128(sdata=s[S_KERNARG_A[0]:S_KERNARG_B[1]], sbase=s[0:1], offset=0x0, soffset=NULL))
  k.emit(s_load_b64(sdata=s[S_OUT_PTR[0]:S_OUT_PTR[1]], sbase=s[0:1], offset=0x10, soffset=NULL))
  k.emit(s_mov_b32(s[S_DIM_N], N))
  k.emit(s_mov_b32(s[S_LOOP_CTR], 0)) # used by LDS swizzle, always 0 for valid workgroups
  k.emit(s_lshl_b32(s[S_TILE_X], s[S_WORKGROUP_X], 7))
  k.emit(s_lshl_b32(s[S_TILE_Y], s[S_WORKGROUP_Y], 7))
  # Lane-derived values
  k.emit(v_and_b32_e32(v[V_LANE_ID_MOD8], 7, v[V_LANE_ID]))
  k.emit(v_lshrrev_b32_e32(v[4], 3, v[V_LANE_ID]))
  k.emit(v_or_b32_e32(v[1], s[S_TILE_X], v[V_LANE_ID]))
  k.emit(v_or_b32_e32(v[22], s[S_TILE_Y], v[4]))
  k.emit(v_lshlrev_b32_e32(v[V_LANE_MOD8_X4], 2, v[V_LANE_ID_MOD8]))
  # the kernarg loads must have landed before the pointers are used below
  k.waitcnt(lgkm=0)
  # Compute 8 A and B matrix tile base pointers for prefetch
  k.emit(s_mov_b64(s[S_PREFETCH_B:S_PREFETCH_B+1], s[S_KERNARG_B[0]:S_KERNARG_B[1]])) # B[0]: no offset
  for i in range(1, 8): # B: each pointer 1 row of B apart (N*4 bytes)
    k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_KERNARG_B[0]], i * N * 4))
    k.emit(s_addc_u32(s[S_PREFETCH_B+i*2+1], s[S_KERNARG_B[1]], 0))
  k.emit(s_mov_b64(s[S_PREFETCH_A:S_PREFETCH_A+1], s[S_KERNARG_A[0]:S_KERNARG_A[1]])) # A[0]: no offset
  for i in range(1, 8): # A: each pointer 16 rows of A apart (16*N*4 bytes)
    k.emit(s_add_u32(s[S_PREFETCH_A+i*2], s[S_KERNARG_A[0]], i * N * 64))
    k.emit(s_addc_u32(s[S_PREFETCH_A+i*2+1], s[S_KERNARG_A[1]], 0))
  # Global prefetch addresses: B = (tile_x + lane_id) * 4, A = (tile_y*N + (lane_id/8)*N + lane_id%8) * 4
  k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], s[S_TILE_X], v[V_LANE_ID]))
  k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_B_ADDR], 2, v[V_GLOBAL_B_ADDR]))
  k.emit(s_mul_i32(s[19], s[S_TILE_Y], N))
  k.emit(v_mul_lo_u32(v[V_GLOBAL_A_ADDR], v[4], N)) # (lane_id/8)*N
  k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], v[V_LANE_ID_MOD8], v[V_GLOBAL_A_ADDR])) # + lane_id%8
  k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], s[19], v[V_GLOBAL_A_ADDR]))
  k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_A_ADDR], 2, v[V_GLOBAL_A_ADDR]))
  # Do initial loads
  for vdst, saddr_lo in INIT_PREFETCH:
    k.emit(global_load_b32(vdst=v[vdst], addr=v[V_GLOBAL_B_ADDR], saddr=s[saddr_lo:saddr_lo+1]))
  for iter in range(6):
    vdst1, vdst2, addr, slo1, slo2 = PREFETCH_LOADS[iter]
    k.emit(global_load_b32(vdst=v[vdst1], addr=v[addr], saddr=s[slo1:slo1+1]))
    k.emit(global_load_b32(vdst=v[vdst2], addr=v[addr], saddr=s[slo2:slo2+1]))
  # ===========================================================================
  # LDS store address computation (bank-conflict-avoiding swizzle)
  # ===========================================================================
  # This section computes LDS store addresses with a swizzle pattern to avoid bank conflicts.
  # The swizzle ensures that threads in the same wavefront write to different LDS banks.
  # Formula: swizzled_addr = base + (lane_id & 7) * LDS_A_STRIDE + swizzle_offset
  # where swizzle_offset depends on (lane_id >> 3) to distribute across banks.
  k.emit(v_add_nc_u32_e32(v[9], s[S_LOOP_CTR], v[22])) # row 0 base
  k.emit(v_and_b32_e32(v[9], ADDR_MASK, v[9]))
  k.emit(v_sub_nc_u32_e32(v[9], v[22], v[9])) # row 0 swizzle offset
  k.emit(v_lshlrev_b32_e32(v[9], 2, v[9])) # * 4
  k.emit(v_mad_u32_u24(v[V_LDS_B_ADDR], LDS_A_STRIDE, v[V_LANE_ID_MOD8], v[9]))
  # For V_LDS_A_BASE and epilogue
  k.emit(v_bfe_u32(v[2], v[V_LANE_ID], 3, 2)) # v[2] = (lane_id >> 3) & 3
  k.emit(v_lshlrev_b32_e32(v[V_LANE_DIV8_X4], 2, v[2]))
  # Compute LDS load/store base addresses for inner loop
  k.emit(v_lshlrev_b32_e32(v[2], 4, v[2]))
  k.emit(v_and_b32_e32(v[3], 0x7F, v[1])) # simplified from 3 lines
  k.emit(v_lshl_or_b32(v[V_LDS_B_BASE], v[V_LANE_ID_MOD8], 4, LDS_BASE_OFFSET))
  k.emit(v_lshl_add_u32(v[V_LDS_A_ADDR], v[3], 2, LDS_BASE_OFFSET))
  k.emit(v_lshlrev_b32_e32(v[3], 2, v[V_LANE_ID]))
  k.emit(v_and_or_b32(v[V_LDS_A_BASE], 0x180, v[3], v[2]))
  # Do initial stores
  k.waitcnt(vm=0)
  for i in range(4): # A tile: 8 values via 4 stride64 stores
    k.emit(ds_store_2addr_stride64_b32(addr=v[V_LDS_A_ADDR], data0=v[V_LDS_A_DATA[i*2]], data1=v[V_LDS_A_DATA[i*2+1]], offset0=i*4, offset1=i*4+2))
  for i in range(8): # B tile: 8 values via 8 scalar stores with 64-byte spacing
    offset = i * 64
    k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[V_LDS_B_DATA[i]], offset0=offset & 0xFF, offset1=offset >> 8))
  # Zero all 128 accumulators using VOPD dual moves (64 instructions instead of 128)
  for i in range(0, len(OUT_REGS), 2):
    k.emit(VOPD(VOPDOp.V_DUAL_MOV_B32, VOPDOp.V_DUAL_MOV_B32, vdstx=v[OUT_REGS[i]], vdsty=v[OUT_REGS[i+1]], srcx0=0, srcy0=0))
  k.emit(s_add_i32(s[S_LOOP_BOUND], s[S_DIM_N], -8))
  # S_LOOP_CTR is already 0 from prologue initialization
  k.emit(s_branch(), target='LOOP_ENTRY')
  # ===========================================================================
  # MAIN GEMM LOOP
  # ===========================================================================
  # Build-time switches that leave parts of the loop body out of the emitted program
  NO_ALU, NO_DS, NO_GLOBAL = getenv("NO_ALU", 0), getenv("NO_DS", 0), getenv("NO_GLOBAL", 0)
  k.label('LOOP_INC')
  k.emit(s_add_i32(s[S_LOOP_CTR], s[S_LOOP_CTR], 8))
  k.emit(s_cmp_ge_i32(s[S_LOOP_CTR], s[S_DIM_N]))
  k.emit(s_cbranch_scc1(), target='EPILOGUE')
  k.label('LOOP_ENTRY')
  k.emit(s_cmp_lt_i32(s[S_LOOP_CTR], s[S_LOOP_BOUND]))
  k.emit(s_cselect_b32(s[S_PREFETCH_FLAG], -1, 0)) # s_cselect doesn't modify SCC
  k.emit(s_cbranch_scc0(), target='SKIP_PREFETCH') # branch if loop_ctr >= loop_bound
  if not NO_GLOBAL:
    # Advance prefetch pointers (VGPR)
    #k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], N * 32, v[V_GLOBAL_B_ADDR]))
    #k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR]))
    # Advance prefetch pointers (64-bit adds): B advances 8 rows (8*N*4 bytes), A advances 8 cols (8*4 bytes)
    k.emit(s_clause(simm16=31))
    for i in range(8):
      k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_PREFETCH_B+i*2], N * 32))
      k.emit(s_addc_u32(s[S_PREFETCH_B+i*2+1], s[S_PREFETCH_B+i*2+1], 0))
    for i in range(8):
      k.emit(s_add_u32(s[S_PREFETCH_A+i*2], s[S_PREFETCH_A+i*2], 0x20))
      k.emit(s_addc_u32(s[S_PREFETCH_A+i*2+1], s[S_PREFETCH_A+i*2+1], 0))
    # do the fetch
    for vdst, saddr_lo in INIT_PREFETCH:
      k.emit(global_load_b32(vdst=v[vdst], addr=v[V_GLOBAL_B_ADDR], saddr=s[saddr_lo:saddr_lo+1]))
  k.label('SKIP_PREFETCH')
  # wait for local stores to finish (either initial or loop)
  # then sync the warp so it's safe to load local
  k.waitcnt(lgkm=0)
  k.emit(s_barrier())
  # 8 inner loop iterations
  for iter in range(8):
    # Load A tile (4 pairs) and B tile (8 pairs) from LDS
    if not NO_DS:
      k.emit(s_clause(simm16=len(V_A_TILE_REGS) + len(V_B_TILE_REGS) - 1)) # 12 loads total: 4 A + 8 B
      # A tile: 4 ds_load_b64
      for i, vdst in enumerate(V_A_TILE_REGS):
        a_off = (i & 1) * 8 + (i >> 1) * 64 + iter * LDS_A_STRIDE
        k.emit(ds_load_b64(vdst=v[vdst:vdst+1], addr=v[V_LDS_A_BASE], offset0=a_off & 0xFF, offset1=a_off >> 8))
      # B tile: 8 ds_load_b64
      for i, vdst in enumerate(V_B_TILE_REGS):
        b_off = (i & 1) * 8 + (i & 2) * 64 + (i >> 2) * 256 + iter * LDS_B_STRIDE
        k.emit(ds_load_b64(vdst=v[vdst:vdst+1], addr=v[V_LDS_B_BASE], offset0=b_off & 0xFF, offset1=b_off >> 8))
    # Issue global prefetch (first 6 iterations only)
    if iter < 6 and not NO_GLOBAL:
      vdst1, vdst2, addr, slo1, slo2 = PREFETCH_LOADS[iter]
      k.emit(global_load_b32(vdst=v[vdst1], addr=v[addr], saddr=s[slo1:slo1+1]))
      k.emit(global_load_b32(vdst=v[vdst2], addr=v[addr], saddr=s[slo2:slo2+1]))
    # 64 dual FMACs
    k.waitcnt(lgkm=0)
    if not NO_ALU:
      k.emit(s_clause(simm16=len(FMAC_PATTERN)-1))
      for i, (vdst_x, vdst_y, ax, bx, ay, by) in enumerate(FMAC_PATTERN):
        k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32,
                    vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by]))
  # wait for all global loads to finish
  # then sync the warp so it's safe to store local
  k.waitcnt(vm=0)
  k.emit(s_barrier())
  # Store prefetched data to LDS
  # NOTE: Register naming reflects LDS tile organization, not source matrix:
  # V_LDS_A_DATA (v163-170) holds data that goes to LDS A-tile region
  # V_LDS_B_DATA (v171-178) holds data that goes to LDS B-tile region
  # The data sources are swapped: A-tile receives B matrix rows, B-tile receives A matrix columns
  if not NO_DS:
    for i in range(4): # A tile: 8 values via 4 stride64 stores
      k.emit(ds_store_2addr_stride64_b32(addr=v[V_LDS_A_ADDR], data0=v[V_LDS_A_DATA[i*2]], data1=v[V_LDS_A_DATA[i*2+1]], offset0=i*4, offset1=i*4+2))
    for i in range(8): # B tile: 8 values via 8 scalar stores with 64-byte spacing
      offset = i * 64
      k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[V_LDS_B_DATA[i]], offset0=offset & 0xFF, offset1=offset >> 8))
  k.emit(s_branch(), target='LOOP_INC')
  # ===========================================================================
  # EPILOGUE: Permute and store results
  # ===========================================================================
  k.label('EPILOGUE')
  # Rearrange accumulators from FMAC layout to contiguous output order
  for a, b in PERMUTE_SWAPS:
    k.emit(v_swap_b32_e32(v[a], v[b]))
  # Compute output base coordinates
  # (v130 and up are reused as scratch here; the accumulators now live in OUT_REGS, v2-v129)
  # v[130] = col_base = tile_x + (lane_id & 7) * 4
  # v[131] = row_base = tile_y + (lane_id & 0x60) + ((lane_id >> 3) & 3) * 4
  # v[132] = 0 (for 64-bit address high part)
  k.emit(v_add_nc_u32_e32(v[130], s[S_TILE_X], v[V_LANE_MOD8_X4]))
  k.emit(v_and_b32_e32(v[131], 0x60, v[V_LANE_ID]))
  k.emit(v_add_nc_u32_e32(v[131], s[S_TILE_Y], v[131]))
  k.emit(v_add_nc_u32_e32(v[131], v[V_LANE_DIV8_X4], v[131]))
  k.emit(v_mov_b32_e32(v[132], 0))
  # Precompute row offsets: v[133-136] for rows 0-3, v[137-140] for rows 16-19
  for base, row_off in [(133, 0), (137, 16)]:
    if row_off: k.emit(v_add_nc_u32_e32(v[141], row_off, v[131]))
    k.emit(v_mul_lo_u32(v[base], v[141] if row_off else v[131], s[S_DIM_N]))
    for j in range(3): k.emit(v_add_nc_u32_e32(v[base + 1 + j], s[S_DIM_N], v[base + j]))
  # s[S_PREFETCH_FLAG] = row stride in bytes (N * 4)
  k.emit(s_lshl_b32(s[S_PREFETCH_FLAG], s[S_DIM_N], 2))
  # Store 128 output values as 32 groups of 4 (128-bit stores)
  # Layout: 2 row halves (0-3, 16-19) x 4 col groups x 4 rows = 32 stores of 4 floats
  for i, (row_half, col_off, row_in_group) in enumerate([(rh, co, ri)
      for rh in range(2) for co in [0, 32, 64, 96] for ri in range(4)]):
    row = row_half * 16 + row_in_group
    src = OUT_REGS[i*4] # first reg of ascending group of 4
    if row_in_group == 0:
      # First row of group: compute full address
      if col_off == 0: k.emit(v_mov_b32_e32(v[141], v[130]))
      else: k.emit(v_add_nc_u32_e32(v[141], col_off, v[130]))
      row_base = 133 + row if row < 4 else 137 + row - 16
      k.emit(v_add_nc_u32_e32(v[141], v[row_base], v[141]))
      k.emit(v_lshlrev_b32_e32(v[141], 2, v[141]))
      # 64-bit add of the output pointer: low half sets the carry, high half consumes it
      k.emit(v_add_co_u32(v[141], VCC_LO, s[S_OUT_PTR[0]], v[141]))
      k.emit(v_add_co_ci_u32_e32(v[142], s[S_OUT_PTR[1]], v[132]))
    else:
      # Subsequent rows: add stride
      k.emit(v_add_co_u32(v[141], VCC_LO, s[S_PREFETCH_FLAG], v[141]))
      k.emit(v_add_co_ci_u32_e32(v[142], v[142], v[132]))
    k.emit(global_store_b128(addr=v[141:142], data=v[src:src+3], saddr=NULL))
  k.emit(s_sendmsg(simm16=3)) # DEALLOC_VGPRS
  k.emit(s_endpgm())
  return k.finalize()
# =============================================================================
# Test harness
# =============================================================================
N = getenv("N", 4096) # matrix dimension for the harness (getenv key "N", default 4096)
BLOCK_M, BLOCK_N = 128, 128 # output tile per workgroup; used to derive the launch grid
THREADS = 128 # local size x (threads per workgroup)
def test_matmul():
  """Build the asm kernel for N, time it, and optionally check it against tinygrad's own matmul.

  Runs the kernel CNT times (default 5) and prints the best-case TFLOPS. Unless VERIFY=0, the
  result is compared with `a @ b`; if the mean squared error is NaN or above 1e-06 a per-tile
  diagnostic is printed and RuntimeError is raised.
  """
  device = Device[Device.DEFAULT]
  print(f"Device arch: {device.renderer.target.arch}")
  insts = build_kernel(N)

  # reproducible inputs: Generator.random samples [0, 1), shifted here by -0.5
  gen = np.random.default_rng(42)
  a = Tensor(gen.random((N, N), dtype=np.float32) - 0.5)
  b = Tensor(gen.random((N, N), dtype=np.float32) - 0.5)
  out = Tensor.empty(N, N)
  Tensor.realize(a, b, out)

  grid = (N // BLOCK_N, N // BLOCK_M, 1)
  local = (THREADS, 1, 1)
  print(f"Grid: {grid}, Local: {local}")
  dname:str = Device.DEFAULT

  def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
    # launch dimensions first (global then local), then the LDS allocation
    specials = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)] + [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
    lds_bytes = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))
    lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=lds_bytes, addrspace=AddrSpace.LOCAL), (), 'lds')
    info = KernelInfo(name=colored("kernel", "cyan"), estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3))
    sink = UOp.sink(A.base, B.base, C.base, lds, *specials, arg=info)
    program = UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))
    return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), program))

  c = Tensor.custom_kernel(a, b, out, fxn=asm_kernel)[2]
  linear = c.schedule_linear()

  # time each run via the delta of GlobalCounters.time_sum_s and report the fastest
  elapsed = []
  with Context(DEBUG=2):
    for _ in range(getenv("CNT", 5)):
      t0 = GlobalCounters.time_sum_s
      run_linear(linear)
      elapsed.append(GlobalCounters.time_sum_s - t0)
  print(f"REAL TFLOPS {N * N * N * 2 / min(elapsed) * 1e-12:.2f}")

  if not getenv("VERIFY", 1): return
  GlobalCounters.reset()
  with Context(DEBUG=2): tc = (a @ b).realize()
  with Context(DEBUG=0): err = (c - tc).square().mean().item()
  print(f"mean squared error {err}")
  # err == err is False only for NaN, so NaN falls through to the failure path
  if err == err and err <= 1e-06: return

  # failure: dump per-128x128-tile statistics before raising
  got, ref = c.numpy(), tc.numpy()
  for bi in range(N // 128):
    rows = slice(bi*128, (bi+1)*128)
    for bj in range(N // 128):
      cols = slice(bj*128, (bj+1)*128)
      blk_c, blk_ref = got[rows, cols], ref[rows, cols]
      blk_diff = blk_c - blk_ref
      # split tile rows into (numerically) all-zero rows and the rest
      zero_rows, nz_rows = [], []
      for i in range(128):
        (zero_rows if np.all(np.abs(blk_c[i,:]) < 1e-10) else nz_rows).append(i)
      nz_mse = float(np.mean(blk_diff[nz_rows,:]**2)) if nz_rows else 0
      print(f"Block ({bi},{bj}): zero_rows={zero_rows}, nz_rows_mse={nz_mse:.2e}")
      # show first few non-zero row comparisons
      if nz_rows and nz_mse > 1e-6:
        for r in nz_rows[:3]:
          print(f" row {r} asm[0:8]: {blk_c[r,:8]}")
          print(f" row {r} ref[0:8]: {blk_ref[r,:8]}")
  raise RuntimeError("matmul is wrong!")
# Run the benchmark / verification harness only when executed as a script
if __name__ == "__main__":
  test_matmul()