From 932844f1cd781aa926439ee2394edfb2f9e696f7 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 23 Feb 2023 23:02:25 +0000
Subject: [PATCH] Fix attention

---
 cacheflow/models/attention.py      | 16 ++++++++++------
 cacheflow/models/input_metadata.py | 11 +++++++++++
 2 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py
index f6d48a3a..6e6b8e98 100644
--- a/cacheflow/models/attention.py
+++ b/cacheflow/models/attention.py
@@ -53,20 +53,19 @@ class OPTCacheFlowAttention(nn.Module):
             context_len = int(input_metadata.context_lens[i])
 
             keys = []
+            values = []
             for j in range(context_len):
-                block_number = block_table[j // block_size]
+                block_number = int(block_table[j // block_size])
                 block_offset = j % block_size
+
                 k = key_cache[block_number, :, :, block_offset, :]
                 k = k.reshape(num_heads, head_size)
                 keys.append(k)
-            keys = torch.stack(keys, dim=0)
 
-            values = []
-            for j in range(context_len):
-                block_number = block_table[j // block_size]
-                block_offset = j % block_size
                 v = value_cache[block_number, :, block_offset, :]
                 values.append(v)
+
+            keys = torch.stack(keys, dim=0)
             values = torch.stack(values, dim=0)
 
             q = q.unsqueeze(0)
@@ -87,6 +86,11 @@ class OPTCacheFlowAttention(nn.Module):
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
+        # Prune out invalid tokens.
+        query = query[:input_metadata.num_valid_tokens]
+        key = key[:input_metadata.num_valid_tokens]
+        value = value[:input_metadata.num_valid_tokens]
+
         # Reshape the input tensors.
         num_heads = value_cache.shape[1]
         head_size = value_cache.shape[3]
diff --git a/cacheflow/models/input_metadata.py b/cacheflow/models/input_metadata.py
index b311e203..86cc2e8f 100644
--- a/cacheflow/models/input_metadata.py
+++ b/cacheflow/models/input_metadata.py
@@ -11,6 +11,7 @@ class InputMetadata:
         prompt_lens: List[int],
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
+        # FIXME: Rename
         max_context_len: int,
         block_tables: torch.Tensor,
     ) -> None:
@@ -23,9 +24,19 @@
 
         self.num_prompts = len(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
+        self.num_valid_tokens = len(slot_mapping)
         if block_tables.numel() > 0:
             self.max_num_blocks_per_seq = block_tables.shape[1]
         else:
             self.max_num_blocks_per_seq = 0
         assert self.num_generation_tokens == block_tables.shape[0]
         assert self.num_prompts + self.num_generation_tokens == len(seq_ids)
+
+    def __repr__(self) -> str:
+        return (f'InputMetadata('
+                f'seq_ids={self.seq_ids}, '
+                f'num_prompts={self.num_prompts}, '
+                f'num_generation_tokens={self.num_generation_tokens}, '
+                f'num_valid_tokens={self.num_valid_tokens}, '
+                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
+                f'max_context_len={self.max_context_len})')
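
Note (not part of the patch): a minimal, self-contained sketch of the per-token KV gather that the rewritten loop above performs, for readers who want to run the indexing in isolation. The cache shapes, the toy block table, and the constants below are assumptions inferred only from the indexing in the diff (key_cache read as [block, head, head_size // x, block_offset, x] and value_cache as [block, head, block_offset, head_size]); they are illustrative, not the actual cacheflow layout.

# Sketch only -- shapes and constants are assumptions inferred from the diff's
# indexing, not the real cacheflow cache configuration.
import torch

num_blocks, num_heads, head_size, block_size, x = 4, 2, 8, 16, 4
key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size, x)
value_cache = torch.randn(num_blocks, num_heads, block_size, head_size)
block_table = torch.tensor([2, 0, 3])  # hypothetical block table for one sequence
context_len = 20                       # spans two physical blocks of the table

keys, values = [], []
for j in range(context_len):
    # Mirrors the patch: the 0-dim tensor element is cast to a plain Python int
    # before being used as an index into the cache.
    block_number = int(block_table[j // block_size])
    block_offset = j % block_size

    k = key_cache[block_number, :, :, block_offset, :]
    k = k.reshape(num_heads, head_size)
    keys.append(k)

    v = value_cache[block_number, :, block_offset, :]
    values.append(v)

keys = torch.stack(keys, dim=0)      # (context_len, num_heads, head_size)
values = torch.stack(values, dim=0)  # (context_len, num_heads, head_size)
print(keys.shape, values.shape)

As in the patched code, keys and values are now gathered in a single pass over the context instead of two separate loops over the same block table.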