llama speed 6 (#16071)

This commit is contained in:
wozeparrot
2026-05-06 23:51:03 -04:00
committed by GitHub
parent 7b91f7c90c
commit 730fa66bf3
7 changed files with 130 additions and 5 deletions

View File

@@ -257,18 +257,19 @@ def _get_pads(uop:UOp) -> list[UOp]:
def apply_grad(grad_buf:Tensor, new_grad:UOp):
pads = _get_pads(new_grad)
new_grad = new_grad.cast(grad_buf.dtype)
if len(pads) <= 1:
new_grad = new_grad.cast(grad_buf.dtype)
store = grad_buf.uop.store(grad_buf.uop + new_grad)
grad_buf.uop = grad_buf.uop.after(store)
return
sorted_pads = sorted(pads, key=lambda p: p.marg[0][0] if p.op == Ops.PAD else 0)
inners = [Tensor(p.src[0] if p.op == Ops.PAD else p, device=grad_buf.device).cast(grad_buf.dtype) for p in sorted_pads]
inners_raw = [Tensor(p.src[0] if p.op == Ops.PAD else p, device=grad_buf.device) for p in sorted_pads]
if getenv("FUSED_PAD_GRAD_ACCUM", 0):
from extra.llama_kernels.fused_pad_grad_accum import fused_pad_grad_accum, can_fused_pad_grad_accum
if can_fused_pad_grad_accum(grad_buf, inners):
grad_buf.uop = fused_pad_grad_accum(grad_buf, inners).uop
if can_fused_pad_grad_accum(grad_buf, inners_raw):
grad_buf.uop = fused_pad_grad_accum(grad_buf, inners_raw).uop
return
inners = [t.cast(grad_buf.dtype) for t in inners_raw]
grad_buf.assign(grad_buf + inners[0].cat(*inners[1:], dim=0))
if __name__ == "__main__":

View File

@@ -9,6 +9,7 @@ export DEVICE_IN_FUNCTION_BUG=1
export DEBUG=${DEBUG:-2}
export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}
export ALL2ALL=${ALL2ALL:-1}
export LATE_ALLREDUCE=${LATE_ALLREDUCE:-0}
export USE_ATOMICS=${USE_ATOMICS:-1}
export ASM_GEMM=${ASM_GEMM:-1}
export WQKV=${WQKV:-1}

View File

@@ -9,6 +9,7 @@ export DEVICE_IN_FUNCTION_BUG=1
export DEBUG=${DEBUG:-0}
export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}
export ALL2ALL=${ALL2ALL:-1}
export LATE_ALLREDUCE=${LATE_ALLREDUCE:-0}
export USE_ATOMICS=${USE_ATOMICS:-1}
export ASM_GEMM=${ASM_GEMM:-1}
export WQKV=${WQKV:-1}

View File

@@ -10,6 +10,7 @@ export DEVICE_IN_FUNCTION_BUG=1
export HK_FLASH_ATTENTION=1
export ALL2ALL=1
export LATE_ALLREDUCE=0
export USE_ATOMICS=1
export ASM_GEMM=1
export WQKV=1

View File

@@ -2722,7 +2722,13 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp):
# dgrad: uses g_scale * x_scale * w_scale
grad_a = asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t, w_scale=s_w_t)
# wgrad: no w_scale
grad_b = asm_gemm(g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1), a_t.reshape(-1, a_t.shape[-1]), x_scale=g_scale * s_x_t)
g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1])
if getenv("FAST_FP8_TRANSPOSE", 0) and g_fp8_2d.shape[0] % 64 == 0 and g_fp8_2d.shape[1] % 64 == 0:
from extra.llama_kernels.fp8_transpose import fast_fp8_transpose
g_fp8_T = fast_fp8_transpose(g_fp8_2d)
else:
g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1)
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=g_scale * s_x_t)
# Attach the delayed-amax store effect (if any) to grad_a so realizing grads commits the amax update.
ret = (None, grad_a.uop.after(store_effect), grad_b.uop, None, None)
if len(inputs) == 6: ret = ret + (None,)

View File

@@ -0,0 +1,41 @@
from __future__ import annotations
import functools, pathlib
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.renderer import Estimates
from extra.llama_kernels import THREADS_PER_WG, alloc_like, dname_of, compile_hip
TILE = 64
@functools.cache
def _custom_fp8_transpose(out:UOp, inp:UOp, dname:str) -> UOp:
M, N = inp.shape
num_wg = (M // TILE) * (N // TILE)
threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(num_wg, "gidx0")
mem = M * N * 2 # one byte read + one byte write per element
sink = UOp.sink(out.base, inp.base, threads, workgroups,
arg=KernelInfo(f"fp8_transpose_{M}_{N}",
estimates=Estimates(ops=M*N, mem=mem)))
src = (pathlib.Path(__file__).parent/"fp8_transpose.cpp").read_text()
defines = [f"-DM_DIM={M}", f"-DN_DIM={N}", f"-DTHREADS_PER_WG={THREADS_PER_WG}"]
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=compile_hip(src, defines))))
def fast_fp8_transpose(t:Tensor) -> Tensor:
assert t.ndim == 2, f"fast_fp8_transpose needs 2D input, got shape {t.shape}"
assert t.dtype in dtypes.fp8s, f"fast_fp8_transpose needs fp8 dtype, got {t.dtype}"
M, N = t.shape
assert M % TILE == 0 and N % TILE == 0, f"M={M}, N={N} must be multiples of {TILE}"
device = t.device
axis = t.uop.axis if isinstance(device, tuple) else None
out_axis = None
if axis == 0: out_axis = 1
elif axis == 1: out_axis = 0
elif axis is not None:
raise ValueError(f"fast_fp8_transpose: unsupported axis {axis}")
out = alloc_like((N, M), t.dtype, device, out_axis)
fxn = functools.partial(_custom_fp8_transpose, dname=dname_of(device))
out, _ = Tensor.custom_kernel(out, t, fxn=fxn)
return out

View File

@@ -0,0 +1,74 @@
#include <hip/hip_runtime.h>
// LDS-staged 64x64 fp8 transpose.
// in : (M_DIM, N_DIM) fp8 contiguous
// out: (N_DIM, M_DIM) fp8 contiguous, out[c][r] = in[r][c]
//
// One WG processes one 64x64 output tile. Each thread reads one uint4 (16 fp8) coalesced
// from input rows, stages into LDS, then writes one uint4 coalesced to the output (whose
// 16 fp8 come from 16 different input rows via in-LDS gather).
//
// LDS layout: lds[64][LDS_STRIDE] with LDS_STRIDE=65 (1 byte pad) to mitigate bank conflicts
// during the column-direction read of the write phase.
#ifndef M_DIM
#define M_DIM 16384
#endif
#ifndef N_DIM
#define N_DIM 28672
#endif
#ifndef THREADS_PER_WG
#define THREADS_PER_WG 256
#endif
constexpr int TILE = 64;
constexpr int VEC = 16; // fp8 per uint4 (128-bit) load/store
constexpr int LDS_PAD = 1;
constexpr int LDS_STRIDE = TILE + LDS_PAD; // 65 fp8 per row
static_assert(THREADS_PER_WG * VEC == TILE * TILE, "256 threads * 16 fp8 = 64*64");
static_assert(M_DIM % TILE == 0, "M_DIM must be a multiple of 64");
static_assert(N_DIM % TILE == 0, "N_DIM must be a multiple of 64");
constexpr int N_TILES_N = N_DIM / TILE;
struct alignas(16) fp8x16 { uint8_t v[16]; };
extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
fp8_transpose(uint8_t* __restrict__ out, // (N_DIM, M_DIM)
const uint8_t* __restrict__ in) // (M_DIM, N_DIM)
{
__shared__ uint8_t lds[TILE * LDS_STRIDE];
const int tid = threadIdx.x;
const int wg_id = blockIdx.x;
const int tile_r = wg_id / N_TILES_N; // tile index along M dim of input
const int tile_c = wg_id % N_TILES_N; // tile index along N dim of input
const int a = tid / (TILE / VEC); // 0..63 (row within tile during read; col within tile during write)
const int b = tid % (TILE / VEC); // 0..3
const int b16 = b * VEC; // 0,16,32,48
// ---- Read phase: input rows -> LDS rows
{
const long long src = (long long)(tile_r * TILE + a) * (long long)N_DIM
+ (long long)(tile_c * TILE + b16);
fp8x16 v = *reinterpret_cast<const fp8x16*>(&in[src]);
*reinterpret_cast<fp8x16*>(&lds[a * LDS_STRIDE + b16]) = v;
}
__syncthreads();
// ---- Write phase: LDS columns (gathered) -> output rows
// out[(tile_c*TILE + a)][(tile_r*TILE + b16 + i)] = in[(tile_r*TILE + b16 + i)][(tile_c*TILE + a)]
// = lds[b16 + i][a]
{
fp8x16 v;
#pragma unroll
for (int i = 0; i < VEC; ++i) {
v.v[i] = lds[(b16 + i) * LDS_STRIDE + a];
}
const long long dst = (long long)(tile_c * TILE + a) * (long long)M_DIM
+ (long long)(tile_r * TILE + b16);
*reinterpret_cast<fp8x16*>(&out[dst]) = v;
}
}