Revert "[Core] Remove unnecessary copies in flash attn backend" (#5478)

This commit is contained in:
Antoni Baum 2024-06-13 11:22:50 -07:00 committed by GitHub
parent a8fda4f661
commit 6b0511a57b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -317,7 +317,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention # normal attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
flash_attn_varlen_func( out = flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
@ -329,13 +329,14 @@ class FlashAttentionImpl(AttentionImpl):
causal=True, causal=True,
window_size=self.sliding_window, window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
out=output[:num_prefill_tokens],
) )
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else: else:
# prefix-enabled attention # prefix-enabled attention
assert prefill_meta.seq_lens is not None assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens) max_seq_len = max(prefill_meta.seq_lens)
flash_attn_varlen_func( output[:num_prefill_tokens] = flash_attn_varlen_func(
q=query, q=query,
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
@ -347,12 +348,11 @@ class FlashAttentionImpl(AttentionImpl):
causal=True, causal=True,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables, block_table=prefill_meta.block_tables,
out=output[:num_prefill_tokens],
) )
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
# Decoding run. # Decoding run.
flash_attn_with_kvcache( output[num_prefill_tokens:] = flash_attn_with_kvcache(
decode_query.unsqueeze(1), decode_query.unsqueeze(1),
key_cache, key_cache,
value_cache, value_cache,
@ -361,8 +361,7 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
out=output[num_prefill_tokens:].unsqueeze(1), ).squeeze(1)
)
# Reshape the output tensor. # Reshape the output tensor.
return output.view(num_tokens, hidden_size) return output.view(num_tokens, hidden_size)