diff --git a/cacheflow/models/input_metadata.py b/cacheflow/models/input_metadata.py
index 8f50fbfb..5963e67e 100644
--- a/cacheflow/models/input_metadata.py
+++ b/cacheflow/models/input_metadata.py
@@ -7,12 +7,14 @@ class InputMetadata:
 
     def __init__(
         self,
+        seq_ids: List[int],
         prompt_lens: List[int],
         slot_mapping: torch.Tensor,
         context_lens: torch.Tensor,
         max_context_len: int,
         block_tables: torch.Tensor,
     ) -> None:
+        self.seq_ids = seq_ids
         self.prompt_lens = prompt_lens
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
@@ -23,3 +25,4 @@ class InputMetadata:
         self.num_generation_tokens = context_lens.shape[0]
         self.max_num_blocks_per_seq = block_tables.shape[1]
         assert self.num_generation_tokens == block_tables.shape[0]
+        assert self.num_prompts + self.num_generation_tokens == len(seq_ids)
diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py
index 1d67c3ef..52a84128 100644
--- a/cacheflow/worker/worker.py
+++ b/cacheflow/worker/worker.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple
+from typing import Dict, List, Tuple, Union
 
 import torch
 
@@ -120,6 +120,7 @@ class Worker:
             padded_block_tables, dtype=int, device=self.device)
 
         input_metadata = InputMetadata(
+            seq_ids=prompt_seq_ids + generation_seq_ids,
             prompt_lens=prompt_lens,
             slot_mapping=slot_mapping_tensor,
             context_lens=context_lens_tensor,
@@ -128,7 +129,6 @@ class Worker:
         )
         return tokens_tensor, positions_tensor, input_metadata
 
-    @torch.inference_mode()
     def execute_stage(
         self,
@@ -139,7 +139,7 @@ class Worker:
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, int],
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Dict[int, Tuple[int, int]]]:
         # Issue cache operations.
         command_issued = False
         if blocks_to_swap_in: