From bbfe4f80ec4e1bee31a4fd879b387a7810e0f7cf Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 22 May 2026 14:54:06 +0300 Subject: [PATCH] quantize_fp8 kernels in uops (#16288) * add tests * simple UOp kernel is n^2 * fast kernel matching c++, opts_to_apply=() * remove cpp * simple o(n) kernel, two passes * fuse the loops * works on DEV=CPU * multi regression test * fix multi, this can possibly be its own bugfix * test cleanups * minimal diff * match C in UOps * Revert "match C in UOps" This reverts commit 0bef740c302035a0213ea2e3c451a5714caa8116. * edit test * match speed with C try 2 * needs_second_gpu * cleanup --- .../quantize_fp8_delayed/__init__.py | 89 +++++++++++++------ .../quantize_fp8_scalar.cpp | 48 ---------- .../quantize_fp8_with_amax.cpp | 63 ------------- test/backend/test_llama_kernels.py | 52 ++++++++++- tinygrad/codegen/late/devectorizer.py | 3 +- tinygrad/viz/cli.py | 3 +- 6 files changed, 115 insertions(+), 143 deletions(-) delete mode 100644 extra/llama_kernels/quantize_fp8_delayed/quantize_fp8_scalar.cpp delete mode 100644 extra/llama_kernels/quantize_fp8_delayed/quantize_fp8_with_amax.cpp diff --git a/extra/llama_kernels/quantize_fp8_delayed/__init__.py b/extra/llama_kernels/quantize_fp8_delayed/__init__.py index 60d7060347..f51f2b49b5 100644 --- a/extra/llama_kernels/quantize_fp8_delayed/__init__.py +++ b/extra/llama_kernels/quantize_fp8_delayed/__init__.py @@ -1,35 +1,64 @@ -from __future__ import annotations -import functools, pathlib +import functools from tinygrad import Tensor, dtypes -from tinygrad.uop.ops import UOp, Ops, KernelInfo -from tinygrad.renderer import Estimates -from extra.llama_kernels import FP8_MAX, NUM_WG, THREADS_PER_WG, alloc_like, alloc_local, scalar_amax, dname_of, compile_hip +from tinygrad.dtype import AddrSpace +from tinygrad.helpers import prod +from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType +from extra.llama_kernels import FP8_MAX, NUM_WG, THREADS_PER_WG, alloc_like, alloc_local, scalar_amax @functools.cache -def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_state:UOp, dname:str) -> UOp: - n_elems = 1 - for d in x.shape: n_elems *= d - threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0") - mem = n_elems * 2 + n_elems + 4 + NUM_WG * 4 - sink = UOp.sink(fp8_out.base, amax_partial.base, x.base, amax_state.base, threads, workgroups, - arg=KernelInfo(f"quantize_fp8_with_amax_{n_elems}", estimates=Estimates(ops=3*n_elems, mem=mem))) - src = (pathlib.Path(__file__).parent/"quantize_fp8_with_amax.cpp").read_text() - defines = [f"-DN_ELEMS={n_elems}", f"-DNUM_WG={NUM_WG}", f"-DTHREADS_PER_WG={THREADS_PER_WG}"] - return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), - UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=compile_hip(src, defines)))) +def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_state:UOp) -> UOp: + VEC = 8 + n_elems = prod(x.shape) + assert n_elems % (NUM_WG * THREADS_PER_WG * VEC) == 0 + assert amax_partial.shape[0] == NUM_WG + + x = x.reshape(n_elems) + fp8_out = fp8_out.reshape(n_elems) + + wg = UOp.range(NUM_WG, 0, AxisType.GLOBAL) + tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL) + it = UOp.range((n_elems // VEC) // (NUM_WG * THREADS_PER_WG), 2, AxisType.LOOP) + lane = UOp.range(VEC, 3, AxisType.UNROLL) + + idx = (((it * NUM_WG + wg) * THREADS_PER_WG + tid) * VEC) + lane + + scale = FP8_MAX / (amax_state[0].cast(dtypes.float) + 1e-8) + x_f = x[idx].cast(dtypes.float) + abs_x = (x_f < 0.0).where(-x_f, x_f) + scaled = (x_f * scale).maximum(-FP8_MAX).minimum(FP8_MAX) + + fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane) + lane_max = abs_x.reduce(lane, arg=Ops.MAX) + + lmax = UOp.placeholder((1,), dtypes.float, slot=1, addrspace=AddrSpace.REG) + lmax_init = lmax.after(wg, tid)[0].store(0.0) + lmax_prev = lmax.after(lmax_init, it)[0] + lmax_store = lmax.after(fp8_store)[0].store(lmax_prev.maximum(lane_max)) + lmax_val = lmax.after(lmax_store.end(it))[0] + + lds = UOp.placeholder((THREADS_PER_WG,), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL) + lds = lds.after(lds[tid].store(lmax_val).barrier()) + + step = THREADS_PER_WG // 2 + while step: + active = tid < step + other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active) + lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier()) + step //= 2 + + amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0]) + return amax_store.end(tid, wg).sink(arg=KernelInfo(f"quantize_fp8_with_amax_{n_elems}", opts_to_apply=())) @functools.cache -def _custom_quantize_fp8_scalar(fp8_out:UOp, x:UOp, amax_state:UOp, dname:str) -> UOp: - n_elems = 1 - for d in x.shape: n_elems *= d - threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0") - mem = n_elems * 2 + n_elems - sink = UOp.sink(fp8_out.base, x.base, amax_state.base, threads, workgroups, - arg=KernelInfo(f"quantize_fp8_scalar_{n_elems}", estimates=Estimates(ops=2*n_elems, mem=mem))) - src = (pathlib.Path(__file__).parent/"quantize_fp8_scalar.cpp").read_text() - defines = [f"-DN_ELEMS={n_elems}", f"-DNUM_WG={NUM_WG}", f"-DTHREADS_PER_WG={THREADS_PER_WG}"] - return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), - UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=compile_hip(src, defines)))) +def _custom_quantize_fp8_scalar(fp8_out:UOp, x:UOp, amax_state:UOp) -> UOp: + n_elems = prod(x.shape) + i = UOp.range(n_elems, 0) + + x_f = x.reshape(n_elems)[i].cast(dtypes.float) + scale = FP8_MAX / (amax_state[0].cast(dtypes.float) + 1e-8) + store = fp8_out.reshape(n_elems)[i].store((x_f * scale).cast(fp8_out.dtype.base)) + + return store.end(i).sink(arg=KernelInfo(f"quantize_fp8_scalar_{n_elems}")) def _quantize_fp8_delayed_bwd(gradient:UOp, kernel:UOp): # NOTE: STE-equivalent backward — grad_x = grad_fp8 * scale, scale = FP8_MAX / amax_state. @@ -49,8 +78,10 @@ def quantize_fp8_delayed(x:Tensor, amax_state:Tensor, fp8_dtype=dtypes.fp8e4m3) assert x.dtype == dtypes.bfloat16, f"expected bf16, got {x.dtype}" axis = x.uop.axis if isinstance(x.device, tuple) else None fp8_out = alloc_like(x.shape, fp8_dtype, x.device, axis) + n_elems = prod(x.uop.shard_shape) + assert n_elems % NUM_WG == 0, f"{n_elems=} must divide over {NUM_WG=}" amax_partial = alloc_local((NUM_WG,), dtypes.float32, x.device, axis) - fxn = functools.partial(_custom_quantize_fp8_with_amax, dname=dname_of(x.device)) + fxn = _custom_quantize_fp8_with_amax fp8_out, amax_partial, *_ = Tensor.custom_kernel(fp8_out, amax_partial, x, amax_state, fxn=fxn, grad_fxn=_quantize_fp8_delayed_bwd) new_amax = scalar_amax(amax_partial) @@ -62,6 +93,6 @@ def quantize_fp8_scalar(x:Tensor, amax_state:Tensor, fp8_dtype=dtypes.fp8e4m3) - # NOTE: pure one-pass bf16 -> fp8 quantize with delayed scalar scale. No amax computation. axis = x.uop.axis if isinstance(x.device, tuple) else None fp8_out = alloc_like(x.shape, fp8_dtype, x.device, axis) - fxn = functools.partial(_custom_quantize_fp8_scalar, dname=dname_of(x.device)) + fxn = _custom_quantize_fp8_scalar fp8_out, *_ = Tensor.custom_kernel(fp8_out, x, amax_state, fxn=fxn) return fp8_out diff --git a/extra/llama_kernels/quantize_fp8_delayed/quantize_fp8_scalar.cpp b/extra/llama_kernels/quantize_fp8_delayed/quantize_fp8_scalar.cpp deleted file mode 100644 index 33c1636f6e..0000000000 --- a/extra/llama_kernels/quantize_fp8_delayed/quantize_fp8_scalar.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include -#include -#include - -// Pure one-pass bf16 -> fp8 quantize with delayed scalar scale. No amax computation. - -#ifndef N_ELEMS -#define N_ELEMS 67108864 -#endif -#ifndef NUM_WG -#define NUM_WG 1024 -#endif -#ifndef THREADS_PER_WG -#define THREADS_PER_WG 256 -#endif - -constexpr int VEC = 8; -constexpr float FP8_MAX = 448.0f; - -static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC"); - -extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void -quantize_fp8_scalar( - __hip_fp8_storage_t* __restrict__ fp8_out, // fp8, N_ELEMS - const __hip_bfloat16* __restrict__ x, // bf16, N_ELEMS - const float* __restrict__ amax_state) // fp32 scalar (delayed) -{ - const int tid = threadIdx.x; - const int wg = blockIdx.x; - const int gid = wg * THREADS_PER_WG + tid; - const int stride_elems = NUM_WG * THREADS_PER_WG * VEC; - - const float scale = FP8_MAX / (static_cast(*amax_state) + 1e-8f); - - for (int base = gid * VEC; base < N_ELEMS; base += stride_elems) { - float4 x_raw = *reinterpret_cast(&x[base]); - const __hip_bfloat16 *xi = reinterpret_cast(&x_raw); - - __hip_fp8_storage_t out[VEC]; - #pragma unroll - for (int i = 0; i < VEC; i++) { - const float v = static_cast(xi[i]); - const float scaled = fmaxf(-FP8_MAX, fminf(FP8_MAX, v * scale)); - out[i] = __hip_cvt_float_to_fp8(scaled, __HIP_SATFINITE, __HIP_E4M3); - } - *reinterpret_cast(&fp8_out[base]) = *reinterpret_cast(out); - } -} diff --git a/extra/llama_kernels/quantize_fp8_delayed/quantize_fp8_with_amax.cpp b/extra/llama_kernels/quantize_fp8_delayed/quantize_fp8_with_amax.cpp deleted file mode 100644 index 360e4c1af4..0000000000 --- a/extra/llama_kernels/quantize_fp8_delayed/quantize_fp8_with_amax.cpp +++ /dev/null @@ -1,63 +0,0 @@ -#include -#include -#include - -// One-pass bf16 -> fp8 quantize using a scalar delayed amax state, -// AND simultaneously computes per-WG |x| max partials for the next step's amax state. -// Saves one full HBM pass over the grad tensor vs. doing quantize + separate abs().max(). - -#ifndef N_ELEMS -#define N_ELEMS 67108864 -#endif -#ifndef NUM_WG -#define NUM_WG 1024 -#endif -#ifndef THREADS_PER_WG -#define THREADS_PER_WG 256 -#endif - -constexpr int VEC = 8; -constexpr float FP8_MAX = 448.0f; - -static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC"); - -extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void -quantize_fp8_with_amax( - __hip_fp8_storage_t* __restrict__ fp8_out, // out: fp8, N_ELEMS - float* __restrict__ amax_partial, // out: fp32, NUM_WG per-WG partials - const __hip_bfloat16* __restrict__ x, // in: bf16, N_ELEMS - const float* __restrict__ amax_state) // in: fp32 scalar (delayed) -{ - __shared__ float sdata[THREADS_PER_WG]; - - const int tid = threadIdx.x; - const int wg = blockIdx.x; - const int gid = wg * THREADS_PER_WG + tid; - const int stride_elems = NUM_WG * THREADS_PER_WG * VEC; - - const float scale = FP8_MAX / (static_cast(*amax_state) + 1e-8f); - float local_max = 0.0f; - - for (int base = gid * VEC; base < N_ELEMS; base += stride_elems) { - float4 x_raw = *reinterpret_cast(&x[base]); - const __hip_bfloat16 *xi = reinterpret_cast(&x_raw); - - __hip_fp8_storage_t out[VEC]; - #pragma unroll - for (int i = 0; i < VEC; i++) { - const float v = static_cast(xi[i]); - local_max = fmaxf(local_max, fabsf(v)); - const float scaled = fmaxf(-FP8_MAX, fminf(FP8_MAX, v * scale)); - out[i] = __hip_cvt_float_to_fp8(scaled, __HIP_SATFINITE, __HIP_E4M3); - } - *reinterpret_cast(&fp8_out[base]) = *reinterpret_cast(out); - } - - sdata[tid] = local_max; - __syncthreads(); - for (int s = THREADS_PER_WG / 2; s > 0; s >>= 1) { - if (tid < s) sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]); - __syncthreads(); - } - if (tid == 0) amax_partial[wg] = sdata[0]; -} diff --git a/test/backend/test_llama_kernels.py b/test/backend/test_llama_kernels.py index ab3868fe7c..fe77c49662 100644 --- a/test/backend/test_llama_kernels.py +++ b/test/backend/test_llama_kernels.py @@ -1,6 +1,10 @@ import unittest from tinygrad import Tensor, Device, dtypes, Context +from tinygrad.helpers import getenv +from examples.mlperf.models.flat_llama import FP8_DTYPE, quantize_fp8 from extra.llama_kernels.fused_ce import fused_ce_loss +from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed, quantize_fp8_scalar +from test.helpers import needs_second_gpu def run_fused_ce(bs:int, seqlen:int, vocab:int, label_smoothing:float=0.0) -> None: Tensor.manual_seed(0) @@ -23,8 +27,10 @@ def run_fused_ce(bs:int, seqlen:int, vocab:int, label_smoothing:float=0.0) -> No assert loss.allclose(ref, atol=2e-3, rtol=2e-3).item(), "forward mismatch" assert logits.grad.allclose(logits_ref.grad, atol=2e-3, rtol=2e-3).item(), "grad mismatch" -@unittest.skipUnless(dtypes.bfloat16 in Device[Device.DEFAULT].renderer.supported_dtypes(), "need bfloat16") class TestFusedCE(unittest.TestCase): + def setUp(self): + if dtypes.bfloat16 not in Device[Device.DEFAULT].renderer.supported_dtypes(): self.skipTest("need bfloat16") + def test_fused_ce_1_2_16(self): run_fused_ce(1, 2, 16, label_smoothing=0.2) def test_fused_ce_2_16_128(self): run_fused_ce(2, 16, 128) def test_fused_ce_4_128_1024(self): run_fused_ce(4, 128, 1024, label_smoothing=0.2) @@ -32,5 +38,49 @@ class TestFusedCE(unittest.TestCase): # note: this is the shape used in llama 8b #def test_fused_ce_smoothing_16_1024_128256(self): run_fused_ce(16, 1024, 128256, label_smoothing=0.2) +def run_quantize_fp8(shape:tuple[int, ...], delayed:bool=True) -> None: + Tensor.manual_seed(0) + x = Tensor.randn(*shape).cast(dtypes.bfloat16).contiguous() + amax_state = Tensor.full((), 2.0, dtype=dtypes.float32).contiguous() + with Context(DEBUG=0): Tensor.realize(x, amax_state) + + if delayed: + fp8, inv_scale, new_amax, _ = quantize_fp8_delayed(x, amax_state, FP8_DTYPE) + ref_fp8, ref_inv_scale, ref_new_amax = quantize_fp8(x, amax_state=amax_state) + Tensor.realize(fp8, inv_scale, new_amax) + Tensor.realize(ref_fp8, ref_inv_scale, ref_new_amax) + else: + fp8 = quantize_fp8_scalar(x, amax_state, FP8_DTYPE) + ref_fp8, _, _ = quantize_fp8(x, amax_state=amax_state) + Tensor.realize(fp8) + Tensor.realize(ref_fp8) + + with Context(DEBUG=0): + assert fp8.cast(dtypes.float).allclose(ref_fp8.cast(dtypes.float), atol=0, rtol=0).item(), "fp8 mismatch" + if delayed: + assert inv_scale.allclose(ref_inv_scale, atol=0, rtol=0).item(), "inv_scale mismatch" + assert new_amax.allclose(ref_new_amax, atol=0, rtol=0).item(), \ + f"amax mismatch: got={new_amax.item()} ref={ref_new_amax.item()} diff={abs(new_amax.item()-ref_new_amax.item())}" + +class TestQuantizeFP8(unittest.TestCase): + def setUp(self): + ren = Device[Device.DEFAULT].renderer + if dtypes.bfloat16 not in ren.supported_dtypes(): self.skipTest("need bfloat16") + if not ren.has_local or not ren.has_shared: self.skipTest("need local/shared") + + def test_scalar(self): run_quantize_fp8((getenv("N", 1024), 32), delayed=False) + def test_delayed(self): run_quantize_fp8((getenv("N", 2048), 1024)) + + @needs_second_gpu + def test_multi(self): + devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(8)) + x = Tensor.empty(2048*8, 1024, dtype=dtypes.bfloat16, device=devs).uop.multi(0) + x = Tensor(x, device=devs) + amax_state = Tensor.full((), 2.0, dtype=dtypes.float32, device=devs).contiguous() + fp8, _, new_amax, _ = quantize_fp8_delayed(x, amax_state, FP8_DTYPE) + Tensor.realize(fp8, new_amax) + assert fp8.uop.shape == x.uop.shape + assert new_amax.shape == () + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/late/devectorizer.py b/tinygrad/codegen/late/devectorizer.py index dcf6284c84..cde5fda449 100644 --- a/tinygrad/codegen/late/devectorizer.py +++ b/tinygrad/codegen/late/devectorizer.py @@ -358,7 +358,8 @@ pm_add_loads = PatternMatcher([ # add loads to non ptr index (UPat(Ops.INDEX, name="idx"), add_load), # remove loads from stores - (UPat(Ops.STORE, src=(UPat(Ops.LOAD), UPat(name="val")), name="s"), lambda s,val: s.replace(src=(s.src[0].src[0], val))), + (UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])), + (UPat(Ops.LOAD, src=(UPat(Ops.LOAD),), allow_any_len=True, name="l"), lambda l: l.replace(src=(l.src[0].src[0],)+l.src[1:])), ]) # make images diff --git a/tinygrad/viz/cli.py b/tinygrad/viz/cli.py index 5ac88b918a..c1b6aeaa03 100755 --- a/tinygrad/viz/cli.py +++ b/tinygrad/viz/cli.py @@ -199,7 +199,8 @@ def main(args) -> None: if DEBUG >= 3 and s["name"] == "View Base AST": print_step(s) if DEBUG >= 4 and s["name"] == "View Source": print_step(s) if DEBUG >= 5 or ls: print(emit(" "*s["depth"]+s["name"]+(f" - {s['match_count']}" if s.get('match_count', 0) else ''))) - if DEBUG >= 6 or (DEBUG >= 5 and s["name"] == "View Kernel Graph") or (s["name"] in args.src): print_step(s, print_graph=True) + if DEBUG >= 6 or (DEBUG >= 5 and s["name"] == "View Kernel Graph") or (s["name"] in args.src): + print_step(s, print_graph=True, reconstruct_matches=s["name"] in args.src) if DEBUG >= 7: print_step(s, reconstruct_matches=True) elif DEBUG >= 3 and k.get("ext"): print(emit(k["ext"])) for k in (produce_top_kernels if args.t else produce_all_kernels)(): render_event(k)