From 898dd4bbf237b24ed8fd2a3d13ee33bd156bfb23 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 13 Jul 2024 00:08:27 -0700 Subject: [PATCH] Pass seqused_k to _flash_attn_varlen_forward --- flash_attn/flash_attn_interface.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 737e7a2..8e7076d 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -77,12 +77,13 @@ def _flash_attn_varlen_forward( dropout_p, softmax_scale, causal, - window_size, - softcap, - alibi_slopes, - return_softmax, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + return_softmax=False, block_table=None, leftpad_k=None, + seqused_k=None, ): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x q, k, v = [maybe_contiguous(x) for x in (q, k, v)] @@ -93,7 +94,7 @@ def _flash_attn_varlen_forward( None, cu_seqlens_q, cu_seqlens_k, - None, + seqused_k, leftpad_k, block_table, alibi_slopes,