2022-05-21 05:21:58 +08:00
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
2023-07-17 20:26:11 +08:00
|
|
|
from einops import rearrange
|
|
|
|
|
|
2023-08-20 12:07:33 +08:00
|
|
|
# isort: off
|
|
|
|
|
# We need to import the CUDA kernels after importing torch
|
|
|
|
|
import flash_attn_2_cuda as flash_attn_cuda
|
|
|
|
|
# isort: on
|
|
|
|
|
|
2023-07-17 20:26:11 +08:00
|
|
|
|
|
|
|
|
def _get_block_size(device, head_dim, is_dropout, is_causal):
|
|
|
|
|
# This should match the block sizes in the CUDA kernel
|
|
|
|
|
assert head_dim <= 256
|
|
|
|
|
major, minor = torch.cuda.get_device_capability(device)
|
|
|
|
|
is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
|
|
|
|
|
is_sm80 = major == 8 and minor == 0
|
|
|
|
|
is_sm90 = major == 9 and minor == 0
|
|
|
|
|
if head_dim <= 32:
|
|
|
|
|
return 128, 128
|
|
|
|
|
if head_dim <= 64:
|
|
|
|
|
return (128, 128) if not is_dropout else (128, 64)
|
|
|
|
|
elif head_dim <= 96:
|
|
|
|
|
return (64, 64) if (is_sm8x and is_causal) else (128, 64)
|
|
|
|
|
elif head_dim <= 128:
|
|
|
|
|
if is_sm8x:
|
|
|
|
|
return (64, 64) if (not is_dropout and is_causal) else (128, 32)
|
|
|
|
|
else:
|
|
|
|
|
return 128, (64 if not is_dropout else 32)
|
|
|
|
|
elif head_dim <= 160:
|
|
|
|
|
if is_sm8x:
|
|
|
|
|
return (128, 64) if not is_causal else (64, 64)
|
|
|
|
|
else:
|
|
|
|
|
return 128, 32
|
|
|
|
|
elif head_dim <= 192:
|
|
|
|
|
return (128, 64) if not is_dropout else (64, 64)
|
|
|
|
|
elif head_dim <= 224:
|
|
|
|
|
return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
|
|
|
|
|
elif head_dim <= 256:
|
|
|
|
|
return (128, 64) if is_sm80 else (64, 64)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
    """Call the fixed-length forward kernel (flash_attn_cuda.fwd).

    q, k and v are made contiguous first whenever their last-dimension
    stride is not 1; otherwise they are passed through untouched.

    Return:
        The 8-tuple produced by the kernel, in kernel order:
        (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state).
    """

    def _unit_last_stride(t):
        # Copy only when the innermost dimension is not densely packed.
        if t.stride(-1) != 1:
            return t.contiguous()
        return t

    q, k, v = map(_unit_last_stride, (q, k, v))
    # NOTE(review): the two None positional arguments are optional kernel
    # inputs left unset here (presumably a preallocated output buffer and an
    # RNG generator) -- confirm against the extension's signature.
    kernel_result = flash_attn_cuda.fwd(
        q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
    )
    # Unpack explicitly so a change in the kernel's arity fails loudly here.
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = kernel_result
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
|
2023-07-17 20:26:11 +08:00
|
|
|
|
|
|
|
|
|
2023-08-19 05:22:11 +08:00
|
|
|
def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    return_softmax,
):
    """Call the variable-length forward kernel (flash_attn_cuda.varlen_fwd).

    q, k and v are made contiguous first whenever their last-dimension
    stride is not 1; the cumulative-length and max-length arguments are
    forwarded unchanged.

    Return:
        The 8-tuple produced by the kernel, in kernel order:
        (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state).
    """

    def _unit_last_stride(t):
        # Copy only when the innermost dimension is not densely packed.
        if t.stride(-1) != 1:
            return t.contiguous()
        return t

    q, k, v = map(_unit_last_stride, (q, k, v))
    # NOTE(review): the None / False / None positional arguments are optional
    # kernel inputs left at fixed values here (presumably an output buffer, a
    # zero-initialisation flag and an RNG generator) -- confirm against the
    # extension's signature.
    kernel_result = flash_attn_cuda.varlen_fwd(
        q,
        k,
        v,
        None,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        return_softmax,
        None,
    )
    # Unpack explicitly so a change in the kernel's arity fails loudly here.
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = kernel_result
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
|
2022-05-21 05:21:58 +08:00
|
|
|
|
|
|
|
|
|
2023-08-19 05:22:11 +08:00
|
|
|
def _flash_attn_backward(
    dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, rng_state=None
):
    """Call the fixed-length backward kernel (flash_attn_cuda.bwd).

    dout, q, k, v and out are made contiguous first whenever their
    last-dimension stride is not 1. dq, dk and dv are passed to the kernel
    as given.

    Return:
        (dq, dk, dv, softmax_d) as returned by the kernel.
    """

    def _unit_last_stride(t):
        # Copy only when the innermost dimension is not densely packed.
        if t.stride(-1) != 1:
            return t.contiguous()
        return t

    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = map(_unit_last_stride, (dout, q, k, v, out))
    # NOTE(review): the None before rng_state is an optional kernel input
    # left unset (presumably an RNG generator) -- confirm against the
    # extension's signature.
    kernel_result = flash_attn_cuda.bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        dropout_p,
        softmax_scale,
        causal,
        None,
        rng_state,
    )
    # Unpack explicitly so a change in the kernel's arity fails loudly here.
    dq, dk, dv, softmax_d = kernel_result
    return dq, dk, dv, softmax_d
|
|
|
|
|
|
|
|
|
|
|
2023-08-19 05:22:11 +08:00
|
|
|
def _flash_attn_varlen_backward(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    rng_state=None,
):
    """Call the variable-length backward kernel (flash_attn_cuda.varlen_bwd).

    dout, q, k, v and out are made contiguous first whenever their
    last-dimension stride is not 1. dq, dk and dv are passed to the kernel
    as given.

    Return:
        (dq, dk, dv, softmax_d) as returned by the kernel.
    """

    def _unit_last_stride(t):
        # Copy only when the innermost dimension is not densely packed.
        if t.stride(-1) != 1:
            return t.contiguous()
        return t

    # dq, dk, dv are allocated by us so they should already be contiguous
    dout, q, k, v, out = map(_unit_last_stride, (dout, q, k, v, out))
    # NOTE(review): the False and None positional arguments are optional
    # kernel inputs left at fixed values here (presumably a
    # zero-initialisation flag and an RNG generator) -- confirm against the
    # extension's signature.
    kernel_result = flash_attn_cuda.varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        None,
        rng_state,
    )
    # Unpack explicitly so a change in the kernel's arity fails loudly here.
    dq, dk, dv, softmax_d = kernel_result
    return dq, dk, dv, softmax_d
|
2022-05-21 05:21:58 +08:00
|
|
|
|
|
|
|
|
|
2022-07-01 11:26:04 +08:00
|
|
|
class FlashAttnQKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention on a packed QKV tensor.

    Q, K and V are taken as indices 0, 1 and 2 along dimension 2 of qkv.
    """

    @staticmethod
    def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
        """Run the forward kernel and stash what backward needs on ctx.

        Returns out alone, or (out, softmax_lse, S_dmask) when return_softmax
        is truthy. The kernel is only asked for the softmax when
        return_softmax is set and dropout_p > 0.
        """
        # Default scaling: (size of the last dimension of qkv) ** -0.5.
        softmax_scale = qkv.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        q_in, k_in, v_in = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q_in,
            k_in,
            v_in,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        # Save the tensors returned by the forward helper (not the views
        # sliced from qkv above) together with the padded output.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the gradient w.r.t. qkv; other forward inputs get None."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        # One packed buffer for dq/dk/dv: same shape as q with a size-3 axis
        # inserted before the last two dimensions.
        packed_shape = (*q.shape[:-2], 3, *q.shape[-2:])
        dqkv = torch.empty(packed_shape, dtype=q.dtype, device=q.device)
        # The helper's return value is ignored: it is expected to fill the
        # three views of dqkv that are passed in as dq, dk, dv.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, :, 0],
            dqkv[:, :, 1],
            dqkv[:, :, 2],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        # We could have padded the head dimension
        dqkv = dqkv[..., : dout.shape[-1]]
        # Slots: qkv, dropout_p, softmax_scale, causal, return_softmax.
        return dqkv, None, None, None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention on a packed QKV tensor.

    Q, K and V are taken as indices 0, 1 and 2 along dimension 1 of qkv. The
    same cu_seqlens and max_seqlen are used for both queries and keys.
    """

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
        """Run the varlen forward kernel and stash what backward needs on ctx.

        Returns out alone, or (out, softmax_lse, S_dmask) when return_softmax
        is truthy. The kernel is only asked for the softmax when
        return_softmax is set and dropout_p > 0.
        """
        # Default scaling: (size of the last dimension of qkv) ** -0.5.
        softmax_scale = qkv.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        q_in, k_in, v_in = qkv[:, 0], qkv[:, 1], qkv[:, 2]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q_in,
            k_in,
            v_in,
            cu_seqlens,
            cu_seqlens,
            max_seqlen,
            max_seqlen,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen = max_seqlen
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute the gradient w.r.t. qkv; other forward inputs get None."""
        q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        # One packed buffer for dq/dk/dv: same shape as q with a size-3 axis
        # inserted before the last two dimensions.
        packed_shape = (*q.shape[:-2], 3, *q.shape[-2:])
        dqkv = torch.empty(packed_shape, dtype=q.dtype, device=q.device)
        # The helper's return value is ignored: it is expected to fill the
        # three views of dqkv that are passed in as dq, dk, dv.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dqkv[:, 0],
            dqkv[:, 1],
            dqkv[:, 2],
            cu_seqlens,
            cu_seqlens,
            ctx.max_seqlen,
            ctx.max_seqlen,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        # We could have padded the head dimension
        dqkv = dqkv[..., : dout.shape[-1]]
        # Slots: qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
        # causal, return_softmax.
        return dqkv, None, None, None, None, None, None
|
2022-05-21 05:21:58 +08:00
|
|
|
|
|
|
|
|
|
2022-07-01 11:26:04 +08:00
|
|
|
class FlashAttnKVPackedFunc(torch.autograd.Function):
    """Autograd function for attention with a separate Q and a packed KV tensor.

    K and V are taken as indices 0 and 1 along dimension 2 of kv.
    """

    @staticmethod
    def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
        """Run the forward kernel and stash what backward needs on ctx.

        Returns out alone, or (out, softmax_lse, S_dmask) when return_softmax
        is truthy. The kernel is only asked for the softmax when
        return_softmax is set and dropout_p > 0.
        """
        # Default scaling: (size of the last dimension of q) ** -0.5.
        softmax_scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        k_in, v_in = kv[:, :, 0], kv[:, :, 1]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k_in,
            v_in,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. q and kv; other forward inputs get None."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # One packed buffer for dk/dv: same shape as k with a size-2 axis
        # inserted before the last two dimensions.
        packed_shape = (*k.shape[:-2], 2, *k.shape[-2:])
        dkv = torch.empty(packed_shape, dtype=k.dtype, device=k.device)
        # The helper's return value is ignored: it is expected to fill dq and
        # the two views of dkv that are passed in as dk, dv.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, :, 0],
            dkv[:, :, 1],
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        # We could have padded the head dimension
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dkv = dkv[..., :head_dim]
        # Slots: q, kv, dropout_p, softmax_scale, causal, return_softmax.
        return dq, dkv, None, None, None, None
|
2022-07-01 11:26:04 +08:00
|
|
|
|
|
|
|
|
|
2023-07-17 20:26:11 +08:00
|
|
|
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with Q and a packed KV.

    K and V are taken as indices 0 and 1 along dimension 1 of kv.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_softmax,
    ):
        """Run the varlen forward kernel and stash what backward needs on ctx.

        Returns out alone, or (out, softmax_lse, S_dmask) when return_softmax
        is truthy. The kernel is only asked for the softmax when
        return_softmax is set and dropout_p > 0.
        """
        # Default scaling: (size of the last dimension of q) ** -0.5.
        softmax_scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        k_in, v_in = kv[:, 0], kv[:, 1]
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k_in,
            v_in,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. q and kv; other forward inputs get None."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        # One packed buffer for dk/dv: same shape as k with a size-2 axis
        # inserted before the last two dimensions.
        packed_shape = (*k.shape[:-2], 2, *k.shape[-2:])
        dkv = torch.empty(packed_shape, dtype=k.dtype, device=k.device)
        # The helper's return value is ignored: it is expected to fill dq and
        # the two views of dkv that are passed in as dk, dv.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dkv[:, 0],
            dkv[:, 1],
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        # We could have padded the head dimension
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dkv = dkv[..., :head_dim]
        # Slots: q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
        # max_seqlen_k, dropout_p, softmax_scale, causal, return_softmax.
        return dq, dkv, None, None, None, None, None, None, None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FlashAttnFunc(torch.autograd.Function):
    """Autograd function for attention with separate Q, K and V tensors."""

    @staticmethod
    def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
        """Run the forward kernel and stash what backward needs on ctx.

        Returns out alone, or (out, softmax_lse, S_dmask) when return_softmax
        is truthy. The kernel is only asked for the softmax when
        return_softmax is set and dropout_p > 0.
        """
        if softmax_scale is None:
            # Default scaling: (size of the last dimension of q) ** -0.5.
            softmax_scale = q.shape[-1] ** (-0.5)
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        # Save the tensors returned by the forward helper together with the
        # padded output; backward trims the gradients back afterwards.
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return out if not return_softmax else (out, softmax_lse, S_dmask)

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. q, k and v; other forward inputs get None."""
        q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        # The helper's return value is ignored: it is expected to fill the
        # dq, dk, dv buffers that are passed in.
        _flash_attn_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
        # Exactly one gradient slot per forward() input:
        # q, k, v, dropout_p, softmax_scale, causal, return_softmax.
        return dq, dk, dv, None, None, None, None
|
2022-07-01 11:26:04 +08:00
|
|
|
|
|
|
|
|
|
2023-07-17 20:26:11 +08:00
|
|
|
class FlashAttnVarlenFunc(torch.autograd.Function):
    """Autograd function for variable-length attention with separate Q, K, V."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_softmax,
    ):
        """Run the varlen forward kernel and stash what backward needs on ctx.

        Returns out alone, or (out, softmax_lse, S_dmask) when return_softmax
        is truthy. The kernel is only asked for the softmax when
        return_softmax is set and dropout_p > 0.
        """
        # Default scaling: (size of the last dimension of q) ** -0.5.
        softmax_scale = q.shape[-1] ** (-0.5) if softmax_scale is None else softmax_scale
        out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            return_softmax=return_softmax and dropout_p > 0,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        if return_softmax:
            return out, softmax_lse, S_dmask
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Compute gradients w.r.t. q, k and v; other forward inputs get None."""
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        # The helper's return value is ignored: it is expected to fill the
        # dq, dk, dv buffers that are passed in.
        _flash_attn_varlen_backward(
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            ctx.max_seqlen_q,
            ctx.max_seqlen_k,
            ctx.dropout_p,
            ctx.softmax_scale,
            ctx.causal,
            rng_state=rng_state,
        )
        # We could have padded the head dimension
        head_dim = dout.shape[-1]
        dq = dq[..., :head_dim]
        dk = dk[..., :head_dim]
        dv = dv[..., :head_dim]
        # Slots: q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
        # max_seqlen_k, dropout_p, softmax_scale, causal, return_softmax.
        return dq, dk, dv, None, None, None, None, None, None, None, None
|
2022-10-14 11:47:54 +08:00
|
|
|
|
|
|
|
|
|
2023-08-19 05:22:11 +08:00
|
|
|
def flash_attn_qkvpacked_func(
    qkv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
):
    """Attention on a packed QKV tensor. Set dropout_p to 0.0 during evaluation.

    When Q, K, V are already stacked into one tensor, this is faster than
    calling flash_attn_func on Q, K, V, because the backward pass avoids an
    explicit concatenation of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), use
    flash_attn_kvpacked_func or flash_attn_func instead.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for
            auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities.
            This option is for testing only. The returned probabilities are not
            guaranteed to be correct (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen).
            The logsumexp of each row of the matrix QK^T * scaling (e.g., log of
            the softmax normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also
            encodes the dropout pattern (negative means that location was
            dropped, nonnegative means it was kept).
    """
    # Autograd functions take positional arguments only, in forward() order.
    fn_args = (qkv, dropout_p, softmax_scale, causal, return_attn_probs)
    return FlashAttnQKVPackedFunc.apply(*fn_args)
|
2022-07-01 11:26:04 +08:00
|
|
|
|
|
|
|
|
|
2023-08-19 05:22:11 +08:00
|
|
|
def flash_attn_kvpacked_func(
    q, kv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
):
    """Attention with a separate Q and a packed KV tensor.

    dropout_p should be set to 0.0 during evaluation.
    When K, V are already stacked into one tensor, this is faster than calling
    flash_attn_func on Q, K, V, because the backward pass avoids an explicit
    concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV
    with fewer heads than Q. The number of heads in Q must be divisible by the
    number of heads in KV. For example, if Q has 6 heads and K, V have 2 heads,
    heads 0, 1, 2 of Q attend to head 0 of K, V, and heads 3, 4, 5 of Q attend
    to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of
    the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask
    (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zero, the output will be zero.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        kv: (batch_size, seqlen, 2, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for
            auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities.
            This option is for testing only. The returned probabilities are not
            guaranteed to be correct (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen).
            The logsumexp of each row of the matrix QK^T * scaling (e.g., log of
            the softmax normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also
            encodes the dropout pattern (negative means that location was
            dropped, nonnegative means it was kept).
    """
    # Autograd functions take positional arguments only, in forward() order.
    fn_args = (q, kv, dropout_p, softmax_scale, causal, return_attn_probs)
    return FlashAttnKVPackedFunc.apply(*fn_args)
|
|
|
|
|
|
|
|
|
|
|
2023-08-19 05:22:11 +08:00
|
|
|
def flash_attn_func(
    q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
):
    """Attention with separate Q, K and V tensors.

    dropout_p should be set to 0.0 during evaluation.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV
    with fewer heads than Q. The number of heads in Q must be divisible by the
    number of heads in KV. For example, if Q has 6 heads and K, V have 2 heads,
    heads 0, 1, 2 of Q attend to head 0 of K, V, and heads 3, 4, 5 of Q attend
    to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of
    the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask
    (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zero, the output will be zero.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for
            auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities.
            This option is for testing only. The returned probabilities are not
            guaranteed to be correct (they might not have the right scaling).
    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen).
            The logsumexp of each row of the matrix QK^T * scaling (e.g., log of
            the softmax normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also
            encodes the dropout pattern (negative means that location was
            dropped, nonnegative means it was kept).
    """
    # Autograd functions take positional arguments only, in forward() order.
    fn_args = (q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)
    return FlashAttnFunc.apply(*fn_args)
|
|
|
|
|
|
|
|
|
|
|
2023-08-19 05:22:11 +08:00
|
|
|
def flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """Variable-length attention on a packed QKV tensor.

    dropout_p should be set to 0.0 during evaluation.
    When Q, K, V are already stacked into one tensor, this is faster than
    calling flash_attn_varlen_func on Q, K, V, because the backward pass avoids
    an explicit concatenation of the gradients of Q, K, V.
    For multi-query and grouped-query attention (MQA/GQA), use
    flash_attn_varlen_kvpacked_func or flash_attn_varlen_func instead.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens
            in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative
            sequence lengths of the sequences in the batch, used to index
            into qkv.
        max_seqlen: int. Maximum sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for
            auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities.
            This option is for testing only. The returned probabilities are not
            guaranteed to be correct (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen).
            The logsumexp of each row of the matrix QK^T * scaling (e.g., log of
            the softmax normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also
            encodes the dropout pattern (negative means that location was
            dropped, nonnegative means it was kept).
    """
    # Autograd functions take positional arguments only, in forward() order.
    fn_args = (qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs)
    return FlashAttnVarlenQKVPackedFunc.apply(*fn_args)
|
2022-07-01 11:26:04 +08:00
|
|
|
|
|
|
|
|
|
2023-08-19 05:22:11 +08:00
|
|
|
def flash_attn_varlen_kvpacked_func(
    q,
    kv,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """Variable-length attention with a separate Q and a packed KV tensor.

    dropout_p should be set to 0.0 during evaluation.
    When K, V are already stacked into one tensor, this is faster than calling
    flash_attn_varlen_func on Q, K, V, because the backward pass avoids an
    explicit concatenation of the gradients of K, V.
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV
    with fewer heads than Q. The number of heads in Q must be divisible by the
    number of heads in KV. For example, if Q has 6 heads and K, V have 2 heads,
    heads 0, 1, 2 of Q attend to head 0 of K, V, and heads 3, 4, 5 of Q attend
    to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of
    the attention matrix.
    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask
    (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If a row of the mask is all zero, the output will be zero.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query
            tokens in the batch.
        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of
            key tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative
            sequence lengths of the sequences in the batch, used to index
            into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative
            sequence lengths of the sequences in the batch, used to index
            into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for
            auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities.
            This option is for testing only. The returned probabilities are not
            guaranteed to be correct (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen).
            The logsumexp of each row of the matrix QK^T * scaling (e.g., log of
            the softmax normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also
            encodes the dropout pattern (negative means that location was
            dropped, nonnegative means it was kept).
    """
    # Autograd functions take positional arguments only, in forward() order.
    fn_args = (
        q,
        kv,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        return_attn_probs,
    )
    return FlashAttnVarlenKVPackedFunc.apply(*fn_args)
|
2022-05-21 05:21:58 +08:00
|
|
|
|
2022-10-14 11:47:54 +08:00
|
|
|
|
2023-08-19 05:22:11 +08:00
|
|
|
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
):
    """Variable-length attention over separately supplied Q, K and V.

    dropout_p should be set to 0.0 during evaluation.

    Multi-query and grouped-query attention (MQA/GQA) are supported by passing
    K and V with fewer heads than Q. The number of heads in Q must be divisible
    by the number of heads in KV. For example, if Q has 6 heads and K, V have
    2 heads, heads 0, 1, 2 of Q attend to head 0 of K, V, and heads 3, 4, 5 of Q
    attend to head 1 of K, V.

    If causal=True, the causal mask is aligned to the bottom right corner of the
    attention matrix. For example, if seqlen_q = 2 and seqlen_k = 5, the causal
    mask (1 = keep, 0 = masked out) is:
        1 1 1 1 0
        1 1 1 1 1
    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
        0 0
        0 0
        0 0
        1 0
        1 1
    If the row of the mask is all zero, the output will be zero.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query
            tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key
            tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key
            tokens in the batch.
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative
            sequence lengths of the sequences in the batch, used to index
            into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative
            sequence lengths of the sequences in the batch, used to index
            into k and v.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for
            auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities.
            This option is for testing only. The returned probabilities are not
            guaranteed to be correct (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]:
            (batch_size, nheads, seqlen). The logsumexp of each row of the
            matrix QK^T * scaling (e.g., log of the softmax normalization
            factor).
        S_dmask [optional, if return_attn_probs=True]:
            (batch_size, nheads, seqlen, seqlen). The output of softmax
            (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative
            means it was kept).
    """
    # The autograd function takes everything positionally, in the same order
    # as this wrapper's signature; the tuples below only group related inputs.
    varlen_layout = (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
    attn_options = (dropout_p, softmax_scale, causal, return_attn_probs)
    return FlashAttnVarlenFunc.apply(q, k, v, *varlen_layout, *attn_options)
|