From 819b18e7ba7f179ba90e44b2a846ddbdd1b0763d Mon Sep 17 00:00:00 2001
From: ljss <31004720+beginlner@users.noreply.github.com>
Date: Tue, 21 Nov 2023 09:46:32 +0800
Subject: [PATCH] Rewrite torch.repeat_interleave to remove cpu
 synchronization (#1599)

---
 vllm/model_executor/layers/attention.py | 30 ++++++++++++++++---------
 1 file changed, 20 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index a0a1d362..e51bb311 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -95,10 +95,15 @@ class PagedAttention(nn.Module):
         """
         if self.num_kv_heads != self.num_heads:
             # Project the key and value tensors to the desired number of heads.
-            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
-            value = torch.repeat_interleave(value,
-                                            self.num_queries_per_kv,
-                                            dim=1)
+            query = query.view(query.shape[0], self.num_kv_heads,
+                               self.num_queries_per_kv, query.shape[-1])
+            key = key[:, :,
+                      None, :].expand(key.shape[0], self.num_kv_heads,
+                                      self.num_queries_per_kv, key.shape[-1])
+            value = value[:, :,
+                          None, :].expand(value.shape[0], self.num_kv_heads,
+                                          self.num_queries_per_kv,
+                                          value.shape[-1])
 
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
@@ -110,7 +115,7 @@ class PagedAttention(nn.Module):
             scale=self.scale,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
-        output.copy_(out.squeeze(0))
+        output.copy_(out.view_as(output))
         return output
 
     def get_alibi_slopes(self) -> Optional[torch.Tensor]:
@@ -427,10 +432,15 @@ class PagedAttentionWithALiBi(PagedAttention):
         """
         if self.num_kv_heads != self.num_heads:
             # Project the key and value tensors to the desired number of heads.
-            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
-            value = torch.repeat_interleave(value,
-                                            self.num_queries_per_kv,
-                                            dim=1)
+            query = query.view(query.shape[0], self.num_kv_heads,
+                               self.num_queries_per_kv, query.shape[-1])
+            key = key[:, :,
+                      None, :].expand(key.shape[0], self.num_kv_heads,
+                                      self.num_queries_per_kv, key.shape[-1])
+            value = value[:, :,
+                          None, :].expand(value.shape[0], self.num_kv_heads,
+                                          self.num_queries_per_kv,
+                                          value.shape[-1])
 
         batch_size = input_metadata.num_prompts
         seq_len = input_metadata.max_prompt_len
@@ -443,7 +453,7 @@ class PagedAttentionWithALiBi(PagedAttention):
             scale=self.scale,
        )
         # TODO(woosuk): Unnecessary copy. Optimize.
-        output.copy_(out.view(-1, self.num_heads, self.head_size))
+        output.copy_(out.view_as(output))
         return output
 
     def get_alibi_slopes(self) -> Optional[torch.Tensor]:
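
Note (not part of the patch): the sketch below, written against plain PyTorch rather than the vLLM classes, illustrates the rewrite the diff applies to the key/value tensors. torch.repeat_interleave materializes a repeated copy of each KV head, and the commit subject attributes a CPU synchronization to it, whereas indexing with None and calling expand only creates a broadcasted view. All tensor names and sizes here are made-up placeholders.

    # Minimal sketch of the repeat_interleave -> expand rewrite.
    # Assumption: standalone example, not vLLM code; shapes are illustrative.
    import torch

    num_tokens, num_kv_heads, head_size = 8, 4, 64
    num_queries_per_kv = 2  # i.e. num_heads // num_kv_heads

    key = torch.randn(num_tokens, num_kv_heads, head_size)

    # Before: repeat each KV head for its query heads, allocating a new
    # [num_tokens, num_kv_heads * num_queries_per_kv, head_size] tensor.
    repeated = torch.repeat_interleave(key, num_queries_per_kv, dim=1)

    # After: insert a size-1 axis and expand it into a broadcasted view;
    # no data is copied and no new tensor is materialized.
    expanded = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                         num_queries_per_kv, head_size)

    # Both carry the same values; the expanded view just keeps a 4-D layout,
    # which is why the patch also reshapes `query` to 4-D and replaces the
    # squeeze/view on the attention output with view_as(output).
    assert torch.equal(
        repeated,
        expanded.reshape(num_tokens, num_kv_heads * num_queries_per_kv,
                         head_size))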