Pass seqused_k to _flash_attn_varlen_forward
This commit is contained in:
parent
7ef24848cf
commit
898dd4bbf2
@ -77,12 +77,13 @@ def _flash_attn_varlen_forward(
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
window_size,
|
||||
softcap,
|
||||
alibi_slopes,
|
||||
return_softmax,
|
||||
window_size=(-1, -1),
|
||||
softcap=0.0,
|
||||
alibi_slopes=None,
|
||||
return_softmax=False,
|
||||
block_table=None,
|
||||
leftpad_k=None,
|
||||
seqused_k=None,
|
||||
):
|
||||
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
|
||||
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
|
||||
@ -93,7 +94,7 @@ def _flash_attn_varlen_forward(
|
||||
None,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
None,
|
||||
seqused_k,
|
||||
leftpad_k,
|
||||
block_table,
|
||||
alibi_slopes,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user