From 4b1ac23f53d0e714a4a48d2c8058438405c0fd07 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 23 Feb 2023 00:10:07 +0000
Subject: [PATCH] Fix slot mapping

---
 cacheflow/models/input_metadata.py |  2 +-
 cacheflow/worker/worker.py         | 20 +++++++++++++++-----
 2 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/cacheflow/models/input_metadata.py b/cacheflow/models/input_metadata.py
index 253b4389..8f50fbfb 100644
--- a/cacheflow/models/input_metadata.py
+++ b/cacheflow/models/input_metadata.py
@@ -14,7 +14,7 @@ class InputMetadata:
         block_tables: torch.Tensor,
     ) -> None:
         self.prompt_lens = prompt_lens
-        self.prompt_block_table = slot_mapping
+        self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables
diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py
index de272eec..1d67c3ef 100644
--- a/cacheflow/worker/worker.py
+++ b/cacheflow/worker/worker.py
@@ -83,12 +83,20 @@ class Worker:
         generation_seq_ids = sorted(generation_tokens.keys())
         for seq_id in generation_seq_ids:
             input_tokens.append(generation_tokens[seq_id])
-            input_positions.append(context_lens[seq_id] - 1)
-            generation_block_tables.append(block_tables[seq_id])
+            position_id = context_lens[seq_id] - 1
+            input_positions.append(position_id)
+
+            block_table = block_tables[seq_id]
+            generation_block_tables.append(block_table)
 
             max_context_len = max(max_context_len, context_lens[seq_id])
             max_num_blocks_per_seq = max(
-                max_num_blocks_per_seq, len(block_tables[seq_id]))
+                max_num_blocks_per_seq, len(block_table))
+
+            block_number = block_table[position_id // self.block_size]
+            block_offset = position_id % self.block_size
+            slot = block_number * self.block_size + block_offset
+            slot_mapping.append(slot)
 
         # Optimization: Pad the input length to be a multiple of 8.
         # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
@@ -105,9 +113,11 @@ class Worker:
         context_lens_tensor = torch.tensor(
             [context_lens[seq_id] for seq_id in generation_seq_ids],
             dtype=torch.int, device=self.device)
+        padded_block_tables = [
+            _pad_to_max(block_table, max_num_blocks_per_seq)
+            for block_table in generation_block_tables]
         block_tables_tensor = torch.tensor(
-            [_pad_to_max(block_table) for block_table in generation_block_tables],
-            dtype=int, device=self.device)
+            padded_block_tables, dtype=int, device=self.device)
 
         input_metadata = InputMetadata(
             prompt_lens=prompt_lens,
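
For context, the change in worker.py computes, for each generation token, the physical cache slot it writes to: the token's position is mapped through the sequence's block table to a physical block number, then offset within that block. Below is a minimal standalone sketch of that arithmetic; the helper name compute_slot and the example values are hypothetical, while block_table, block_size, and position_id mirror the variables used in the patch.

from typing import List

def compute_slot(position_id: int, block_table: List[int], block_size: int) -> int:
    # The token at `position_id` lives in the (position_id // block_size)-th
    # logical block; the block table maps that logical block to a physical
    # block number in the KV cache.
    block_number = block_table[position_id // block_size]
    # Offset of the token within its block.
    block_offset = position_id % block_size
    # Flat index of the token's slot in the physical cache.
    return block_number * block_size + block_offset

# Example: with block_size=8 and block_table=[3, 7], the token at position 9
# falls in logical block 1 -> physical block 7, offset 1 -> slot 57.
assert compute_slot(9, [3, 7], 8) == 57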