mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 23:46:02 +08:00
quantize_fp8 kernels in uops (#16288)
* add tests
* simple UOp kernel is n^2
* fast kernel matching c++, opts_to_apply=()
* remove cpp
* simple o(n) kernel, two passes
* fuse the loops
* works on DEV=CPU
* multi regression test
* fix multi, this can possibly be its own bugfix
* test cleanups
* minimal diff
* match C in UOps
* Revert "match C in UOps"
This reverts commit 0bef740c30.
* edit test
* match speed with C try 2
* needs_second_gpu
* cleanup
This commit is contained in:
@@ -1,35 +1,64 @@
|
||||
from __future__ import annotations
|
||||
import functools, pathlib
|
||||
import functools
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.renderer import Estimates
|
||||
from extra.llama_kernels import FP8_MAX, NUM_WG, THREADS_PER_WG, alloc_like, alloc_local, scalar_amax, dname_of, compile_hip
|
||||
from tinygrad.dtype import AddrSpace
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
||||
from extra.llama_kernels import FP8_MAX, NUM_WG, THREADS_PER_WG, alloc_like, alloc_local, scalar_amax
|
||||
|
||||
@functools.cache
|
||||
def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_state:UOp, dname:str) -> UOp:
|
||||
n_elems = 1
|
||||
for d in x.shape: n_elems *= d
|
||||
threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
|
||||
mem = n_elems * 2 + n_elems + 4 + NUM_WG * 4
|
||||
sink = UOp.sink(fp8_out.base, amax_partial.base, x.base, amax_state.base, threads, workgroups,
|
||||
arg=KernelInfo(f"quantize_fp8_with_amax_{n_elems}", estimates=Estimates(ops=3*n_elems, mem=mem)))
|
||||
src = (pathlib.Path(__file__).parent/"quantize_fp8_with_amax.cpp").read_text()
|
||||
defines = [f"-DN_ELEMS={n_elems}", f"-DNUM_WG={NUM_WG}", f"-DTHREADS_PER_WG={THREADS_PER_WG}"]
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
|
||||
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=compile_hip(src, defines))))
|
||||
def _custom_quantize_fp8_with_amax(fp8_out:UOp, amax_partial:UOp, x:UOp, amax_state:UOp) -> UOp:
|
||||
VEC = 8
|
||||
n_elems = prod(x.shape)
|
||||
assert n_elems % (NUM_WG * THREADS_PER_WG * VEC) == 0
|
||||
assert amax_partial.shape[0] == NUM_WG
|
||||
|
||||
x = x.reshape(n_elems)
|
||||
fp8_out = fp8_out.reshape(n_elems)
|
||||
|
||||
wg = UOp.range(NUM_WG, 0, AxisType.GLOBAL)
|
||||
tid = UOp.range(THREADS_PER_WG, 1, AxisType.LOCAL)
|
||||
it = UOp.range((n_elems // VEC) // (NUM_WG * THREADS_PER_WG), 2, AxisType.LOOP)
|
||||
lane = UOp.range(VEC, 3, AxisType.UNROLL)
|
||||
|
||||
idx = (((it * NUM_WG + wg) * THREADS_PER_WG + tid) * VEC) + lane
|
||||
|
||||
scale = FP8_MAX / (amax_state[0].cast(dtypes.float) + 1e-8)
|
||||
x_f = x[idx].cast(dtypes.float)
|
||||
abs_x = (x_f < 0.0).where(-x_f, x_f)
|
||||
scaled = (x_f * scale).maximum(-FP8_MAX).minimum(FP8_MAX)
|
||||
|
||||
fp8_store = fp8_out[idx].store(scaled.cast(fp8_out.dtype.base)).end(lane)
|
||||
lane_max = abs_x.reduce(lane, arg=Ops.MAX)
|
||||
|
||||
lmax = UOp.placeholder((1,), dtypes.float, slot=1, addrspace=AddrSpace.REG)
|
||||
lmax_init = lmax.after(wg, tid)[0].store(0.0)
|
||||
lmax_prev = lmax.after(lmax_init, it)[0]
|
||||
lmax_store = lmax.after(fp8_store)[0].store(lmax_prev.maximum(lane_max))
|
||||
lmax_val = lmax.after(lmax_store.end(it))[0]
|
||||
|
||||
lds = UOp.placeholder((THREADS_PER_WG,), dtypes.float, slot=0, addrspace=AddrSpace.LOCAL)
|
||||
lds = lds.after(lds[tid].store(lmax_val).barrier())
|
||||
|
||||
step = THREADS_PER_WG // 2
|
||||
while step:
|
||||
active = tid < step
|
||||
other = lds[tid + step].load(UOp.const(dtypes.float, 0.0), active)
|
||||
lds = lds.after(lds[tid].store(lds[tid].maximum(other), gate=active).barrier())
|
||||
step //= 2
|
||||
|
||||
amax_store = amax_partial[tid.eq(0).where(wg, UOp.invalid())].store(lds[0])
|
||||
return amax_store.end(tid, wg).sink(arg=KernelInfo(f"quantize_fp8_with_amax_{n_elems}", opts_to_apply=()))
|
||||
|
||||
@functools.cache
|
||||
def _custom_quantize_fp8_scalar(fp8_out:UOp, x:UOp, amax_state:UOp, dname:str) -> UOp:
|
||||
n_elems = 1
|
||||
for d in x.shape: n_elems *= d
|
||||
threads, workgroups = UOp.special(THREADS_PER_WG, "lidx0"), UOp.special(NUM_WG, "gidx0")
|
||||
mem = n_elems * 2 + n_elems
|
||||
sink = UOp.sink(fp8_out.base, x.base, amax_state.base, threads, workgroups,
|
||||
arg=KernelInfo(f"quantize_fp8_scalar_{n_elems}", estimates=Estimates(ops=2*n_elems, mem=mem)))
|
||||
src = (pathlib.Path(__file__).parent/"quantize_fp8_scalar.cpp").read_text()
|
||||
defines = [f"-DN_ELEMS={n_elems}", f"-DNUM_WG={NUM_WG}", f"-DTHREADS_PER_WG={THREADS_PER_WG}"]
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
|
||||
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=compile_hip(src, defines))))
|
||||
def _custom_quantize_fp8_scalar(fp8_out:UOp, x:UOp, amax_state:UOp) -> UOp:
|
||||
n_elems = prod(x.shape)
|
||||
i = UOp.range(n_elems, 0)
|
||||
|
||||
x_f = x.reshape(n_elems)[i].cast(dtypes.float)
|
||||
scale = FP8_MAX / (amax_state[0].cast(dtypes.float) + 1e-8)
|
||||
store = fp8_out.reshape(n_elems)[i].store((x_f * scale).cast(fp8_out.dtype.base))
|
||||
|
||||
return store.end(i).sink(arg=KernelInfo(f"quantize_fp8_scalar_{n_elems}"))
|
||||
|
||||
def _quantize_fp8_delayed_bwd(gradient:UOp, kernel:UOp):
|
||||
# NOTE: STE-equivalent backward — grad_x = grad_fp8 * scale, scale = FP8_MAX / amax_state.
|
||||
@@ -49,8 +78,10 @@ def quantize_fp8_delayed(x:Tensor, amax_state:Tensor, fp8_dtype=dtypes.fp8e4m3)
|
||||
assert x.dtype == dtypes.bfloat16, f"expected bf16, got {x.dtype}"
|
||||
axis = x.uop.axis if isinstance(x.device, tuple) else None
|
||||
fp8_out = alloc_like(x.shape, fp8_dtype, x.device, axis)
|
||||
n_elems = prod(x.uop.shard_shape)
|
||||
assert n_elems % NUM_WG == 0, f"{n_elems=} must divide over {NUM_WG=}"
|
||||
amax_partial = alloc_local((NUM_WG,), dtypes.float32, x.device, axis)
|
||||
fxn = functools.partial(_custom_quantize_fp8_with_amax, dname=dname_of(x.device))
|
||||
fxn = _custom_quantize_fp8_with_amax
|
||||
fp8_out, amax_partial, *_ = Tensor.custom_kernel(fp8_out, amax_partial, x, amax_state,
|
||||
fxn=fxn, grad_fxn=_quantize_fp8_delayed_bwd)
|
||||
new_amax = scalar_amax(amax_partial)
|
||||
@@ -62,6 +93,6 @@ def quantize_fp8_scalar(x:Tensor, amax_state:Tensor, fp8_dtype=dtypes.fp8e4m3) -
|
||||
# NOTE: pure one-pass bf16 -> fp8 quantize with delayed scalar scale. No amax computation.
|
||||
axis = x.uop.axis if isinstance(x.device, tuple) else None
|
||||
fp8_out = alloc_like(x.shape, fp8_dtype, x.device, axis)
|
||||
fxn = functools.partial(_custom_quantize_fp8_scalar, dname=dname_of(x.device))
|
||||
fxn = _custom_quantize_fp8_scalar
|
||||
fp8_out, *_ = Tensor.custom_kernel(fp8_out, x, amax_state, fxn=fxn)
|
||||
return fp8_out
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp8.h>
|
||||
|
||||
// Pure one-pass bf16 -> fp8 quantize with delayed scalar scale. No amax computation.
|
||||
|
||||
#ifndef N_ELEMS
|
||||
#define N_ELEMS 67108864
|
||||
#endif
|
||||
#ifndef NUM_WG
|
||||
#define NUM_WG 1024
|
||||
#endif
|
||||
#ifndef THREADS_PER_WG
|
||||
#define THREADS_PER_WG 256
|
||||
#endif
|
||||
|
||||
constexpr int VEC = 8;
|
||||
constexpr float FP8_MAX = 448.0f;
|
||||
|
||||
static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC");
|
||||
|
||||
extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
|
||||
quantize_fp8_scalar(
|
||||
__hip_fp8_storage_t* __restrict__ fp8_out, // fp8, N_ELEMS
|
||||
const __hip_bfloat16* __restrict__ x, // bf16, N_ELEMS
|
||||
const float* __restrict__ amax_state) // fp32 scalar (delayed)
|
||||
{
|
||||
const int tid = threadIdx.x;
|
||||
const int wg = blockIdx.x;
|
||||
const int gid = wg * THREADS_PER_WG + tid;
|
||||
const int stride_elems = NUM_WG * THREADS_PER_WG * VEC;
|
||||
|
||||
const float scale = FP8_MAX / (static_cast<float>(*amax_state) + 1e-8f);
|
||||
|
||||
for (int base = gid * VEC; base < N_ELEMS; base += stride_elems) {
|
||||
float4 x_raw = *reinterpret_cast<const float4*>(&x[base]);
|
||||
const __hip_bfloat16 *xi = reinterpret_cast<const __hip_bfloat16*>(&x_raw);
|
||||
|
||||
__hip_fp8_storage_t out[VEC];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VEC; i++) {
|
||||
const float v = static_cast<float>(xi[i]);
|
||||
const float scaled = fmaxf(-FP8_MAX, fminf(FP8_MAX, v * scale));
|
||||
out[i] = __hip_cvt_float_to_fp8(scaled, __HIP_SATFINITE, __HIP_E4M3);
|
||||
}
|
||||
*reinterpret_cast<uint64_t*>(&fp8_out[base]) = *reinterpret_cast<uint64_t*>(out);
|
||||
}
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp8.h>
|
||||
|
||||
// One-pass bf16 -> fp8 quantize using a scalar delayed amax state,
|
||||
// AND simultaneously computes per-WG |x| max partials for the next step's amax state.
|
||||
// Saves one full HBM pass over the grad tensor vs. doing quantize + separate abs().max().
|
||||
|
||||
#ifndef N_ELEMS
|
||||
#define N_ELEMS 67108864
|
||||
#endif
|
||||
#ifndef NUM_WG
|
||||
#define NUM_WG 1024
|
||||
#endif
|
||||
#ifndef THREADS_PER_WG
|
||||
#define THREADS_PER_WG 256
|
||||
#endif
|
||||
|
||||
constexpr int VEC = 8;
|
||||
constexpr float FP8_MAX = 448.0f;
|
||||
|
||||
static_assert(N_ELEMS % VEC == 0, "N_ELEMS must be divisible by VEC");
|
||||
|
||||
extern "C" __global__ __launch_bounds__(THREADS_PER_WG) void
|
||||
quantize_fp8_with_amax(
|
||||
__hip_fp8_storage_t* __restrict__ fp8_out, // out: fp8, N_ELEMS
|
||||
float* __restrict__ amax_partial, // out: fp32, NUM_WG per-WG partials
|
||||
const __hip_bfloat16* __restrict__ x, // in: bf16, N_ELEMS
|
||||
const float* __restrict__ amax_state) // in: fp32 scalar (delayed)
|
||||
{
|
||||
__shared__ float sdata[THREADS_PER_WG];
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int wg = blockIdx.x;
|
||||
const int gid = wg * THREADS_PER_WG + tid;
|
||||
const int stride_elems = NUM_WG * THREADS_PER_WG * VEC;
|
||||
|
||||
const float scale = FP8_MAX / (static_cast<float>(*amax_state) + 1e-8f);
|
||||
float local_max = 0.0f;
|
||||
|
||||
for (int base = gid * VEC; base < N_ELEMS; base += stride_elems) {
|
||||
float4 x_raw = *reinterpret_cast<const float4*>(&x[base]);
|
||||
const __hip_bfloat16 *xi = reinterpret_cast<const __hip_bfloat16*>(&x_raw);
|
||||
|
||||
__hip_fp8_storage_t out[VEC];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VEC; i++) {
|
||||
const float v = static_cast<float>(xi[i]);
|
||||
local_max = fmaxf(local_max, fabsf(v));
|
||||
const float scaled = fmaxf(-FP8_MAX, fminf(FP8_MAX, v * scale));
|
||||
out[i] = __hip_cvt_float_to_fp8(scaled, __HIP_SATFINITE, __HIP_E4M3);
|
||||
}
|
||||
*reinterpret_cast<uint64_t*>(&fp8_out[base]) = *reinterpret_cast<uint64_t*>(out);
|
||||
}
|
||||
|
||||
sdata[tid] = local_max;
|
||||
__syncthreads();
|
||||
for (int s = THREADS_PER_WG / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
|
||||
__syncthreads();
|
||||
}
|
||||
if (tid == 0) amax_partial[wg] = sdata[0];
|
||||
}
|
||||
@@ -1,6 +1,10 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, Device, dtypes, Context
|
||||
from tinygrad.helpers import getenv
|
||||
from examples.mlperf.models.flat_llama import FP8_DTYPE, quantize_fp8
|
||||
from extra.llama_kernels.fused_ce import fused_ce_loss
|
||||
from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed, quantize_fp8_scalar
|
||||
from test.helpers import needs_second_gpu
|
||||
|
||||
def run_fused_ce(bs:int, seqlen:int, vocab:int, label_smoothing:float=0.0) -> None:
|
||||
Tensor.manual_seed(0)
|
||||
@@ -23,8 +27,10 @@ def run_fused_ce(bs:int, seqlen:int, vocab:int, label_smoothing:float=0.0) -> No
|
||||
assert loss.allclose(ref, atol=2e-3, rtol=2e-3).item(), "forward mismatch"
|
||||
assert logits.grad.allclose(logits_ref.grad, atol=2e-3, rtol=2e-3).item(), "grad mismatch"
|
||||
|
||||
@unittest.skipUnless(dtypes.bfloat16 in Device[Device.DEFAULT].renderer.supported_dtypes(), "need bfloat16")
|
||||
class TestFusedCE(unittest.TestCase):
|
||||
def setUp(self):
|
||||
if dtypes.bfloat16 not in Device[Device.DEFAULT].renderer.supported_dtypes(): self.skipTest("need bfloat16")
|
||||
|
||||
def test_fused_ce_1_2_16(self): run_fused_ce(1, 2, 16, label_smoothing=0.2)
|
||||
def test_fused_ce_2_16_128(self): run_fused_ce(2, 16, 128)
|
||||
def test_fused_ce_4_128_1024(self): run_fused_ce(4, 128, 1024, label_smoothing=0.2)
|
||||
@@ -32,5 +38,49 @@ class TestFusedCE(unittest.TestCase):
|
||||
# note: this is the shape used in llama 8b
|
||||
#def test_fused_ce_smoothing_16_1024_128256(self): run_fused_ce(16, 1024, 128256, label_smoothing=0.2)
|
||||
|
||||
def run_quantize_fp8(shape:tuple[int, ...], delayed:bool=True) -> None:
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(*shape).cast(dtypes.bfloat16).contiguous()
|
||||
amax_state = Tensor.full((), 2.0, dtype=dtypes.float32).contiguous()
|
||||
with Context(DEBUG=0): Tensor.realize(x, amax_state)
|
||||
|
||||
if delayed:
|
||||
fp8, inv_scale, new_amax, _ = quantize_fp8_delayed(x, amax_state, FP8_DTYPE)
|
||||
ref_fp8, ref_inv_scale, ref_new_amax = quantize_fp8(x, amax_state=amax_state)
|
||||
Tensor.realize(fp8, inv_scale, new_amax)
|
||||
Tensor.realize(ref_fp8, ref_inv_scale, ref_new_amax)
|
||||
else:
|
||||
fp8 = quantize_fp8_scalar(x, amax_state, FP8_DTYPE)
|
||||
ref_fp8, _, _ = quantize_fp8(x, amax_state=amax_state)
|
||||
Tensor.realize(fp8)
|
||||
Tensor.realize(ref_fp8)
|
||||
|
||||
with Context(DEBUG=0):
|
||||
assert fp8.cast(dtypes.float).allclose(ref_fp8.cast(dtypes.float), atol=0, rtol=0).item(), "fp8 mismatch"
|
||||
if delayed:
|
||||
assert inv_scale.allclose(ref_inv_scale, atol=0, rtol=0).item(), "inv_scale mismatch"
|
||||
assert new_amax.allclose(ref_new_amax, atol=0, rtol=0).item(), \
|
||||
f"amax mismatch: got={new_amax.item()} ref={ref_new_amax.item()} diff={abs(new_amax.item()-ref_new_amax.item())}"
|
||||
|
||||
class TestQuantizeFP8(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ren = Device[Device.DEFAULT].renderer
|
||||
if dtypes.bfloat16 not in ren.supported_dtypes(): self.skipTest("need bfloat16")
|
||||
if not ren.has_local or not ren.has_shared: self.skipTest("need local/shared")
|
||||
|
||||
def test_scalar(self): run_quantize_fp8((getenv("N", 1024), 32), delayed=False)
|
||||
def test_delayed(self): run_quantize_fp8((getenv("N", 2048), 1024))
|
||||
|
||||
@needs_second_gpu
|
||||
def test_multi(self):
|
||||
devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(8))
|
||||
x = Tensor.empty(2048*8, 1024, dtype=dtypes.bfloat16, device=devs).uop.multi(0)
|
||||
x = Tensor(x, device=devs)
|
||||
amax_state = Tensor.full((), 2.0, dtype=dtypes.float32, device=devs).contiguous()
|
||||
fp8, _, new_amax, _ = quantize_fp8_delayed(x, amax_state, FP8_DTYPE)
|
||||
Tensor.realize(fp8, new_amax)
|
||||
assert fp8.uop.shape == x.uop.shape
|
||||
assert new_amax.shape == ()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -358,7 +358,8 @@ pm_add_loads = PatternMatcher([
|
||||
# add loads to non ptr index
|
||||
(UPat(Ops.INDEX, name="idx"), add_load),
|
||||
# remove loads from stores
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.LOAD), UPat(name="val")), name="s"), lambda s,val: s.replace(src=(s.src[0].src[0], val))),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.LOAD),), allow_any_len=True, name="s"), lambda s: s.replace(src=(s.src[0].src[0],)+s.src[1:])),
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.LOAD),), allow_any_len=True, name="l"), lambda l: l.replace(src=(l.src[0].src[0],)+l.src[1:])),
|
||||
])
|
||||
|
||||
# make images
|
||||
|
||||
@@ -199,7 +199,8 @@ def main(args) -> None:
|
||||
if DEBUG >= 3 and s["name"] == "View Base AST": print_step(s)
|
||||
if DEBUG >= 4 and s["name"] == "View Source": print_step(s)
|
||||
if DEBUG >= 5 or ls: print(emit(" "*s["depth"]+s["name"]+(f" - {s['match_count']}" if s.get('match_count', 0) else '')))
|
||||
if DEBUG >= 6 or (DEBUG >= 5 and s["name"] == "View Kernel Graph") or (s["name"] in args.src): print_step(s, print_graph=True)
|
||||
if DEBUG >= 6 or (DEBUG >= 5 and s["name"] == "View Kernel Graph") or (s["name"] in args.src):
|
||||
print_step(s, print_graph=True, reconstruct_matches=s["name"] in args.src)
|
||||
if DEBUG >= 7: print_step(s, reconstruct_matches=True)
|
||||
elif DEBUG >= 3 and k.get("ext"): print(emit(k["ext"]))
|
||||
for k in (produce_top_kernels if args.t else produce_all_kernels)(): render_event(k)
|
||||
|
||||
Reference in New Issue
Block a user