Files
tinygrad/extra/llama_kernels/fused_ce/__init__.py
qazal 1e0fffe256 fused ce llama kernel in UOps (#16263)
* work

* using uops

* delete things

* work

* work

* higher level uops

* cleanups
2026-05-20 19:45:28 +09:00

90 lines
4.8 KiB
Python

import functools
from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
@functools.cache
def _custom_fused_ce_loss_fwd(loss_out:UOp, max_out:UOp, lse_out:UOp, logits:UOp, targets:UOp,
                              vocab:int, rows:int, label_smoothing:float) -> UOp:
  """Build the forward UOp kernel for fused label-smoothed cross-entropy.

  For every one of `rows` rows of `logits` (indexed [row, 0..vocab)), computes in fp32:
    row_max = max(logits[row])
    row_lse = log(sum(exp(logits[row] - row_max))) + row_max   (max-shifted logsumexp)
    loss    = row_lse - (1 - label_smoothing) * logits[row, targets[row]] - label_smoothing * mean(logits[row])
  and stores them into loss_out[row], max_out[row] and lse_out[row].
  Decorated with functools.cache, so repeated calls with identical arguments return the same kernel graph.
  """
  row = UOp.range(rows, 0)

  def as_f32(col) -> UOp:
    # one logit of the current row, promoted to fp32 for the math below
    return logits[row, col].cast(dtypes.float)

  def over_vocab(range_id:int, op:Ops, body) -> UOp:
    # open a fresh REDUCE range over the vocab axis and fold body(col) along it with `op`
    col = UOp.range(vocab, range_id, axis_type=AxisType.REDUCE)
    return body(col).reduce(col, arg=op)

  row_max = over_vocab(1, Ops.MAX, as_f32)
  row_lse = over_vocab(2, Ops.ADD, lambda col: (as_f32(col) - row_max).exp()).log() + row_max
  mean_logits = over_vocab(3, Ops.ADD, as_f32) / vocab
  target_logit = as_f32(targets[row].cast(dtypes.weakint))
  hard_weight = 1.0 - label_smoothing
  loss = row_lse - hard_weight * target_logit - label_smoothing * mean_logits

  writes = UOp.group(loss_out[row].store(loss), max_out[row].store(row_max), lse_out[row].store(row_lse))
  return writes.end(row).sink(arg=KernelInfo(f"fused_ce_loss_fwd_{rows}_{vocab}"))
@functools.cache
def _custom_fused_ce_loss_bwd(d_logits:UOp, logits:UOp, lse:UOp, targets:UOp, scale:UOp,
                              vocab:int, rows:int, label_smoothing:float) -> UOp:
  """Build the backward UOp kernel: gradient of the label-smoothed CE loss w.r.t. logits.

  For each (row, col) it writes
    d_logits[row, col] = (exp(logits[row, col] - lse[row]) - (1 - label_smoothing) * [col == targets[row]]
                          - label_smoothing / vocab) * scale[0]
  cast to d_logits' base dtype. `lse` is the per-row logsumexp saved by the forward kernel.
  Decorated with functools.cache, so repeated calls with identical arguments return the same kernel graph.
  """
  row, col = UOp.range(rows, 0), UOp.range(vocab, 1)
  is_target = col.eq(targets[row].cast(dtypes.weakint))
  softmax = (logits[row, col].cast(dtypes.float) - lse[row]).exp()
  grad = (softmax - is_target.where(1.0 - label_smoothing, 0.0) - label_smoothing / vocab) * scale[0]
  write = d_logits[row, col].store(grad.cast(d_logits.dtype.base))
  return write.end(col, row).sink(arg=KernelInfo(f"fused_ce_loss_bwd_{rows}_{vocab}"))
def _fused_ce_loss_bwd(gradient:UOp, kernel:UOp, label_smoothing:float):
  """grad_fxn for the forward fused-CE custom kernel.

  The forward kernel's sources after index 0 are (loss_out, max_out, lse_out, logits, targets); the returned
  tuple lines up with them and only carries a gradient for `logits` (a bf16 tensor of the flattened
  (rows, VOCAB) shape). `gradient` is the upstream gradient w.r.t. the per-row loss.
  """
  _loss_u, _max_u, lse_u, logits_u, targets_u = kernel.src[1:]
  device = logits_u.device
  rows, vocab = logits_u.shape  # logits were reshaped to (rows, VOCAB) by the forward
  sharded = isinstance(device, tuple)
  # the custom kernel runs per device, so on multi-device it only sees its own slice of rows
  rows_per_dev = rows // len(device) if sharded else rows
  blank = Tensor.invalids(rows_per_dev, vocab, dtype=dtypes.bfloat16, device=device)
  d_logits = Tensor(blank.uop.multi(logits_u.axis), device=device) if sharded else blank

  # NOTE: the forward returns loss_out.mean(), whose backward hands every row the same gradient,
  # so a single scalar taken from the first element is enough
  scale = Tensor(gradient, device=device).float().reshape(-1)[0:1].contiguous()
  # .after(kernel) orders these reads after the forward kernel that produced lse
  kernel_inputs = (Tensor(logits_u.after(kernel), device=device), Tensor(lse_u.after(kernel), device=device),
                   Tensor(targets_u, device=device), scale)
  bwd_fxn = functools.partial(_custom_fused_ce_loss_bwd, vocab=vocab, rows=rows_per_dev, label_smoothing=label_smoothing)
  d_logits, *_ = Tensor.custom_kernel(d_logits, *kernel_inputs, fxn=bwd_fxn)
  return (None, None, None, d_logits.uop, None)
def fused_ce_loss(logits:Tensor, targets:Tensor, label_smoothing:float=0.1) -> Tensor:
  """Fused sparse categorical cross-entropy with label smoothing; returns the mean loss as a scalar.

  Per row the loss is lse(x) - (1 - label_smoothing) * x[target] - label_smoothing * mean(x), computed in fp32
  by one custom kernel that also saves the row max and logsumexp (the logsumexp feeds the backward kernel).

  Args:
    logits: bf16 tensor of shape (MBS, SEQ, VOCAB); on multi-device it may be sharded on axis 0 or 1.
    targets: class indices, one per (MBS, SEQ) position; flattened and cast to int32 before being used
      to index the vocab axis. Values are not bounds-checked against VOCAB here.
    label_smoothing: weight of the uniform distribution mixed into the one-hot target.
  Raises:
    AssertionError: on wrong dtype/rank, an unsupported sharding axis, a targets count that differs from
      MBS*SEQ, or a row count not divisible by the number of devices.
  """
  assert logits.dtype == dtypes.bfloat16, f"expected bf16, got {logits.dtype}"
  assert logits.ndim == 3, f"expected (MBS, SEQ, VOCAB), got {logits.shape}"
  MBS, SEQ, VOCAB = logits.shape
  rows = MBS * SEQ
  # the kernel reads targets[row] for every row: a count mismatch would silently pair logits with the wrong labels
  assert targets.numel() == rows, f"expected {rows} targets for logits {logits.shape}, got {targets.shape}"
  device = logits.device
  if isinstance(device, tuple):
    axis = logits.uop.axis
    assert axis in (0, 1), f"unsupported sharding axis={axis} for CE loss"
    ndev = len(device)
    # each device's kernel covers rows // ndev rows; an uneven split would leave some rows uncomputed
    assert rows % ndev == 0, f"rows={rows} not divisible by {ndev} devices"
    rows_per_dev = rows // ndev
    # NOTE(review): targets presumably must be sharded like logits so each shard's rows line up — confirm
  else:
    rows_per_dev = rows

  def _row_buffer() -> Tensor:
    # Tensor.invalids placeholder with one fp32 slot per row, fully overwritten by the kernel;
    # on multi-device it is split along axis 0
    buf = Tensor.invalids(rows_per_dev, dtype=dtypes.float32, device=device)
    return Tensor(buf.uop.multi(0), device=device) if isinstance(device, tuple) else buf

  loss_out, max_out, lse_out = _row_buffer(), _row_buffer(), _row_buffer()
  logits_flat = logits.reshape(rows, VOCAB)
  targets_flat = targets.reshape(-1).cast(dtypes.int32)
  fxn = functools.partial(_custom_fused_ce_loss_fwd, vocab=VOCAB, rows=rows_per_dev,
                          label_smoothing=label_smoothing)
  loss_out, max_out, lse_out, *_ = Tensor.custom_kernel(
    loss_out, max_out, lse_out, logits_flat, targets_flat,
    fxn=fxn, grad_fxn=functools.partial(_fused_ce_loss_bwd, label_smoothing=label_smoothing))
  return loss_out.mean()