Enable paged attention in varlen forward (#831)
* Enable paged attention in varlen forward * Format + fix padding
This commit is contained in:
parent
26c9e82743
commit
2a15840f09
@ -494,12 +494,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
||||
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
||||
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
|
||||
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
|
||||
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
|
||||
int max_seqlen_q,
|
||||
const int max_seqlen_k,
|
||||
@ -535,6 +536,15 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
||||
CHECK_DEVICE(cu_seqlens_q);
|
||||
CHECK_DEVICE(cu_seqlens_k);
|
||||
|
||||
at::Tensor block_table;
|
||||
const bool paged_KV = block_table_.has_value();
|
||||
if (paged_KV) {
|
||||
block_table = block_table_.value();
|
||||
CHECK_DEVICE(block_table);
|
||||
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
|
||||
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
|
||||
}
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
@ -546,8 +556,12 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
||||
const int batch_size = cu_seqlens_q.numel() - 1;
|
||||
int num_heads = sizes[1];
|
||||
const int head_size_og = sizes[2];
|
||||
const int total_k = k.size(0);
|
||||
const int num_heads_k = k.size(1);
|
||||
const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
|
||||
|
||||
const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
|
||||
const int num_blocks = !paged_KV ? 0 : k.size(0);
|
||||
const int page_block_size = !paged_KV ? 1 : k.size(1);
|
||||
TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
|
||||
|
||||
if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
|
||||
if (is_causal) { window_size_right = 0; }
|
||||
@ -575,8 +589,16 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
||||
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
|
||||
|
||||
CHECK_SHAPE(q, total_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
|
||||
if (!paged_KV) {
|
||||
const int total_k = k.size(0);
|
||||
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
|
||||
} else {
|
||||
CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
|
||||
}
|
||||
|
||||
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
||||
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
||||
if (seqused_k.has_value()){
|
||||
@ -654,6 +676,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
seqlenq_ngroups_swapped);
|
||||
|
||||
if (paged_KV) {
|
||||
params.block_table = block_table.data_ptr<int>();
|
||||
params.block_table_batch_stride = block_table.stride(0);
|
||||
params.k_batch_stride = k_padded.stride(0);
|
||||
params.v_batch_stride = v_padded.stride(0);
|
||||
}
|
||||
params.page_block_size = page_block_size;
|
||||
if (seqlenq_ngroups_swapped) {
|
||||
// Only apply split-k for decoding
|
||||
set_params_splitkv(params, batch_size, num_heads,
|
||||
@ -682,7 +712,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
|
||||
|
||||
if (max_seqlen_k > 0) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
run_mha_fwd(params, stream);
|
||||
run_mha_fwd(params, stream, paged_KV);
|
||||
} else {
|
||||
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
|
||||
out.zero_();
|
||||
|
||||
@ -79,6 +79,7 @@ def _flash_attn_varlen_forward(
|
||||
window_size,
|
||||
alibi_slopes,
|
||||
return_softmax,
|
||||
block_table,
|
||||
):
|
||||
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
|
||||
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
|
||||
@ -90,6 +91,7 @@ def _flash_attn_varlen_forward(
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
None,
|
||||
block_table,
|
||||
alibi_slopes,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
@ -299,6 +301,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
return_softmax=return_softmax and dropout_p > 0,
|
||||
block_table=None,
|
||||
)
|
||||
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
@ -440,6 +443,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
return_softmax=return_softmax and dropout_p > 0,
|
||||
block_table=None,
|
||||
)
|
||||
ctx.save_for_backward(
|
||||
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
|
||||
@ -570,6 +574,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
|
||||
alibi_slopes,
|
||||
deterministic,
|
||||
return_softmax,
|
||||
block_table,
|
||||
):
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
@ -587,6 +592,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
return_softmax=return_softmax and dropout_p > 0,
|
||||
block_table=block_table,
|
||||
)
|
||||
ctx.save_for_backward(
|
||||
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
|
||||
@ -630,7 +636,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
|
||||
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
|
||||
dk = dk[..., : dout.shape[-1]]
|
||||
dv = dv[..., : dout.shape[-1]]
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def flash_attn_qkvpacked_func(
|
||||
@ -1001,6 +1007,7 @@ def flash_attn_varlen_func(
|
||||
alibi_slopes=None,
|
||||
deterministic=False,
|
||||
return_attn_probs=False,
|
||||
block_table=None,
|
||||
):
|
||||
"""dropout_p should be set to 0.0 during evaluation
|
||||
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
|
||||
@ -1071,6 +1078,7 @@ def flash_attn_varlen_func(
|
||||
alibi_slopes,
|
||||
deterministic,
|
||||
return_attn_probs,
|
||||
block_table,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1542,8 +1542,12 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
|
||||
(1023, 1024),
|
||||
],
|
||||
)
|
||||
# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
|
||||
@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
|
||||
# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
|
||||
def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
|
||||
def test_flash_attn_varlen_causal(
|
||||
seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
|
||||
):
|
||||
if (
|
||||
max(seqlen_q, seqlen_k) >= 2048
|
||||
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
|
||||
@ -1559,8 +1563,19 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
|
||||
nheads = 9
|
||||
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
|
||||
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
||||
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
||||
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
|
||||
|
||||
if paged_kv_block_size is None:
|
||||
k = torch.randn(
|
||||
batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
v = torch.randn(
|
||||
batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
block_table = None
|
||||
else:
|
||||
k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
|
||||
seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
|
||||
)
|
||||
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
|
||||
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
|
||||
(
|
||||
@ -1580,8 +1595,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
|
||||
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
|
||||
out_unpad = flash_attn_varlen_func(
|
||||
q_unpad,
|
||||
k_unpad,
|
||||
v_unpad,
|
||||
k_unpad if paged_kv_block_size is None else k_cache_paged,
|
||||
v_unpad if paged_kv_block_size is None else v_cache_paged,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
@ -1589,6 +1604,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
|
||||
0.0,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
block_table=block_table,
|
||||
)
|
||||
out = output_pad_fn(out_unpad)
|
||||
out_ref, attn_ref = attention_ref(
|
||||
@ -1625,7 +1641,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
|
||||
|
||||
g = torch.randn_like(out)
|
||||
do_o = (g.float() * out.float()).sum(-1)
|
||||
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
|
||||
test_backward = (d <= MAX_HEADDIM_SM8x or d > 224 or is_sm80 or is_sm90) and block_table is None
|
||||
if test_backward:
|
||||
(
|
||||
dq_unpad,
|
||||
dk_unpad,
|
||||
@ -1661,7 +1678,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
|
||||
# of a Pytorch implementation.
|
||||
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5
|
||||
|
||||
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
|
||||
if test_backward:
|
||||
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
|
||||
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
|
||||
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
|
||||
@ -1888,29 +1905,16 @@ def test_flash_attn_kvcache(
|
||||
v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
|
||||
block_table = None
|
||||
else:
|
||||
num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
|
||||
k_cache_paged = torch.randn(
|
||||
num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
|
||||
(
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
k_cache_paged,
|
||||
v_cache_paged,
|
||||
num_blocks,
|
||||
) = _generate_block_kvcache(
|
||||
seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
|
||||
)
|
||||
v_cache_paged = torch.randn(
|
||||
num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
|
||||
)
|
||||
block_table = rearrange(
|
||||
torch.randperm(num_blocks, dtype=torch.int32, device=device),
|
||||
"(b nblocks) -> b nblocks",
|
||||
b=batch_size,
|
||||
)
|
||||
k_cache = rearrange(
|
||||
# pytorch 1.12 doesn't have indexing with int32
|
||||
k_cache_paged[block_table.to(dtype=torch.long).flatten()],
|
||||
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
|
||||
b=batch_size,
|
||||
)[:, :seqlen_k]
|
||||
v_cache = rearrange(
|
||||
v_cache_paged[block_table.to(dtype=torch.long).flatten()],
|
||||
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
|
||||
b=batch_size,
|
||||
)[:, :seqlen_k]
|
||||
cache_seqlens = torch.randint(
|
||||
0 if new_kv else 1,
|
||||
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
|
||||
@ -2073,6 +2077,33 @@ def test_flash_attn_kvcache(
|
||||
assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5
|
||||
|
||||
|
||||
def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
|
||||
num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
|
||||
k_cache_paged = torch.randn(
|
||||
num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
|
||||
)
|
||||
v_cache_paged = torch.randn(
|
||||
num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
|
||||
)
|
||||
block_table = rearrange(
|
||||
torch.randperm(num_blocks, dtype=torch.int32, device=device),
|
||||
"(b nblocks) -> b nblocks",
|
||||
b=batch_size,
|
||||
)
|
||||
k_cache = rearrange(
|
||||
# pytorch 1.12 doesn't have indexing with int32
|
||||
k_cache_paged[block_table.to(dtype=torch.long).flatten()],
|
||||
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
|
||||
b=batch_size,
|
||||
)[:, :seqlen_k]
|
||||
v_cache = rearrange(
|
||||
v_cache_paged[block_table.to(dtype=torch.long).flatten()],
|
||||
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
|
||||
b=batch_size,
|
||||
)[:, :seqlen_k]
|
||||
return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
@pytest.mark.parametrize("causal", [False, True])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user