Files
tinygrad/extra/gemm/amd_copy_matmul.py
George Hotz 48a7627b04 add RDNA4 support to copy WMMA (#15663)
* add RDNA4 support to copy WMMA

* simpler

* simpler

* comment

* assert
2026-04-09 22:48:20 +08:00

119 lines
5.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from tinygrad import Device, UOp, getenv
from tinygrad.uop.ops import AxisType, KernelInfo, Ops
from tinygrad.dtype import AddrSpace, dtypes
# problem size: c is tiled as (M, N) in amd_copy_matmul and K is the reduction length; M and K default to N
N = getenv("N", 4096)
M = getenv("M", N)
K = getenv("K", N)
WARP_SIZE = 32
# every (block_id_m, block_id_n) pair handles one BLOCK_M x BLOCK_N tile of c, BLOCK_K of the reduction at a time
BLOCK_M, BLOCK_N = 128, 128
BLOCK_K = getenv("BK", 16)
assert N % BLOCK_N == 0 and M % BLOCK_M == 0 and K % BLOCK_K == 0
# WMMA=1 selects the SHAPED_WMMA path (run with half inputs), otherwise the plain FMA path (float inputs)
use_wmma = getenv("WMMA")
if use_wmma:
  is_rdna4 = Device[Device.DEFAULT].renderer.target.arch.startswith("gfx12")
  WAVES_M, WAVES_N = 2, 2
  LANES_PER_WAVE_M, LANES_PER_WAVE_N = 2, 16
  # wmma params
  WMMA_M, WMMA_N, WMMA_K = 16, 16, 16
  # block_128x128_gemm loops BLOCK_K // WMMA_K times and reshapes LDS by that factor,
  # so BK has to be a whole number of WMMA_K slices (otherwise the reshape fails later with an obscure error)
  assert BLOCK_K % WMMA_K == 0, f"BK={BLOCK_K} must be a multiple of WMMA_K={WMMA_K} when WMMA is set"
  # accumulator values one lane holds per WMMA tile (acc_frag in block_128x128_gemm has WMMA_ACC elements)
  WMMA_ACC = WMMA_M // LANES_PER_WAVE_M
  UNROLL_M, UNROLL_N = (WMMA_ACC, 1) if is_rdna4 else (1, 1)
else:
  WAVES_M, WAVES_N = 4, 1
  LANES_PER_WAVE_M, LANES_PER_WAVE_N = 4, 8
  UNROLL_M, UNROLL_N = 4, 4
# total lanes must be the warp size
assert LANES_PER_WAVE_M*LANES_PER_WAVE_N == WARP_SIZE
# WARP_SIZE * total waves
THREADS_PER_BLOCK = WARP_SIZE * WAVES_M * WAVES_N
# accumulator size
TM = BLOCK_M // (WAVES_M * LANES_PER_WAVE_M)
TN = BLOCK_N // (WAVES_N * LANES_PER_WAVE_N)
def block_128x128_gemm(c:UOp, a:UOp, b:UOp) -> UOp:
  """Build the UOp graph that computes one BLOCK_M x BLOCK_N tile of the output.

  Per k_tile step, THREADS_PER_BLOCK threads copy a BLOCK_K slice of `a` and `b`
  into LOCAL placeholders, then each thread accumulates into a (TM, TN) REG
  accumulator, either through SHAPED_WMMA (use_wmma) or through a broadcasted
  multiply-add. After the reduction the accumulator is stored to `c`.

  Args:
    c: output tile; reshaped below into 8 factors whose product is BLOCK_M * BLOCK_N.
    a: k-major input; reshaped below to (K // BLOCK_K, BLOCK_K, BLOCK_M).
    b: k-major input; reshaped below to (K // BLOCK_K, BLOCK_K, BLOCK_N).

  Returns:
    The store of the accumulator into `c`, ended over the wave_m, wave_n and lane ranges.
  """
  wave_m = UOp.range(WAVES_M, 2, AxisType.LOCAL)
  wave_n = UOp.range(WAVES_N, 3, AxisType.LOCAL)
  lane = UOp.range(WARP_SIZE, -1, AxisType.WARP)
  # flat thread index in [0, THREADS_PER_BLOCK): waves are ordered (wave_m, wave_n), lanes innermost
  tid = (wave_m * WAVES_N + wave_n) * WARP_SIZE + lane
  # -- GLOBAL -> LOCAL --
  # wmma: spatial outer, k inner (k contiguous for vectorized WMMA tile loads)
  # gemm: k outer, spatial inner
  A_local = UOp.placeholder((BLOCK_M, BLOCK_K) if use_wmma else (BLOCK_K, BLOCK_M), a.dtype.base, slot=0, addrspace=AddrSpace.LOCAL)
  B_local = UOp.placeholder((BLOCK_N, BLOCK_K) if use_wmma else (BLOCK_K, BLOCK_N), b.dtype.base, slot=1, addrspace=AddrSpace.LOCAL)
  # split the reduction axis into K // BLOCK_K slices so a[k_tile] / b[k_tile] select one slice
  a = a.reshape(K // BLOCK_K, BLOCK_K, BLOCK_M)
  b = b.reshape(K // BLOCK_K, BLOCK_K, BLOCK_N)
  k_tile = UOp.range(K // BLOCK_K, 100, AxisType.REDUCE)
  # copy with transpose for wmma (input is k x spatial, LDS is spatial x k)
  A_copy = A_local.permute((1,0)) if use_wmma else A_local
  B_copy = B_local.permute((1,0)) if use_wmma else B_local
  # both sides are viewed as (-1, THREADS_PER_BLOCK); thread `tid` copies column `tid` of that view
  A_store = A_copy.reshape(-1, THREADS_PER_BLOCK)[:, tid].store(a[k_tile].reshape(-1, THREADS_PER_BLOCK)[:, tid])
  B_store = B_copy.reshape(-1, THREADS_PER_BLOCK)[:, tid].store(b[k_tile].reshape(-1, THREADS_PER_BLOCK)[:, tid])
  # all later reads of A_local / B_local are ordered after a barrier that depends on both stores
  barrier = UOp.barrier(A_store, B_store)
  A_local, B_local = A_local.after(barrier), B_local.after(barrier)
  # -- COMPUTE --
  # split the lane index into its (m, n) position inside the wave
  lane_m, lane_n = lane // LANES_PER_WAVE_N, lane % LANES_PER_WAVE_N
  # accumulator (unified: both paths use (TM, TN) with scalar dtypes.float)
  acc = UOp.placeholder((TM, TN), dtypes.float, slot=2, addrspace=AddrSpace.REG)
  # zero-initialize; later uses of acc are ordered after this store
  acc = acc.after(acc.store(acc.zeros_like()))
  if use_wmma:
    k = UOp.range(BLOCK_K // WMMA_K, 101, AxisType.REDUCE)
    tile_m = UOp.range(TM // WMMA_ACC, 200, AxisType.LOOP)
    tile_n = UOp.range(TN, 201, AxisType.LOOP)
    # the WMMA_ACC accumulator values this lane holds for WMMA tile (tile_m, tile_n)
    acc_frag = acc.reshape(TM // WMMA_ACC, WMMA_ACC, TN).permute(0,2,1)[tile_m, tile_n]
    # one row of WMMA_K values from LDS per lane; both operands are indexed by lane_n
    a_frag = A_local.reshape(WAVES_M, TM // WMMA_ACC, WMMA_M, BLOCK_K // WMMA_K, WMMA_K)[wave_m, tile_m, lane_n, k]
    b_frag = B_local.reshape(WAVES_N, TN, WMMA_N, BLOCK_K // WMMA_K, WMMA_K)[wave_n, tile_n, lane_n, k]
    if is_rdna4:
      # NOTE: since this is part of K, these 2 can be anywhere in the frags as long as a and b match
      # each lane keeps one half (8 of the 16 k values), selected by lane_m
      a_frag = a_frag.reshape(2, 8)[lane_m, :]
      b_frag = b_frag.reshape(2, 8)[lane_m, :]
    wmma = UOp(Ops.SHAPED_WMMA, dtypes.float, (a_frag, b_frag, acc_frag.after(k)), arg=((16, 16, 16), 'AMD', 32))
    # write the WMMA result back into the accumulator and close the per-tile loops
    acc_store = acc_frag.store(wmma).end(tile_m, tile_n)
  else:
    # registers for LOCAL -> REG
    a_frag = UOp.placeholder((TM//UNROLL_M, UNROLL_M), dtypes.float, slot=0, addrspace=AddrSpace.REG)
    b_frag = UOp.placeholder((TN//UNROLL_N, UNROLL_N), dtypes.float, slot=1, addrspace=AddrSpace.REG)
    k = UOp.range(BLOCK_K, 101, AxisType.REDUCE)
    # for row k of LDS, pick this thread's TM (resp. TN) values: UNROLL-sized runs, interleaved across lanes and waves
    a_frag = a_frag.after(a_frag.store(A_local[k].reshape(WAVES_M, TM//UNROLL_M, LANES_PER_WAVE_M, UNROLL_M)[wave_m, :, lane_m, :]))
    b_frag = b_frag.after(b_frag.store(B_local[k].reshape(WAVES_N, TN//UNROLL_N, LANES_PER_WAVE_N, UNROLL_N)[wave_n, :, lane_n, :]))
    # FMA
    # outer product: broadcast a down the columns and b across the rows of the (TM, TN) accumulator
    a_frag = a_frag.reshape(TM, 1).expand(TM, TN)
    b_frag = b_frag.reshape(1, TN).expand(TM, TN)
    acc_store = acc.store(acc.after(k) + (a_frag * b_frag))
  # store accumulator and loop
  # close the inner k loop, barrier, then close the k_tile loop
  # NOTE(review): the barrier presumably keeps the next k_tile copy from overwriting LDS still being read; confirm
  acc = acc.after(acc_store.end(k).barrier().end(k_tile))
  # store accumulator to output (unified)
  # factor the tile rows/cols the same way the fragments were indexed, then move the
  # (wave_m, wave_n, lane_m, lane_n) factors to the front so axis 0 lines up with tid
  c = c.reshape(WAVES_M, TM//UNROLL_M, LANES_PER_WAVE_M, UNROLL_M,
                WAVES_N, TN//UNROLL_N, LANES_PER_WAVE_N, UNROLL_N)
  c = c.permute((0,4,2,6, 1,3,5,7)).reshape(THREADS_PER_BLOCK, TM, TN)
  return c[tid].store(acc).end(wave_m, wave_n, lane)
def amd_copy_matmul(c:UOp, a:UOp, b:UOp) -> UOp:
  """Tile the matmul over a (M // BLOCK_M, N // BLOCK_N) grid and return the kernel sink.

  `c` is split into BLOCK_M x BLOCK_N tiles, `a` is transposed and split along its
  BLOCK_M-sized groups, `b` along its BLOCK_N-sized groups; each grid cell is handed
  to block_128x128_gemm. The result is sunk with an empty `opts_to_apply`.
  NOTE(review): the reshapes imply c is (M, N), a is (M, K) and b is (K, N); confirm against the caller.
  """
  grid_m, grid_n = M // BLOCK_M, N // BLOCK_N
  # one GLOBAL range per grid axis
  block_id_m = UOp.range(grid_m, 0, AxisType.GLOBAL)
  block_id_n = UOp.range(grid_n, 1, AxisType.GLOBAL)
  # output tile for this grid cell
  c_tile = c.reshape(grid_m, BLOCK_M, grid_n, BLOCK_N)[block_id_m, :, block_id_n, :]
  # k-major panels: full K extent, BLOCK_M (resp. BLOCK_N) wide
  a_panel = a.T.reshape(K, grid_m, BLOCK_M)[:, block_id_m, :]
  b_panel = b.reshape(K, grid_n, BLOCK_N)[:, block_id_n, :]
  tiled = block_128x128_gemm(c_tile, a_panel, b_panel).end(block_id_n, block_id_m)
  return tiled.sink(arg=KernelInfo(opts_to_apply=()))
if __name__ == "__main__":
  # script entry: hand the kernel builder to the shared harness from amd_uop_matmul
  from amd_uop_matmul import eval_custom_matmul
  # the WMMA path is evaluated with half, the FMA path with float
  matmul_dtype = dtypes.half if use_wmma else dtypes.float
  eval_custom_matmul(amd_copy_matmul, matmul_dtype)