flash-attention/tests/test_alibi.py
Sanghun Cho e4f726fc44
Support alibi, by Sanghun Cho from Kakao Brain
* hard-code alibi in fwd

* use params.h as hun_heads

* hard-code alibi in bwd

* add alibi on/off option

* compute alibi_start, ratio outside of kernels

* fix minor merge conflict

* add test_alibi.py

* change apply_alibi() location before masking

* add alibi in splitkv kernel

* fix backward func # of returns

* add out-of-bound check in apply_alibi()

* update test_alibi.py

* update test_alibi.py for kvcache

* simplify alibi parameter interface

* fix performance issue
by computing alibi outside of branch

* update test_flash_attn_varlen_func() for left padding

* implement alibi_slopes (b, nh) loading

* optimize apply_alibi() a bit

* update test cases for alibi_slopes loading

* reflect stylistic comments

* disable "seqlenq_ngroups_swapped" when using alibi

---------

Co-authored-by: monk.detective <monk.detective@kakaobrain.com>
2023-12-19 22:56:06 -08:00

1007 lines
36 KiB
Python

import math
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from flash_attn import (flash_attn_func, flash_attn_kvpacked_func,
flash_attn_qkvpacked_func, flash_attn_varlen_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_with_kvcache)
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size
from flash_attn.flash_attn_triton import \
flash_attn_func as flash_attn_func_triton
from flash_attn.layers.rotary import apply_rotary_emb
MAX_HEADDIM_SM8x = 192
is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5)
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0)
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0)
def generate_alibi(max_seq_len, num_attention_heads, tp_world_size, tp_index, key_padding_mask=None, device="cuda"):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = (2 ** (-2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio ** i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][
:n - closest_power_of_2]
slopes = torch.tensor(get_slopes(num_attention_heads)).to(device=device)
# Select the part of the tensor that corresponds to our tensor parallel index.
assert (num_attention_heads/tp_world_size).is_integer(
), "it works only when (num_attention_heads/tp_world_size) is integer"
nh_tp = num_attention_heads // tp_world_size
slopes = slopes[nh_tp * tp_index:nh_tp * (tp_index + 1)]
if (key_padding_mask is None):
arange_tensor = rearrange(torch.arange(max_seq_len), "sqk -> 1 sqk").to(device=device)
else:
arange_tensor = (key_padding_mask.cumsum(dim=-1, dtype=slopes.dtype) - 1) \
.masked_fill_(~key_padding_mask, torch.finfo(torch.float).min).to(device=device)
arange_tensor = rearrange(arange_tensor, 'b sqk -> b 1 1 sqk')
# (1, nheads, 1, seqlen_k) or (batch, nheads, 1, seqlen_k)
alibi_tensor = rearrange(slopes, 'nh -> 1 nh 1 1') * arange_tensor
return alibi_tensor, slopes
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", right_padding=True):
assert mode in ["full", "random", "third"]
if mode == "full":
lengths = torch.full((batch_size, 1), max_seqlen,
device=device, dtype=torch.int32)
elif mode == "random":
lengths = torch.randint(
max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
)
elif mode == "third":
lengths = torch.randint(
max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
if right_padding:
padding_mask = (
repeat(torch.arange(max_seqlen, device=device),
"s -> b s", b=batch_size) < lengths
)
else:
padding_mask = (
repeat(torch.arange(start=max_seqlen-1, end=-1, step=-1, device=device),
"s -> b s", b=batch_size) < lengths
)
return padding_mask
def generate_qkv(
q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask)
def output_pad_fn(output_unpad): return pad_input(
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
)
max_seqlen_q = seqlen_q
def output_pad_fn(output_unpad): return rearrange(
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(
k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert (query_padding_mask == key_padding_mask).all()
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
if query_padding_mask is not None:
def dqkv_pad_fn(dqkv_unpad): return pad_input(
dqkv_unpad, indices_q, batch_size, seqlen_q)
else:
def dqkv_pad_fn(dqkv_unpad): return rearrange(
dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
qkv_unpad.detach().requires_grad_(),
cu_seqlens_q,
max_seqlen_q,
qkv.detach().requires_grad_(),
output_pad_fn,
dqkv_pad_fn,
)
elif kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
def dkv_pad_fn(dkv_unpad): return pad_input(
dkv_unpad, indices_k, batch_size, seqlen_k)
else:
def dkv_pad_fn(dkv_unpad): return rearrange(
dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
)
return (
q_unpad.detach().requires_grad_(),
kv_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
kv.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dkv_pad_fn,
)
else:
dq_pad_fn = output_pad_fn
if key_padding_mask is not None:
def dk_pad_fn(dk_unpad): return pad_input(
dk_unpad, indices_k, batch_size, seqlen_k)
else:
def dk_pad_fn(dk_unpad): return rearrange(
dk_unpad, "(b s) h d -> b s h d", b=batch_size)
return (
q_unpad.detach().requires_grad_(),
k_unpad.detach().requires_grad_(),
v_unpad.detach().requires_grad_(),
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q.detach().requires_grad_(),
k.detach().requires_grad_(),
v.detach().requires_grad_(),
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
)
def construct_local_mask(
seqlen_q,
seqlen_k,
window_size=(-1, -1), # -1 means infinite window size
query_padding_mask=None,
key_padding_mask=None,
device=None,
):
row_idx = rearrange(torch.arange(
seqlen_q, device=device, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
sk = (
seqlen_k
if key_padding_mask is None
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
)
sq = (
seqlen_q
if query_padding_mask is None
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
)
if window_size[0] < 0:
return col_idx > row_idx + sk - sq + window_size[1]
else:
sk = torch.full_like(
col_idx, seqlen_k) if key_padding_mask is None else sk
return torch.logical_or(
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
col_idx < row_idx + sk - sq - window_size[0],
)
def attention_ref(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True,
reorder_ops=False,
bias=None
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k: (batch_size, seqlen_k, nheads_k, head_dim)
v: (batch_size, seqlen_k, nheads_k, head_dim)
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
causal: whether to apply causal masking
window_size: (int, int), left and right window size
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
output back to fp16/bf16.
reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.)
without changing the math. This is to estimate the numerical error from operation
reordering.
Output:
output: (batch_size, seqlen_q, nheads, head_dim)
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
"""
if causal:
window_size = (window_size[0], 0)
dtype_og = q.dtype
if upcast:
q, k, v = q.float(), k.float(), v.float()
seqlen_q, seqlen_k = q.shape[1], k.shape[1]
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
d = q.shape[-1]
if not reorder_ops:
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if bias is not None:
bias = bias.to(scores.dtype)
scores += bias
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask,
"b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
q.device,
)
scores.masked_fill_(local_mask, float("-inf"))
attention = torch.softmax(scores, dim=-1)
# Some rows might be completely masked out so we fill them with zero instead of NaN
if window_size[0] >= 0 or window_size[1] >= 0:
attention = attention.masked_fill(
torch.all(local_mask, dim=-1, keepdim=True), 0.0)
# We want to mask here so that the attention matrix doesn't have any NaNs
# Otherwise we'll get NaN in dV
if query_padding_mask is not None:
attention = attention.masked_fill(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
dropout_scaling = 1.0 / (1 - dropout_p)
# attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
if dropout_mask is not None:
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
else:
attention_drop = attention
output = torch.einsum(
"bhts,bshd->bthd", attention_drop, v * dropout_scaling)
if query_padding_mask is not None:
output.masked_fill_(
rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
def attention_kvpacked_ref(
q,
kv,
query_padding_mask=None,
key_padding_mask=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True,
reorder_ops=False,
):
return attention_ref(
q,
kv[:, :, 0],
kv[:, :, 1],
query_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
upcast=upcast,
causal=causal,
window_size=window_size,
reorder_ops=reorder_ops,
)
def attention_qkvpacked_ref(
qkv,
key_padding_mask=None,
dropout_p=0.0,
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
upcast=True,
reorder_ops=False,
):
return attention_ref(
qkv[:, :, 0],
qkv[:, :, 1],
qkv[:, :, 2],
key_padding_mask,
key_padding_mask,
dropout_p,
dropout_mask,
upcast=upcast,
causal=causal,
window_size=window_size,
reorder_ops=reorder_ops,
)
def generate_sparsity_mask(seqlen, sparsity=0.3):
repeats = seqlen // 16 // 2
# mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda'),
# torch.tensor([0, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda'),
# torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 1] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
# mask = torch.stack([torch.tensor([1, 0] * repeats, dtype=torch.bool, device='cuda')], dim=-1)
nrow, ncol = seqlen // 16, seqlen // 256
mask = torch.rand(nrow, ncol, device="cuda") < sparsity
return mask
def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask):
"""
Arguments:
qkv: (batch_size, seqlen, 3, nheads, head_dim)
blockmask: (seqlen / 16, seqlen / 256)
attn_mask: (batch_size, seqlen)
dropout_p: float
dropout_mask: (batch_size, nheads, seqlen, seqlen)
Output:
output: (batch_size, seqlen, nheads, head_dim)
attention: softmax after dropout
"""
q, k, v = qkv.float().unbind(dim=2)
d = qkv.shape[-1]
seqlen = qkv.shape[1]
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf"))
blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
blockmask = blockmask[:seqlen, :seqlen]
scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf"))
attention = torch.softmax(scores, dim=-1)
attention = attention.masked_fill(
rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0)
attention = attention.masked_fill_(
rearrange(~blockmask, "t s -> 1 1 t s"), 0.0)
attention_drop = attention.masked_fill(
~dropout_mask, 0.0) / (1 - dropout_p)
output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0)
return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
def convert_flash_attn_S_to_softmax(
S,
seqlen_q,
seqlen_k,
query_padding_mask,
key_padding_mask,
head_dim,
is_dropout,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""FlashAttention stores the S matrix in a different way.
Arguments:
S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
query_padding_mask: (batch_size, seqlen_q_rounded)
key_padding_mask: (batch_size, seqlen_k_rounded)
"""
if causal:
window_size = (window_size[0], 0)
seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
warps_n = 4
blocksize_m, blocksize_n = _get_block_size(
S.device, head_dim, is_dropout, causal)
nblocks_n = (seqlen_k_rounded + blocksize_n - 1) // blocksize_n
nblocks_m = (seqlen_q_rounded + blocksize_m - 1) // blocksize_m
mmas_n = (blocksize_n + 16 - 1) // 16
S_flat = rearrange(
S,
"b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)",
blocksize_m=blocksize_m,
blocksize_n=blocksize_n,
)
S_converted = rearrange(
S_flat,
"b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)",
mmas_n=mmas_n,
warps_n=warps_n,
eight=8,
c0=2,
c1=2,
c2=2,
four=4,
)
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
S.device,
)
local_mask = F.pad(
local_mask,
(0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
value=True,
)
S_converted.masked_fill_(local_mask, 0.0)
# Need to zero out things not in attention_mask in case S was initialized with random values
# and some of those values aren't overwritten.
seqlen_q_og = (
query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
)
if query_padding_mask is not None:
query_padding_mask = F.pad(
query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
S_converted = S_converted.masked_fill(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
if key_padding_mask is not None:
key_padding_mask = F.pad(
key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
S_converted = S_converted.masked_fill(
rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
return S_converted[:, :, :seqlen_q, :seqlen_k]
def normalize_flash_attn_S(
attn_unnorm,
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
is_dropout=False,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""
Arguments:
q: (batch_size, seqlen_q, nheads, head_dim)
k, v: (batch_size, seqlen_k, nheads, head_dim)
key_padding_mask: (batch_size, seqlen_q)
Output:
softmax_lse: (batch_size, nheads, seqlen_q)
softmax_max: (batch_size, nheads, seqlen_q)
"""
if causal:
window_size = (window_size[0], 0)
q, k, v = q.float(), k.float(), v.float()
_, seqlen_q, _, head_dim = q.shape
seqlen_k = k.shape[1]
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k)
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask,
"b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
q.device,
)
scores.masked_fill_(local_mask, float("-inf"))
_, block_size_n = _get_block_size(
scores.device, head_dim, is_dropout, causal)
scores_block = scores.split(block_size_n, dim=-1)
lse_block = torch.stack([torch.logsumexp(s, dim=-1)
for s in scores_block], dim=-1)
lse = torch.logsumexp(lse_block, dim=-1)
# lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
# so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
lse[lse == float("-inf")] = float("inf")
scores_max_block = torch.stack(
[torch.amax(s, dim=-1) for s in scores_block], dim=-1)
cummax_block = torch.cummax(
scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
attn_norm = torch.cat(
[
a * rearrange(torch.exp(m - lse), "b h s -> b h s 1")
for a, m in zip(attn_unnorm_block, cummax_block)
],
dim=-1,
)
if query_padding_mask is not None:
attn_norm.masked_fill_(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
return attn_norm.to(dtype=attn_unnorm.dtype)
def get_dropout_fraction(
dropout_mask,
query_padding_mask=None,
key_padding_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
):
"""
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop.
query_padding_mask: (batch_size, seqlen_q)
key_padding_mask: (batch_size, seqlen_k)
"""
if causal:
window_size = (window_size[0], 0)
batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape
dropped = ~dropout_mask
valid = torch.ones_like(dropout_mask)
if query_padding_mask is not None:
dropped.masked_fill_(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
valid.masked_fill_(
rearrange(~query_padding_mask, "b s -> b 1 s 1"), False)
if key_padding_mask is not None:
dropped.masked_fill_(
rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
valid.masked_fill_(
rearrange(~key_padding_mask, "b s -> b 1 1 s"), False)
if window_size[0] >= 0 or window_size[1] >= 0:
local_mask = construct_local_mask(
seqlen_q,
seqlen_k,
window_size,
query_padding_mask,
key_padding_mask,
dropout_mask.device,
)
dropped.masked_fill_(local_mask, False)
valid.masked_fill_(local_mask, False)
dropped_total = dropped.sum()
return dropped.sum() / valid.sum()
@pytest.mark.parametrize(
"dtype", [torch.float16]
)
@pytest.mark.parametrize(
"b_sq",
[
(32, 512),
(16, 1024),
(8, 2048),
(4, 4096),
(2, 8192),
(1, 16384)
]
)
@pytest.mark.parametrize(
"nh_hd",
[
(32, 64),
(16, 128),
(40, 128) # non power of 2 nh
]
)
@pytest.mark.parametrize(
"tp_world_size", [1, 2, 4]
)
def test_flash_attn_func(b_sq, nh_hd, tp_world_size, dtype):
b, sq = b_sq
nh, hd = nh_hd
nh_tp = nh // tp_world_size
q, k, v = [torch.randn(b, sq, nh_tp, hd, device="cuda",
dtype=dtype, requires_grad=True) for _ in range(3)]
dout = torch.rand_like(q)
for tp_index in range(tp_world_size):
alibi, alibi_slopes = generate_alibi(
max_seq_len=sq,
num_attention_heads=nh,
tp_world_size=tp_world_size,
tp_index=tp_index,
key_padding_mask=None,
device="cuda"
)
triton_out = flash_attn_func_triton(
q, k, v, alibi, True, hd**(-0.5))
triton_out.backward(dout)
triton_dq, q.grad = q.grad.clone(), None
triton_dk, k.grad = k.grad.clone(), None
triton_dv, v.grad = v.grad.clone(), None
flash_out = flash_attn_func(q, k, v, causal=True, alibi_slopes=repeat(alibi_slopes, "nh -> b nh", b=b))
flash_out.backward(dout)
flash_dq, q.grad = q.grad.clone(), None
flash_dk, k.grad = k.grad.clone(), None
flash_dv, v.grad = v.grad.clone(), None
assert torch.allclose(flash_out, triton_out, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dq, triton_dq, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dk, triton_dk, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dv, triton_dv, atol=1e-2, rtol=0.)
@pytest.mark.parametrize(
"dtype", [torch.float16]
)
@pytest.mark.parametrize(
"right_padding", [True, False]
)
@pytest.mark.parametrize(
"b_sq",
[
(32, 512),
(16, 1024),
(8, 2048),
(4, 4096),
(2, 8192),
(1, 16384)
]
)
@pytest.mark.parametrize(
"nh_hd",
[
(32, 64),
(16, 128),
(40, 128) # non power of 2 nh
]
)
@pytest.mark.parametrize(
"tp_world_size", [1, 2, 4]
)
def test_flash_attn_varlen_func(b_sq, nh_hd, tp_world_size, right_padding, dtype):
b, sqk = b_sq
nh, hd = nh_hd
nh_tp = nh // tp_world_size
# flash_attn_func_triton(), flash-attention v2 (above v2.1) causal logic are different
# so only (seqlen_q == 1, causal=False to triton ver.) shows correct results
# https://github.com/huggingface/text-generation-inference/blob/v1.1.1/server/text_generation_server/models/custom_modeling/mpt_modeling.py#L53-L63
q = torch.randn(b, 1, nh_tp, hd, device="cuda", dtype=dtype, requires_grad=True)
k, v = [torch.randn(b, sqk, nh_tp, hd, device="cuda",
dtype=dtype, requires_grad=True) for _ in range(2)]
dout = torch.rand_like(q)
padding_mask = generate_random_padding_mask(sqk, b, "cuda", "random", right_padding)
(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
dq_pad_fn,
dk_pad_fn,
) = generate_qkv(q, k, v, None, padding_mask, kvpacked=False)
for tp_index in range(tp_world_size):
alibi, alibi_slopes = generate_alibi(
max_seq_len=sqk,
num_attention_heads=nh,
tp_world_size=tp_world_size,
tp_index=tp_index,
key_padding_mask=padding_mask,
device="cuda"
)
triton_out = flash_attn_func_triton(
q, k, v, alibi, False, hd**(-0.5))
triton_out.backward(dout)
triton_dq, q.grad = q.grad.clone(), None
triton_dk, k.grad = k.grad.clone(), None
triton_dv, v.grad = v.grad.clone(), None
flash_out_unpad = flash_attn_varlen_func(
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
causal=True,
alibi_slopes=repeat(alibi_slopes, "nh -> b nh", b=b)
)
flash_out = output_pad_fn(flash_out_unpad)
flash_out.backward(dout)
flash_dq_unpad, q_unpad.grad = q_unpad.grad.clone(), None
flash_dk_unpad, k_unpad.grad = k_unpad.grad.clone(), None
flash_dv_unpad, v_unpad.grad = v_unpad.grad.clone(), None
flash_dq = dq_pad_fn(flash_dq_unpad)
flash_dk = dk_pad_fn(flash_dk_unpad)
flash_dv = dk_pad_fn(flash_dv_unpad)
assert torch.allclose(flash_out, triton_out, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dq, triton_dq, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dk, triton_dk, atol=1e-2, rtol=0.)
assert torch.allclose(flash_dv, triton_dv, atol=1e-2, rtol=0.)
@pytest.mark.parametrize("alibi", [True])
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("num_splits", [1, 0])
# @pytest.mark.parametrize("num_splits", [0])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("new_kv", [False, True])
# @pytest.mark.parametrize("new_kv", [True])
# @pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
@pytest.mark.parametrize("rotary_interleaved", [False, True])
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
@pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [True])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(1, 128),
(1, 339),
(3, 1024),
(64, 800),
(64, 256),
(3, 799),
(64, 2048),
(16, 20000),
(1, 128 * 1024),
(16, 128 * 1024),
(128, 128),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_kvcache(
seqlen_q,
seqlen_k,
d,
has_batch_idx,
rotary_fraction,
rotary_interleaved,
seqlen_new_eq_seqlen_q,
causal,
local,
new_kv,
mha_type,
num_splits,
dtype,
alibi,
):
if seqlen_q > seqlen_k and new_kv:
pytest.skip()
if not new_kv and rotary_fraction > 0.0:
pytest.skip()
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 2
batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
nheads = 8
# rotary_dim must be a multiple of 16, and must be <= d
rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 4)
assert nheads % nheads_k == 0
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
q = torch.randn(batch_size, seqlen_q, nheads,
d, device=device, dtype=dtype)
seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(
1, seqlen_q + 1, (1,)).item()
if new_kv:
k = torch.randn(batch_size, seqlen_new, nheads_k,
d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen_new, nheads_k,
d, device=device, dtype=dtype)
else:
k, v = None, None
k_cache = torch.randn(batch_size_cache, seqlen_k,
nheads_k, d, device=device, dtype=dtype)
v_cache = torch.randn(batch_size_cache, seqlen_k,
nheads_k, d, device=device, dtype=dtype)
cache_seqlens = torch.randint(
0,
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
(seqlen_k - (seqlen_q if (causal or local)
and rotary_dim > 1 else seqlen_new) + 1)
if new_kv
else (seqlen_k + 1),
(batch_size,),
dtype=torch.int32,
device=device,
)
if has_batch_idx:
cache_batch_idx = torch.randperm(
batch_size_cache, dtype=torch.int32, device=device)[:batch_size]
else:
cache_batch_idx = None
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
if rotary_dim > 0:
angle = torch.rand(seqlen_k, rotary_dim // 2,
device=device) * 2 * math.pi
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)
if causal or local:
q_ro = apply_rotary_emb(
q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
)
else:
q_ro = rearrange(
apply_rotary_emb(
rearrange(q, "b s h d -> b 1 (s h) d"),
cos,
sin,
seqlen_offsets=cache_seqlens,
interleaved=rotary_interleaved,
),
"b 1 (s h) d -> b s h d",
s=seqlen_q,
)
# q_ro = q
k_ro = apply_rotary_emb(
k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved
)
else:
cos, sin = None, None
q_ro, k_ro = q, k
# k_cache[:, 64:] = -1
k_cache_ref = (
k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
v_cache_ref = (
v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
if new_kv:
update_mask = torch.logical_and(
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
)
k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...")
v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...")
k_cache_rep = repeat(
k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
v_cache_rep = repeat(
v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k)
if alibi:
seqlen_alibi = k_cache_rep.shape[1]
alibi_tensor, alibi_slopes = generate_alibi(
max_seq_len=seqlen_alibi,
num_attention_heads=nheads,
tp_world_size=1,
tp_index=0,
key_padding_mask=None,
device="cuda"
)
# alibi_tensor = alibi_tensor.expand(batch_size, -1, seqlen_q, -1)
alibi_slopes = repeat(alibi_slopes, "nh -> b nh", b=batch_size)
if alibi_tensor.abs().max().item() >= torch.finfo(dtype).max:
pytest.skip()
else:
alibi_tensor, alibi_slopes = None, None
out = flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k,
v,
cos,
sin,
cache_seqlens,
cache_batch_idx,
causal=causal,
window_size=window_size,
rotary_interleaved=rotary_interleaved,
num_splits=num_splits,
alibi_slopes=alibi_slopes
)
# out = flash_attn_with_kvcache(
# q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
# )
# out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
# qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
# m = qk.amax(-1, keepdim=True)
# s_tmp = torch.exp((qk - m) / math.sqrt(d))
# o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
# probs = torch.softmax(qk, dim=-1)
key_padding_mask = arange < cache_seqlens_expanded + \
(seqlen_new if new_kv else 0)
out_ref, _ = attention_ref(
q_ro,
k_cache_rep,
v_cache_rep,
None,
key_padding_mask,
0.0,
None,
causal=causal,
window_size=window_size,
bias=alibi_tensor
)
out_pt, _ = attention_ref(
q_ro,
k_cache_rep,
v_cache_rep,
None,
key_padding_mask,
0.0,
None,
causal=causal,
window_size=window_size,
upcast=False,
reorder_ops=True,
bias=alibi_tensor
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
if new_kv:
k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
assert torch.allclose(k_cache_select, k_cache_ref,
rtol=1e-3, atol=1e-3)
assert torch.equal(v_cache_select, v_cache_ref)
assert (out - out_ref).abs().max().item() <= 3 * \
(out_pt - out_ref).abs().max().item() + 1e-5