Pass seqused_k to _flash_attn_varlen_forward

This commit is contained in:
Tri Dao 2024-07-13 00:08:27 -07:00
parent 7ef24848cf
commit 898dd4bbf2

View File

@@ -77,12 +77,13 @@ def _flash_attn_varlen_forward(
dropout_p,
softmax_scale,
causal,
-window_size,
-softcap,
-alibi_slopes,
-return_softmax,
+window_size=(-1, -1),
+softcap=0.0,
+alibi_slopes=None,
+return_softmax=False,
block_table=None,
leftpad_k=None,
+seqused_k=None,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -93,7 +94,7 @@ def _flash_attn_varlen_forward(
None,
cu_seqlens_q,
cu_seqlens_k,
-None,
+seqused_k,
leftpad_k,
block_table,
alibi_slopes,