mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-06-14 00:45:16 +08:00
90 lines · 4.8 KiB · Python
import functools
|
|
from tinygrad import Tensor, dtypes
|
|
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
|
|
|
|
@functools.cache
def _custom_fused_ce_loss_fwd(loss_out:UOp, max_out:UOp, lse_out:UOp, logits:UOp, targets:UOp,
                              vocab:int, rows:int, label_smoothing:float) -> UOp:
  """Build the forward kernel for label-smoothed cross-entropy over a (rows, vocab) logits buffer.

  For every row the kernel writes three float32 values:
    max_out[row]  = max_v logits[row, v]
    lse_out[row]  = log(sum_v exp(logits[row, v] - max)) + max          (numerically stable log-sum-exp)
    loss_out[row] = lse - (1 - label_smoothing) * logits[row, targets[row]] - label_smoothing * mean_v logits[row, v]
  Logits are cast to float before any arithmetic.
  NOTE(review): targets[row] is used directly as a column index with no bounds check; values are assumed to be in
  [0, vocab) — TODO confirm callers never pass an ignore index.
  """
  r = UOp.range(rows, 0)

  def as_f32(col:UOp) -> UOp: return logits[r, col].cast(dtypes.float)

  # three independent reductions over the vocab axis, each needs its own range id
  rng_max, rng_lse, rng_mean = (UOp.range(vocab, i, axis_type=AxisType.REDUCE) for i in (1, 2, 3))

  m = as_f32(rng_max).reduce(rng_max, arg=Ops.MAX)
  lse = (as_f32(rng_lse) - m).exp().reduce(rng_lse, arg=Ops.ADD).log() + m

  tgt_logit = as_f32(targets[r].cast(dtypes.weakint))
  avg = as_f32(rng_mean).reduce(rng_mean, arg=Ops.ADD) / vocab

  # with smoothed targets q = (1-ls)*onehot + ls/vocab:  -sum_v q_v*(x_v - lse) = lse - (1-ls)*x_t - ls*mean(x)
  nll = lse - (1.0 - label_smoothing) * tgt_logit - label_smoothing * avg

  writes = UOp.group(loss_out[r].store(nll), max_out[r].store(m), lse_out[r].store(lse))
  return writes.end(r).sink(arg=KernelInfo(f"fused_ce_loss_fwd_{rows}_{vocab}"))
|
|
|
|
@functools.cache
def _custom_fused_ce_loss_bwd(d_logits:UOp, logits:UOp, lse:UOp, targets:UOp, scale:UOp,
                              vocab:int, rows:int, label_smoothing:float) -> UOp:
  """Build the backward kernel: d_logits[row, v] = (softmax(logits)[row, v] - q[row, v]) * scale[0].

  The softmax is recomputed from the per-row log-sum-exp saved by the forward pass, as exp(logits - lse).
  q is the smoothed target distribution: (1 - label_smoothing) on the target class plus
  label_smoothing / vocab on every class. The result is cast to d_logits' element dtype before storing.
  """
  r, col = UOp.range(rows, 0), UOp.range(vocab, 1)

  # Python-side constants of the smoothed target distribution
  on_target = 1.0 - label_smoothing
  uniform = label_smoothing / vocab

  p = (logits[r, col].cast(dtypes.float) - lse[r]).exp()
  hot = col.eq(targets[r].cast(dtypes.weakint)).where(on_target, 0.0)
  g = (p - hot - uniform) * scale[0]

  out_dtype = d_logits.dtype.base
  kname = f"fused_ce_loss_bwd_{rows}_{vocab}"
  return d_logits[r, col].store(g.cast(out_dtype)).end(col, r).sink(arg=KernelInfo(kname))
|
|
|
|
def _fused_ce_loss_bwd(gradient:UOp, kernel:UOp, label_smoothing:float):
  """Gradient function for the fused CE forward custom kernel.

  gradient: upstream gradient w.r.t. the per-row loss output (loss_out).
  kernel:   the forward custom-kernel UOp; kernel.src[1:] are the forward inputs in the order
            (loss_out, max_out, lse_out, logits, targets).
  Returns one entry per forward input; only logits receives a gradient, all others get None.
  """
  # NOTE: forward inputs are (loss_out, max_out, lse_out, logits, targets)
  _, _, lse_u, logits_u, targets_u = kernel.src[1:]
  device = logits_u.device
  rows, VOCAB = logits_u.shape  # (rows, VOCAB) after reshape
  # the gradient buffer must have the dtype of the logits it belongs to (was hard-coded to bfloat16)
  grad_dtype = logits_u.dtype.base
  if isinstance(device, tuple):
    # sharded: each device runs the kernel on its own slice of rows; d_logits is sharded exactly like logits
    rows_per_dev = rows // len(device)
    d_logits = Tensor(Tensor.invalids(rows_per_dev, VOCAB, dtype=grad_dtype, device=device).uop.multi(logits_u.axis),
                      device=device)
  else:
    rows_per_dev = rows
    d_logits = Tensor.invalids(rows, VOCAB, dtype=grad_dtype, device=device)
  # NOTE: .mean() backward gives same grad per row (1/N), so broadcast is safe; take scalar.
  # This relies on a uniform upstream gradient; a per-row weighted loss would get a wrong gradient here.
  scale = Tensor(gradient, device=device).float().reshape(-1)[0:1].contiguous()
  # .after(kernel) orders these reads after the forward kernel has produced lse
  logits_t = Tensor(logits_u.after(kernel), device=device)
  lse_t = Tensor(lse_u.after(kernel), device=device)
  targets_t = Tensor(targets_u, device=device)
  fxn = functools.partial(_custom_fused_ce_loss_bwd, vocab=VOCAB, rows=rows_per_dev, label_smoothing=label_smoothing)
  d_logits, *_ = Tensor.custom_kernel(d_logits, logits_t, lse_t, targets_t, scale, fxn=fxn)
  return (None, None, None, d_logits.uop, None)
|
|
|
|
def fused_ce_loss(logits:Tensor, targets:Tensor, label_smoothing:float=0.1) -> Tensor:
  """Fused sparse categorical cross-entropy with label smoothing; returns the mean loss scalar.

  logits:  bfloat16 tensor of shape (MBS, SEQ, VOCAB). When multi-device it must be sharded on axis 0 or 1.
  targets: MBS*SEQ integer class indices (e.g. shape (MBS, SEQ)). When multi-device it must be sharded on the same
           axis as logits, because each device's kernel only reads its local targets.
  label_smoothing: weight of the uniform distribution mixed into the one-hot targets.

  Raises AssertionError on a dtype, rank, target count, or sharding mismatch.
  """
  # NOTE: fused sparse_categorical_crossentropy with label smoothing, returns mean loss scalar
  assert logits.dtype == dtypes.bfloat16, f"expected bf16, got {logits.dtype}"
  assert logits.ndim == 3, f"expected (MBS, SEQ, VOCAB), got {logits.shape}"
  MBS, SEQ, VOCAB = logits.shape
  rows = MBS * SEQ
  # the kernel indexes targets[row] for every row; a size mismatch would silently read wrong/out-of-range entries
  assert targets.numel() == rows, f"expected {rows} targets for logits {logits.shape}, got {targets.shape}"
  device = logits.device
  if isinstance(device, tuple):
    axis = logits.uop.axis
    assert axis in (0, 1), f"unsupported sharding axis={axis} for CE loss"
    # each device's kernel reads its local targets[0:rows_per_dev], so targets must be split exactly like logits
    assert isinstance(targets.device, tuple) and targets.uop.axis == axis, \
      f"targets must be sharded on axis={axis} like logits"
    rows_per_dev = rows // len(device)
  else:
    rows_per_dev = rows

  def alloc() -> Tensor:
    # one float32 value per row: per-device slices wrapped as a multi tensor split on axis 0 when sharded
    buf = Tensor.invalids(rows_per_dev, dtype=dtypes.float32, device=device)
    return Tensor(buf.uop.multi(0), device=device) if isinstance(device, tuple) else buf

  loss_out, max_out, lse_out = alloc(), alloc(), alloc()
  logits_flat = logits.reshape(rows, VOCAB)
  targets_flat = targets.reshape(-1).cast(dtypes.int32)
  fxn = functools.partial(_custom_fused_ce_loss_fwd, vocab=VOCAB, rows=rows_per_dev, label_smoothing=label_smoothing)
  grad_fxn = functools.partial(_fused_ce_loss_bwd, label_smoothing=label_smoothing)
  loss_out, *_ = Tensor.custom_kernel(loss_out, max_out, lse_out, logits_flat, targets_flat, fxn=fxn, grad_fxn=grad_fxn)
  return loss_out.mean()
|