Add the return_softmax_lse parameter to the flash_attn_with_kvcache function to allow returning the logsumexp of the attention scores. (#989)
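A minimal usage sketch of the new flag (the shapes, dtypes, and single-token decode setup below are illustrative assumptions, not part of the commit):

    import torch
    from flash_attn import flash_attn_with_kvcache

    # Illustrative decode step: batch 2, query length 1, 8 heads, head dim 64,
    # KV cache of length 128. fp16 CUDA tensors are assumed, as the kernel expects.
    q = torch.randn(2, 1, 8, 64, dtype=torch.float16, device="cuda")
    k_cache = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
    v_cache = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")

    # Default behavior is unchanged: only the attention output is returned.
    out = flash_attn_with_kvcache(q, k_cache, v_cache)

    # Opting in returns an (out, softmax_lse) tuple instead.
    out, softmax_lse = flash_attn_with_kvcache(
        q, k_cache, v_cache, return_softmax_lse=True
    )
    # out:         (2, 1, 8, 64) -> (batch_size, seqlen, nheads, headdim)
    # softmax_lse: (2, 8, 1)     -> (batch_size, nheads, seqlen)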

Jianwei Dong 2024-07-08 23:29:40 +08:00 committed by GitHub
parent 6df7e0a02e
commit 4e8d60069f


@@ -1109,6 +1109,7 @@ def flash_attn_with_kvcache(
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    return_softmax_lse=False,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1187,9 +1188,13 @@ def flash_attn_with_kvcache(
             If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
             to automatically determine the number of splits.
             Don't change this unless you know what you are doing.
+        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

     Return:
         out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
     """
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
@@ -1224,4 +1229,4 @@
         rotary_interleaved,
         num_splits,
     )
-    return out
+    return (out, softmax_lse) if return_softmax_lse else out
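For reference, the returned logsumexp follows the docstring's definition: the log of the softmax normalization factor for each row of QK^T * scaling. A plain-PyTorch recomputation sketch, assuming no new k/v appended, no causal masking, equal query/key head counts (no GQA), and the default 1/sqrt(headdim) scaling; the helper name is hypothetical:

    import math
    import torch

    def reference_softmax_lse(q, k_cache):
        # q:       (batch_size, seqlen_q, nheads, headdim)
        # k_cache: (batch_size, seqlen_k, nheads, headdim)
        scale = 1.0 / math.sqrt(q.shape[-1])
        # scores = QK^T * scaling, shape (batch_size, nheads, seqlen_q, seqlen_k);
        # computed in float32 for a stable reference.
        scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k_cache.float()) * scale
        # logsumexp over the key dimension -> (batch_size, nheads, seqlen_q)
        return torch.logsumexp(scores, dim=-1)

Under those assumptions, this should agree with the kernel's softmax_lse up to numerical tolerance, which is one way to sanity-check the new return value.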