From 7dcfd144b6f88fd0f8105fb3ecbfa73a1c930f40 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Tue, 2 Jun 2026 21:55:45 -0400 Subject: [PATCH] llama: columnwise fp8 scaling (#16480) --- examples/mlperf/model_train.py | 2 +- examples/mlperf/models/flat_llama.py | 17 ++++++++++---- examples/mlperf/optim.py | 4 ++-- extra/gemm/cdna_asm_gemm.py | 35 +++++++++++++++++++--------- 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index d2dab87814..3da2a7792d 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1447,7 +1447,7 @@ def train_llama3(): idx = next(j for j, p in enumerate(optim.params) if p is w) master = optim.master_params[idx] inv = w._inv_scale if w._inv_scale.device == master.device else w._inv_scale.to(master.device) - master.assign((master * inv.reshape(-1, *([1]*(w.ndim-1)))).contiguous()) + master.assign((master * inv.reshape(*inv.shape, *([1]*(w.ndim-inv.ndim)))).contiguous()) # realize everything here if optim.master_params: Tensor.realize(*optim.master_params) diff --git a/examples/mlperf/models/flat_llama.py b/examples/mlperf/models/flat_llama.py index 12240d068c..62c33beba1 100644 --- a/examples/mlperf/models/flat_llama.py +++ b/examples/mlperf/models/flat_llama.py @@ -23,6 +23,7 @@ FUSED_INPUT_QUANTIZE = getenv("FUSED_INPUT_QUANTIZE", 0) FUSED_ADD_NORM_MUL_QUANTIZE = getenv("FUSED_ADD_NORM_MUL_QUANTIZE", 0) FUSED_SILU_W13 = getenv("FUSED_SILU_W13", 0) SPLIT_W13 = getenv("SPLIT_W13", 0) +COLUMNWISE_WEIGHT_SCALE = getenv("COLUMNWISE_WEIGHT_SCALE", 0) FP8_DTYPE = dtypes.fp8e4m3 FP8_GRAD_DTYPE = dtypes.fp8e5m2 @@ -52,7 +53,11 @@ def matmul(x:Tensor, w:Tensor, fp8:bool=True, amax_x:Tensor|None=None, w_inv_sca if ASM_GEMM: from extra.gemm.cdna_asm_gemm import can_use_asm_gemm, asm_gemm if can_use_asm_gemm(x_fp8, w.T): - return asm_gemm(x_fp8, w.T, x_scale=x_scale, w_scale=w_inv_scale, grad_amax_state=grad_amax_state), x_new_amax, x_fp8 + if COLUMNWISE_WEIGHT_SCALE: + out = asm_gemm(x_fp8, w.T, x_scale=x_scale, grad_amax_state=grad_amax_state, w_post_scale=w_inv_scale) + else: + out = asm_gemm(x_fp8, w.T, x_scale=x_scale, w_scale=w_inv_scale, grad_amax_state=grad_amax_state) + return out, x_new_amax, x_fp8 return (x_fp8.dot(w.T, dtype=dtypes.float) * x_scale * w_inv_scale).cast(dtypes.bfloat16), x_new_amax, x_fp8 def norm_quantize_matmul(x:Tensor, norm:Tensor, w:Tensor, w_inv_scale:Tensor, eps:float, amax_x:Tensor, grad_amax_state:Tensor): @@ -141,10 +146,11 @@ class FlatTransformer: def lin_per_layer(self, in_features:int, out_features:int, std:float=0.02): if getenv("ZEROS"): w = Tensor.zeros(self.n_layers, out_features, in_features) else: w = Tensor.normal(self.n_layers, out_features, in_features, mean=0.0, std=std) - amax = w.abs().flatten(1).max(1).detach() + amax = (w.abs().max(axis=2) if COLUMNWISE_WEIGHT_SCALE else w.abs().flatten(1).max(1)).detach() scale = FP8_MAX / (amax + 1e-8) inv_scale = (amax + 1e-8) / FP8_MAX - return (w * scale.reshape(-1, 1, 1)).clamp(-FP8_MAX, FP8_MAX).cast(FP8_DTYPE), inv_scale + scale_b = scale.reshape(self.n_layers, out_features, 1) if COLUMNWISE_WEIGHT_SCALE else scale.reshape(-1, 1, 1) + return (w * scale_b).clamp(-FP8_MAX, FP8_MAX).cast(FP8_DTYPE), inv_scale def attention(self, x:Tensor, freqs_cis:Tensor, *, attention_norm:Tensor, wqkv:Tensor, wo:Tensor, amax_xqkv:Tensor, amax_xo:Tensor, s_qkv:Tensor, s_o:Tensor, @@ -224,8 +230,9 @@ class FlatTransformer: # flat per-layer weights: axis 0 is n_layers, so shard axes are +1 vs per-layer Transformer def _shard_fp8(name:str, axis:int): getattr(self, name).shard_(device, axis=axis) - self._fp8_inv_scale[name] = self._fp8_inv_scale[name].to(device).contiguous().is_param_(False) - self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].to(device).contiguous().is_param_(False) + scale_axis = (1 if axis == 1 else None) if COLUMNWISE_WEIGHT_SCALE else None + self._fp8_inv_scale[name] = self._fp8_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False) + self._fp8_next_inv_scale[name] = self._fp8_next_inv_scale[name].shard(device, axis=scale_axis).contiguous().is_param_(False) Tensor.realize(getattr(self, name), self._fp8_inv_scale[name], self._fp8_next_inv_scale[name]) _shard_fp8("wqkv", 1) # (n_layers, out, dim) shard out _shard_fp8("wo", 2) # (n_layers, dim, in) shard in diff --git a/examples/mlperf/optim.py b/examples/mlperf/optim.py index d2fe057167..053bc5d3e6 100644 --- a/examples/mlperf/optim.py +++ b/examples/mlperf/optim.py @@ -93,11 +93,11 @@ class GradAccClipAdamW(Optimizer): # delayed scaling: reuse previous step's inv_scale t._inv_scale.assign(t._next_inv_scale) inv_scale = t._inv_scale.to(new_w.device) if offloaded else t._inv_scale - scale = inv_scale.reciprocal().reshape(-1, *([1]*(new_w.ndim-1))) + scale = inv_scale.reciprocal().reshape(*inv_scale.shape, *([1]*(new_w.ndim-inv_scale.ndim))) scaled = (new_w * scale).clamp(-FP8_MAX, FP8_MAX) ret = scaled.cast(t.dtype) # update inv_scale for next step from quantized result - new_amax = (ret.float().abs().max(axis=tuple(range(1, ret.ndim))) * inv_scale * FP8_AMAX_MARGIN).detach() + new_amax = (ret.float().abs().max(axis=tuple(range(inv_scale.ndim, ret.ndim))) * inv_scale * FP8_AMAX_MARGIN).detach() new_inv = ((new_amax + 1e-8) / FP8_MAX).cast(t._inv_scale.dtype) t._next_inv_scale.assign(new_inv.shard_like(t._next_inv_scale) if offloaded else new_inv) return ret.shard_like(t) if offloaded else ret diff --git a/extra/gemm/cdna_asm_gemm.py b/extra/gemm/cdna_asm_gemm.py index 5486c7598d..c92b909161 100644 --- a/extra/gemm/cdna_asm_gemm.py +++ b/extra/gemm/cdna_asm_gemm.py @@ -2700,13 +2700,20 @@ def custom_uop_gemm(C:UOp, A:UOp, B:UOp) -> UOp: # ** backward gemm, might use the asm gemm -def custom_gemm_bw(gradient:UOp, kernel:UOp): +def custom_gemm_bw(gradient:UOp, kernel:UOp, n_scales:int=2, has_grad_amax:bool=False, has_w_post:bool=False): inputs = kernel.src[1:] if inputs[1].dtype == FP8_DTYPE: - grad_amax_state = inputs[5] if len(inputs) == 6 else None - out, a, b, s_x, s_w = inputs[:5] + out, a, b = inputs[:3] + i = 3 + s_x = inputs[i]; i += 1 + has_w = n_scales == 2 + s_w = inputs[i] if has_w else None; i += has_w + grad_amax_state = inputs[i] if has_grad_amax else None; i += has_grad_amax + w_post = inputs[i] if has_w_post else None a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device) - s_x_t, s_w_t = Tensor(s_x, device=a.device), Tensor(s_w, device=a.device) + s_x_t = Tensor(s_x, device=a.device) + s_w_t = Tensor(s_w, device=a.device) if has_w else None + w_post_t = Tensor(w_post, device=a.device) if has_w_post else None g_t = g_t[:a.shape[0]] from extra.llama_kernels.cast_amax import _grad_fp8_mailbox from extra.llama_kernels.quantize_fp8_delayed import quantize_fp8_delayed @@ -2727,8 +2734,8 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp): g_fp8, g_scale, new_grad_amax = quantize_fp8(g_t, amax_state=grad_amax_t) store_effect = grad_amax_state.store(new_grad_amax.uop) g_fp8 = Tensor(g_fp8.contiguous().uop.after(store_effect), device=a.device) - # dgrad: uses g_scale * x_scale * w_scale - grad_a = asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t, w_scale=s_w_t) + # dgrad: uses g_scale * x_scale * w_scale (only when scalar) + grad_a = asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t, w_scale=s_w_t) if has_w else asm_gemm(g_fp8, b_t, x_scale=g_scale * s_x_t) # wgrad: no w_scale g_fp8_2d = g_fp8.reshape(-1, g_fp8.shape[-1]) if getenv("FAST_FP8_TRANSPOSE", 0) and g_fp8_2d.shape[0] % 64 == 0 and g_fp8_2d.shape[1] % 64 == 0: @@ -2737,8 +2744,11 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp): else: g_fp8_T = g_fp8.permute(2, 0, 1).reshape(g_t.shape[-1], -1) grad_b = asm_gemm(g_fp8_T, a_t.reshape(-1, a_t.shape[-1]), x_scale=g_scale * s_x_t) - ret = (None, grad_a.uop, grad_b.uop, None, None) - if len(inputs) == 6: ret = ret + (None,) + # wgrad: rescale if not scalar + if w_post_t is not None: + grad_b = grad_b / w_post_t.reshape(*w_post_t.shape, *([1]*(grad_b.ndim - w_post_t.ndim))) + # one None per input: (out, a, b, x_scale[, w_scale][, grad_amax][, w_post_scale]) + ret = (None, grad_a.uop, grad_b.uop) + tuple(None for _ in inputs[3:]) return ret else: out, a, b = inputs @@ -2754,7 +2764,8 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp): # ** main gemm function -def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=None, grad_amax_state:Tensor|None=None) -> Tensor: +def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=None, grad_amax_state:Tensor|None=None, + w_post_scale:Tensor|None=None) -> Tensor: assert can_use_asm_gemm(a, b), f"{counters['todos'][-1]}" counters["used"] += 1 unfold_batch = a.ndim == 3 and isinstance(a.device, tuple) and a.uop.axis == 2 and b.uop.axis == 0 @@ -2790,9 +2801,10 @@ def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=N if a.dtype == FP8_DTYPE: scales = tuple(s for s in (x_scale, w_scale) if s is not None) scale_mode = (1 if x_scale is not None else 0) | (2 if w_scale is not None else 0) - extra = [grad_amax_state] if grad_amax_state is not None else [] + extra = ([grad_amax_state] if grad_amax_state is not None else []) + ([w_post_scale] if w_post_scale is not None else []) fxn = functools.partial(custom_hk_fp8_gemm, dname=dname, scale_mode=scale_mode) - out = Tensor.custom_kernel(out, a, b.T, *scales, *extra, fxn=fxn, grad_fxn=custom_gemm_bw)[0] + bw = functools.partial(custom_gemm_bw, n_scales=len(scales), has_grad_amax=grad_amax_state is not None, has_w_post=w_post_scale is not None) + out = Tensor.custom_kernel(out, a, b.T, *scales, *extra, fxn=fxn, grad_fxn=bw)[0] else: out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname), grad_fxn=custom_gemm_bw)[0] else: @@ -2800,4 +2812,5 @@ def asm_gemm(a:Tensor, b:Tensor, x_scale:Tensor|None=None, w_scale:Tensor|None=N if k_sharded: out = out.sum(0) out = out.squeeze(0) if squeeze else out if unfold_batch: out = out.reshape(orig_batch, -1, out.shape[-1]) + if w_post_scale is not None: out = (out * w_post_scale.reshape(*([1]*(out.ndim-1)), -1)).cast(out.dtype) return out