Refactor and annotate types for attention
This commit is contained in:
parent
7f22f90e8c
commit
762fd1c3fa
@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -30,24 +30,34 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
|
||||
def multi_query_kv_attention(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
prompt_lens: List[int],
|
||||
) -> None:
|
||||
# FIXME(woosuk): Replace this with a custom op call.
|
||||
attention_mask = torch.triu(
|
||||
torch.ones(query.shape[0], key.shape[0]), diagonal=1) * -1e5
|
||||
attention_mask = attention_mask.to(dtype=query.dtype, device=query.device)
|
||||
out = self._masked_attention(query, key, value, attention_mask)
|
||||
output.copy_(out, non_blocking=True)
|
||||
# FIXME(woosuk): Replace the following with a custom op.
|
||||
start_idx = 0
|
||||
for prompt_len in prompt_lens:
|
||||
out = output[start_idx:start_idx + prompt_len]
|
||||
q = query[start_idx:start_idx + prompt_len]
|
||||
k = key[start_idx:start_idx + prompt_len]
|
||||
v = value[start_idx:start_idx + prompt_len]
|
||||
|
||||
attention_mask = torch.triu(
|
||||
torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
|
||||
attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
|
||||
attention_out = self._masked_attention(q, k, v, attention_mask)
|
||||
out.copy_(attention_out, non_blocking=True)
|
||||
|
||||
start_idx += prompt_len
|
||||
|
||||
def single_query_cached_kv_attention(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
output: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
|
||||
query: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
|
||||
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
value_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
|
||||
input_metadata: InputMetadata,
|
||||
) -> None:
|
||||
num_heads = value_cache.shape[1]
|
||||
@ -82,15 +92,18 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
query: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
key: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
value: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
value_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
# Prune out invalid tokens.
|
||||
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
|
||||
# Pre-allocate the output tensor.
|
||||
output = torch.empty_like(query)
|
||||
|
||||
# Prune out paddings if any.
|
||||
query = query[:input_metadata.num_valid_tokens]
|
||||
key = key[:input_metadata.num_valid_tokens]
|
||||
value = value[:input_metadata.num_valid_tokens]
|
||||
@ -101,18 +114,11 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
key = key.view(-1, num_heads, head_size)
|
||||
value = value.view(-1, num_heads, head_size)
|
||||
output = output.view(-1, num_heads, head_size)
|
||||
|
||||
# Compute the attention op for prompts.
|
||||
output = torch.empty_like(query)
|
||||
start_idx = 0
|
||||
for i in range(input_metadata.num_prompts):
|
||||
prompt_len = input_metadata.prompt_lens[i]
|
||||
out = output[start_idx:start_idx + prompt_len]
|
||||
q = query[start_idx:start_idx + prompt_len]
|
||||
k = key[start_idx:start_idx + prompt_len]
|
||||
v = value[start_idx:start_idx + prompt_len]
|
||||
self.multi_query_kv_attention(out, q, k, v)
|
||||
start_idx += prompt_len
|
||||
self.multi_query_kv_attention(
|
||||
output, query, key, value, input_metadata.prompt_lens)
|
||||
|
||||
# Wait until the cache op is done.
|
||||
if cache_event is not None:
|
||||
@ -124,6 +130,7 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
|
||||
if input_metadata.num_generation_tokens > 0:
|
||||
# Compute the attention op for generation tokens.
|
||||
start_idx = sum(input_metadata.prompt_lens)
|
||||
self.single_query_cached_kv_attention(
|
||||
output[start_idx:],
|
||||
query[start_idx:],
|
||||
@ -132,4 +139,5 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
input_metadata)
|
||||
|
||||
# Reshape the output tensor.
|
||||
# NOTE(woosuk): The output tensor may include paddings.
|
||||
return output.view(-1, num_heads * head_size)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user