quantize_fp8 kernels in uops (#16288)

* add tests

* simple UOp kernel is n^2

* fast kernel matching c++, opts_to_apply=()

* remove cpp

* simple o(n) kernel, two passes

* fuse the loops

* works on DEV=CPU

* multi regression test

* fix multi, this can possibly be its own bugfix

* test cleanups

* minimal diff

* match C in UOps

* Revert "match C in UOps"

This reverts commit 0bef740c30.

* edit test

* match speed with C try 2

* needs_second_gpu

* cleanup
This commit is contained in:
qazal
2026-05-22 14:54:06 +03:00
committed by GitHub
parent 3115952266
commit bbfe4f80ec
6 changed files with 115 additions and 143 deletions

View File

@@ -1,35 +1,64 @@
from __future__ import annotations
import functools, pathlib
import functools
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.renderer import Estimates
from extra.llama_kernels import FP8_MAX, NUM_WG, THREADS_PER_WG, alloc_like, alloc_local, scalar_amax, dname_of, compile_hip
from tinygrad.dtype import AddrSpace
from tinygrad.helpers import prod
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from extra.llama_kernels import FP8_MAX, NUM_WG, THREADS_PER_WG, alloc_like, alloc_local, scalar_amax
@functools.cache
def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_state:UOp, dname:str) -> UOp:
  """Wrap the hand-written HIP `quantize_fp8_with_amax` kernel in a PROGRAM UOp.

  The kernel text is read from `quantize_fp8_with_amax.cpp` next to this file and compiled with
  `compile_hip`, specialized through -D defines on the element count of `x`, NUM_WG and THREADS_PER_WG.
  `dname` is stored as the arg of the DEVICE UOp.
  """
  # total element count: product of every dim of x.shape, starting from 1
  n_elems = functools.reduce(lambda acc, dim: acc * dim, x.shape, 1)
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
  # NOTE(review): presumably 2B/elem bf16 read + 1B/elem fp8 write + 4B amax scalar + NUM_WG fp32 partials
  mem = n_elems * 2 + n_elems + 4 + NUM_WG * 4
  info = KernelInfo(f"quantize_fp8_with_amax_{n_elems}", estimates=Estimates(ops=3*n_elems, mem=mem))
  sink = UOp.sink(fp8_out.base, amax_partial.base, x.base, amax_state.base, lidx, gidx, arg=info)
  kernel_src = (pathlib.Path(__file__).parent/"quantize_fp8_with_amax.cpp").read_text()
  defines = [f"-DN_ELEMS={n_elems}", f"-DNUM_WG={NUM_WG}", f"-DTHREADS_PER_WG={THREADS_PER_WG}"]
  program_srcs = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
                  UOp(Ops.SOURCE, arg=kernel_src), UOp(Ops.BINARY, arg=compile_hip(kernel_src, defines)))
  return UOp(Ops.PROGRAM, src=program_srcs)
# Builds, as a UOp graph, a single kernel that
#  (1) writes x * (FP8_MAX / (amax_state[0] + 1e-8)), clamped to [-FP8_MAX, FP8_MAX], into fp8_out, and
#  (2) stores one max(|x|) partial per workgroup into amax_partial.
# Requires prod(x.shape) to be a multiple of NUM_WG * THREADS_PER_WG * VEC and amax_partial.shape[0] == NUM_WG.
# NOTE(review): the halving reduction below only visits every LDS slot when THREADS_PER_WG is a power of two.
def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_state:UOp) -> UOp:
# number of consecutive elements covered by the UNROLL axis `lane`
VEC = 8
n_elems = prod(x.shape)
assert n_elems % (NUM_WG * THREADS_PER_WG * VEC) == 0
assert amax_partial.shape[0] == NUM_WG
# both tensors are indexed as flat 1-D buffers from here on
x = x.reshape(n_elems)
fp8_out = fp8_out.reshape(n_elems)
# axes: workgroup (GLOBAL), thread within workgroup (LOCAL), per-thread iteration (LOOP), vector lane (UNROLL)
wg = UOp.range(NUM_WG, 0, AxisType.GLOBAL)
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
it = UOp.range((n_elems // VEC) // (NUM_WG * THREADS_PER_WG), 2, AxisType.LOOP)
lane = UOp.range(VEC, 3, AxisType.UNROLL)
# flat element index: a thread owns VEC consecutive elements per iteration,
# and consecutive iterations are NUM_WG * THREADS_PER_WG * VEC elements apart
idx = (((it * NUM_WG + wg) * THREADS_PER_WG + tid) * VEC) + lane
scale = FP8_MAX / (amax_state[0].cast(dtypes.float) + 1e-8)
x_f = x[idx].cast(dtypes.float)
# |x| written as a select on the sign
abs_x = (x_f < 0.0).where(-x_f, x_f)
# clamp to the fp8 range before the cast
scaled = (x_f * scale).maximum(-FP8_MAX).minimum(FP8_MAX)
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
# max of |x| over the VEC lanes of this iteration
lane_max = abs_x.reduce(lane, arg=Ops.MAX)
# single-float accumulator in AddrSpace.REG holding the running max across `it`
lmax = UOp.placeholder((1,), dtypes.float, slot=1, addrspace=AddrSpace.REG)
# presumably `.after(...)` orders the access after the given UOps / inside the given ranges -- TODO confirm
lmax_init = lmax.after(wg, tid)[0].store(0.0)
lmax_prev = lmax.after(lmax_init, it)[0]
# the accumulator update is chained after the fp8 store
lmax_store = lmax.after(fp8_store)[0].store(lmax_prev.maximum(lane_max))
# value of the accumulator once the `it` loop has ended
lmax_val = lmax.after(lmax_store.end(it))[0]
# one float per thread in AddrSpace.LOCAL; each thread publishes its max, followed by a barrier
lds = UOp.placeholder((THREADS_PER_WG,), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL)
lds = lds.after(lds[tid].store(lmax_val).barrier())
# halving reduction, unrolled at graph-build time by this Python loop:
# at each step, threads with tid < step fold lds[tid + step] into lds[tid], then a barrier follows
step = THREADS_PER_WG // 2
while step:
active = tid < step
# gated load: 0.0 is supplied as the alternative value alongside the `active` gate
other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active)
lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier())
step //= 2
# lds[0] is stored to amax_partial[wg] where tid == 0; other threads index with UOp.invalid()
# (presumably that suppresses their store -- TODO confirm)
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])
# opts_to_apply=() passes an empty optimization list (presumably keeps the hand-laid-out axes as written)
return amax_store.end(tid, wg).sink(arg=KernelInfo(f"quantize_fp8_with_amax_{n_elems}", opts_to_apply=()))
@functools.cache
def _custom_quantize_fp8_scalar(fp8_out:UOp, x:UOp, amax_state:UOp, dname:str) -> UOp:
  """Wrap the hand-written HIP `quantize_fp8_scalar` kernel in a PROGRAM UOp.

  The kernel text is read from `quantize_fp8_scalar.cpp` next to this file and compiled with
  `compile_hip`, specialized through -D defines on the element count of `x`, NUM_WG and THREADS_PER_WG.
  `dname` is stored as the arg of the DEVICE UOp.
  """
  # total element count: product of every dim of x.shape, starting from 1
  n_elems = functools.reduce(lambda acc, dim: acc * dim, x.shape, 1)
  lidx, gidx = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
  # NOTE(review): presumably 2B/elem bf16 read + 1B/elem fp8 write
  mem = n_elems * 2 + n_elems
  info = KernelInfo(f"quantize_fp8_scalar_{n_elems}", estimates=Estimates(ops=2*n_elems, mem=mem))
  sink = UOp.sink(fp8_out.base, x.base, amax_state.base, lidx, gidx, arg=info)
  kernel_src = (pathlib.Path(__file__).parent/"quantize_fp8_scalar.cpp").read_text()
  defines = [f"-DN_ELEMS={n_elems}", f"-DNUM_WG={NUM_WG}", f"-DTHREADS_PER_WG={THREADS_PER_WG}"]
  program_srcs = (sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
                  UOp(Ops.SOURCE, arg=kernel_src), UOp(Ops.BINARY, arg=compile_hip(kernel_src, defines)))
  return UOp(Ops.PROGRAM, src=program_srcs)
def _custom_quantize_fp8_scalar(fp8_out:UOp, x:UOp, amax_state:UOp) -> UOp:
  """Build the one-pass quantize kernel (no amax computation) as a UOp graph.

  Every element of `x` is cast to float, multiplied by FP8_MAX / (amax_state[0] + 1e-8), clamped to
  [-FP8_MAX, FP8_MAX] and stored into `fp8_out` cast to fp8_out's base dtype.
  Both tensors are indexed as flat buffers of prod(x.shape) elements.
  Returns the kernel sink, named `quantize_fp8_scalar_{n_elems}`.
  """
  n_elems = prod(x.shape)
  i = UOp.range(n_elems, 0)
  x_f = x.reshape(n_elems)[i].cast(dtypes.float)
  scale = FP8_MAX / (amax_state[0].cast(dtypes.float) + 1e-8)
  # clamp before the cast, exactly like the with_amax kernel: amax_state is a delayed value, so |x * scale|
  # can exceed FP8_MAX, and the result of an out-of-range float -> fp8 cast should not be left to the backend
  scaled = (x_f * scale).maximum(-FP8_MAX).minimum(FP8_MAX)
  store = fp8_out.reshape(n_elems)[i].store(scaled.cast(fp8_out.dtype.base))
  return store.end(i).sink(arg=KernelInfo(f"quantize_fp8_scalar_{n_elems}"))
def _quantize_fp8_delayed_bwd(gradient:UOp, kernel:UOp):
# NOTE: STE-equivalent backward — grad_x = grad_fp8 * scale, scale = FP8_MAX / amax_state.
@@ -49,8 +78,10 @@ def quantize_fp8_delayed(x:Tensor, amax_state:Tensor, fp8_dtype=dtypes.fp8e4m3)
assert x.dtype == dtypes.bfloat16, f"expected bf16, got {x.dtype}"
axis = x.uop.axis if isinstance(x.device, tuple) else None
fp8_out = alloc_like(x.shape, fp8_dtype, x.device, axis)
n_elems = prod(x.uop.shard_shape)
assert n_elems % NUM_WG == 0, f"{n_elems=} must divide over {NUM_WG=}"
amax_partial = alloc_local((NUM_WG,), dtypes.float32, x.device, axis)
fxn = functools.partial(_custom_quantize_fp8_with_amax, dname=dname_of(x.device))
fxn = _custom_quantize_fp8_with_amax
fp8_out, amax_partial, *_ = Tensor.custom_kernel(fp8_out, amax_partial, x, amax_state,
fxn=fxn, grad_fxn=_quantize_fp8_delayed_bwd)
new_amax = scalar_amax(amax_partial)
@@ -62,6 +93,6 @@ def quantize_fp8_scalar(x:Tensor, amax_state:Tensor, fp8_dtype=dtypes.fp8e4m3) -
# NOTE: pure one-pass bf16 -> fp8 quantize with delayed scalar scale. No amax computation.
axis = x.uop.axis if isinstance(x.device, tuple) else None
fp8_out = alloc_like(x.shape, fp8_dtype, x.device, axis)
fxn = functools.partial(_custom_quantize_fp8_scalar, dname=dname_of(x.device))
fxn = _custom_quantize_fp8_scalar
fp8_out, *_ = Tensor.custom_kernel(fp8_out, x, amax_state, fxn=fxn)
return fp8_out

View File

@@ -1,48 +0,0 @@
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp8.h>
// Pure one-pass bf16 -> fp8 quantize with delayed scalar scale. No amax computation.
#ifndef N_ELEMS
#define N_ELEMS 67108864
#endif
#ifndef NUM_WG
#define NUM_WG 1024
#endif
#ifndef THREADS_PER_WG
#define THREADS_PER_WG 256
#endif
constexpr int VEC = 8;
constexpr float FP8_MAX = 448.0f;
static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC");
// Quantizes N_ELEMS bf16 values from x into fp8 (E4M3) values in fp8_out, using
// scale = FP8_MAX / (*amax_state + 1e-8f). Each thread processes VEC consecutive elements per
// loop iteration, starting at element gid * VEC and advancing by NUM_WG * THREADS_PER_WG * VEC.
// NOTE(review): the vector load/store below assume x is 16-byte aligned and fp8_out is 8-byte
// aligned at every `base` -- confirm against the allocator.
extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
quantize_fp8_scalar(
__hip_fp8_storage_t* __restrict__ fp8_out, // fp8, N_ELEMS
const __hip_bfloat16* __restrict__ x, // bf16, N_ELEMS
const float* __restrict__ amax_state) // fp32 scalar (delayed)
{
const int tid = threadIdx.x;
const int wg = blockIdx.x;
// flat thread id across all workgroups
const int gid = wg * THREADS_PER_WG + tid;
// elements covered by the whole launch in one loop iteration (uses the compile-time NUM_WG)
const int stride_elems = NUM_WG * THREADS_PER_WG * VEC;
const float scale = FP8_MAX / (static_cast<float>(*amax_state) + 1e-8f);
for (int base = gid * VEC; base < N_ELEMS; base += stride_elems) {
// one float4-sized load brings in the VEC bf16 inputs, reinterpreted below as bf16 elements
float4 x_raw = *reinterpret_cast<const float4*>(&x[base]);
const __hip_bfloat16 *xi = reinterpret_cast<const __hip_bfloat16*>(&x_raw);
__hip_fp8_storage_t out[VEC];
#pragma unroll
for (int i = 0; i < VEC; i++) {
const float v = static_cast<float>(xi[i]);
// clamp to [-FP8_MAX, FP8_MAX] before the conversion
const float scaled = fmaxf(-FP8_MAX, fminf(FP8_MAX, v * scale));
out[i] = __hip_cvt_float_to_fp8(scaled, __HIP_SATFINITE, __HIP_E4M3);
}
// the VEC fp8 results are written back with a single uint64_t store
*reinterpret_cast<uint64_t*>(&fp8_out[base]) = *reinterpret_cast<uint64_t*>(out);
}
}

View File

@@ -1,63 +0,0 @@
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp8.h>
// One-pass bf16 -> fp8 quantize using a scalar delayed amax state,
// AND simultaneously computes per-WG |x| max partials for the next step's amax state.
// Saves one full HBM pass over the grad tensor vs. doing quantize + separate abs().max().
#ifndef N_ELEMS
#define N_ELEMS 67108864
#endif
#ifndef NUM_WG
#define NUM_WG 1024
#endif
#ifndef THREADS_PER_WG
#define THREADS_PER_WG 256
#endif
constexpr int VEC = 8;
constexpr float FP8_MAX = 448.0f;
static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC");
// Quantizes N_ELEMS bf16 values from x into fp8 (E4M3) values in fp8_out, using
// scale = FP8_MAX / (*amax_state + 1e-8f), and in the same pass writes max(|x|) over the elements
// handled by each workgroup into amax_partial[blockIdx.x].
// Each thread processes VEC consecutive elements per loop iteration, starting at element gid * VEC
// and advancing by NUM_WG * THREADS_PER_WG * VEC.
// NOTE(review): the vector load/store below assume x is 16-byte aligned and fp8_out is 8-byte
// aligned at every `base`; the halving reduction only visits every sdata slot when THREADS_PER_WG
// is a power of two -- confirm both against the callers.
extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
quantize_fp8_with_amax(
__hip_fp8_storage_t* __restrict__ fp8_out, // out: fp8, N_ELEMS
float* __restrict__ amax_partial, // out: fp32, NUM_WG per-WG partials
const __hip_bfloat16* __restrict__ x, // in: bf16, N_ELEMS
const float* __restrict__ amax_state) // in: fp32 scalar (delayed)
{
// one slot per thread for the workgroup-level max reduction
__shared__ float sdata[THREADS_PER_WG];
const int tid = threadIdx.x;
const int wg = blockIdx.x;
// flat thread id across all workgroups
const int gid = wg * THREADS_PER_WG + tid;
// elements covered by the whole launch in one loop iteration (uses the compile-time NUM_WG)
const int stride_elems = NUM_WG * THREADS_PER_WG * VEC;
const float scale = FP8_MAX / (static_cast<float>(*amax_state) + 1e-8f);
// per-thread running max of |x|; starts at 0 since |x| >= 0
float local_max = 0.0f;
for (int base = gid * VEC; base < N_ELEMS; base += stride_elems) {
// one float4-sized load brings in the VEC bf16 inputs, reinterpreted below as bf16 elements
float4 x_raw = *reinterpret_cast<const float4*>(&x[base]);
const __hip_bfloat16 *xi = reinterpret_cast<const __hip_bfloat16*>(&x_raw);
__hip_fp8_storage_t out[VEC];
#pragma unroll
for (int i = 0; i < VEC; i++) {
const float v = static_cast<float>(xi[i]);
// the amax is tracked on the unscaled input value
local_max = fmaxf(local_max, fabsf(v));
// clamp to [-FP8_MAX, FP8_MAX] before the conversion
const float scaled = fmaxf(-FP8_MAX, fminf(FP8_MAX, v * scale));
out[i] = __hip_cvt_float_to_fp8(scaled, __HIP_SATFINITE, __HIP_E4M3);
}
// the VEC fp8 results are written back with a single uint64_t store
*reinterpret_cast<uint64_t*>(&fp8_out[base]) = *reinterpret_cast<uint64_t*>(out);
}
// publish each thread's max, then fold the upper half into the lower half until one value remains
sdata[tid] = local_max;
__syncthreads();
for (int s = THREADS_PER_WG / 2; s > 0; s >>= 1) {
if (tid < s) sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
__syncthreads();
}
// thread 0 writes this workgroup's partial
if (tid == 0) amax_partial[wg] = sdata[0];
}

View File

@@ -1,6 +1,10 @@
import unittest
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.helpers import getenv
from examples.mlperf.models.flat_llama import FP8_DTYPE, quantize_fp8
from extra.llama_kernels.fused_ce import fused_ce_loss
from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed, quantize_fp8_scalar
from test.helpers import needs_second_gpu
def run_fused_ce(bs:int, seqlen:int, vocab:int, label_smoothing:float=0.0) -> None:
Tensor.manual_seed(0)
@@ -23,8 +27,10 @@ def run_fused_ce(bs:int, seqlen:int, vocab:int, label_smoothing:float=0.0) -> No
assert loss.allclose(ref, atol=2e-3, rtol=2e-3).item(), "forward mismatch"
assert logits.grad.allclose(logits_ref.grad, atol=2e-3, rtol=2e-3).item(), "grad mismatch"
@unittest.skipUnless(dtypes.bfloat16 in Device[Device.DEFAULT].renderer.supported_dtypes(), "need bfloat16")
class TestFusedCE(unittest.TestCase):
def setUp(self):
if dtypes.bfloat16 not in Device[Device.DEFAULT].renderer.supported_dtypes(): self.skipTest("need bfloat16")
def test_fused_ce_1_2_16(self): run_fused_ce(1, 2, 16, label_smoothing=0.2)
def test_fused_ce_2_16_128(self): run_fused_ce(2, 16, 128)
def test_fused_ce_4_128_1024(self): run_fused_ce(4, 128, 1024, label_smoothing=0.2)
@@ -32,5 +38,49 @@ class TestFusedCE(unittest.TestCase):
# note: this is the shape used in llama 8b
#def test_fused_ce_smoothing_16_1024_128256(self): run_fused_ce(16, 1024, 128256, label_smoothing=0.2)
# Checks the custom quantize kernels against the reference `quantize_fp8` on a random bf16 tensor of `shape`.
# delayed=True exercises quantize_fp8_delayed (fp8 output, inv_scale and new amax are all compared);
# delayed=False exercises quantize_fp8_scalar (only the fp8 output is compared).
# Every comparison uses atol=0, rtol=0, i.e. the results must match exactly.
def run_quantize_fp8(shape:tuple[int, ...], delayed:bool=True) -> None:
# fixed seed so the random input is reproducible
Tensor.manual_seed(0)
x = Tensor.randn(*shape).cast(dtypes.bfloat16).contiguous()
# scalar amax state fixed at 2.0, used as the delayed amax by both the kernel and the reference
amax_state = Tensor.full((), 2.0, dtype=dtypes.float32).contiguous()
# inputs are realized up front under DEBUG=0
with Context(DEBUG=0): Tensor.realize(x, amax_state)
if delayed:
fp8, inv_scale, new_amax, _ = quantize_fp8_delayed(x, amax_state, FP8_DTYPE)
ref_fp8, ref_inv_scale, ref_new_amax = quantize_fp8(x, amax_state=amax_state)
# kernel outputs and reference outputs are realized in separate calls
Tensor.realize(fp8, inv_scale, new_amax)
Tensor.realize(ref_fp8, ref_inv_scale, ref_new_amax)
else:
fp8 = quantize_fp8_scalar(x, amax_state, FP8_DTYPE)
ref_fp8, _, _ = quantize_fp8(x, amax_state=amax_state)
Tensor.realize(fp8)
Tensor.realize(ref_fp8)
with Context(DEBUG=0):
# fp8 values are compared after casting both sides to float
assert fp8.cast(dtypes.float).allclose(ref_fp8.cast(dtypes.float), atol=0, rtol=0).item(), "fp8 mismatch"
if delayed:
assert inv_scale.allclose(ref_inv_scale, atol=0, rtol=0).item(), "inv_scale mismatch"
assert new_amax.allclose(ref_new_amax, atol=0, rtol=0).item(), \
f"amax mismatch: got={new_amax.item()} ref={ref_new_amax.item()} diff={abs(new_amax.item()-ref_new_amax.item())}"
# Tests for quantize_fp8_scalar / quantize_fp8_delayed on the default device.
class TestQuantizeFP8(unittest.TestCase):
# skip when the default renderer lacks bfloat16 or local/shared support
def setUp(self):
ren = Device[Device.DEFAULT].renderer
if dtypes.bfloat16 not in ren.supported_dtypes(): self.skipTest("need bfloat16")
if not ren.has_local or not ren.has_shared: self.skipTest("need local/shared")
# the first dimension of both shapes can be overridden through the N environment variable
def test_scalar(self): run_quantize_fp8((getenv("N", 1024), 32), delayed=False)
def test_delayed(self): run_quantize_fp8((getenv("N", 2048), 1024))
# multi-device case: only output shapes are asserted, not values (the input comes from Tensor.empty)
# NOTE(review): this builds 8 devices while the guard is named needs_second_gpu -- confirm the guard is sufficient
@needs_second_gpu
def test_multi(self):
devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(8))
x = Tensor.empty(2048*8, 1024, dtype=dtypes.bfloat16, device=devs).uop.multi(0)
x = Tensor(x, device=devs)
amax_state = Tensor.full((), 2.0, dtype=dtypes.float32, device=devs).contiguous()
fp8, _, new_amax, _ = quantize_fp8_delayed(x, amax_state, FP8_DTYPE)
Tensor.realize(fp8, new_amax)
assert fp8.uop.shape == x.uop.shape
assert new_amax.shape == ()
if __name__ == '__main__':
unittest.main()

View File

@@ -358,7 +358,8 @@ pm_add_loads = PatternMatcher([
# add loads to non ptr index
(UPat(Ops.INDEX, name="idx"), add_load),
# remove loads from stores
(UPat(Ops.STORE, src=(UPat(Ops.LOAD), UPat(name="val")), name="s"), lambda s,val: s.replace(src=(s.src[0].src[0], val))),
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])),
(UPat(Ops.LOAD, src=(UPat(Ops.LOAD),), allow_any_len=True, name="l"), lambda l: l.replace(src=(l.src[0].src[0],)+l.src[1:])),
])
# make images

View File

@@ -199,7 +199,8 @@ def main(args) -> None:
if DEBUG >= 3 and s["name"] == "View Base AST": print_step(s)
if DEBUG >= 4 and s["name"] == "View Source": print_step(s)
if DEBUG >= 5 or ls: print(emit(" "*s["depth"]+s["name"]+(f" - {s['match_count']}" if s.get('match_count', 0) else '')))
if DEBUG >= 6 or (DEBUG >= 5 and s["name"] == "View Kernel Graph") or (s["name"] in args.src): print_step(s, print_graph=True)
if DEBUG >= 6 or (DEBUG >= 5 and s["name"] == "View Kernel Graph") or (s["name"] in args.src):
print_step(s, print_graph=True, reconstruct_matches=s["name"] in args.src)
if DEBUG >= 7: print_step(s, reconstruct_matches=True)
elif DEBUG >= 3 and k.get("ext"): print(emit(k["ext"]))
for k in (produce_top_kernels if args.t else produce_all_kernels)(): render_event(k)