diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index a0a1d362..e51bb311 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -95,10 +95,15 @@ class PagedAttention(nn.Module):
         """
         if self.num_kv_heads != self.num_heads:
             # Project the key and value tensors to the desired number of heads.
-            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
-            value = torch.repeat_interleave(value,
-                                            self.num_queries_per_kv,
-                                            dim=1)
+            query = query.view(query.shape[0], self.num_kv_heads,
+                               self.num_queries_per_kv, query.shape[-1])
+            key = key[:, :,
+                      None, :].expand(key.shape[0], self.num_kv_heads,
+                                      self.num_queries_per_kv, key.shape[-1])
+            value = value[:, :,
+                          None, :].expand(value.shape[0], self.num_kv_heads,
+                                          self.num_queries_per_kv,
+                                          value.shape[-1])
 
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
@@ -110,7 +115,7 @@ class PagedAttention(nn.Module):
             scale=self.scale,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
-        output.copy_(out.squeeze(0))
+        output.copy_(out.view_as(output))
         return output
 
     def get_alibi_slopes(self) -> Optional[torch.Tensor]:
@@ -427,10 +432,15 @@ class PagedAttentionWithALiBi(PagedAttention):
         """
         if self.num_kv_heads != self.num_heads:
             # Project the key and value tensors to the desired number of heads.
-            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
-            value = torch.repeat_interleave(value,
-                                            self.num_queries_per_kv,
-                                            dim=1)
+            query = query.view(query.shape[0], self.num_kv_heads,
+                               self.num_queries_per_kv, query.shape[-1])
+            key = key[:, :,
+                      None, :].expand(key.shape[0], self.num_kv_heads,
+                                      self.num_queries_per_kv, key.shape[-1])
+            value = value[:, :,
+                          None, :].expand(value.shape[0], self.num_kv_heads,
+                                          self.num_queries_per_kv,
+                                          value.shape[-1])
 
         batch_size = input_metadata.num_prompts
         seq_len = input_metadata.max_prompt_len
@@ -443,7 +453,7 @@ class PagedAttentionWithALiBi(PagedAttention):
             scale=self.scale,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
-        output.copy_(out.view(-1, self.num_heads, self.head_size))
+        output.copy_(out.view_as(output))
         return output
 
     def get_alibi_slopes(self) -> Optional[torch.Tensor]:
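
For context, the patch replaces torch.repeat_interleave on the key/value tensors with a view/expand layout: the query is reshaped to [num_tokens, num_kv_heads, num_queries_per_kv, head_size] and key/value get a broadcast dimension via expand(), which is a zero-copy view, so the grouped-query heads line up with their shared KV head without materializing repeated data. The snippet below is a minimal standalone sketch of that equivalence; the tensor sizes and the [num_tokens, num_heads, head_size] shape convention are illustrative assumptions, not taken from the patch.

import torch

# Illustrative sizes (assumed, not from the patch).
num_tokens, num_heads, num_kv_heads, head_size = 8, 32, 8, 128
num_queries_per_kv = num_heads // num_kv_heads

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)

# Old approach: repeat_interleave materializes num_queries_per_kv copies
# of each KV head, allocating new memory.
key_repeated = torch.repeat_interleave(key, num_queries_per_kv, dim=1)

# New approach: insert a singleton dim and expand it. expand() returns a
# stride-0 view along the new dim, so no data is copied.
key_expanded = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                         num_queries_per_kv, head_size)

# Both layouts hold the same values once flattened back to
# [num_tokens, num_heads, head_size].
assert torch.equal(key_repeated,
                   key_expanded.reshape(num_tokens, num_heads, head_size))

# The query is likewise regrouped so each block of num_queries_per_kv query
# heads is paired with the single KV head it shares.
query_grouped = query.view(num_tokens, num_kv_heads, num_queries_per_kv,
                           head_size)
print(query_grouped.shape, key_expanded.shape)

The output.copy_(out.view_as(output)) change follows from the same reshaping: with the grouped 4-D layout, the attention output comes back with the extra head-group dimension, and view_as flattens it to whatever shape the preallocated output buffer has instead of hard-coding a squeeze or an explicit head count.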