mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-13 16:37:04 +08:00
501 lines
24 KiB
Python
501 lines
24 KiB
Python
# RDNA3 128x128 tiled GEMM kernel - DSL version
|
|
# Computes C = A @ B for NxN float32 matrices using 128x128 tiles
|
|
#
|
|
# Architecture: RDNA3 (gfx1100)
|
|
# Tile size: 128x128 (each workgroup computes one tile of C)
|
|
# Workgroup: 128 threads (arranged as 32x4 for coalesced memory access)
|
|
# Inner loop: 8 iterations per K-block, processing 8 columns of A and 8 rows of B
|
|
#
|
|
# Accumulators: 128 vgprs (v[2-129])
|
|
|
|
import numpy as np
|
|
from tinygrad import Tensor, Device, Context, GlobalCounters
|
|
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
|
from tinygrad.helpers import getenv, colored
|
|
from tinygrad.dtype import dtypes, AddrSpace
|
|
from tinygrad.engine.realize import Estimates, run_linear
|
|
from tinygrad.renderer.amd.dsl import s, v, VCC_LO, NULL
|
|
from tinygrad.runtime.autogen.amd.rdna3.ins import *
|
|
|
|
# =============================================================================
|
|
# Kernel constants
|
|
# =============================================================================
|
|
# Numerically: 8 * LDS_A_STRIDE == LDS_BASE_OFFSET, and LDS_BASE_OFFSET + 8 * LDS_B_STRIDE == LDS_SIZE
LDS_SIZE = 8320  # Local data share size in bytes
LDS_A_STRIDE = 0x210  # LDS stride for A tile (528 bytes)
LDS_B_STRIDE = 0x200  # LDS stride for B tile (512 bytes)
LDS_BASE_OFFSET = 0x1080  # Base LDS offset for tiles (4224 bytes)
ADDR_MASK = 0x3fffff80  # Address alignment mask: clears the low 7 bits and the top 2 bits
|
|
|
|
# =============================================================================
|
|
# Named register assignments (VGPRs)
|
|
# =============================================================================
|
|
V_LANE_ID = 0  # lane_id set on startup
# Named regs sit in the gaps between the 2-wide tile register pairs (v146-159), plus v162
# just past the last B tile pair (v160-161), to minimize max VGPR
V_LANE_ID_MOD8 = 146  # lane_id & 7
V_LANE_MOD8_X4 = 147  # (lane_id & 7) << 2
V_LANE_DIV8_X4 = 150  # ((lane_id >> 3) & 3) << 2
V_LDS_B_BASE = 151  # LDS B-tile base address for inner loop
V_LDS_A_BASE = 154  # LDS A-tile base address for inner loop
V_GLOBAL_A_ADDR = 155  # global memory A prefetch address
V_GLOBAL_B_ADDR = 158  # global memory B prefetch address
V_LDS_A_ADDR = 159  # single base register for A stores
V_LDS_B_ADDR = 162  # single base register for B stores

# LDS tile register destinations - SEPARATE from DATA to avoid overlap
# Each entry is the first register of a 2-register pair filled by one ds_load_b64
# A on banks 2-3, B on banks 0-1 to avoid bank conflicts in VOPD
V_A_TILE_REGS = [130, 134, 138, 142]  # A tile: banks 2,2,2,2 (130%4=2, etc.)
V_B_TILE_REGS = [132, 136, 140, 144, 148, 152, 156, 160]  # B tile: banks 0,0,0,0,0,0,0,0
|
|
|
|
# =============================================================================
|
|
# Named register assignments (SGPRs)
|
|
# =============================================================================
|
|
# s[0:1] is also the base used for the kernarg loads in the prologue; the last of those loads
# overwrites it with the output pointer
S_OUT_PTR = (0, 1)  # output C matrix base pointer
S_WORKGROUP_X = 2  # workgroup_id_x (system SGPR, follows user SGPRs)
S_WORKGROUP_Y = 3  # workgroup_id_y (system SGPR)
S_DIM_N = 4  # matrix dimension N
S_LOOP_BOUND = 7  # K-8 (loop termination bound)
S_LOOP_CTR = 12  # loop counter (increments by 8)
S_PREFETCH_FLAG = 13  # prefetch condition flag / row stride in epilogue
S_TILE_X = 14  # workgroup_x << 7
S_TILE_Y = 15  # workgroup_y << 7
# Kernarg load destinations (s[20:23] are filled by a single s_load_b128)
S_KERNARG_A = (20, 21)  # A pointer from kernarg
S_KERNARG_B = (22, 23)  # B pointer from kernarg
# Prefetch base pointers (8 pairs each, B: N*4 bytes apart, A: N*64 bytes apart)
S_PREFETCH_B = 24  # s[24:39] - 8 B tile pointers
S_PREFETCH_A = 40  # s[40:55] - 8 A tile pointers
|
|
|
|
# =============================================================================
|
|
# Data tables
|
|
# =============================================================================
|
|
|
|
# Accumulator grid: ACC_GRID[a_idx][b_idx] = vgpr for C[a,b]
# a_idx: which A value (0-7), b_idx: which B value (0-15)
# Scattered due to VOPD bank constraints (vdst_x % 4 != vdst_y % 4)
# Range is from v2 - v129 (8 x 16 = 128 registers, each used exactly once)
ACC_GRID = [
  [  5,   3,   9,   8,  37,  35,  41,  40,  69,  67,  73,  72, 101,  99, 105, 104],  # a0
  [  4,   2,   7,   6,  36,  34,  39,  38,  68,  66,  71,  70, 100,  98, 103, 102],  # a1
  [ 17,  16,  13,  11,  49,  48,  45,  43,  81,  80,  77,  75, 113, 112, 109, 107],  # a2
  [ 15,  14,  12,  10,  47,  46,  44,  42,  79,  78,  76,  74, 111, 110, 108, 106],  # a3
  [ 21,  19,  25,  24,  53,  51,  57,  56,  85,  83,  89,  88, 117, 115, 121, 120],  # a4
  [ 20,  18,  23,  22,  52,  50,  55,  54,  84,  82,  87,  86, 116, 114, 123, 122],  # a5
  [125, 128,  29,  27,  33,  32,  61,  59,  65,  64,  93,  91,  97,  96, 129, 127],  # a6
  [119, 118,  28,  26,  31,  30,  60,  58,  63,  62,  92,  90,  95,  94, 124, 126],  # a7
]

# Optimized (a_pair, b_pair) iteration order for better GPU scheduling
# Interleaves A and B pairs to maximize instruction-level parallelism
# a_pair indexes V_A_TILE_REGS (0-3), b_pair indexes V_B_TILE_REGS (0-7); all 32 combinations appear once
FMAC_PAIR_ORDER = [
  (0,0),(0,1),(1,1),(1,0), (2,0),(2,1),(3,1),(3,2), (0,2),(0,3),(1,3),(1,2), (2,2),(2,3),(3,3),(3,4),
  (0,4),(0,5),(1,5),(1,4), (2,4),(2,5),(3,5),(3,6), (0,6),(0,7),(1,7),(1,6), (2,6),(2,7),(3,7),(3,0),
]
|
|
|
|
def derive_fmac_pattern(acc_grid, a_tile_regs=None, b_tile_regs=None):
  """Generate 64 dual FMAC ops from accumulator grid with optimized iteration order.

  Args:
    acc_grid: acc_grid[a][b] is the accumulator VGPR for C[a, b] (8 rows x 16 columns are indexed).
    a_tile_regs: first register of each A tile pair, indexed by a_pair; defaults to V_A_TILE_REGS.
    b_tile_regs: first register of each B tile pair, indexed by b_pair; defaults to V_B_TILE_REGS.

  Returns:
    A list with two entries per element of FMAC_PAIR_ORDER. Each entry is a tuple
    (vdst_x, vdst_y, a_src_x, b_src_x, a_src_y, b_src_y) describing one dual FMAC.
  """
  # The signature advertises None defaults; resolve them to the module tables instead of
  # failing with "'NoneType' object is not subscriptable" on the first lookup.
  if a_tile_regs is None: a_tile_regs = V_A_TILE_REGS
  if b_tile_regs is None: b_tile_regs = V_B_TILE_REGS
  pattern = []
  for idx, (a_pair, b_pair) in enumerate(FMAC_PAIR_ORDER):
    a_even, a_odd = a_pair * 2, a_pair * 2 + 1
    b_even, b_odd = b_pair * 2, b_pair * 2 + 1
    a_base, b_base = a_tile_regs[a_pair], b_tile_regs[b_pair]
    # Op 1: normal order -> C[a_even, b_even] + C[a_odd, b_odd]
    pattern.append((acc_grid[a_even][b_even], acc_grid[a_odd][b_odd],
                    a_base, b_base, a_base+1, b_base+1))
    # Op 2: alternate swapping A vs B to vary register banks
    if idx % 2 == 0:  # swap B
      pattern.append((acc_grid[a_even][b_odd], acc_grid[a_odd][b_even],
                      a_base, b_base+1, a_base+1, b_base))
    else:  # swap A
      pattern.append((acc_grid[a_odd][b_even], acc_grid[a_even][b_odd],
                      a_base+1, b_base, a_base, b_base+1))
  return pattern
|
|
|
|
# Derived: 64 dual FMAC operations, each a (vdst_x, vdst_y, ax, bx, ay, by) register tuple
FMAC_PATTERN = derive_fmac_pattern(ACC_GRID, V_A_TILE_REGS, V_B_TILE_REGS)
|
|
|
|
def derive_permute_swaps(acc_grid, out_regs):
  """Compute the register swaps that move accumulators from FMAC layout into output order.

  On entry register acc_grid[a][b] holds C[a,b] (a in 0-7, b in 0-15).
  Output slot i (0-127) decomposes, from most to least significant, into
  row_half (0-1), col_group (0-3), row_in_group (0-3), b_off (0-3), and out_regs[i]
  must end up holding C[row_half*4 + row_in_group, col_group*4 + b_off].

  Returns the list of (reg_a, reg_b) pairs to swap, in order.
  """
  # register -> (a, b) it currently holds
  holds = {}
  for a in range(8):
    for b in range(16):
      holds[acc_grid[a][b]] = (a, b)
  # (a, b) -> register currently holding it
  lives_in = {}
  for reg, ab in holds.items():
    lives_in[ab] = reg

  sequence = []
  for slot in range(128):
    rest, b_off = divmod(slot, 4)
    rest, row_in_group = divmod(rest, 4)
    row_half, col_group = divmod(rest, 4)
    wanted = (row_half * 4 + row_in_group, col_group * 4 + b_off)

    dst = out_regs[slot]
    src = lives_in[wanted]
    if src == dst: continue  # already in place

    sequence.append((src, dst))
    displaced = holds.get(dst)
    holds[dst] = wanted
    lives_in[wanted] = dst
    if displaced is not None:
      # whatever was sitting in dst now lives where the wanted value came from
      holds[src] = displaced
      lives_in[displaced] = src
  return sequence
|
|
|
|
# Derived: swap sequence to arrange accumulators for output
# Each group of 4 registers is ascending for direct global_store_b128
# Groups themselves descend: [126..129], [122..125], ..., [2..5]
OUT_REGS = [r for i in range(32) for r in range(126 - i*4, 130 - i*4)]
PERMUTE_SWAPS = derive_permute_swaps(ACC_GRID, OUT_REGS)
|
|
|
|
# =============================================================================
|
|
# LDS tile staging registers
|
|
# =============================================================================
|
|
# DATA regs receive contiguous global prefetch, then write to LDS
# TILE regs receive scattered LDS loads (ds_load_b64 pairs), then feed FMACs
# Contiguous layout with mod4=[3,0,1,2,3,0,1,2] for bank conflict avoidance
V_LDS_A_DATA = [163, 164, 165, 166, 167, 168, 169, 170]
V_LDS_B_DATA = [171, 172, 173, 174, 175, 176, 177, 178]

# NOTE: in the tables below the A data regs are loaded through the B prefetch pointers and the
# B data regs through the A prefetch pointers

# Initial tile prefetch: (vdst, saddr_lo) - load into A data regs using B prefetch pointers (s[24:31])
INIT_PREFETCH = [(V_LDS_A_DATA[i], S_PREFETCH_B+2*i) for i in range(4)]

# Global memory prefetch schedule: (vdst1, vdst2, addr_vreg, saddr_lo1, saddr_lo2)
# First 2 pairs from B prefetch pointers (s[32:39]), next 4 pairs from A prefetch pointers (s[40:55])
PREFETCH_LOADS = [(V_LDS_A_DATA[4+2*i], V_LDS_A_DATA[4+2*i+1], V_GLOBAL_B_ADDR, S_PREFETCH_B+8+4*i, S_PREFETCH_B+10+4*i) for i in range(2)] + \
                 [(V_LDS_B_DATA[2*(i-2)], V_LDS_B_DATA[2*(i-2)+1], V_GLOBAL_A_ADDR, S_PREFETCH_A+4*(i-2), S_PREFETCH_A+2+4*(i-2)) for i in range(2, 6)]
|
|
|
|
# =============================================================================
|
|
# Kernel class
|
|
# =============================================================================
|
|
|
|
class Kernel:
  """Ordered instruction list that tracks byte positions and named labels for branch fix-up."""

  def __init__(self):
    self.instructions = []  # emitted instructions, in program order
    self.labels = {}        # label name -> byte position
    self.pos = 0            # byte position of the next instruction

  def label(self, name):
    """Bind `name` to the current byte position."""
    self.labels[name] = self.pos

  def emit(self, inst, target=None):
    """Append `inst`, record its position and optional branch target label, and return it."""
    self.instructions.append(inst)
    inst._target = target
    inst._pos = self.pos
    self.pos += inst.size()
    return inst

  def waitcnt(self, lgkm=None, vm=None):
    """Emit s_waitcnt. lgkm=N waits until N lgkm ops remain, vm=N waits until N vmem ops remain.

    A counter left as None is encoded as its maximum field value (63); expcnt is always 7.
    """
    vm_field = 63 if vm is None else vm
    lgkm_field = 63 if lgkm is None else lgkm
    exp_field = 7
    encoded = ((vm_field & 0x3f) << 10) | ((lgkm_field & 0x3f) << 4) | (exp_field & 0x7)
    self.emit(s_waitcnt(simm16=encoded))

  def finalize(self):
    """Patch branch offsets and return the finalized instruction list.

    The offset is measured in dwords from the end of the branching instruction to the label.
    Raises ValueError if it does not fit in a signed 16-bit immediate.
    """
    for inst in self.instructions:
      if inst._target is None: continue
      label_pos = self.labels[inst._target]
      after_inst = inst._pos + inst.size()
      offset_dwords = (label_pos - after_inst) // 4
      if offset_dwords < -32768 or offset_dwords > 32767:
        raise ValueError(f"branch to '{inst._target}' offset {offset_dwords} exceeds simm16 range")
      inst.simm16 = offset_dwords
    return self.instructions
|
|
|
|
|
|
# =============================================================================
|
|
# Kernel builder
|
|
# =============================================================================
|
|
|
|
def build_kernel(N):
  """Emit the instruction list for the 128x128-tiled float32 GEMM kernel, specialized for a fixed N.

  N is baked into the instruction stream as immediates (matrix dimension, pointer strides).

  Args:
    N: matrix dimension; must be a multiple of 128 and at least 256 (both asserted).

  Returns:
    The instruction list returned by Kernel.finalize(), with branch offsets patched.

  The NO_ALU, NO_DS and NO_GLOBAL environment variables (read via getenv) drop the in-loop
  FMACs, the in-loop LDS loads/stores, and the in-loop global prefetch respectively.
  """
  assert N % 128 == 0, f"N must be a multiple of 128 (tile size), got {N}"
  assert N >= 256, f"N must be >= 256 (prefetch pipeline requires at least 2 K-blocks), got {N}"
  k = Kernel()

  # ===========================================================================
  # PROLOGUE: Load kernel arguments, compute tile coordinates and addresses
  # ===========================================================================
  # s[0:1] is the kernarg base for both loads; the second load overwrites it with the C pointer
  k.emit(s_load_b128(sdata=s[S_KERNARG_A[0]:S_KERNARG_B[1]], sbase=s[0:1], offset=0x0, soffset=NULL))
  k.emit(s_load_b64(sdata=s[S_OUT_PTR[0]:S_OUT_PTR[1]], sbase=s[0:1], offset=0x10, soffset=NULL))
  k.emit(s_mov_b32(s[S_DIM_N], N))
  k.emit(s_mov_b32(s[S_LOOP_CTR], 0))  # used by LDS swizzle, always 0 for valid workgroups
  k.emit(s_lshl_b32(s[S_TILE_X], s[S_WORKGROUP_X], 7))
  k.emit(s_lshl_b32(s[S_TILE_Y], s[S_WORKGROUP_Y], 7))

  # Lane-derived values (v[1], v[4], v[22] are scratch here; the accumulators are zeroed later)
  k.emit(v_and_b32_e32(v[V_LANE_ID_MOD8], 7, v[V_LANE_ID]))
  k.emit(v_lshrrev_b32_e32(v[4], 3, v[V_LANE_ID]))
  k.emit(v_or_b32_e32(v[1], s[S_TILE_X], v[V_LANE_ID]))
  k.emit(v_or_b32_e32(v[22], s[S_TILE_Y], v[4]))
  k.emit(v_lshlrev_b32_e32(v[V_LANE_MOD8_X4], 2, v[V_LANE_ID_MOD8]))
  k.waitcnt(lgkm=0)  # kernarg loads must have landed before the pointers are used

  # Compute 8 A and B matrix tile base pointers for prefetch
  k.emit(s_mov_b64(s[S_PREFETCH_B:S_PREFETCH_B+1], s[S_KERNARG_B[0]:S_KERNARG_B[1]]))  # B[0]: no offset
  for i in range(1, 8):  # B: each pointer 1 row of B apart (N*4 bytes)
    k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_KERNARG_B[0]], i * N * 4))
    k.emit(s_addc_u32(s[S_PREFETCH_B+i*2+1], s[S_KERNARG_B[1]], 0))
  k.emit(s_mov_b64(s[S_PREFETCH_A:S_PREFETCH_A+1], s[S_KERNARG_A[0]:S_KERNARG_A[1]]))  # A[0]: no offset
  for i in range(1, 8):  # A: each pointer 16 rows of A apart (16*N*4 bytes)
    k.emit(s_add_u32(s[S_PREFETCH_A+i*2], s[S_KERNARG_A[0]], i * N * 64))
    k.emit(s_addc_u32(s[S_PREFETCH_A+i*2+1], s[S_KERNARG_A[1]], 0))

  # Global prefetch addresses: B = (tile_x + lane_id) * 4, A = (tile_y*N + (lane_id/8)*N + lane_id%8) * 4
  k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], s[S_TILE_X], v[V_LANE_ID]))
  k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_B_ADDR], 2, v[V_GLOBAL_B_ADDR]))
  k.emit(s_mul_i32(s[19], s[S_TILE_Y], N))  # s[19] is scratch: tile_y * N
  k.emit(v_mul_lo_u32(v[V_GLOBAL_A_ADDR], v[4], N))  # (lane_id/8)*N
  k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], v[V_LANE_ID_MOD8], v[V_GLOBAL_A_ADDR]))  # + lane_id%8
  k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], s[19], v[V_GLOBAL_A_ADDR]))
  k.emit(v_lshlrev_b32_e32(v[V_GLOBAL_A_ADDR], 2, v[V_GLOBAL_A_ADDR]))

  # Do initial loads (same load schedule the main loop uses for its prefetch)
  for vdst, saddr_lo in INIT_PREFETCH:
    k.emit(global_load_b32(vdst=v[vdst], addr=v[V_GLOBAL_B_ADDR], saddr=s[saddr_lo:saddr_lo+1]))
  for iter in range(6):
    vdst1, vdst2, addr, slo1, slo2 = PREFETCH_LOADS[iter]
    k.emit(global_load_b32(vdst=v[vdst1], addr=v[addr], saddr=s[slo1:slo1+1]))
    k.emit(global_load_b32(vdst=v[vdst2], addr=v[addr], saddr=s[slo2:slo2+1]))

  # ===========================================================================
  # LDS store address computation (bank-conflict-avoiding swizzle)
  # ===========================================================================
  # This section computes LDS store addresses with a swizzle pattern to avoid bank conflicts.
  # The swizzle ensures that threads in the same wavefront write to different LDS banks.
  # Formula: swizzled_addr = base + (lane_id & 7) * LDS_A_STRIDE + swizzle_offset
  # where swizzle_offset depends on (lane_id >> 3) to distribute across banks.
  k.emit(v_add_nc_u32_e32(v[9], s[S_LOOP_CTR], v[22]))  # row 0 base
  k.emit(v_and_b32_e32(v[9], ADDR_MASK, v[9]))
  k.emit(v_sub_nc_u32_e32(v[9], v[22], v[9]))  # row 0 swizzle offset
  k.emit(v_lshlrev_b32_e32(v[9], 2, v[9]))  # * 4
  k.emit(v_mad_u32_u24(v[V_LDS_B_ADDR], LDS_A_STRIDE, v[V_LANE_ID_MOD8], v[9]))

  # For V_LDS_A_BASE and epilogue
  k.emit(v_bfe_u32(v[2], v[V_LANE_ID], 3, 2))  # v[2] = (lane_id >> 3) & 3
  k.emit(v_lshlrev_b32_e32(v[V_LANE_DIV8_X4], 2, v[2]))

  # Compute LDS load/store base addresses for inner loop
  k.emit(v_lshlrev_b32_e32(v[2], 4, v[2]))
  k.emit(v_and_b32_e32(v[3], 0x7F, v[1]))  # simplified from 3 lines
  k.emit(v_lshl_or_b32(v[V_LDS_B_BASE], v[V_LANE_ID_MOD8], 4, LDS_BASE_OFFSET))
  k.emit(v_lshl_add_u32(v[V_LDS_A_ADDR], v[3], 2, LDS_BASE_OFFSET))
  k.emit(v_lshlrev_b32_e32(v[3], 2, v[V_LANE_ID]))
  k.emit(v_and_or_b32(v[V_LDS_A_BASE], 0x180, v[3], v[2]))

  # Do initial stores (after every initial global load has completed)
  k.waitcnt(vm=0)
  for i in range(4):  # A tile: 8 values via 4 stride64 stores
    k.emit(ds_store_2addr_stride64_b32(addr=v[V_LDS_A_ADDR], data0=v[V_LDS_A_DATA[i*2]], data1=v[V_LDS_A_DATA[i*2+1]], offset0=i*4, offset1=i*4+2))
  for i in range(8):  # B tile: 8 values via 8 scalar stores with 64-byte spacing
    offset = i * 64
    k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[V_LDS_B_DATA[i]], offset0=offset & 0xFF, offset1=offset >> 8))

  # Zero all 128 accumulators using VOPD dual moves (64 instructions instead of 128)
  # This also clobbers the v[1]-v[22] scratch values computed above, which are no longer needed
  for i in range(0, len(OUT_REGS), 2):
    k.emit(VOPD(VOPDOp.V_DUAL_MOV_B32, VOPDOp.V_DUAL_MOV_B32, vdstx=v[OUT_REGS[i]], vdsty=v[OUT_REGS[i+1]], srcx0=0, srcy0=0))
  k.emit(s_add_i32(s[S_LOOP_BOUND], s[S_DIM_N], -8))

  # S_LOOP_CTR is already 0 from prologue initialization
  k.emit(s_branch(), target='LOOP_ENTRY')

  # ===========================================================================
  # MAIN GEMM LOOP
  # ===========================================================================

  # Debug switches: each one removes a class of instructions from the loop body
  NO_ALU, NO_DS, NO_GLOBAL = getenv("NO_ALU", 0), getenv("NO_DS", 0), getenv("NO_GLOBAL", 0)

  k.label('LOOP_INC')
  k.emit(s_add_i32(s[S_LOOP_CTR], s[S_LOOP_CTR], 8))
  k.emit(s_cmp_ge_i32(s[S_LOOP_CTR], s[S_DIM_N]))
  k.emit(s_cbranch_scc1(), target='EPILOGUE')

  k.label('LOOP_ENTRY')
  k.emit(s_cmp_lt_i32(s[S_LOOP_CTR], s[S_LOOP_BOUND]))
  k.emit(s_cselect_b32(s[S_PREFETCH_FLAG], -1, 0))  # s_cselect doesn't modify SCC
  k.emit(s_cbranch_scc0(), target='SKIP_PREFETCH')  # branch if loop_ctr >= loop_bound

  if not NO_GLOBAL:
    # Advance prefetch pointers (VGPR)
    #k.emit(v_add_nc_u32_e32(v[V_GLOBAL_B_ADDR], N * 32, v[V_GLOBAL_B_ADDR]))
    #k.emit(v_add_nc_u32_e32(v[V_GLOBAL_A_ADDR], 0x20, v[V_GLOBAL_A_ADDR]))

    # Advance prefetch pointers (64-bit adds): B advances 8 rows (8*N*4 bytes), A advances 8 cols (8*4 bytes)
    k.emit(s_clause(simm16=31))
    for i in range(8):
      k.emit(s_add_u32(s[S_PREFETCH_B+i*2], s[S_PREFETCH_B+i*2], N * 32))
      k.emit(s_addc_u32(s[S_PREFETCH_B+i*2+1], s[S_PREFETCH_B+i*2+1], 0))
    for i in range(8):
      k.emit(s_add_u32(s[S_PREFETCH_A+i*2], s[S_PREFETCH_A+i*2], 0x20))
      k.emit(s_addc_u32(s[S_PREFETCH_A+i*2+1], s[S_PREFETCH_A+i*2+1], 0))

    # do the fetch
    for vdst, saddr_lo in INIT_PREFETCH:
      k.emit(global_load_b32(vdst=v[vdst], addr=v[V_GLOBAL_B_ADDR], saddr=s[saddr_lo:saddr_lo+1]))

  k.label('SKIP_PREFETCH')

  # wait for local stores to finish (either initial or loop)
  # then sync the warp so it's safe to load local
  k.waitcnt(lgkm=0)
  k.emit(s_barrier())

  # 8 inner loop iterations
  for iter in range(8):
    # Load A tile (4 pairs) and B tile (8 pairs) from LDS
    if not NO_DS:
      k.emit(s_clause(simm16=len(V_A_TILE_REGS) + len(V_B_TILE_REGS) - 1))  # 12 loads total: 4 A + 8 B
      # A tile: 4 ds_load_b64
      for i, vdst in enumerate(V_A_TILE_REGS):
        a_off = (i & 1) * 8 + (i >> 1) * 64 + iter * LDS_A_STRIDE
        k.emit(ds_load_b64(vdst=v[vdst:vdst+1], addr=v[V_LDS_A_BASE], offset0=a_off & 0xFF, offset1=a_off >> 8))
      # B tile: 8 ds_load_b64
      for i, vdst in enumerate(V_B_TILE_REGS):
        b_off = (i & 1) * 8 + (i & 2) * 64 + (i >> 2) * 256 + iter * LDS_B_STRIDE
        k.emit(ds_load_b64(vdst=v[vdst:vdst+1], addr=v[V_LDS_B_BASE], offset0=b_off & 0xFF, offset1=b_off >> 8))

    # Issue global prefetch (first 6 iterations only)
    if iter < 6 and not NO_GLOBAL:
      vdst1, vdst2, addr, slo1, slo2 = PREFETCH_LOADS[iter]
      k.emit(global_load_b32(vdst=v[vdst1], addr=v[addr], saddr=s[slo1:slo1+1]))
      k.emit(global_load_b32(vdst=v[vdst2], addr=v[addr], saddr=s[slo2:slo2+1]))

    # 64 dual FMACs (after the LDS tile loads above have returned)
    k.waitcnt(lgkm=0)
    if not NO_ALU:
      k.emit(s_clause(simm16=len(FMAC_PATTERN)-1))
      for i, (vdst_x, vdst_y, ax, bx, ay, by) in enumerate(FMAC_PATTERN):
        k.emit(VOPD(VOPDOp.V_DUAL_FMAC_F32, VOPDOp.V_DUAL_FMAC_F32,
                    vdstx=v[vdst_x], vdsty=v[vdst_y], srcx0=v[ax], vsrcx1=v[bx], srcy0=v[ay], vsrcy1=v[by]))

  # wait for all global loads to finish
  # then sync the warp so it's safe to store local
  k.waitcnt(vm=0)
  k.emit(s_barrier())

  # Store prefetched data to LDS
  # NOTE: Register naming reflects LDS tile organization, not source matrix:
  #   V_LDS_A_DATA (v163-170) holds data that goes to LDS A-tile region
  #   V_LDS_B_DATA (v171-178) holds data that goes to LDS B-tile region
  #   The data sources are swapped: A-tile receives B matrix rows, B-tile receives A matrix columns
  if not NO_DS:
    for i in range(4):  # A tile: 8 values via 4 stride64 stores
      k.emit(ds_store_2addr_stride64_b32(addr=v[V_LDS_A_ADDR], data0=v[V_LDS_A_DATA[i*2]], data1=v[V_LDS_A_DATA[i*2+1]], offset0=i*4, offset1=i*4+2))
    for i in range(8):  # B tile: 8 values via 8 scalar stores with 64-byte spacing
      offset = i * 64
      k.emit(ds_store_b32(addr=v[V_LDS_B_ADDR], data0=v[V_LDS_B_DATA[i]], offset0=offset & 0xFF, offset1=offset >> 8))

  k.emit(s_branch(), target='LOOP_INC')

  # ===========================================================================
  # EPILOGUE: Permute and store results
  # ===========================================================================
  k.label('EPILOGUE')

  # Rearrange accumulators from FMAC layout to contiguous output order
  for a, b in PERMUTE_SWAPS:
    k.emit(v_swap_b32_e32(v[a], v[b]))

  # Compute output base coordinates (v130+ were tile registers in the loop and are reused as scratch)
  # v[130] = col_base = tile_x + (lane_id & 7) * 4
  # v[131] = row_base = tile_y + (lane_id & 0x60) + ((lane_id >> 3) & 3) * 4
  # v[132] = 0 (for 64-bit address high part)
  k.emit(v_add_nc_u32_e32(v[130], s[S_TILE_X], v[V_LANE_MOD8_X4]))
  k.emit(v_and_b32_e32(v[131], 0x60, v[V_LANE_ID]))
  k.emit(v_add_nc_u32_e32(v[131], s[S_TILE_Y], v[131]))
  k.emit(v_add_nc_u32_e32(v[131], v[V_LANE_DIV8_X4], v[131]))
  k.emit(v_mov_b32_e32(v[132], 0))

  # Precompute row offsets: v[133-136] for rows 0-3, v[137-140] for rows 16-19
  # Each holds row * N; consecutive rows are derived by adding N
  for base, row_off in [(133, 0), (137, 16)]:
    if row_off: k.emit(v_add_nc_u32_e32(v[141], row_off, v[131]))
    k.emit(v_mul_lo_u32(v[base], v[141] if row_off else v[131], s[S_DIM_N]))
    for j in range(3): k.emit(v_add_nc_u32_e32(v[base + 1 + j], s[S_DIM_N], v[base + j]))

  # s[S_PREFETCH_FLAG] = row stride in bytes (N * 4)
  k.emit(s_lshl_b32(s[S_PREFETCH_FLAG], s[S_DIM_N], 2))

  # Store 128 output values as 32 groups of 4 (128-bit stores)
  # Layout: 2 row halves (0-3, 16-19) x 4 col groups x 4 rows = 32 stores of 4 floats
  # v[141:142] holds the 64-bit store address
  for i, (row_half, col_off, row_in_group) in enumerate([(rh, co, ri)
      for rh in range(2) for co in [0, 32, 64, 96] for ri in range(4)]):
    row = row_half * 16 + row_in_group
    src = OUT_REGS[i*4]  # first reg of ascending group of 4

    if row_in_group == 0:
      # First row of group: compute full address
      if col_off == 0: k.emit(v_mov_b32_e32(v[141], v[130]))
      else: k.emit(v_add_nc_u32_e32(v[141], col_off, v[130]))
      row_base = 133 + row if row < 4 else 137 + row - 16
      k.emit(v_add_nc_u32_e32(v[141], v[row_base], v[141]))
      k.emit(v_lshlrev_b32_e32(v[141], 2, v[141]))  # element index -> byte offset
      k.emit(v_add_co_u32(v[141], VCC_LO, s[S_OUT_PTR[0]], v[141]))
      k.emit(v_add_co_ci_u32_e32(v[142], s[S_OUT_PTR[1]], v[132]))
    else:
      # Subsequent rows: add stride
      k.emit(v_add_co_u32(v[141], VCC_LO, s[S_PREFETCH_FLAG], v[141]))
      k.emit(v_add_co_ci_u32_e32(v[142], v[142], v[132]))

    k.emit(global_store_b128(addr=v[141:142], data=v[src:src+3], saddr=NULL))

  k.emit(s_sendmsg(simm16=3))  # DEALLOC_VGPRS
  k.emit(s_endpgm())

  return k.finalize()
|
|
|
|
# =============================================================================
|
|
# Test harness
|
|
# =============================================================================
|
|
|
|
N = getenv("N", 4096)  # matrix dimension passed to build_kernel (getenv default: 4096)
BLOCK_M, BLOCK_N = 128, 128  # output tile per workgroup; used to size the launch grid
THREADS = 128  # local size (threads per workgroup) used for the launch
|
|
|
|
def test_matmul():
  """Build the asm GEMM for N, time it CNT times, and (unless VERIFY=0) check it against a @ b.

  Raises RuntimeError after printing a per-tile diagnostic when the mean squared error
  is NaN or above 1e-06.
  """
  dev = Device[Device.DEFAULT]
  print(f"Device arch: {dev.renderer.target.arch}")

  insts = build_kernel(N)

  # deterministic inputs in [-0.5, 0.5)
  rng = np.random.default_rng(42)
  a = Tensor(rng.random((N, N), dtype=np.float32) - 0.5)
  b = Tensor(rng.random((N, N), dtype=np.float32) - 0.5)
  c = Tensor.empty(N, N)
  Tensor.realize(a, b, c)

  grid = (N // BLOCK_N, N // BLOCK_M, 1)
  local = (THREADS, 1, 1)
  print(f"Grid: {grid}, Local: {local}")

  dname:str = Device.DEFAULT
  def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
    # wrap the hand-written instruction list as a PROGRAM uop bound to the three buffers
    gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
    lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
    lds_bytes = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))
    lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=lds_bytes, addrspace=AddrSpace.LOCAL), (), 'lds')
    info = KernelInfo(name=colored("kernel", "cyan"), estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3))
    sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=info)
    body = UOp(Ops.LINEAR, src=tuple(UOp(Ops.INS, arg=x) for x in insts))
    return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), body))
  c = Tensor.custom_kernel(a, b, c, fxn=asm_kernel)[2]
  linear = c.schedule_linear()

  # time each run as the delta of the global kernel-time counter
  ets = []
  with Context(DEBUG=2):
    for _ in range(getenv("CNT", 5)):
      start = GlobalCounters.time_sum_s
      run_linear(linear)
      ets.append(GlobalCounters.time_sum_s - start)
  print(f"REAL TFLOPS {N * N * N * 2 / min(ets) * 1e-12:.2f}")

  if not getenv("VERIFY", 1): return

  GlobalCounters.reset()
  with Context(DEBUG=2):
    tc = (a @ b).realize()
  with Context(DEBUG=0):
    err = (c - tc).square().mean().item()
  print(f"mean squared error {err}")
  # err == err is False only for NaN, so NaN falls through to the failure path
  if err == err and err <= 1e-06: return

  # failure: report, tile by tile, which rows are all-zero and how far off the rest are
  c_np, tc_np = c.numpy(), tc.numpy()
  for bi in range(N // 128):
    rows = slice(bi*128, (bi+1)*128)
    for bj in range(N // 128):
      cols = slice(bj*128, (bj+1)*128)
      blk_c = c_np[rows, cols]
      blk_ref = tc_np[rows, cols]
      blk_diff = blk_c - blk_ref
      zero_rows = [i for i in range(128) if np.all(np.abs(blk_c[i,:]) < 1e-10)]
      nz_rows = [i for i in range(128) if i not in zero_rows]
      nz_mse = float(np.mean(blk_diff[nz_rows,:]**2)) if nz_rows else 0
      print(f"Block ({bi},{bj}): zero_rows={zero_rows}, nz_rows_mse={nz_mse:.2e}")
      # show first few non-zero row comparisons
      if not nz_rows or nz_mse <= 1e-6: continue
      for r in nz_rows[:3]:
        print(f" row {r} asm[0:8]: {blk_c[r,:8]}")
        print(f" row {r} ref[0:8]: {blk_ref[r,:8]}")
  raise RuntimeError("matmul is wrong!")
|
|
|
|
# Run the benchmark and verification only when executed as a script
if __name__ == "__main__":
  test_matmul()
|