mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-10 23:17:20 +08:00
llama speed 6 (#16071)
This commit is contained in:
@@ -257,18 +257,19 @@ def _get_pads(uop:UOp) -> list[UOp]:
|
||||
|
||||
def apply_grad(grad_buf:Tensor, new_grad:UOp):
|
||||
pads = _get_pads(new_grad)
|
||||
new_grad = new_grad.cast(grad_buf.dtype)
|
||||
if len(pads) <= 1:
|
||||
new_grad = new_grad.cast(grad_buf.dtype)
|
||||
store = grad_buf.uop.store(grad_buf.uop + new_grad)
|
||||
grad_buf.uop = grad_buf.uop.after(store)
|
||||
return
|
||||
sorted_pads = sorted(pads, key=lambda p: p.marg[0][0] if p.op == Ops.PAD else 0)
|
||||
inners = [Tensor(p.src[0] if p.op == Ops.PAD else p, device=grad_buf.device).cast(grad_buf.dtype) for p in sorted_pads]
|
||||
inners_raw = [Tensor(p.src[0] if p.op == Ops.PAD else p, device=grad_buf.device) for p in sorted_pads]
|
||||
if getenv("FUSED_PAD_GRAD_ACCUM", 0):
|
||||
from extra.llama_kernels.fused_pad_grad_accum import fused_pad_grad_accum, can_fused_pad_grad_accum
|
||||
if can_fused_pad_grad_accum(grad_buf, inners):
|
||||
grad_buf.uop = fused_pad_grad_accum(grad_buf, inners).uop
|
||||
if can_fused_pad_grad_accum(grad_buf, inners_raw):
|
||||
grad_buf.uop = fused_pad_grad_accum(grad_buf, inners_raw).uop
|
||||
return
|
||||
inners = [t.cast(grad_buf.dtype) for t in inners_raw]
|
||||
grad_buf.assign(grad_buf + inners[0].cat(*inners[1:], dim=0))
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -9,6 +9,7 @@ export DEVICE_IN_FUNCTION_BUG=1
|
||||
export DEBUG=${DEBUG:-2}
|
||||
export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}
|
||||
export ALL2ALL=${ALL2ALL:-1}
|
||||
export LATE_ALLREDUCE=${LATE_ALLREDUCE:-0}
|
||||
export USE_ATOMICS=${USE_ATOMICS:-1}
|
||||
export ASM_GEMM=${ASM_GEMM:-1}
|
||||
export WQKV=${WQKV:-1}
|
||||
|
||||
@@ -9,6 +9,7 @@ export DEVICE_IN_FUNCTION_BUG=1
|
||||
export DEBUG=${DEBUG:-0}
|
||||
export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1}
|
||||
export ALL2ALL=${ALL2ALL:-1}
|
||||
export LATE_ALLREDUCE=${LATE_ALLREDUCE:-0}
|
||||
export USE_ATOMICS=${USE_ATOMICS:-1}
|
||||
export ASM_GEMM=${ASM_GEMM:-1}
|
||||
export WQKV=${WQKV:-1}
|
||||
|
||||
@@ -10,6 +10,7 @@ export DEVICE_IN_FUNCTION_BUG=1
|
||||
|
||||
export HK_FLASH_ATTENTION=1
|
||||
export ALL2ALL=1
|
||||
export LATE_ALLREDUCE=0
|
||||
export USE_ATOMICS=1
|
||||
export ASM_GEMM=1
|
||||
export WQKV=1
|
||||
|
||||
@@ -2722,7 +2722,13 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp):
|
||||
# dgrad: uses g_scale * x_scale * w_scale
|
||||
grad_a = asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t, w_scale=s_w_t)
|
||||
# wgrad: no w_scale
|
||||
grad_b = asm_gemm(g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1), a_t.reshape(-1, a_t.shape[-1]), x_scale=g_scale * s_x_t)
|
||||
g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1])
|
||||
if getenv("FAST_FP8_TRANSPOSE", 0) and g_fp8_2d.shape[0] % 64 == 0 and g_fp8_2d.shape[1] % 64 == 0:
|
||||
from extra.llama_kernels.fp8_transpose import fast_fp8_transpose
|
||||
g_fp8_T = fast_fp8_transpose(g_fp8_2d)
|
||||
else:
|
||||
g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1)
|
||||
grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=g_scale * s_x_t)
|
||||
# Attach the delayed-amax store effect (if any) to grad_a so realizing grads commits the amax update.
|
||||
ret = (None, grad_a.uop.after(store_effect), grad_b.uop, None, None)
|
||||
if len(inputs) == 6: ret = ret + (None,)
|
||||
|
||||
41
extra/llama_kernels/fp8_transpose/__init__.py
Normal file
41
extra/llama_kernels/fp8_transpose/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
import functools, pathlib
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.renderer import Estimates
|
||||
from extra.llama_kernels import THREADS_PER_WG, alloc_like, dname_of, compile_hip
|
||||
|
||||
TILE = 64
|
||||
|
||||
@functools.cache
|
||||
def _custom_fp8_transpose(out:UOp, inp:UOp, dname:str) -> UOp:
|
||||
M, N = inp.shape
|
||||
num_wg = (M // TILE) * (N // TILE)
|
||||
threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(num_wg, "gidx0")
|
||||
mem = M * N * 2 # one byte read + one byte write per element
|
||||
sink = UOp.sink(out.base, inp.base, threads, workgroups,
|
||||
arg=KernelInfo(f"fp8_transpose_{M}_{N}",
|
||||
estimates=Estimates(ops=M*N, mem=mem)))
|
||||
src = (pathlib.Path(__file__).parent/"fp8_transpose.cpp").read_text()
|
||||
defines = [f"-DM_DIM={M}", f"-DN_DIM={N}", f"-DTHREADS_PER_WG={THREADS_PER_WG}"]
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
|
||||
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=compile_hip(src, defines))))
|
||||
|
||||
def fast_fp8_transpose(t:Tensor) -> Tensor:
|
||||
assert t.ndim == 2, f"fast_fp8_transpose needs 2D input, got shape {t.shape}"
|
||||
assert t.dtype in dtypes.fp8s, f"fast_fp8_transpose needs fp8 dtype, got {t.dtype}"
|
||||
M, N = t.shape
|
||||
assert M % TILE == 0 and N % TILE == 0, f"M={M}, N={N} must be multiples of {TILE}"
|
||||
|
||||
device = t.device
|
||||
axis = t.uop.axis if isinstance(device, tuple) else None
|
||||
out_axis = None
|
||||
if axis == 0: out_axis = 1
|
||||
elif axis == 1: out_axis = 0
|
||||
elif axis is not None:
|
||||
raise ValueError(f"fast_fp8_transpose: unsupported axis {axis}")
|
||||
|
||||
out = alloc_like((N, M), t.dtype, device, out_axis)
|
||||
fxn = functools.partial(_custom_fp8_transpose, dname=dname_of(device))
|
||||
out, _ = Tensor.custom_kernel(out, t, fxn=fxn)
|
||||
return out
|
||||
74
extra/llama_kernels/fp8_transpose/fp8_transpose.cpp
Normal file
74
extra/llama_kernels/fp8_transpose/fp8_transpose.cpp
Normal file
@@ -0,0 +1,74 @@
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
// LDS-staged 64x64 fp8 transpose.
|
||||
// in : (M_DIM, N_DIM) fp8 contiguous
|
||||
// out: (N_DIM, M_DIM) fp8 contiguous, out[c][r] = in[r][c]
|
||||
//
|
||||
// One WG processes one 64x64 output tile. Each thread reads one uint4 (16 fp8) coalesced
|
||||
// from input rows, stages into LDS, then writes one uint4 coalesced to the output (whose
|
||||
// 16 fp8 come from 16 different input rows via in-LDS gather).
|
||||
//
|
||||
// LDS layout: lds[64][LDS_STRIDE] with LDS_STRIDE=65 (1 byte pad) to mitigate bank conflicts
|
||||
// during the column-direction read of the write phase.
|
||||
|
||||
#ifndef M_DIM
|
||||
#define M_DIM 16384
|
||||
#endif
|
||||
#ifndef N_DIM
|
||||
#define N_DIM 28672
|
||||
#endif
|
||||
#ifndef THREADS_PER_WG
|
||||
#define THREADS_PER_WG 256
|
||||
#endif
|
||||
|
||||
constexpr int TILE = 64;
|
||||
constexpr int VEC = 16; // fp8 per uint4 (128-bit) load/store
|
||||
constexpr int LDS_PAD = 1;
|
||||
constexpr int LDS_STRIDE = TILE + LDS_PAD; // 65 fp8 per row
|
||||
|
||||
static_assert(THREADS_PER_WG * VEC == TILE * TILE, "256 threads * 16 fp8 = 64*64");
|
||||
static_assert(M_DIM % TILE == 0, "M_DIM must be a multiple of 64");
|
||||
static_assert(N_DIM % TILE == 0, "N_DIM must be a multiple of 64");
|
||||
|
||||
constexpr int N_TILES_N = N_DIM / TILE;
|
||||
|
||||
struct alignas(16) fp8x16 { uint8_t v[16]; };
|
||||
|
||||
extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
|
||||
fp8_transpose(uint8_t* __restrict__ out, // (N_DIM, M_DIM)
|
||||
const uint8_t* __restrict__ in) // (M_DIM, N_DIM)
|
||||
{
|
||||
__shared__ uint8_t lds[TILE * LDS_STRIDE];
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int wg_id = blockIdx.x;
|
||||
const int tile_r = wg_id / N_TILES_N; // tile index along M dim of input
|
||||
const int tile_c = wg_id % N_TILES_N; // tile index along N dim of input
|
||||
|
||||
const int a = tid / (TILE / VEC); // 0..63 (row within tile during read; col within tile during write)
|
||||
const int b = tid % (TILE / VEC); // 0..3
|
||||
const int b16 = b * VEC; // 0,16,32,48
|
||||
|
||||
// ---- Read phase: input rows -> LDS rows
|
||||
{
|
||||
const long long src = (long long)(tile_r * TILE + a) * (long long)N_DIM
|
||||
+ (long long)(tile_c * TILE + b16);
|
||||
fp8x16 v = *reinterpret_cast<const fp8x16*>(&in[src]);
|
||||
*reinterpret_cast<fp8x16*>(&lds[a * LDS_STRIDE + b16]) = v;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ---- Write phase: LDS columns (gathered) -> output rows
|
||||
// out[(tile_c*TILE + a)][(tile_r*TILE + b16 + i)] = in[(tile_r*TILE + b16 + i)][(tile_c*TILE + a)]
|
||||
// = lds[b16 + i][a]
|
||||
{
|
||||
fp8x16 v;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VEC; ++i) {
|
||||
v.v[i] = lds[(b16 + i) * LDS_STRIDE + a];
|
||||
}
|
||||
const long long dst = (long long)(tile_c * TILE + a) * (long long)M_DIM
|
||||
+ (long long)(tile_r * TILE + b16);
|
||||
*reinterpret_cast<fp8x16*>(&out[dst]) = v;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user