diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 9061372..60c8cd2 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -1109,6 +1109,7 @@ def flash_attn_with_kvcache(
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    return_softmax_lse=False,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1187,9 +1188,13 @@
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
+        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
 
     Return:
         out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
     """
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
@@ -1224,4 +1229,4 @@
         rotary_interleaved,
         num_splits,
     )
-    return out
+    return (out, softmax_lse) if return_softmax_lse else out
\ No newline at end of file
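
Usage sketch (not part of the patch): a minimal call exercising the new return_softmax_lse flag. The tensor layout (batch_size, seqlen, nheads, headdim) and the softmax_lse shape come from the docstring above; the concrete sizes and the fp16/CUDA setup are illustrative assumptions, not something this diff specifies.

import torch
from flash_attn import flash_attn_with_kvcache

# Illustrative decoding step: one new query token attends to a KV cache of length 128.
batch, seqlen_q, seqlen_k, nheads, headdim = 2, 1, 128, 8, 64
q = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.randn(batch, seqlen_k, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.randn(batch, seqlen_k, nheads, headdim, device="cuda", dtype=torch.float16)

# Default behavior is unchanged: only the attention output is returned.
out = flash_attn_with_kvcache(q, k_cache, v_cache)

# With the new flag, the call also returns the logsumexp of each row of
# QK^T * scaling, shaped (batch_size, nheads, seqlen_q) per the docstring.
out, softmax_lse = flash_attn_with_kvcache(
    q, k_cache, v_cache, return_softmax_lse=True
)

Since the flag defaults to False and the function returns a tuple only when it is set, existing callers of flash_attn_with_kvcache are unaffected.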