diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index fdb6c71..96f0cff 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -494,12 +494,13 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size

 std::vector<at::Tensor>
 mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-               const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
-               const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
+               const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q, // b+1
                const at::Tensor &cu_seqlens_k, // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                int max_seqlen_q,
                const int max_seqlen_k,
@@ -535,6 +536,15 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     CHECK_DEVICE(cu_seqlens_q);
     CHECK_DEVICE(cu_seqlens_k);

+    at::Tensor block_table;
+    const bool paged_KV = block_table_.has_value();
+    if (paged_KV) {
+        block_table = block_table_.value();
+        CHECK_DEVICE(block_table);
+        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
+        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+    }
+
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -546,8 +556,12 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size_og = sizes[2];
-    const int total_k = k.size(0);
-    const int num_heads_k = k.size(1);
+    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+
+    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
+    const int num_blocks = !paged_KV ? 0 : k.size(0);
+    const int page_block_size = !paged_KV ? 1 : k.size(1);
+    TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");

     if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }  // causal=true is the same as causal=false in this case
     if (is_causal) { window_size_right = 0; }
@@ -575,8 +589,16 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     if (window_size_right >= max_seqlen_k) { window_size_right = -1; }

     CHECK_SHAPE(q, total_q, num_heads, head_size_og);
-    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
-    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+    if (!paged_KV) {
+        const int total_k = k.size(0);
+        CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+    } else {
+        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
+        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
+    }
+
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
     if (seqused_k.has_value()){
@@ -654,6 +676,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      window_size_left,
                      window_size_right,
                      seqlenq_ngroups_swapped);
+
+    if (paged_KV) {
+        params.block_table = block_table.data_ptr<int>();
+        params.block_table_batch_stride = block_table.stride(0);
+        params.k_batch_stride = k_padded.stride(0);
+        params.v_batch_stride = v_padded.stride(0);
+    }
+    params.page_block_size = page_block_size;
     if (seqlenq_ngroups_swapped) {
         // Only apply split-k for decoding
         set_params_splitkv(params, batch_size, num_heads,
@@ -682,7 +712,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s

     if (max_seqlen_k > 0) {
         auto stream = at::cuda::getCurrentCUDAStream().stream();
-        run_mha_fwd(params, stream);
+        run_mha_fwd(params, stream, paged_KV);
     } else {
         // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
         out.zero_();
diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index a1ef865..a7f15be 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -79,6 +79,7 @@ def _flash_attn_varlen_forward(
     window_size,
     alibi_slopes,
     return_softmax,
+    block_table,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -90,6 +91,7 @@ def _flash_attn_varlen_forward(
         cu_seqlens_q,
         cu_seqlens_k,
         None,
+        block_table,
         alibi_slopes,
         max_seqlen_q,
         max_seqlen_k,
@@ -299,6 +301,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            block_table=None,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
@@ -440,6 +443,7 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            block_table=None,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
         )
@@ -570,6 +574,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         alibi_slopes,
         deterministic,
         return_softmax,
+        block_table,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -587,6 +592,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            block_table=block_table,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
         )
@@ -630,7 +636,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None


 def flash_attn_qkvpacked_func(
@@ -1001,6 +1007,7 @@ def flash_attn_varlen_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    block_table=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1071,6 +1078,7 @@ def flash_attn_varlen_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        block_table,
     )


diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 892b8be..308e30b 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -1542,8 +1542,12 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
         (1023, 1024),
     ],
 )
+# TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged
+@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
 # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)])
-def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
+def test_flash_attn_varlen_causal(
+    seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype
+):
     if (
         max(seqlen_q, seqlen_k) >= 2048
         and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
@@ -1559,8 +1563,19 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     nheads = 9
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
+
+    if paged_kv_block_size is None:
+        k = torch.randn(
+            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
+        )
+        v = torch.randn(
+            batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
+        )
+        block_table = None
+    else:
+        k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache(
+            seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype
+        )
     query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
     key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
     (
@@ -1580,8 +1595,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
     out_unpad = flash_attn_varlen_func(
         q_unpad,
-        k_unpad,
-        v_unpad,
+        k_unpad if paged_kv_block_size is None else k_cache_paged,
+        v_unpad if paged_kv_block_size is None else v_cache_paged,
         cu_seqlens_q,
         cu_seqlens_k,
         max_seqlen_q,
@@ -1589,6 +1604,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
         0.0,
         causal=causal,
         window_size=window_size,
+        block_table=block_table,
     )
     out = output_pad_fn(out_unpad)
     out_ref, attn_ref = attention_ref(
@@ -1625,7 +1641,8 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp

     g = torch.randn_like(out)
     do_o = (g.float() * out.float()).sum(-1)
-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
+    test_backward = (d <= MAX_HEADDIM_SM8x or d > 224 or is_sm80 or is_sm90) and block_table is None
+    if test_backward:
         (
             dq_unpad,
             dk_unpad,
@@ -1661,7 +1678,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
     # of a Pytorch implementation.
     assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

-    if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
+    if test_backward:
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@@ -1888,29 +1905,16 @@ def test_flash_attn_kvcache(
         v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype)
         block_table = None
     else:
-        num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
-        k_cache_paged = torch.randn(
-            num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+        (
+            k_cache,
+            v_cache,
+            block_table,
+            k_cache_paged,
+            v_cache_paged,
+            num_blocks,
+        ) = _generate_block_kvcache(
+            seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype
         )
-        v_cache_paged = torch.randn(
-            num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
-        )
-        block_table = rearrange(
-            torch.randperm(num_blocks, dtype=torch.int32, device=device),
-            "(b nblocks) -> b nblocks",
-            b=batch_size,
-        )
-        k_cache = rearrange(
-            # pytorch 1.12 doesn't have indexing with int32
-            k_cache_paged[block_table.to(dtype=torch.long).flatten()],
-            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
-            b=batch_size,
-        )[:, :seqlen_k]
-        v_cache = rearrange(
-            v_cache_paged[block_table.to(dtype=torch.long).flatten()],
-            "(b nblocks) block_size ... -> b (nblocks block_size) ...",
-            b=batch_size,
-        )[:, :seqlen_k]
     cache_seqlens = torch.randint(
         0 if new_kv else 1,
         # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
@@ -2073,6 +2077,33 @@ def test_flash_attn_kvcache(
     assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5


+def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype):
+    num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3
+    k_cache_paged = torch.randn(
+        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+    )
+    v_cache_paged = torch.randn(
+        num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype
+    )
+    block_table = rearrange(
+        torch.randperm(num_blocks, dtype=torch.int32, device=device),
+        "(b nblocks) -> b nblocks",
+        b=batch_size,
+    )
+    k_cache = rearrange(
+        # pytorch 1.12 doesn't have indexing with int32
+        k_cache_paged[block_table.to(dtype=torch.long).flatten()],
+        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+        b=batch_size,
+    )[:, :seqlen_k]
+    v_cache = rearrange(
+        v_cache_paged[block_table.to(dtype=torch.long).flatten()],
+        "(b nblocks) block_size ... -> b (nblocks block_size) ...",
+        b=batch_size,
+    )[:, :seqlen_k]
+    return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks
+
+
 # @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 @pytest.mark.parametrize("dtype", [torch.float16])
 @pytest.mark.parametrize("causal", [False, True])
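Note for reviewers (not part of the patch): below is a minimal, hypothetical sketch of driving the new paged-KV path through flash_attn_varlen_func, following the shape and dtype checks added in mha_varlen_fwd above. The function and argument names come from this diff; all concrete sizes (batch size, heads, head dim, block counts, sequence lengths) are made-up example values, and the paged path is exercised forward-only here, matching the test, which skips the backward pass when a block_table is given.

# Illustrative only -- assumes flash-attn built from this branch and a CUDA device.
import torch
import torch.nn.functional as F

from flash_attn import flash_attn_varlen_func

batch_size, nheads, d = 2, 8, 128
page_block_size = 256              # must be a multiple of 256 (see the TORCH_CHECK above)
max_num_blocks_per_seq = 4
num_blocks = batch_size * max_num_blocks_per_seq

seqlens_q = torch.tensor([5, 7], dtype=torch.int32, device="cuda")
seqlens_k = torch.tensor([600, 900], dtype=torch.int32, device="cuda")  # each <= 4 * 256
cu_seqlens_q = F.pad(seqlens_q.cumsum(0, dtype=torch.int32), (1, 0))
cu_seqlens_k = F.pad(seqlens_k.cumsum(0, dtype=torch.int32), (1, 0))

# Queries keep the usual varlen layout: (total_q, nheads, head_size).
q = torch.randn(int(seqlens_q.sum()), nheads, d, device="cuda", dtype=torch.float16)
# With a block_table, K/V are paged: (num_blocks, page_block_size, nheads_k, head_size).
k_cache = torch.randn(num_blocks, page_block_size, nheads, d, device="cuda", dtype=torch.float16)
v_cache = torch.randn_like(k_cache)
# block_table maps each sequence to its blocks: (batch_size, max_num_blocks_per_seq), int32,
# contiguous in the last dimension.
block_table = torch.arange(num_blocks, dtype=torch.int32, device="cuda").reshape(
    batch_size, max_num_blocks_per_seq
)

out = flash_attn_varlen_func(
    q,
    k_cache,
    v_cache,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=int(seqlens_q.max()),
    max_seqlen_k=int(seqlens_k.max()),
    causal=True,
    block_table=block_table,  # forward-only: backward is not exercised with a block_table
)
print(out.shape)  # (total_q, nheads, head_size)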