diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 4d41310..07e97ce 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -279,7 +279,7 @@ def flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seq
             (they might not have the right scaling).
         deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
-        out: (total, nheads, headdim).
+        out: (total_q, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
@@ -315,7 +315,7 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
             (they might not have the right scaling).
         deterministic: bool. Whether or not to ensure deterministic execution.
     Return:
-        out: (total, nheads, headdim).
+        out: (total_q, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
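
Not part of the patch: a minimal usage sketch illustrating why the docstring fix matters. With variable-length ("unpadded") inputs, q has total_q rows while k/v have total_k rows, so the output follows q and has shape (total_q, nheads, headdim), not a generic "total". This assumes the flash-attn v1 varlen signature flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, ...), a CUDA device, and illustrative shapes chosen here.

    # Sketch only; shapes and sequence lengths are made up for illustration.
    import torch
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func

    nheads, headdim = 16, 64
    # Two sequences: query lengths 3 and 5 (total_q = 8), key lengths 7 and 9 (total_k = 16).
    cu_seqlens_q = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
    cu_seqlens_k = torch.tensor([0, 7, 16], dtype=torch.int32, device="cuda")

    q = torch.randn(8, nheads, headdim, dtype=torch.float16, device="cuda")
    k = torch.randn(16, nheads, headdim, dtype=torch.float16, device="cuda")
    v = torch.randn(16, nheads, headdim, dtype=torch.float16, device="cuda")

    out = flash_attn_unpadded_func(
        q, k, v, cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q=5, max_seqlen_k=9, dropout_p=0.0,
    )
    assert out.shape == (8, nheads, headdim)  # total_q rows (matching q), not total_k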