mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-11 15:35:51 +08:00
fa: faster (#14453)
This commit is contained in:
@@ -8,9 +8,11 @@ export REWRITE_STACK_LIMIT=5000000 HCQDEV_WAIT_TIMEOUT_MS=240000
|
||||
export DEBUG=${DEBUG:-2}
|
||||
export FLASH_ATTENTION=${FLASH_ATTENTION:-1}
|
||||
export ALL2ALL=${ALL2ALL:-1}
|
||||
export USE_ATOMICS=${USE_ATOMICS:-1}
|
||||
export ASM_GEMM=${ASM_GEMM:-1}
|
||||
|
||||
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
|
||||
export DP=8 BS=8 EVAL_BS=8 GRADIENT_ACC_STEPS=2
|
||||
export DP=8 BS=16 EVAL_BS=8 GRADIENT_ACC_STEPS=1
|
||||
export GBS=$((BS * GRADIENT_ACC_STEPS))
|
||||
|
||||
export MODEL="llama3"
|
||||
|
||||
@@ -13,7 +13,7 @@ export USE_ATOMICS=${USE_ATOMICS:-1}
|
||||
export ASM_GEMM=${ASM_GEMM:-1}
|
||||
|
||||
export DEFAULT_FLOAT="bfloat16" OPTIM_DTYPE="bfloat16"
|
||||
export DP=${DP:-8} BS=${BS:-8} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-2}
|
||||
export DP=${DP:-8} BS=${BS:-16} EVAL_BS=${EVAL_BS:-8} GRADIENT_ACC_STEPS=${GRADIENT_ACC_STEPS:-1}
|
||||
export GBS=$((BS * GRADIENT_ACC_STEPS))
|
||||
|
||||
export MODEL="llama3"
|
||||
|
||||
@@ -2,7 +2,7 @@ import math
|
||||
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
|
||||
from extra.thunder.tiny.tk import WARP_THREADS
|
||||
from extra.thunder.tiny.tk.kernel import Kernel
|
||||
@@ -43,11 +43,12 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
B_local = B // num_devices
|
||||
if DEBUG >= 2: print(f"Flash Attention {B=} {B_local=} {N=} {H=} {D=} {H_KV=} {GROUP_SIZE=}")
|
||||
|
||||
def custom_forward(ou:UOp, l_vecu:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp) -> UOp:
|
||||
def _custom_forward_impl(ou:UOp, l_vecu:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp|None) -> UOp:
|
||||
with Kernel("fa_custom_forward", (H, N // (Q_BLOCK_SIZE*NUM_WORKERS), B_local), NUM_WORKERS * WARP_THREADS) as ker:
|
||||
warp = ker.warp
|
||||
|
||||
o, q, k, v, mask, l_vec = GL(ou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker), GL(l_vecu, ker)
|
||||
o, q, k, v, l_vec = GL(ou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(l_vecu, ker)
|
||||
mask = GL(masku, ker) if masku is not None else None
|
||||
|
||||
head = ker.blockIdx_x
|
||||
head_kv = head // GROUP_SIZE
|
||||
@@ -86,7 +87,8 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
q_reg = warp.copy(q_reg, q_reg_fl)
|
||||
q_reg_transposed = warp.transpose(q_reg_transposed, q_reg)
|
||||
|
||||
for kv_idx in ker.range(N // KV_BLOCK_SIZE):
|
||||
num_kv_blocks = (q_seq + 1) if is_causal else (N // KV_BLOCK_SIZE)
|
||||
for kv_idx in ker.range(num_kv_blocks):
|
||||
k_smem = warp.load(k_smem, k, (), (batch, kv_idx, head_kv, 0), axis=1)
|
||||
v_smem = warp.load(v_smem, v, (), (batch, kv_idx, head_kv, 0), axis=1)
|
||||
|
||||
@@ -99,9 +101,16 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
att_block = warp.mma_AtB(att_block, k_reg_transposed, q_reg_transposed)
|
||||
|
||||
# apply attention mask
|
||||
mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_seq, kv_idx), axis=2)
|
||||
mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
|
||||
att_block += mask_reg_transposed
|
||||
if is_causal:
|
||||
bs_rows, bs_cols, bs_stride = att_block.base_shape.rows, att_block.base_shape.cols, att_block.base_shape.stride
|
||||
q_base = q_seq * Q_BLOCK_SIZE + (warp.laneid % bs_cols)
|
||||
kv_base = kv_idx * KV_BLOCK_SIZE + (warp.laneid // bs_cols) * bs_stride
|
||||
att_block = warp.map(att_block,
|
||||
lambda x, idx: ((kv_base + idx[0]*bs_rows + idx[2]) > (q_base + idx[1]*bs_cols)).alu(Ops.WHERE, UOp.ufix(x._uop, -math.inf), x))
|
||||
elif mask is not None:
|
||||
mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_seq, kv_idx), axis=2)
|
||||
mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
|
||||
att_block += mask_reg_transposed
|
||||
|
||||
# softmax
|
||||
max_vec_last = warp.copy(max_vec_last.after(kv_idx), max_vec)
|
||||
@@ -141,11 +150,18 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
|
||||
return ker.finish()
|
||||
|
||||
def custom_backward_q(dqu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
|
||||
def custom_forward_causal(ou:UOp, l_vecu:UOp, qu:UOp, ku:UOp, vu:UOp) -> UOp:
|
||||
return _custom_forward_impl(ou, l_vecu, qu, ku, vu, None)
|
||||
|
||||
def custom_forward_masked(ou:UOp, l_vecu:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp) -> UOp:
|
||||
return _custom_forward_impl(ou, l_vecu, qu, ku, vu, masku)
|
||||
|
||||
def _custom_backward_q_impl(dqu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp|None, l_vecu:UOp, delta_vecu:UOp) -> UOp:
|
||||
with Kernel("fa_custom_backward_q", (H, N // (Q_BLOCK_SIZE*NUM_WORKERS), B_local), NUM_WORKERS * WARP_THREADS) as ker:
|
||||
warp = ker.warp
|
||||
|
||||
dq, do, q, k, v, mask = GL(dqu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker)
|
||||
dq, do, q, k, v = GL(dqu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker)
|
||||
mask = GL(masku, ker) if masku is not None else None
|
||||
l_vec, delta_vec = GL(l_vecu, ker), GL(delta_vecu, ker)
|
||||
|
||||
head = ker.blockIdx_x
|
||||
@@ -194,7 +210,8 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
l_vec_reg *= 1.0 / math.log(2)
|
||||
delta_vec_reg = warp.load(delta_vec_reg, delta_vec, (), (batch, head, 0, q_seq), axis=2)
|
||||
|
||||
for kv_idx in ker.range(N // KV_BLOCK_SIZE):
|
||||
num_kv_blocks = (q_seq + 1) if is_causal else (N // KV_BLOCK_SIZE)
|
||||
for kv_idx in ker.range(num_kv_blocks):
|
||||
k_smem = warp.load(k_smem, k, (), (batch, kv_idx, head_kv, 0), axis=1)
|
||||
v_smem = warp.load(v_smem, v, (), (batch, kv_idx, head_kv, 0), axis=1)
|
||||
|
||||
@@ -209,9 +226,16 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
att_block = warp.mma_AtB(att_block, k_reg_t, q_reg_t)
|
||||
|
||||
# apply attention mask
|
||||
mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_seq, kv_idx), axis=2)
|
||||
mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
|
||||
att_block += mask_reg_transposed
|
||||
if is_causal:
|
||||
bs_rows, bs_cols, bs_stride = att_block.base_shape.rows, att_block.base_shape.cols, att_block.base_shape.stride
|
||||
q_base = q_seq * Q_BLOCK_SIZE + (warp.laneid % bs_cols)
|
||||
kv_base = kv_idx * KV_BLOCK_SIZE + (warp.laneid // bs_cols) * bs_stride
|
||||
att_block = warp.map(att_block,
|
||||
lambda x, idx: ((kv_base + idx[0]*bs_rows + idx[2]) > (q_base + idx[1]*bs_cols)).alu(Ops.WHERE, UOp.ufix(x._uop, -math.inf), x))
|
||||
elif mask is not None:
|
||||
mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_seq, kv_idx), axis=2)
|
||||
mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
|
||||
att_block += mask_reg_transposed
|
||||
|
||||
att_block -= l_vec_reg
|
||||
att_block = att_block.exp2()
|
||||
@@ -231,11 +255,18 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
|
||||
return ker.finish()
|
||||
|
||||
def custom_backward_kv(dku:UOp, dvu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp):
|
||||
def custom_backward_q_causal(dqu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
|
||||
return _custom_backward_q_impl(dqu, dou, qu, ku, vu, None, l_vecu, delta_vecu)
|
||||
|
||||
def custom_backward_q_masked(dqu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
|
||||
return _custom_backward_q_impl(dqu, dou, qu, ku, vu, masku, l_vecu, delta_vecu)
|
||||
|
||||
def _custom_backward_kv_impl(dku:UOp, dvu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp|None, l_vecu:UOp, delta_vecu:UOp):
|
||||
with Kernel("fa_custom_backward_kv", (H_KV, N // (KV_BLOCK_SIZE*NUM_WORKERS), B_local), NUM_WORKERS * WARP_THREADS) as ker:
|
||||
warp = ker.warp
|
||||
|
||||
dk, dv, do, q, k, v, mask = GL(dku, ker), GL(dvu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker)
|
||||
dk, dv, do, q, k, v = GL(dku, ker), GL(dvu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker)
|
||||
mask = GL(masku, ker) if masku is not None else None
|
||||
l_vec, delta_vec = GL(l_vecu, ker), GL(delta_vecu, ker)
|
||||
|
||||
head_kv = ker.blockIdx_x
|
||||
@@ -302,9 +333,16 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
att_block *= (1.0 / math.sqrt(D)) * (1.0 / math.log(2))
|
||||
|
||||
# apply attention mask
|
||||
mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_idx, kv_seq), axis=2)
|
||||
mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
|
||||
att_block += mask_reg_transposed
|
||||
if is_causal:
|
||||
bs_rows, bs_cols, bs_stride = att_block.base_shape.rows, att_block.base_shape.cols, att_block.base_shape.stride
|
||||
q_base = q_idx * Q_BLOCK_SIZE + (warp.laneid % bs_cols)
|
||||
kv_base = kv_seq * KV_BLOCK_SIZE + (warp.laneid // bs_cols) * bs_stride
|
||||
att_block = warp.map(att_block,
|
||||
lambda x, idx: ((kv_base + idx[0]*bs_rows + idx[2]) > (q_base + idx[1]*bs_cols)).alu(Ops.WHERE, UOp.ufix(x._uop, -math.inf), x))
|
||||
elif mask is not None:
|
||||
mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_idx, kv_seq), axis=2)
|
||||
mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
|
||||
att_block += mask_reg_transposed
|
||||
|
||||
att_block -= l_vec_reg
|
||||
att_block = att_block.exp2()
|
||||
@@ -336,24 +374,31 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
|
||||
return ker.finish(2)
|
||||
|
||||
def custom_backward_kv_causal(dku:UOp, dvu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, l_vecu:UOp, delta_vecu:UOp):
|
||||
return _custom_backward_kv_impl(dku, dvu, dou, qu, ku, vu, None, l_vecu, delta_vecu)
|
||||
|
||||
def custom_backward_kv_masked(dku:UOp, dvu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp):
|
||||
return _custom_backward_kv_impl(dku, dvu, dou, qu, ku, vu, masku, l_vecu, delta_vecu)
|
||||
|
||||
single_device = xq.device[0] if isinstance(xq.device, tuple) else xq.device
|
||||
|
||||
if is_causal:
|
||||
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
|
||||
attn_mask = Tensor.ones((B, 1, N, N), requires_grad=False, device=single_device, dtype=dtypes.bool).tril()
|
||||
if attn_mask is not None:
|
||||
elif attn_mask is not None:
|
||||
if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
|
||||
if attn_mask.shape != (B, 1, N, N):
|
||||
attn_mask = attn_mask.expand(B, 1, N, N)
|
||||
if isinstance(xq.device, tuple) and not isinstance(attn_mask.device, tuple):
|
||||
attn_mask = attn_mask.shard(xq.device, axis=0)
|
||||
else:
|
||||
attn_mask = Tensor.zeros((B, 1, N, N), requires_grad=False, device=single_device, dtype=dtypes.float32)
|
||||
if attn_mask.shape != (B, 1, N, N):
|
||||
attn_mask = attn_mask.expand(B, 1, N, N)
|
||||
if isinstance(xq.device, tuple) and not isinstance(attn_mask.device, tuple):
|
||||
attn_mask = attn_mask.shard(xq.device, axis=0)
|
||||
if isinstance(xq.device, tuple):
|
||||
attn_mask = attn_mask.shard(xq.device, axis=0)
|
||||
|
||||
attn = _sharded_empty_like(xq, axis=0)
|
||||
l_vec = _sharded_empty((B, H, 1, N), xq, axis=0)
|
||||
|
||||
def grad(gradu:UOp, _) -> tuple[None, None, UOp, UOp, UOp, None]:
|
||||
def grad_causal(gradu:UOp, _) -> tuple[None, None, UOp, UOp, UOp]:
|
||||
grad = Tensor(gradu, device=gradu.device)
|
||||
grad_q = _sharded_empty_like(xq, axis=0)
|
||||
grad_k = _sharded_empty_like(xk, axis=0)
|
||||
@@ -361,11 +406,26 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
|
||||
delta_vec = (grad * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach()
|
||||
|
||||
grad_q = Tensor.custom_kernel(grad_q, grad, xq, xk, xv, attn_mask, l_vec, delta_vec, fxn=custom_backward_q)[0]
|
||||
grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, grad, xq, xk, xv, attn_mask, l_vec, delta_vec, fxn=custom_backward_kv)[:2]
|
||||
grad_q = Tensor.custom_kernel(grad_q, grad, xq, xk, xv, l_vec, delta_vec, fxn=custom_backward_q_causal)[0]
|
||||
grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, grad, xq, xk, xv, l_vec, delta_vec, fxn=custom_backward_kv_causal)[:2]
|
||||
return (None, None, grad_q.uop, grad_k.uop, grad_v.uop)
|
||||
|
||||
def grad_masked(gradu:UOp, _) -> tuple[None, None, UOp, UOp, UOp, None]:
|
||||
grad = Tensor(gradu, device=gradu.device)
|
||||
grad_q = _sharded_empty_like(xq, axis=0)
|
||||
grad_k = _sharded_empty_like(xk, axis=0)
|
||||
grad_v = _sharded_empty_like(xv, axis=0)
|
||||
|
||||
delta_vec = (grad * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach()
|
||||
|
||||
grad_q = Tensor.custom_kernel(grad_q, grad, xq, xk, xv, attn_mask, l_vec, delta_vec, fxn=custom_backward_q_masked)[0]
|
||||
grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, grad, xq, xk, xv, attn_mask, l_vec, delta_vec, fxn=custom_backward_kv_masked)[:2]
|
||||
return (None, None, grad_q.uop, grad_k.uop, grad_v.uop, None)
|
||||
|
||||
attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, attn_mask, fxn=custom_forward, grad_fxn=grad)[:2]
|
||||
if is_causal:
|
||||
attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, fxn=custom_forward_causal, grad_fxn=grad_causal)[:2]
|
||||
else:
|
||||
attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, attn_mask, fxn=custom_forward_masked, grad_fxn=grad_masked)[:2]
|
||||
attn_ = attn[:, :N_, :, :D_]
|
||||
|
||||
return attn_.transpose(1, 2).cast(odtype)
|
||||
|
||||
@@ -750,6 +750,35 @@ class TestTK(unittest.TestCase):
|
||||
|
||||
fa_jitted = TinyJit(flash_attention)
|
||||
|
||||
for _ in range(10):
|
||||
st = time.perf_counter()
|
||||
out = fa_jitted(q, k, v, is_causal=False)
|
||||
et = time.perf_counter() - st
|
||||
attn_flops = 2 * B * H * N * N * D + \
|
||||
4 * B * H * N * N + \
|
||||
2 * B * H * N * N * D
|
||||
print(f"{attn_flops/(et*1e9):2f} GFLOPS")
|
||||
out = out.float().transpose(1, 2)
|
||||
|
||||
ref = q.scaled_dot_product_attention(k, v, is_causal=False, enable_gqa=True).float().transpose(1, 2)
|
||||
|
||||
np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=2e-2, rtol=2e-2)
|
||||
|
||||
def test_fast_fa_causal(self):
|
||||
from extra.thunder.tiny.fa import flash_attention
|
||||
|
||||
B, N, H, H_KV, D = 2, 8192, 32, 8, 128
|
||||
|
||||
with Context(DEBUG=0):
|
||||
q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16).contiguous()
|
||||
k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16).contiguous()
|
||||
v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16).contiguous()
|
||||
Tensor.realize(q, k, v)
|
||||
|
||||
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
|
||||
|
||||
fa_jitted = TinyJit(flash_attention)
|
||||
|
||||
for _ in range(10):
|
||||
st = time.perf_counter()
|
||||
out = fa_jitted(q, k, v, is_causal=True)
|
||||
@@ -838,7 +867,7 @@ class TestTK(unittest.TestCase):
|
||||
|
||||
np.testing.assert_allclose(q.grad.numpy(), q_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
|
||||
np.testing.assert_allclose(v.grad.numpy(), v_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
|
||||
np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
|
||||
np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=6e-2, rtol=2e-2)
|
||||
|
||||
def test_fast_fa_bwd_causal_jitted(self):
|
||||
from extra.thunder.tiny.fa import flash_attention
|
||||
|
||||
Reference in New Issue
Block a user