Revert "[Core] Remove unnecessary copies in flash attn backend" (#5478)
commit 6b0511a57b
parent a8fda4f661
@@ -317,7 +317,7 @@ class FlashAttentionImpl(AttentionImpl):
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                flash_attn_varlen_func(
+                out = flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -329,13 +329,14 @@ class FlashAttentionImpl(AttentionImpl):
                     causal=True,
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
-                    out=output[:num_prefill_tokens],
                 )
+                assert output[:num_prefill_tokens].shape == out.shape
+                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                flash_attn_varlen_func(
+                output[:num_prefill_tokens] = flash_attn_varlen_func(
                     q=query,
                     k=key_cache,
                     v=value_cache,
@@ -347,12 +348,11 @@ class FlashAttentionImpl(AttentionImpl):
                     causal=True,
                     alibi_slopes=self.alibi_slopes,
                     block_table=prefill_meta.block_tables,
-                    out=output[:num_prefill_tokens],
                 )

         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            flash_attn_with_kvcache(
+            output[num_prefill_tokens:] = flash_attn_with_kvcache(
                 decode_query.unsqueeze(1),
                 key_cache,
                 value_cache,
@@ -361,8 +361,7 @@ class FlashAttentionImpl(AttentionImpl):
                 softmax_scale=self.scale,
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
-                out=output[num_prefill_tokens:].unsqueeze(1),
-            )
+            ).squeeze(1)

         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
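For context on what the diff toggles: the reverted change had flash-attn write its result directly into a slice of the preallocated output tensor through an out= argument, while the revert restores the pattern of computing into a temporary tensor and copying it back into that slice. The sketch below illustrates only this buffer-handling difference; attn_kernel is a hypothetical stand-in, not the vllm-flash-attn API.

from typing import Optional

import torch


def attn_kernel(q: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Hypothetical stand-in for an attention kernel (not the flash-attn API)."""
    # Toy self-attention over a (tokens, hidden) tensor.
    result = torch.softmax(q @ q.T / q.shape[-1] ** 0.5, dim=-1) @ q
    if out is not None:
        # Write the result into the caller-provided buffer; no copy-back needed.
        out.copy_(result)
        return out
    return result


num_tokens, hidden_size = 8, 16
num_prefill_tokens = 5
query = torch.randn(num_tokens, hidden_size)
output = torch.empty(num_tokens, hidden_size)

# Pattern the reverted PR introduced: the kernel fills a view of the
# preallocated output directly through the out= argument.
attn_kernel(query[:num_prefill_tokens], out=output[:num_prefill_tokens])

# Pattern this revert restores: compute into a temporary, then copy it back
# into the output slice (the copy the original PR had removed).
out = attn_kernel(query[:num_prefill_tokens])
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out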