Rewrite torch.repeat_interleave to remove CPU synchronization (#1599)

Author: ljss, 2023-11-21 09:46:32 +08:00 (committed by GitHub)
Parent: 19849db573
Commit: 819b18e7ba


@@ -95,10 +95,15 @@ class PagedAttention(nn.Module):
         """
         if self.num_kv_heads != self.num_heads:
             # Project the key and value tensors to the desired number of heads.
-            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
-            value = torch.repeat_interleave(value,
-                                            self.num_queries_per_kv,
-                                            dim=1)
+            query = query.view(query.shape[0], self.num_kv_heads,
+                               self.num_queries_per_kv, query.shape[-1])
+            key = key[:, :,
+                      None, :].expand(key.shape[0], self.num_kv_heads,
+                                      self.num_queries_per_kv, key.shape[-1])
+            value = value[:, :,
+                          None, :].expand(value.shape[0], self.num_kv_heads,
+                                          self.num_queries_per_kv,
+                                          value.shape[-1])
 
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
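
Not part of the commit: a minimal standalone sketch (shapes chosen arbitrarily) showing why this rewrite helps. torch.repeat_interleave materializes a new tensor with every KV head duplicated, while indexing with None and calling expand() returns a zero-copy view; the query reshape exists so all three tensors carry the same grouped (kv_head, group) layout. The two approaches produce identical values:

    import torch

    # Arbitrary stand-ins for (num_tokens, num_kv_heads, head_size).
    num_tokens, num_kv_heads, num_queries_per_kv, head_size = 16, 4, 2, 64
    key = torch.randn(num_tokens, num_kv_heads, head_size)

    # Old path: allocates a new tensor with each KV head repeated.
    repeated = torch.repeat_interleave(key, num_queries_per_kv, dim=1)

    # New path: insert a singleton group dim and expand it. expand() is a
    # zero-copy view (stride 0 along the new dim), so nothing is allocated.
    expanded = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                         num_queries_per_kv, head_size)

    # Flattening the (kv_head, group) pair reproduces repeat_interleave.
    assert torch.equal(expanded.reshape(num_tokens, -1, head_size), repeated)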
@@ -110,7 +115,7 @@ class PagedAttention(nn.Module):
             scale=self.scale,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
-        output.copy_(out.squeeze(0))
+        output.copy_(out.view_as(output))
         return output
 
     def get_alibi_slopes(self) -> Optional[torch.Tensor]:
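
A note on the squeeze(0) -> view_as(output) change: with the grouped layout above, the attention output now carries a leading batch dim plus split (kv_head, group) head dims, so stripping only the batch dim no longer yields the destination shape. view_as collapses everything to the destination shape in one step. A small sketch with assumed shapes (not from the commit):

    import torch

    # Hypothetical shapes mirroring the grouped path.
    num_tokens, num_kv_heads, num_queries_per_kv, head_size = 16, 4, 2, 64
    num_heads = num_kv_heads * num_queries_per_kv

    out = torch.randn(1, num_tokens, num_kv_heads, num_queries_per_kv,
                      head_size)
    output = torch.empty(num_tokens, num_heads, head_size)

    # view_as flattens the batch and grouped-head dims to match `output`,
    # with no shape arithmetic at the call site.
    output.copy_(out.view_as(output))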
@@ -427,10 +432,15 @@ class PagedAttentionWithALiBi(PagedAttention):
         """
         if self.num_kv_heads != self.num_heads:
             # Project the key and value tensors to the desired number of heads.
-            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
-            value = torch.repeat_interleave(value,
-                                            self.num_queries_per_kv,
-                                            dim=1)
+            query = query.view(query.shape[0], self.num_kv_heads,
+                               self.num_queries_per_kv, query.shape[-1])
+            key = key[:, :,
+                      None, :].expand(key.shape[0], self.num_kv_heads,
+                                      self.num_queries_per_kv, key.shape[-1])
+            value = value[:, :,
+                          None, :].expand(value.shape[0], self.num_kv_heads,
+                                          self.num_queries_per_kv,
+                                          value.shape[-1])
 
         batch_size = input_metadata.num_prompts
         seq_len = input_metadata.max_prompt_len
@@ -443,7 +453,7 @@ class PagedAttentionWithALiBi(PagedAttention):
             scale=self.scale,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
-        output.copy_(out.view(-1, self.num_heads, self.head_size))
+        output.copy_(out.view_as(output))
         return output
 
     def get_alibi_slopes(self) -> Optional[torch.Tensor]:
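
One way to probe the synchronization claim in the title (assuming a CUDA build of PyTorch; whether repeat_interleave actually syncs is version-dependent) is torch.cuda.set_sync_debug_mode, which warns whenever an op forces a device-to-host synchronization:

    import torch

    # Treat this as a probe, not a guarantee: sync behavior varies by
    # PyTorch version.
    if torch.cuda.is_available():
        torch.cuda.set_sync_debug_mode("warn")  # warn on host syncs
        key = torch.randn(16, 4, 64, device="cuda")
        _ = torch.repeat_interleave(key, 2, dim=1)   # may emit a sync warning
        _ = key[:, :, None, :].expand(16, 4, 2, 64)  # pure view: never syncs
        torch.cuda.set_sync_debug_mode("default")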