diff --git a/examples/mlperf/models/flat_llama.py b/examples/mlperf/models/flat_llama.py index de0643c1b9..63997e0553 100644 --- a/examples/mlperf/models/flat_llama.py +++ b/examples/mlperf/models/flat_llama.py @@ -257,18 +257,19 @@ def _get_pads(uop:UOp) -> list[UOp]: def apply_grad(grad_buf:Tensor, new_grad:UOp): pads = _get_pads(new_grad) - new_grad = new_grad.cast(grad_buf.dtype) if len(pads) <= 1: + new_grad = new_grad.cast(grad_buf.dtype) store = grad_buf.uop.store(grad_buf.uop + new_grad) grad_buf.uop = grad_buf.uop.after(store) return sorted_pads = sorted(pads, key=lambda p: p.marg[0][0] if p.op == Ops.PAD else 0) - inners = [Tensor(p.src[0] if p.op == Ops.PAD else p, device=grad_buf.device).cast(grad_buf.dtype) for p in sorted_pads] + inners_raw = [Tensor(p.src[0] if p.op == Ops.PAD else p, device=grad_buf.device) for p in sorted_pads] if getenv("FUSED_PAD_GRAD_ACCUM", 0): from extra.llama_kernels.fused_pad_grad_accum import fused_pad_grad_accum, can_fused_pad_grad_accum - if can_fused_pad_grad_accum(grad_buf, inners): - grad_buf.uop = fused_pad_grad_accum(grad_buf, inners).uop + if can_fused_pad_grad_accum(grad_buf, inners_raw): + grad_buf.uop = fused_pad_grad_accum(grad_buf, inners_raw).uop return + inners = [t.cast(grad_buf.dtype) for t in inners_raw] grad_buf.assign(grad_buf + inners[0].cat(*inners[1:], dim=0)) if __name__ == "__main__": diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh index 9e4c6e5cba..4b8493c0ba 100755 --- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh +++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_beam.sh @@ -9,6 +9,7 @@ export DEVICE_IN_FUNCTION_BUG=1 export DEBUG=${DEBUG:-2} export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1} export ALL2ALL=${ALL2ALL:-1} +export LATE_ALLREDUCE=${LATE_ALLREDUCE:-0} export USE_ATOMICS=${USE_ATOMICS:-1} export ASM_GEMM=${ASM_GEMM:-1} export WQKV=${WQKV:-1} diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh index f759e58e2e..f38282a354 100755 --- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh +++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/dev_run.sh @@ -9,6 +9,7 @@ export DEVICE_IN_FUNCTION_BUG=1 export DEBUG=${DEBUG:-0} export HK_FLASH_ATTENTION=${HK_FLASH_ATTENTION:-1} export ALL2ALL=${ALL2ALL:-1} +export LATE_ALLREDUCE=${LATE_ALLREDUCE:-0} export USE_ATOMICS=${USE_ATOMICS:-1} export ASM_GEMM=${ASM_GEMM:-1} export WQKV=${WQKV:-1} diff --git a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/run_and_time.sh b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/run_and_time.sh index fbff01a222..21c1ace10c 100755 --- a/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/run_and_time.sh +++ b/examples/mlperf/training_submission_v6.0/tinycorp/benchmarks/llama8b/implementations/tinybox_8xMI350X/run_and_time.sh @@ -10,6 +10,7 @@ export DEVICE_IN_FUNCTION_BUG=1 export HK_FLASH_ATTENTION=1 export ALL2ALL=1 +export LATE_ALLREDUCE=0 export USE_ATOMICS=1 export ASM_GEMM=1 export WQKV=1 diff --git a/extra/gemm/cdna_asm_gemm.py b/extra/gemm/cdna_asm_gemm.py index 5dd3a54ad7..2220f76a12 100644 --- a/extra/gemm/cdna_asm_gemm.py +++ b/extra/gemm/cdna_asm_gemm.py @@ -2722,7 +2722,13 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp): # dgrad: uses g_scale * x_scale * w_scale grad_a = asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t, w_scale=s_w_t) # wgrad: no w_scale - grad_b = asm_gemm(g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1), a_t.reshape(-1, a_t.shape[-1]), x_scale=g_scale * s_x_t) + g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1]) + if getenv("FAST_FP8_TRANSPOSE", 0) and g_fp8_2d.shape[0] % 64 == 0 and g_fp8_2d.shape[1] % 64 == 0: + from extra.llama_kernels.fp8_transpose import fast_fp8_transpose + g_fp8_T = fast_fp8_transpose(g_fp8_2d) + else: + g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1) + grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=g_scale * s_x_t) # Attach the delayed-amax store effect (if any) to grad_a so realizing grads commits the amax update. ret = (None, grad_a.uop.after(store_effect), grad_b.uop, None, None) if len(inputs) == 6: ret = ret + (None,) diff --git a/extra/llama_kernels/fp8_transpose/__init__.py b/extra/llama_kernels/fp8_transpose/__init__.py new file mode 100644 index 0000000000..cba18a2741 --- /dev/null +++ b/extra/llama_kernels/fp8_transpose/__init__.py @@ -0,0 +1,41 @@ +from __future__ import annotations +import functools, pathlib +from tinygrad import Tensor, dtypes +from tinygrad.uop.ops import UOp, Ops, KernelInfo +from tinygrad.renderer import Estimates +from extra.llama_kernels import THREADS_PER_WG, alloc_like, dname_of, compile_hip + +TILE = 64 + +@functools.cache +def _custom_fp8_transpose(out:UOp, inp:UOp, dname:str) -> UOp: + M, N = inp.shape + num_wg = (M // TILE) * (N // TILE) + threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(num_wg, "gidx0") + mem = M * N * 2 # one byte read + one byte write per element + sink = UOp.sink(out.base, inp.base, threads, workgroups, + arg=KernelInfo(f"fp8_transpose_{M}_{N}", + estimates=Estimates(ops=M*N, mem=mem))) + src = (pathlib.Path(__file__).parent/"fp8_transpose.cpp").read_text() + defines = [f"-DM_DIM={M}", f"-DN_DIM={N}", f"-DTHREADS_PER_WG={THREADS_PER_WG}"] + return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), + UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=compile_hip(src, defines)))) + +def fast_fp8_transpose(t:Tensor) -> Tensor: + assert t.ndim == 2, f"fast_fp8_transpose needs 2D input, got shape {t.shape}" + assert t.dtype in dtypes.fp8s, f"fast_fp8_transpose needs fp8 dtype, got {t.dtype}" + M, N = t.shape + assert M % TILE == 0 and N % TILE == 0, f"M={M}, N={N} must be multiples of {TILE}" + + device = t.device + axis = t.uop.axis if isinstance(device, tuple) else None + out_axis = None + if axis == 0: out_axis = 1 + elif axis == 1: out_axis = 0 + elif axis is not None: + raise ValueError(f"fast_fp8_transpose: unsupported axis {axis}") + + out = alloc_like((N, M), t.dtype, device, out_axis) + fxn = functools.partial(_custom_fp8_transpose, dname=dname_of(device)) + out, _ = Tensor.custom_kernel(out, t, fxn=fxn) + return out diff --git a/extra/llama_kernels/fp8_transpose/fp8_transpose.cpp b/extra/llama_kernels/fp8_transpose/fp8_transpose.cpp new file mode 100644 index 0000000000..02c7e18d69 --- /dev/null +++ b/extra/llama_kernels/fp8_transpose/fp8_transpose.cpp @@ -0,0 +1,74 @@ +#include + +// LDS-staged 64x64 fp8 transpose. +// in : (M_DIM, N_DIM) fp8 contiguous +// out: (N_DIM, M_DIM) fp8 contiguous, out[c][r] = in[r][c] +// +// One WG processes one 64x64 output tile. Each thread reads one uint4 (16 fp8) coalesced +// from input rows, stages into LDS, then writes one uint4 coalesced to the output (whose +// 16 fp8 come from 16 different input rows via in-LDS gather). +// +// LDS layout: lds[64][LDS_STRIDE] with LDS_STRIDE=65 (1 byte pad) to mitigate bank conflicts +// during the column-direction read of the write phase. + +#ifndef M_DIM +#define M_DIM 16384 +#endif +#ifndef N_DIM +#define N_DIM 28672 +#endif +#ifndef THREADS_PER_WG +#define THREADS_PER_WG 256 +#endif + +constexpr int TILE = 64; +constexpr int VEC = 16; // fp8 per uint4 (128-bit) load/store +constexpr int LDS_PAD = 1; +constexpr int LDS_STRIDE = TILE + LDS_PAD; // 65 fp8 per row + +static_assert(THREADS_PER_WG * VEC == TILE * TILE, "256 threads * 16 fp8 = 64*64"); +static_assert(M_DIM % TILE == 0, "M_DIM must be a multiple of 64"); +static_assert(N_DIM % TILE == 0, "N_DIM must be a multiple of 64"); + +constexpr int N_TILES_N = N_DIM / TILE; + +struct alignas(16) fp8x16 { uint8_t v[16]; }; + +extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void +fp8_transpose(uint8_t* __restrict__ out, // (N_DIM, M_DIM) + const uint8_t* __restrict__ in) // (M_DIM, N_DIM) +{ + __shared__ uint8_t lds[TILE * LDS_STRIDE]; + + const int tid = threadIdx.x; + const int wg_id = blockIdx.x; + const int tile_r = wg_id / N_TILES_N; // tile index along M dim of input + const int tile_c = wg_id % N_TILES_N; // tile index along N dim of input + + const int a = tid / (TILE / VEC); // 0..63 (row within tile during read; col within tile during write) + const int b = tid % (TILE / VEC); // 0..3 + const int b16 = b * VEC; // 0,16,32,48 + + // ---- Read phase: input rows -> LDS rows + { + const long long src = (long long)(tile_r * TILE + a) * (long long)N_DIM + + (long long)(tile_c * TILE + b16); + fp8x16 v = *reinterpret_cast(&in[src]); + *reinterpret_cast(&lds[a * LDS_STRIDE + b16]) = v; + } + __syncthreads(); + + // ---- Write phase: LDS columns (gathered) -> output rows + // out[(tile_c*TILE + a)][(tile_r*TILE + b16 + i)] = in[(tile_r*TILE + b16 + i)][(tile_c*TILE + a)] + // = lds[b16 + i][a] + { + fp8x16 v; + #pragma unroll + for (int i = 0; i < VEC; ++i) { + v.v[i] = lds[(b16 + i) * LDS_STRIDE + a]; + } + const long long dst = (long long)(tile_c * TILE + a) * (long long)M_DIM + + (long long)(tile_r * TILE + b16); + *reinterpret_cast(&out[dst]) = v; + } +}