diff --git a/requirements-cuda.txt b/requirements-cuda.txt
index 5109f173..35361798 100644
--- a/requirements-cuda.txt
+++ b/requirements-cuda.txt
@@ -6,4 +6,4 @@ ray >= 2.9
 nvidia-ml-py # for pynvml package
 torch == 2.3.0
 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
-vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0
+vllm-flash-attn == 2.5.9 # Requires PyTorch 2.3.0
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 0b9d6283..070c074e 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -317,7 +317,7 @@ class FlashAttentionImpl(AttentionImpl):
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                out = flash_attn_varlen_func(
+                flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -329,14 +329,13 @@ class FlashAttentionImpl(AttentionImpl):
                     causal=True,
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
+                    out=output[:num_prefill_tokens],
                 )
-                assert output[:num_prefill_tokens].shape == out.shape
-                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                output[:num_prefill_tokens] = flash_attn_varlen_func(
+                flash_attn_varlen_func(
                     q=query,
                     k=key_cache,
                     v=value_cache,
@@ -348,11 +347,12 @@ class FlashAttentionImpl(AttentionImpl):
                     causal=True,
                     alibi_slopes=self.alibi_slopes,
                     block_table=prefill_meta.block_tables,
+                    out=output[:num_prefill_tokens],
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            output[num_prefill_tokens:] = flash_attn_with_kvcache(
+            flash_attn_with_kvcache(
                 decode_query.unsqueeze(1),
                 key_cache,
                 value_cache,
@@ -361,7 +361,8 @@ class FlashAttentionImpl(AttentionImpl):
                 softmax_scale=self.scale,
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
-            ).squeeze(1)
+                out=output[num_prefill_tokens:].unsqueeze(1),
+            )
 
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)