mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-14 00:45:16 +08:00
25 lines
968 B
Python
25 lines
968 B
Python
from __future__ import annotations
|
|
import functools
|
|
from tinygrad import Tensor
|
|
from tinygrad.uop.ops import UOp
|
|
|
|
def rmsnorm_fwd(x_in:Tensor, eps:float) -> tuple[Tensor, Tensor]:
|
|
x = x_in.float()
|
|
rrms = (x.square().mean(-1, keepdim=True) + eps).rsqrt()
|
|
return (x * rrms).cast(x_in.dtype), rrms
|
|
|
|
@functools.cache
|
|
def _rmsnorm_fwd_fxn(x_in_p, eps, device):
|
|
return rmsnorm_fwd(Tensor(x_in_p, device=device), eps)
|
|
|
|
def _rmsnorm_bwd(grad:UOp, call:UOp) -> tuple:
|
|
x_normed = Tensor(call.gettuple(0)).float()
|
|
do_float = Tensor(grad).float()
|
|
d_x = Tensor(call.gettuple(1)) * (do_float - x_normed * (do_float * x_normed).mean(-1, keepdim=True))
|
|
return (d_x.cast(call.src[1].dtype).uop,)
|
|
|
|
def rmsnorm(x_in:Tensor, eps:float) -> tuple[Tensor, Tensor]:
|
|
fxn = _rmsnorm_fwd_fxn(x_in.as_param(0).uop, eps, x_in.device)
|
|
call = UOp.maketuple(fxn[0].uop, fxn[1].uop).call(x_in.uop, grad_fxn=_rmsnorm_bwd)
|
|
return Tensor(call.gettuple(0)), Tensor(call.gettuple(1))
|