Rewrite torch.repeat_interleave to remove cpu synchronization (#1599)
commit 819b18e7ba
parent 19849db573
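Background for the change (not part of the commit message): torch.repeat_interleave materializes repeated copies of the key/value tensors, and with CUDA inputs it has to determine the output length, which can force a device-to-host read that stalls the host. The diff below replaces it with indexing plus expand(), which builds broadcast views over the existing storage and stays asynchronous. A minimal, self-contained sketch of the two approaches; the tensor names and sizes here are made up for illustration:

    import torch

    num_tokens, num_kv_heads, num_queries_per_kv, head_size = 8, 4, 2, 64
    key = torch.randn(num_tokens, num_kv_heads, head_size)

    # Old approach: allocates a new [num_tokens, num_kv_heads * n, head_size]
    # tensor and, on CUDA, can block the host while the output size is resolved.
    key_repeated = torch.repeat_interleave(key, num_queries_per_kv, dim=1)

    # New approach: add a singleton dim and broadcast it with expand(); the
    # result is a stride-0 view over the same storage, so nothing is copied.
    key_expanded = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                             num_queries_per_kv, head_size)

    # The two layouts describe the same values.
    assert torch.equal(
        key_repeated,
        key_expanded.reshape(num_tokens,
                             num_kv_heads * num_queries_per_kv, head_size))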
@@ -95,10 +95,15 @@ class PagedAttention(nn.Module):
         """
         if self.num_kv_heads != self.num_heads:
             # Project the key and value tensors to the desired number of heads.
-            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
-            value = torch.repeat_interleave(value,
-                                            self.num_queries_per_kv,
-                                            dim=1)
+            query = query.view(query.shape[0], self.num_kv_heads,
+                               self.num_queries_per_kv, query.shape[-1])
+            key = key[:, :,
+                      None, :].expand(key.shape[0], self.num_kv_heads,
+                                      self.num_queries_per_kv, key.shape[-1])
+            value = value[:, :,
+                          None, :].expand(value.shape[0], self.num_kv_heads,
+                                          self.num_queries_per_kv,
+                                          value.shape[-1])
 
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
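A sketch of the shapes this hunk produces; the variable names follow the diff, the sizes are invented. query, key, and value all end up in a grouped [tokens, kv_heads, queries_per_kv, head_size] layout that the attention kernel can consume without repeated K/V copies:

    import torch

    num_tokens, num_kv_heads, num_queries_per_kv, head_size = 8, 4, 2, 64
    num_heads = num_kv_heads * num_queries_per_kv

    query = torch.randn(num_tokens, num_heads, head_size)
    key = torch.randn(num_tokens, num_kv_heads, head_size)
    value = torch.randn(num_tokens, num_kv_heads, head_size)

    # Split the query heads into groups of consecutive heads that share a KV head.
    query = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
    # Broadcast each KV head across its group without copying.
    key = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                    num_queries_per_kv, head_size)
    value = value[:, :, None, :].expand(num_tokens, num_kv_heads,
                                        num_queries_per_kv, head_size)

    assert query.shape == key.shape == value.shape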
@@ -110,7 +115,7 @@ class PagedAttention(nn.Module):
             scale=self.scale,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
-        output.copy_(out.squeeze(0))
+        output.copy_(out.view_as(output))
         return output
 
     def get_alibi_slopes(self) -> Optional[torch.Tensor]:
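With the grouped layout, the attention result keeps its heads split into (num_kv_heads, num_queries_per_kv) plus the leading batch dim of 1, so squeeze(0) alone would no longer match output's [num_tokens, num_heads, head_size]; view_as(output) collapses everything back in one step and works unchanged when no grouping is needed, which is presumably why a single line now covers both cases. A small illustration with assumed shapes, not taken from the kernel:

    import torch

    num_tokens, num_kv_heads, num_queries_per_kv, head_size = 8, 4, 2, 64
    num_heads = num_kv_heads * num_queries_per_kv

    # Shape of the attention result in the grouped (GQA) case.
    out = torch.randn(1, num_tokens, num_kv_heads, num_queries_per_kv, head_size)
    output = torch.empty(num_tokens, num_heads, head_size)

    # view_as() flattens the leading 1 and the grouped head dims back to
    # [num_tokens, num_heads, head_size] before the copy.
    output.copy_(out.view_as(output))
    assert output.shape == (num_tokens, num_heads, head_size)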
@@ -427,10 +432,15 @@ class PagedAttentionWithALiBi(PagedAttention):
         """
         if self.num_kv_heads != self.num_heads:
             # Project the key and value tensors to the desired number of heads.
-            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
-            value = torch.repeat_interleave(value,
-                                            self.num_queries_per_kv,
-                                            dim=1)
+            query = query.view(query.shape[0], self.num_kv_heads,
+                               self.num_queries_per_kv, query.shape[-1])
+            key = key[:, :,
+                      None, :].expand(key.shape[0], self.num_kv_heads,
+                                      self.num_queries_per_kv, key.shape[-1])
+            value = value[:, :,
+                          None, :].expand(value.shape[0], self.num_kv_heads,
+                                          self.num_queries_per_kv,
+                                          value.shape[-1])
         batch_size = input_metadata.num_prompts
         seq_len = input_metadata.max_prompt_len
 
@@ -443,7 +453,7 @@ class PagedAttentionWithALiBi(PagedAttention):
             scale=self.scale,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
-        output.copy_(out.view(-1, self.num_heads, self.head_size))
+        output.copy_(out.view_as(output))
         return output
 
     def get_alibi_slopes(self) -> Optional[torch.Tensor]:
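A quick check (not part of the commit) that the expand() path really is a zero-copy view while repeat_interleave allocates a new tensor:

    import torch

    key = torch.randn(8, 4, 64)

    expanded = key[:, :, None, :].expand(8, 4, 2, 64)
    assert expanded.data_ptr() == key.data_ptr()  # same storage, no copy
    assert expanded.stride(2) == 0                # broadcast dim has stride 0

    repeated = torch.repeat_interleave(key, 2, dim=1)
    assert repeated.data_ptr() != key.data_ptr()  # new allocation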