From 4e8d60069f65573aca0626d1425513a7597655c2 Mon Sep 17 00:00:00 2001
From: Jianwei Dong <1913953267@qq.com>
Date: Mon, 8 Jul 2024 23:29:40 +0800
Subject: [PATCH] Add the return_softmax_lse parameter to the
 flash_attn_with_kvcache function to allow returning the logsumexp of the
 attention scores. (#989)

---
 flash_attn/flash_attn_interface.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 9061372..60c8cd2 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -1109,6 +1109,7 @@ def flash_attn_with_kvcache(
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    return_softmax_lse=False,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1187,9 +1188,13 @@ def flash_attn_with_kvcache(
            If num_splits == 1, we don't split the key/value.
            If num_splits == 0, we use a heuristic to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
+        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

     Return:
         out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
     """
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
@@ -1224,4 +1229,4 @@ def flash_attn_with_kvcache(
         rotary_interleaved,
         num_splits,
     )
-    return out
+    return (out, softmax_lse) if return_softmax_lse else out
\ No newline at end of file
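
For context, a minimal usage sketch of the new flag from the caller's side is shown below. Only flash_attn_with_kvcache and its return_softmax_lse parameter come from this patch; the tensor shapes, dtypes, and sizes are illustrative assumptions, not part of the diff.

    # Minimal sketch: shapes/dtypes below are assumptions for illustration only.
    import torch
    from flash_attn import flash_attn_with_kvcache

    batch, seqlen_q, seqlen_k, nheads, headdim = 2, 1, 128, 8, 64
    device, dtype = "cuda", torch.float16

    q = torch.randn(batch, seqlen_q, nheads, headdim, device=device, dtype=dtype)
    k_cache = torch.randn(batch, seqlen_k, nheads, headdim, device=device, dtype=dtype)
    v_cache = torch.randn(batch, seqlen_k, nheads, headdim, device=device, dtype=dtype)

    # Default behavior is unchanged: a single output tensor is returned.
    out = flash_attn_with_kvcache(q, k_cache, v_cache)

    # With return_softmax_lse=True, the call also returns the logsumexp of the
    # attention scores, shaped (batch_size, nheads, seqlen) per the docstring.
    out, softmax_lse = flash_attn_with_kvcache(q, k_cache, v_cache, return_softmax_lse=True)

Because the flag defaults to False and the extra value is returned only when it is requested, existing callers are unaffected by this change.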