[Core] Remove unnecessary copies in flash attn backend (#5138)
Commit: 0ab278ca31
Parent: 7a64d24aad
@@ -6,4 +6,4 @@ ray >= 2.9
 nvidia-ml-py # for pynvml package
 torch == 2.3.0
 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
-vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0
+vllm-flash-attn == 2.5.9 # Requires PyTorch 2.3.0
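Not part of the diff, just a minimal sketch of a startup sanity check that the installed kernel package matches the pin above; it uses only the standard library, and the expected version string is taken from this requirements change.

from importlib.metadata import version

# Distribution name matches the pin above; raises PackageNotFoundError if absent.
installed = version("vllm-flash-attn")
assert installed == "2.5.9", f"expected vllm-flash-attn 2.5.9, found {installed}"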
@@ -317,7 +317,7 @@ class FlashAttentionImpl(AttentionImpl):
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                out = flash_attn_varlen_func(
+                flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -329,14 +329,13 @@ class FlashAttentionImpl(AttentionImpl):
                     causal=True,
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
+                    out=output[:num_prefill_tokens],
                 )
-                assert output[:num_prefill_tokens].shape == out.shape
-                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                output[:num_prefill_tokens] = flash_attn_varlen_func(
+                flash_attn_varlen_func(
                     q=query,
                     k=key_cache,
                     v=value_cache,
@@ -348,11 +347,12 @@ class FlashAttentionImpl(AttentionImpl):
                     causal=True,
                     alibi_slopes=self.alibi_slopes,
                     block_table=prefill_meta.block_tables,
+                    out=output[:num_prefill_tokens],
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            output[num_prefill_tokens:] = flash_attn_with_kvcache(
+            flash_attn_with_kvcache(
                 decode_query.unsqueeze(1),
                 key_cache,
                 value_cache,
@@ -361,7 +361,8 @@ class FlashAttentionImpl(AttentionImpl):
                 softmax_scale=self.scale,
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
-            ).squeeze(1)
+                out=output[num_prefill_tokens:].unsqueeze(1),
+            )
 
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
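For illustration only, a plain-PyTorch sketch of the pattern the diff adopts (the real calls are flash_attn_varlen_func / flash_attn_with_kvcache from vllm-flash-attn, which accept an out= argument per the diff above; the tensor names and sizes below are made up): letting the kernel write directly into a slice of the preallocated output buffer via out= removes the intermediate result tensor and the extra copy performed by output[:n] = out, and because unsqueeze(1) returns a view, the decode path's output[num_prefill_tokens:].unsqueeze(1) still writes into the same buffer.

import torch

# Hypothetical sizes for the illustration.
num_prefill_tokens, num_decode_tokens, hidden_size = 4, 2, 8
num_tokens = num_prefill_tokens + num_decode_tokens

q = torch.randn(num_tokens, hidden_size)
w = torch.randn(hidden_size, hidden_size)

# Before: the op allocates its own result, which is then copied into the
# preallocated buffer (one temporary tensor + one extra copy).
output_old = torch.empty(num_tokens, hidden_size)
out = torch.matmul(q[:num_prefill_tokens], w)
output_old[:num_prefill_tokens] = out

# After: pass a view of the buffer as out=, so the result is written in
# place with no temporary tensor and no copy.
output_new = torch.empty(num_tokens, hidden_size)
torch.matmul(q[:num_prefill_tokens], w, out=output_new[:num_prefill_tokens])
assert torch.allclose(output_old[:num_prefill_tokens],
                      output_new[:num_prefill_tokens])

# unsqueeze(1) returns a view, so writing through it (as the decode path
# does with output[num_prefill_tokens:].unsqueeze(1)) modifies the same
# storage, and no trailing .squeeze(1) copy is needed.
decode_view = output_new[num_prefill_tokens:].unsqueeze(1)  # shape (2, 1, 8)
decode_view.fill_(0.0)
assert (output_new[num_prefill_tokens:] == 0).all()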