Add seq_ids to input metadata
This commit is contained in:
parent
4f6f4967f6
commit
343cea3dbc
@ -7,12 +7,14 @@ class InputMetadata:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
seq_ids: List[int],
|
||||||
prompt_lens: List[int],
|
prompt_lens: List[int],
|
||||||
slot_mapping: torch.Tensor,
|
slot_mapping: torch.Tensor,
|
||||||
context_lens: torch.Tensor,
|
context_lens: torch.Tensor,
|
||||||
max_context_len: int,
|
max_context_len: int,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
self.seq_ids = seq_ids
|
||||||
self.prompt_lens = prompt_lens
|
self.prompt_lens = prompt_lens
|
||||||
self.slot_mapping = slot_mapping
|
self.slot_mapping = slot_mapping
|
||||||
self.context_lens = context_lens
|
self.context_lens = context_lens
|
||||||
@ -23,3 +25,4 @@ class InputMetadata:
|
|||||||
self.num_generation_tokens = context_lens.shape[0]
|
self.num_generation_tokens = context_lens.shape[0]
|
||||||
self.max_num_blocks_per_seq = block_tables.shape[1]
|
self.max_num_blocks_per_seq = block_tables.shape[1]
|
||||||
assert self.num_generation_tokens == block_tables.shape[0]
|
assert self.num_generation_tokens == block_tables.shape[0]
|
||||||
|
assert self.num_prompts + self.num_generation_tokens == len(seq_ids)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -120,6 +120,7 @@ class Worker:
|
|||||||
padded_block_tables, dtype=int, device=self.device)
|
padded_block_tables, dtype=int, device=self.device)
|
||||||
|
|
||||||
input_metadata = InputMetadata(
|
input_metadata = InputMetadata(
|
||||||
|
seq_ids=prompt_seq_ids + generation_seq_ids,
|
||||||
prompt_lens=prompt_lens,
|
prompt_lens=prompt_lens,
|
||||||
slot_mapping=slot_mapping_tensor,
|
slot_mapping=slot_mapping_tensor,
|
||||||
context_lens=context_lens_tensor,
|
context_lens=context_lens_tensor,
|
||||||
@ -128,7 +129,6 @@ class Worker:
|
|||||||
)
|
)
|
||||||
return tokens_tensor, positions_tensor, input_metadata
|
return tokens_tensor, positions_tensor, input_metadata
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_stage(
|
def execute_stage(
|
||||||
self,
|
self,
|
||||||
@ -139,7 +139,7 @@ class Worker:
|
|||||||
blocks_to_swap_in: Dict[int, int],
|
blocks_to_swap_in: Dict[int, int],
|
||||||
blocks_to_swap_out: Dict[int, int],
|
blocks_to_swap_out: Dict[int, int],
|
||||||
blocks_to_copy: Dict[int, int],
|
blocks_to_copy: Dict[int, int],
|
||||||
) -> torch.Tensor:
|
) -> Union[torch.Tensor, Dict[int, Tuple[int, int]]]:
|
||||||
# Issue cache operations.
|
# Issue cache operations.
|
||||||
command_issued = False
|
command_issued = False
|
||||||
if blocks_to_swap_in:
|
if blocks_to_swap_in:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user