Allow varlen_fwd to take optional seqused_k (#647)

Co-authored-by: bottler <bottler@users.noreply.github.com>
This commit is contained in:
Jeremy Reizenstein 2023-11-27 08:41:23 +00:00 committed by GitHub
parent 23b77c8148
commit ce3e7280f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 19 additions and 1 deletions

View File

@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params &params,
at::Tensor out,
void *cu_seqlens_q_d,
void *cu_seqlens_k_d,
void *seqused_k,
void *p_d,
void *softmax_lse_d,
float p_dropout,
@ -72,6 +73,7 @@ void set_params_fprop(Flash_fwd_params &params,
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
params.seqused_k = static_cast<int *>(seqused_k);
// P = softmax(QK^T)
params.p_ptr = p_d;
@ -156,6 +158,7 @@ void set_params_dgrad(Flash_bwd_params &params,
cu_seqlens_q_d,
cu_seqlens_k_d,
nullptr,
nullptr,
softmax_lse_d,
p_dropout,
softmax_scale,
@ -363,6 +366,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
q_padded, k_padded, v_padded, out,
/*cu_seqlens_q_d=*/nullptr,
/*cu_seqlens_k_d=*/nullptr,
/*seqused_k=*/nullptr,
return_softmax ? p.data_ptr() : nullptr,
softmax_lse.data_ptr(),
p_dropout,
@ -436,6 +440,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
const int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
@ -494,6 +499,13 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
if (seqused_k.has_value()){
auto seqused_k_ = seqused_k.value();
TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
CHECK_SHAPE(seqused_k_, batch_size);
}
at::Tensor q_padded, k_padded, v_padded;
if (head_size_og % 8 != 0) {
@ -554,6 +566,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
q_padded, k_padded, v_padded, out,
cu_seqlens_q.data_ptr(),
cu_seqlens_k.data_ptr(),
seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
return_softmax ? p.data_ptr() : nullptr,
softmax_lse.data_ptr(),
p_dropout,
@ -1167,6 +1180,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
q_padded, kcache_padded, vcache_padded, out,
/*cu_seqlens_q_d=*/nullptr,
/*cu_seqlens_k_d=*/nullptr,
/*seqused_k=*/nullptr,
/*p_ptr=*/nullptr,
softmax_lse.data_ptr(),
/*p_dropout=*/0.f,

View File

@ -19,7 +19,7 @@ struct BlockInfo {
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}

View File

@ -77,6 +77,9 @@ struct Flash_fwd_params : public Qkv_params {
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
// If provided, seqused_k[b] is the number of keys actually used for batch element b.
int * __restrict__ seqused_k;
int *__restrict__ blockmask;
// The K_new and V_new matrices.

View File

@ -83,6 +83,7 @@ def _flash_attn_varlen_forward(
None,
cu_seqlens_q,
cu_seqlens_k,
None,
max_seqlen_q,
max_seqlen_k,
dropout_p,