Add seq_ids to input metadata

This commit is contained in:
Woosuk Kwon 2023-02-23 09:25:01 +00:00
parent 4f6f4967f6
commit 343cea3dbc
2 changed files with 6 additions and 3 deletions

View File

@@ -7,12 +7,14 @@ class InputMetadata:
def __init__(
self,
seq_ids: List[int],
prompt_lens: List[int],
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
block_tables: torch.Tensor,
) -> None:
self.seq_ids = seq_ids
self.prompt_lens = prompt_lens
self.slot_mapping = slot_mapping
self.context_lens = context_lens
@@ -23,3 +25,4 @@ class InputMetadata:
self.num_generation_tokens = context_lens.shape[0]
self.max_num_blocks_per_seq = block_tables.shape[1]
assert self.num_generation_tokens == block_tables.shape[0]
assert self.num_prompts + self.num_generation_tokens == len(seq_ids)

View File

@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union
import torch
@@ -120,6 +120,7 @@ class Worker:
padded_block_tables, dtype=int, device=self.device)
input_metadata = InputMetadata(
seq_ids=prompt_seq_ids + generation_seq_ids,
prompt_lens=prompt_lens,
slot_mapping=slot_mapping_tensor,
context_lens=context_lens_tensor,
@@ -128,7 +129,6 @@ class Worker:
)
return tokens_tensor, positions_tensor, input_metadata
@torch.inference_mode()
def execute_stage(
self,
@@ -139,7 +139,7 @@ class Worker:
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, int],
) -> torch.Tensor:
) -> Union[torch.Tensor, Dict[int, Tuple[int, int]]]:
# Issue cache operations.
command_issued = False
if blocks_to_swap_in: