Enable paged attention in varlen forward (#831)

* Enable paged attention in varlen forward

* Format + fix padding
Grigory Sizov 2024-03-15 08:48:19 +01:00 committed by GitHub
parent 26c9e82743
commit 2a15840f09
3 changed files with 106 additions and 37 deletions
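
This change lets the variable-length forward path read K and V from a paged (block) KV cache addressed through a block table, the layout used by vLLM-style paged attention. Below is a minimal sketch of the call pattern this enables; the shapes follow the comments in the diff, the tensors are random placeholders rather than a real decode loop, and it assumes a CUDA build of flash-attn that includes this commit.

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func

device, dtype = "cuda", torch.float16
batch_size, nheads, head_dim = 2, 8, 128
page_block_size = 256                                   # must be a multiple of 256 in this version
seqlens_q = torch.tensor([5, 7], device=device)
seqlens_k = torch.tensor([300, 600], device=device)

blocks_per_seq = (int(seqlens_k.max()) + page_block_size - 1) // page_block_size
num_blocks = batch_size * blocks_per_seq
q = torch.randn(int(seqlens_q.sum()), nheads, head_dim, device=device, dtype=dtype)
k_cache = torch.randn(num_blocks, page_block_size, nheads, head_dim, device=device, dtype=dtype)
v_cache = torch.randn_like(k_cache)
# Each row maps a sequence's logical blocks to physical blocks in the cache.
block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).reshape(batch_size, blocks_per_seq)
cu_seqlens_q = F.pad(seqlens_q.cumsum(0), (1, 0)).to(torch.int32)
cu_seqlens_k = F.pad(seqlens_k.cumsum(0), (1, 0)).to(torch.int32)

out = flash_attn_varlen_func(
    q, k_cache, v_cache,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q=int(seqlens_q.max()), max_seqlen_k=int(seqlens_k.max()),
    causal=True,
    block_table=block_table,
)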

Changed file 1/3: C++ varlen forward API (mha_varlen_fwd)

@@ -494,12 +494,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
 std::vector<at::Tensor>
 mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                int max_seqlen_q,
                const int max_seqlen_k,
@@ -535,6 +536,15 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     CHECK_DEVICE(cu_seqlens_q);
     CHECK_DEVICE(cu_seqlens_k);
 
+    at::Tensor block_table;
+    const bool paged_KV = block_table_.has_value();
+    if (paged_KV) {
+        block_table = block_table_.value();
+        CHECK_DEVICE(block_table);
+        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+    }
+
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
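
The checks above spell out what a valid block table looks like on the C++ side: an int32 tensor on the same GPU as q/k/v with a contiguous last dimension, shaped batch_size x max_num_blocks_per_seq (the shape is verified further down). A small illustrative sketch of a tensor that satisfies them:

import torch

batch_size, max_num_blocks_per_seq, num_blocks = 4, 8, 32
# A real allocator would hand each sequence distinct physical blocks; random ids
# are enough to satisfy the dtype/device/contiguity checks above.
block_table = torch.randint(
    0, num_blocks, (batch_size, max_num_blocks_per_seq), dtype=torch.int32, device="cuda"
)
assert block_table.dtype == torch.int32 and block_table.stride(-1) == 1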
@@ -546,8 +556,12 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size_og = sizes[2];
-    const int total_k = k.size(0);
-    const int num_heads_k = k.size(1);
+    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int num_blocks = !paged_KV ? 0 : k.size(0);
+    const int page_block_size = !paged_KV ? 1 : k.size(1);
+    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
 
     if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
     if (is_causal) { window_size_right = 0; }
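
When a block table is present, K and V are interpreted as (num_blocks, page_block_size, num_heads_k, head_size) instead of the packed (total_k, num_heads_k, head_size), and the block size currently has to be a multiple of 256 (the test file below points at PR #824 for smaller pages). A quick worked example of the bookkeeping, as a sketch:

import math

seqlen_k, page_block_size = 1000, 256
assert page_block_size % 256 == 0              # mirrors the TORCH_CHECK above
blocks_per_seq = math.ceil(seqlen_k / page_block_size)
print(blocks_per_seq)                          # 4; the last block holds only 1000 - 3 * 256 = 232 keys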
@@ -575,8 +589,16 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
 
     CHECK_SHAPE(q, total_q, num_heads, head_size_og);
-    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+    if (!paged_KV) {
+        const int total_k = k.size(0);
+        CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+    } else {
+        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+    }
+
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
     if (seqused_k.has_value()){
@@ -654,6 +676,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      window_size_left,
                      window_size_right,
                      seqlenq_ngroups_swapped);
+    if (paged_KV) {
+        params.block_table = block_table.data_ptr<int>();
+        params.block_table_batch_stride = block_table.stride(0);
+        params.k_batch_stride = k_padded.stride(0);
+        params.v_batch_stride = v_padded.stride(0);
+    }
+    params.page_block_size = page_block_size;
+
     if (seqlenq_ngroups_swapped) {
         // Only apply split-k for decoding
         set_params_splitkv(params, batch_size, num_heads,
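
params.block_table, params.block_table_batch_stride and params.page_block_size give the kernel what it needs to translate a logical key position into a physical cache location. A host-side Python illustration of that lookup follows; the actual kernel works on raw pointers and strides, so this is only the indexing logic, mirroring the gather the test helper _generate_block_kvcache uses to rebuild a contiguous reference cache:

import torch

def lookup_paged_key(k_cache_paged, block_table, batch_idx, key_idx, page_block_size):
    # k_cache_paged: (num_blocks, page_block_size, nheads_k, head_size)
    # block_table:   (batch_size, max_num_blocks_per_seq), int32 physical block ids
    physical_block = int(block_table[batch_idx, key_idx // page_block_size])
    offset = key_idx % page_block_size
    return k_cache_paged[physical_block, offset]          # (nheads_k, head_size)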
@@ -682,7 +712,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     if (max_seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
-        run_mha_fwd(params, stream);
+        run_mha_fwd(params, stream, paged_KV);
     } else {
         // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
         out.zero_();

Changed file 2/3: Python interface (_flash_attn_varlen_forward, FlashAttnVarlenFunc, flash_attn_varlen_func)

@@ -79,6 +79,7 @@ def _flash_attn_varlen_forward(
     window_size,
     alibi_slopes,
     return_softmax,
+    block_table,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -90,6 +91,7 @@ def _flash_attn_varlen_forward(
         cu_seqlens_q,
         cu_seqlens_k,
         None,
+        block_table,
         alibi_slopes,
         max_seqlen_q,
         max_seqlen_k,
@@ -299,6 +301,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            block_table=None,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
@@ -440,6 +443,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
+           block_table=None,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -570,6 +574,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
        alibi_slopes,
        deterministic,
        return_softmax,
+       block_table,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
@@ -587,6 +592,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
            window_size=window_size,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
+           block_table=block_table,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -630,7 +636,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
        dk = dk[..., : dout.shape[-1]]
        dv = dv[..., : dout.shape[-1]]
-       return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
+       return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_qkvpacked_func(
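
The extra None in FlashAttnVarlenFunc.backward corresponds to the new block_table input of forward: torch.autograd.Function.backward must return one gradient (or None) per forward argument, and the block table is a non-differentiable index tensor. A generic sketch of that contract, unrelated to the flash-attention internals:

import torch

class ScaleByTable(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, index_table):   # index_table is an extra, non-differentiable input
        ctx.save_for_backward(index_table)
        return x * index_table.to(x.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        (index_table,) = ctx.saved_tensors
        # One return value per forward input: a gradient for x, None for index_table.
        return grad_out * index_table.to(grad_out.dtype), None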
@@ -1001,6 +1007,7 @@ def flash_attn_varlen_func(
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
+   block_table=None,
 ):
    """dropout_p should be set to 0.0 during evaluation
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1071,6 +1078,7 @@ def flash_attn_varlen_func(
        alibi_slopes,
        deterministic,
        return_attn_probs,
+       block_table,
    )

Changed file 3/3: tests (test_flash_attn_varlen_causal, test_flash_attn_kvcache)

@@ -1542,8 +1542,12 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
         (1023, 1024),
     ],
 )
+# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
+@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
 # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
-def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
+def test_flash_attn_varlen_causal(
+    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
+):
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
@@ -1559,8 +1563,19 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     nheads = 9
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    if paged_kv_block_size is None:
+        k = torch.randn(
+            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
+        )
+        v = torch.randn(
+            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
+        )
+        block_table = None
+    else:
+        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
+            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
+        )
     query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
     key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
     (
@@ -1580,8 +1595,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
     out_unpad = flash_attn_varlen_func(
         q_unpad,
-        k_unpad,
-        v_unpad,
+        k_unpad if paged_kv_block_size is None else k_cache_paged,
+        v_unpad if paged_kv_block_size is None else v_cache_paged,
         cu_seqlens_q,
         cu_seqlens_k,
         max_seqlen_q,
@@ -1589,6 +1604,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
         0.0,
         causal=causal,
         window_size=window_size,
+        block_table=block_table,
     )
     out = output_pad_fn(out_unpad)
     out_ref, attn_ref = attention_ref(
@@ -1625,7 +1641,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
+    test_backward = (d <= MAX_HEADDIM_SM8x or d > 224 or is_sm80 or is_sm90) and block_table is None
+    if test_backward:
         (
             dq_unpad,
             dk_unpad,
@@ -1661,7 +1678,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
 
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
+    if test_backward:
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@@ -1888,29 +1905,16 @@ def test_flash_attn_kvcache(
         v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
         block_table = None
     else:
-        num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
-        k_cache_paged = torch.randn(
-            num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
-        )
-        v_cache_paged = torch.randn(
-            num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
-        )
-        block_table = rearrange(
-            torch.randperm(num_blocks, dtype=torch.int32, device=device),
-            "(b nblocks) -> b nblocks",
-            b=batch_size,
-        )
-        k_cache = rearrange(
-            # pytorch 1.12 doesn't have indexing with int32
-            k_cache_paged[block_table.to(dtype=torch.long).flatten()],
-            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
-            b=batch_size,
-        )[:, :seqlen_k]
-        v_cache = rearrange(
-            v_cache_paged[block_table.to(dtype=torch.long).flatten()],
-            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
-            b=batch_size,
-        )[:, :seqlen_k]
+        (
+            k_cache,
+            v_cache,
+            block_table,
+            k_cache_paged,
+            v_cache_paged,
+            num_blocks,
+        ) = _generate_block_kvcache(
+            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
+        )
     cache_seqlens = torch.randint(
         0 if new_kv else 1,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
@@ -2073,6 +2077,33 @@ def test_flash_attn_kvcache(
     assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
 
 
+def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
+    num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
+    k_cache_paged = torch.randn(
+        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+    )
+    v_cache_paged = torch.randn(
+        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+    )
+    block_table = rearrange(
+        torch.randperm(num_blocks, dtype=torch.int32, device=device),
+        "(b nblocks) -> b nblocks",
+        b=batch_size,
+    )
+    k_cache = rearrange(
+        # pytorch 1.12 doesn't have indexing with int32
+        k_cache_paged[block_table.to(dtype=torch.long).flatten()],
+        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+        b=batch_size,
+    )[:, :seqlen_k]
+    v_cache = rearrange(
+        v_cache_paged[block_table.to(dtype=torch.long).flatten()],
+        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+        b=batch_size,
+    )[:, :seqlen_k]
+    return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks
+
+
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("causal", [False, True])