Rewrite torch.repeat_interleave to remove CPU synchronization (#1599)

ljss 2023-11-21 09:46:32 +08:00 committed by GitHub
parent 19849db573
commit 819b18e7ba


@@ -95,10 +95,15 @@ class PagedAttention(nn.Module):
         """
         if self.num_kv_heads != self.num_heads:
             # Project the key and value tensors to the desired number of heads.
-            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
-            value = torch.repeat_interleave(value,
-                                            self.num_queries_per_kv,
-                                            dim=1)
+            query = query.view(query.shape[0], self.num_kv_heads,
+                               self.num_queries_per_kv, query.shape[-1])
+            key = key[:, :,
+                      None, :].expand(key.shape[0], self.num_kv_heads,
+                                      self.num_queries_per_kv, key.shape[-1])
+            value = value[:, :,
+                          None, :].expand(value.shape[0], self.num_kv_heads,
+                                          self.num_queries_per_kv,
+                                          value.shape[-1])
 
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
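The removed repeat_interleave calls materialize enlarged key/value tensors (and are what forced the CPU synchronization this commit removes), while the replacement only inserts a singleton axis and expands it, giving a stride-0 view with no allocation or copy. A minimal sketch of the equivalence, not part of the diff, with made-up shapes (num_tokens=4, num_kv_heads=2, num_queries_per_kv=3, head_size=8):

    import torch

    num_tokens, num_kv_heads, num_queries_per_kv, head_size = 4, 2, 3, 8
    key = torch.randn(num_tokens, num_kv_heads, head_size)

    # Old path: allocates and fills a (num_tokens, num_kv_heads * num_queries_per_kv, head_size) tensor.
    repeated = torch.repeat_interleave(key, num_queries_per_kv, dim=1)

    # New path: a broadcasted view; stride 0 along the inserted axis, nothing is copied.
    expanded = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                         num_queries_per_kv, head_size)

    assert expanded.stride(2) == 0                                # it is a view, not a copy
    assert torch.equal(expanded.reshape_as(repeated), repeated)   # same values, same ordering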
@@ -110,7 +115,7 @@ class PagedAttention(nn.Module):
             scale=self.scale,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
-        output.copy_(out.squeeze(0))
+        output.copy_(out.view_as(output))
         return output
 
     def get_alibi_slopes(self) -> Optional[torch.Tensor]:
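Because the query is now viewed as (num_tokens, num_kv_heads, num_queries_per_kv, head_size), the attention result comes back with grouped head dimensions rather than a plain (1, num_tokens, num_heads, head_size) tensor, so a fixed squeeze(0) no longer matches. out.view_as(output) collapses whichever layout comes back into output's (num_tokens, num_heads, head_size) shape, since the element counts are identical. A small stand-in illustration (the real out comes from xops.memory_efficient_attention_forward; the shape here is assumed for the grouped case):

    import torch

    num_tokens, num_kv_heads, num_queries_per_kv, head_size = 4, 2, 3, 8
    num_heads = num_kv_heads * num_queries_per_kv

    output = torch.empty(num_tokens, num_heads, head_size)

    # Stand-in for the attention result: leading batch dim of 1 plus grouped head dims.
    out = torch.randn(1, num_tokens, num_kv_heads, num_queries_per_kv, head_size)

    # Identical element count, so one view_as maps it back onto output's layout.
    output.copy_(out.view_as(output))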
@@ -427,10 +432,15 @@ class PagedAttentionWithALiBi(PagedAttention):
         """
         if self.num_kv_heads != self.num_heads:
             # Project the key and value tensors to the desired number of heads.
-            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
-            value = torch.repeat_interleave(value,
-                                            self.num_queries_per_kv,
-                                            dim=1)
+            query = query.view(query.shape[0], self.num_kv_heads,
+                               self.num_queries_per_kv, query.shape[-1])
+            key = key[:, :,
+                      None, :].expand(key.shape[0], self.num_kv_heads,
+                                      self.num_queries_per_kv, key.shape[-1])
+            value = value[:, :,
+                          None, :].expand(value.shape[0], self.num_kv_heads,
+                                          self.num_queries_per_kv,
+                                          value.shape[-1])
 
         batch_size = input_metadata.num_prompts
         seq_len = input_metadata.max_prompt_len
@@ -443,7 +453,7 @@ class PagedAttentionWithALiBi(PagedAttention):
             scale=self.scale,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
-        output.copy_(out.view(-1, self.num_heads, self.head_size))
+        output.copy_(out.view_as(output))
         return output
 
     def get_alibi_slopes(self) -> Optional[torch.Tensor]:
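One way to sanity-check that the expand-based path really avoids host-device synchronization is PyTorch's CUDA sync debug mode; a rough sketch, not part of this change, requiring a CUDA device and using illustrative shapes:

    import torch

    if torch.cuda.is_available():
        key = torch.randn(4, 2, 8, device="cuda")
        torch.cuda.set_sync_debug_mode("error")   # raise if any op forces a host-device sync
        try:
            key[:, :, None, :].expand(4, 2, 3, 8)      # view only: no sync expected
            # torch.repeat_interleave(key, 3, dim=1)   # the replaced path, which could sync
        finally:
            torch.cuda.set_sync_debug_mode("default")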