diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 91cc370..bf8cdcb 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -1037,13 +1037,14 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size std::vector mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &kcache, // batch_size x seqlen_k x num_heads_k x head_size - const at::Tensor &vcache, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size + const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size c10::optional &k_, // batch_size x seqlen_knew x num_heads_k x head_size c10::optional &v_, // batch_size x seqlen_knew x num_heads_k x head_size c10::optional &seqlens_k_, // batch_size c10::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) c10::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + c10::optional &cache_batch_idx_, // indices to index into the KV cache c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size const float softmax_scale, bool is_causal, @@ -1084,6 +1085,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he const int head_size_og = sizes[3]; const int seqlen_k = kcache.size(1); const int num_heads_k = kcache.size(2); + const int batch_size_c = kcache.size(0); TORCH_CHECK(batch_size > 0, "batch size must be postive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -1102,8 +1104,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he } CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); - CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og); - CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); at::Tensor q_padded, kcache_padded, vcache_padded; if (head_size_og % 8 != 0) { @@ -1229,6 +1231,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he params.rotary_dim = 0; } + if (cache_batch_idx_.has_value()) { + auto cache_batch_idx = cache_batch_idx_.value(); + CHECK_DEVICE(cache_batch_idx); + CHECK_CONTIGUOUS(cache_batch_idx); + TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32"); + params.cache_batch_idx = reinterpret_cast(cache_batch_idx.data_ptr()); + } // This needs to match with run_mha_fwd_splitkv_dispatch const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; @@ -1248,8 +1257,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he } auto stream = at::cuda::getCurrentCUDAStream().stream(); - // Only split kernel supports appending to KV cache - run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value()); + // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx + run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value()); if (head_size_og % 8 != 0) { out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 81f33d9..fe0fe3f 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -95,6 +95,9 @@ struct Flash_fwd_params : public Qkv_params { void * __restrict__ rotary_cos_ptr; void * __restrict__ rotary_sin_ptr; + // The indices to index into the KV cache. + int *__restrict__ cache_batch_idx; + // The dropout probability (probability of keeping an activation). float p_dropout; // uint32_t p_dropout_in_uint; diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 312b4dd..323068e 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -668,9 +668,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; // We move K and V to the last block. - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index eba913a..e0444cd 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -928,6 +928,7 @@ def flash_attn_with_kvcache( rotary_cos=None, rotary_sin=None, cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window @@ -978,8 +979,8 @@ def flash_attn_with_kvcache( Arguments: q: (batch_size, seqlen, nheads, headdim) - k_cache: (batch_size, seqlen_cache, nheads_k, headdim) - v_cache: (batch_size, seqlen_cache, nheads_k, headdim) + k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate k with k_cache, starting at the indices specified by cache_seqlens. v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. @@ -988,6 +989,10 @@ def flash_attn_with_kvcache( rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the KV cache. + cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. + If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. + If the indices are not distinct, and k and v are provided, the values updated in the cache + might come from any of the duplicate indices. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). @@ -1014,6 +1019,8 @@ def flash_attn_with_kvcache( cache_seqlens = torch.full( (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device ) + cache_seqlens = maybe_contiguous(cache_seqlens) + cache_batch_idx = maybe_contiguous(cache_batch_idx) out, softmax_lse = flash_attn_cuda.fwd_kvcache( q, k_cache, @@ -1023,6 +1030,7 @@ def flash_attn_with_kvcache( cache_seqlens, rotary_cos, rotary_sin, + cache_batch_idx, None, softmax_scale, causal, diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 4fe3309..90e5899 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1668,7 +1668,7 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt @pytest.mark.parametrize("new_kv", [False, True]) # @pytest.mark.parametrize("new_kv", [True]) @pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [True]) +# @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) @@ -1677,6 +1677,8 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dt # @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) +@pytest.mark.parametrize("has_batch_idx", [False, True]) +# @pytest.mark.parametrize("has_batch_idx", [True]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) @@ -1703,6 +1705,7 @@ def test_flash_attn_kvcache( seqlen_q, seqlen_k, d, + has_batch_idx, rotary_fraction, rotary_interleaved, seqlen_new_eq_seqlen_q, @@ -1721,6 +1724,7 @@ def test_flash_attn_kvcache( # set seed torch.random.manual_seed(0) batch_size = 2 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 nheads = 6 # rotary_dim must be a multiple of 16, and must be <= d rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 @@ -1734,8 +1738,8 @@ def test_flash_attn_kvcache( v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) else: k, v = None, None - k_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype) - v_cache = torch.randn(batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype) + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) cache_seqlens = torch.randint( 0, # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough @@ -1746,6 +1750,10 @@ def test_flash_attn_kvcache( dtype=torch.int32, device=device, ) + if has_batch_idx: + cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[:batch_size] + else: + cache_batch_idx = None # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) if rotary_dim > 0: angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi @@ -1775,8 +1783,8 @@ def test_flash_attn_kvcache( cos, sin = None, None q_ro, k_ro = q, k # k_cache[:, 64:] = -1 - k_cache_ref = k_cache.clone() - v_cache_ref = v_cache.clone() + k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() + v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") if new_kv: @@ -1796,6 +1804,7 @@ def test_flash_attn_kvcache( cos, sin, cache_seqlens, + cache_batch_idx, causal=causal, window_size=window_size, rotary_interleaved=rotary_interleaved, @@ -1844,8 +1853,10 @@ def test_flash_attn_kvcache( # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. if new_kv: - assert torch.allclose(k_cache, k_cache_ref, rtol=1e-3, atol=1e-3) - assert torch.equal(v_cache, v_cache_ref) + k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx] + v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx] + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + assert torch.equal(v_cache_select, v_cache_ref) assert (out - out_ref).abs().max().item() <= 3 * (out_pt - out_ref).abs().max().item() + 1e-5