From 762fd1c3faf17001b943cb68e5d08e7d7fe59119 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 24 Feb 2023 08:58:46 +0000 Subject: [PATCH] Refactor and annotate types for attention --- cacheflow/models/attention.py | 72 +++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index 068c7d49..71218f7e 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import List, Optional import torch import torch.nn as nn @@ -30,24 +30,34 @@ class OPTCacheFlowAttention(nn.Module): def multi_query_kv_attention( self, - output: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, + output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] + query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] + key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] + value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] + prompt_lens: List[int], ) -> None: - # FIXME(woosuk): Replace this with a custom op call. - attention_mask = torch.triu( - torch.ones(query.shape[0], key.shape[0]), diagonal=1) * -1e5 - attention_mask = attention_mask.to(dtype=query.dtype, device=query.device) - out = self._masked_attention(query, key, value, attention_mask) - output.copy_(out, non_blocking=True) + # FIXME(woosuk): Replace the following with a custom op. + start_idx = 0 + for prompt_len in prompt_lens: + out = output[start_idx:start_idx + prompt_len] + q = query[start_idx:start_idx + prompt_len] + k = key[start_idx:start_idx + prompt_len] + v = value[start_idx:start_idx + prompt_len] + + attention_mask = torch.triu( + torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5 + attention_mask = attention_mask.to(dtype=q.dtype, device=q.device) + attention_out = self._masked_attention(q, k, v, attention_mask) + out.copy_(attention_out, non_blocking=True) + + start_idx += prompt_len def single_query_cached_kv_attention( self, - output: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, + output: torch.Tensor, # [num_generation_tokens, num_heads, head_size] + query: torch.Tensor, # [num_generation_tokens, num_heads, head_size] + key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] + value_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size] input_metadata: InputMetadata, ) -> None: num_heads = value_cache.shape[1] @@ -82,15 +92,18 @@ class OPTCacheFlowAttention(nn.Module): def forward( self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, + query: torch.Tensor, # [num_tokens, num_heads * head_size] + key: torch.Tensor, # [num_tokens, num_heads * head_size] + value: torch.Tensor, # [num_tokens, num_heads * head_size] + key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] + value_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size] input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], - ) -> torch.Tensor: - # Prune out invalid tokens. + ) -> torch.Tensor: # [num_tokens, num_heads * head_size] + # Pre-allocate the output tensor. + output = torch.empty_like(query) + + # Prune out paddings if any. query = query[:input_metadata.num_valid_tokens] key = key[:input_metadata.num_valid_tokens] value = value[:input_metadata.num_valid_tokens] @@ -101,18 +114,11 @@ class OPTCacheFlowAttention(nn.Module): query = query.view(-1, num_heads, head_size) key = key.view(-1, num_heads, head_size) value = value.view(-1, num_heads, head_size) + output = output.view(-1, num_heads, head_size) # Compute the attention op for prompts. - output = torch.empty_like(query) - start_idx = 0 - for i in range(input_metadata.num_prompts): - prompt_len = input_metadata.prompt_lens[i] - out = output[start_idx:start_idx + prompt_len] - q = query[start_idx:start_idx + prompt_len] - k = key[start_idx:start_idx + prompt_len] - v = value[start_idx:start_idx + prompt_len] - self.multi_query_kv_attention(out, q, k, v) - start_idx += prompt_len + self.multi_query_kv_attention( + output, query, key, value, input_metadata.prompt_lens) # Wait until the cache op is done. if cache_event is not None: @@ -124,6 +130,7 @@ class OPTCacheFlowAttention(nn.Module): if input_metadata.num_generation_tokens > 0: # Compute the attention op for generation tokens. + start_idx = sum(input_metadata.prompt_lens) self.single_query_cached_kv_attention( output[start_idx:], query[start_idx:], @@ -132,4 +139,5 @@ class OPTCacheFlowAttention(nn.Module): input_metadata) # Reshape the output tensor. + # NOTE(woosuk): The output tensor may include paddings. return output.view(-1, num_heads * head_size)