Allow varlen_fwd to take optional seqused_k (#647)
Co-authored-by: bottler <bottler@users.noreply.github.com>
commit ce3e7280f8
parent 23b77c8148
@@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params &params,
                       at::Tensor out,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
+                      void *seqused_k,
                       void *p_d,
                       void *softmax_lse_d,
                       float p_dropout,
@@ -72,6 +73,7 @@ void set_params_fprop(Flash_fwd_params &params,
 
     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
+    params.seqused_k = static_cast<int *>(seqused_k);
 
     // P = softmax(QK^T)
     params.p_ptr = p_d;
@@ -156,6 +158,7 @@ void set_params_dgrad(Flash_bwd_params &params,
                      cu_seqlens_q_d,
                      cu_seqlens_k_d,
+                     nullptr,
                      nullptr,
                      softmax_lse_d,
                      p_dropout,
                      softmax_scale,
@@ -363,6 +366,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      q_padded, k_padded, v_padded, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_k=*/nullptr,
                      return_softmax ? p.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
                      p_dropout,
@@ -436,6 +440,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q,  // b+1
                const at::Tensor &cu_seqlens_k,  // b+1
+               c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
                const int max_seqlen_q,
                const int max_seqlen_k,
                const float p_dropout,
@@ -494,6 +499,13 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (seqused_k.has_value()){
+        auto seqused_k_ = seqused_k.value();
+        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
+        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
+        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
+        CHECK_SHAPE(seqused_k_, batch_size);
+    }
 
     at::Tensor q_padded, k_padded, v_padded;
     if (head_size_og % 8 != 0) {
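The checks added above pin the new argument down precisely: when present, seqused_k must be an int32, CUDA, contiguous tensor with one entry per batch element. A minimal sketch of a tensor that would satisfy them (plain PyTorch; the values here are illustrative and not part of this diff):

    import torch

    # Hypothetical batch of 3 sequences whose packed KV segments hold
    # 128, 64, and 32 keys, of which only the first 100, 64, and 10
    # should participate in attention.
    batch_size = 3
    seqused_k = torch.tensor([100, 64, 10], dtype=torch.int32, device="cuda")

    # Mirrors the checks in the hunk above.
    assert seqused_k.dtype == torch.int32     # "seqused_k must have dtype int32"
    assert seqused_k.is_cuda                  # "seqused_k must be on CUDA device"
    assert seqused_k.is_contiguous()          # "seqused_k must be contiguous"
    assert seqused_k.shape == (batch_size,)   # CHECK_SHAPE(seqused_k_, batch_size)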
@@ -554,6 +566,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      q_padded, k_padded, v_padded, out,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
+                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      return_softmax ? p.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
                      p_dropout,
@@ -1167,6 +1180,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
                      q_padded, kcache_padded, vcache_padded, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_k=*/nullptr,
                      /*p_ptr=*/nullptr,
                      softmax_lse.data_ptr(),
                      /*p_dropout=*/0.f,
@@ -19,7 +19,7 @@ struct BlockInfo {
         // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
         // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
         , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
-        , actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
+        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
         {
         }
 
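Restated outside the initializer list, the length selection that this BlockInfo hunk implements is roughly the following (a Python pseudocode sketch of the C++ expression above, not code from this diff):

    def actual_seqlen_k(params, bidb, sum_s_k, varlen):
        """Sketch of the BlockInfo logic; attribute names mirror Flash_fwd_params."""
        # Base length of the cached keys for this batch element.
        if not varlen or params.cu_seqlens_k is None:
            seqlen_k_cache = params.seqlen_k
        elif params.is_seqlens_k_cumulative:
            seqlen_k_cache = params.cu_seqlens_k[bidb + 1] - sum_s_k
        else:
            seqlen_k_cache = params.cu_seqlens_k[bidb]

        # New in this commit: an explicit per-batch override takes precedence;
        # otherwise keep the old behavior (cache length plus any appended K_new).
        if params.seqused_k is not None:
            return params.seqused_k[bidb]
        return seqlen_k_cache + (0 if params.knew_ptr is None else params.seqlen_knew)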
@@ -77,6 +77,9 @@ struct Flash_fwd_params : public Qkv_params {
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
 
+    // If provided, the actual length of each k sequence.
+    int * __restrict__ seqused_k;
+
     int *__restrict__ blockmask;
 
     // The K_new and V_new matrices.
@@ -83,6 +83,7 @@ def _flash_attn_varlen_forward(
         None,
         cu_seqlens_q,
         cu_seqlens_k,
+        None,
         max_seqlen_q,
         max_seqlen_k,
         dropout_p,
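At the Python layer this commit still passes None for the new parameter, so existing callers are unaffected. Semantically, supplying seqused_k should behave like truncating each batch element's packed keys and values to its first seqused_k[i] entries before running the varlen attention; a hedged reference sketch of that equivalence (illustrative helper, not an API introduced by this diff):

    import torch

    def truncate_packed_kv(k, v, cu_seqlens_k, seqused_k):
        """Build the K/V that a call with seqused_k would effectively attend over.

        k, v:         (total_k, num_heads_k, head_size) varlen-packed tensors
        cu_seqlens_k: (b + 1,) int32 cumulative key offsets
        seqused_k:    (b,) int32 number of keys actually used per batch element
        """
        ks, vs, new_cu = [], [], [0]
        for i in range(seqused_k.shape[0]):
            start = int(cu_seqlens_k[i])
            used = int(seqused_k[i])          # only the first `used` keys are attended to
            ks.append(k[start:start + used])
            vs.append(v[start:start + used])
            new_cu.append(new_cu[-1] + used)
        return (torch.cat(ks), torch.cat(vs),
                torch.tensor(new_cu, dtype=torch.int32, device=k.device))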