Add the return_softmax_lse parameter to the flash_attn_with_kvcache function to allow returning the logsumexp of the attention scores. (#989)
Commit: 4e8d60069f (parent: 6df7e0a02e)
@@ -1109,6 +1109,7 @@ def flash_attn_with_kvcache(
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    return_softmax_lse=False,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1187,9 +1188,13 @@ def flash_attn_with_kvcache(
            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
            to automatically determine the number of splits.
            Don't change this unless you know what you are doing.
+        return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

     Return:
         out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
     """
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
@@ -1224,4 +1229,4 @@ def flash_attn_with_kvcache(
         rotary_interleaved,
         num_splits,
     )
-    return out
+    return (out, softmax_lse) if return_softmax_lse else out
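For context, a minimal usage sketch of the new flag (shapes are illustrative; assumes the flash_attn package is installed and a CUDA device is available):

import torch
from flash_attn import flash_attn_with_kvcache

# Query attends against a pre-filled KV cache. Shapes follow the docstring:
# q is (batch_size, seqlen, nheads, headdim); the caches are
# (batch_size, seqlen_cache, nheads, headdim).
batch, seqlen_q, seqlen_cache, nheads, headdim = 2, 1, 128, 8, 64
q = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.randn(batch, seqlen_cache, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.randn(batch, seqlen_cache, nheads, headdim, device="cuda", dtype=torch.float16)

# Default behavior is unchanged: only the attention output is returned.
out = flash_attn_with_kvcache(q, k_cache, v_cache)

# With return_softmax_lse=True, the call returns an (out, softmax_lse) tuple;
# softmax_lse has shape (batch_size, nheads, seqlen) and holds the logsumexp
# of each row of QK^T * scaling, i.e. the log of the softmax normalization factor.
out, softmax_lse = flash_attn_with_kvcache(q, k_cache, v_cache, return_softmax_lse=True)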