diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index c40bbc1..351472c 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params &params,
                       at::Tensor out,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
+                      void *seqused_k,
                       void *p_d,
                       void *softmax_lse_d,
                       float p_dropout,
@@ -72,6 +73,7 @@ void set_params_fprop(Flash_fwd_params &params,
 
     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
+    params.seqused_k = static_cast<int *>(seqused_k);
 
     // P = softmax(QK^T)
     params.p_ptr = p_d;
@@ -156,6 +158,7 @@ void set_params_dgrad(Flash_bwd_params &params,
                      cu_seqlens_q_d,
                      cu_seqlens_k_d,
                      nullptr,
+                     nullptr,
                      softmax_lse_d,
                      p_dropout,
                      softmax_scale,
@@ -363,6 +366,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      q_padded, k_padded, v_padded, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_k=*/nullptr,
                      return_softmax ? p.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
                      p_dropout,
@@ -436,6 +440,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
+               c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
                const int max_seqlen_q,
                const int max_seqlen_k,
                const float p_dropout,
@@ -494,6 +499,13 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (seqused_k.has_value()){
+        auto seqused_k_ = seqused_k.value();
+        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
+        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
+        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
+        CHECK_SHAPE(seqused_k_, batch_size);
+    }
 
     at::Tensor q_padded, k_padded, v_padded;
     if (head_size_og % 8 != 0) {
@@ -554,6 +566,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      q_padded, k_padded, v_padded, out,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
+                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      return_softmax ? p.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
                      p_dropout,
@@ -1167,6 +1180,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                      q_padded, kcache_padded, vcache_padded, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_k=*/nullptr,
                      /*p_ptr=*/nullptr,
                      softmax_lse.data_ptr(),
                      /*p_dropout=*/0.f,
diff --git a/csrc/flash_attn/src/block_info.h b/csrc/flash_attn/src/block_info.h
index 18793e9..65435e5 100644
--- a/csrc/flash_attn/src/block_info.h
+++ b/csrc/flash_attn/src/block_info.h
@@ -19,7 +19,7 @@ struct BlockInfo {
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
         , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
-        , actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
+        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
         {
         }
 
diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h
index fe0fe3f..9983805 100644
--- a/csrc/flash_attn/src/flash.h
+++ b/csrc/flash_attn/src/flash.h
@@ -77,6 +77,9 @@ struct Flash_fwd_params : public Qkv_params {
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
 
+    // If provided, the actual length of each k sequence.
+    int * __restrict__ seqused_k;
+
     int *__restrict__ blockmask;
 
     // The K_new and V_new matrices.
diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index e0444cd..ae3a1c7 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -83,6 +83,7 @@ def _flash_attn_varlen_forward(
         None,
         cu_seqlens_q,
         cu_seqlens_k,
+        None,
         max_seqlen_q,
         max_seqlen_k,
         dropout_p,
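
For reference, below is a minimal host-side sketch of the length-selection logic this patch adds to `BlockInfo`: when `params.seqused_k` is non-null, the per-batch `actual_seqlen_k` is taken from `seqused_k[bidb]` instead of being derived from `cu_seqlens_k`. The `ToyParams` struct and sample values are illustrative only (not part of the patch), and the sketch omits the `knew_ptr`/`seqlen_knew` append path that the real constructor also handles.

```cpp
// Illustrative only: standalone mirror of the actual_seqlen_k selection added
// to BlockInfo above. ToyParams is a hypothetical stand-in for Flash_fwd_params;
// the real logic runs on the device inside the BlockInfo constructor.
#include <cstdio>

struct ToyParams {
    const int *cu_seqlens_k;        // b+1 cumulative key offsets (varlen layout)
    const int *seqused_k;           // optional per-batch key count, or nullptr
    bool is_seqlens_k_cumulative;   // whether cu_seqlens_k stores cumulative offsets
};

// Same ternary as the patched initializer: prefer seqused_k[bidb] when given,
// otherwise fall back to the length derived from cu_seqlens_k.
int actual_seqlen_k(const ToyParams &params, int bidb) {
    int seqlen_k_cache = params.is_seqlens_k_cumulative
        ? params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb]
        : params.cu_seqlens_k[bidb];
    return params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache;
}

int main() {
    const int cu_seqlens_k[] = {0, 5, 12, 20};  // 3 sequences of length 5, 7, 8
    const int seqused_k[]    = {3, 7, 2};       // attend to only this many keys each

    ToyParams without{cu_seqlens_k, nullptr, true};
    ToyParams with{cu_seqlens_k, seqused_k, true};
    for (int b = 0; b < 3; ++b) {
        printf("batch %d: default seqlen_k = %d, with seqused_k = %d\n",
               b, actual_seqlen_k(without, b), actual_seqlen_k(with, b));
    }
    return 0;
}
```

Note that on the Python side this diff only threads a `None` placeholder through `_flash_attn_varlen_forward`, so existing callers keep the previous behavior; the C++ `mha_varlen_fwd` entry point is what accepts an int32 `seqused_k` tensor of shape `b`.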