[Bugfix] Fix type annotations in CPU model runner (#4256)

This commit is contained in:
Woosuk Kwon 2024-04-22 00:54:16 -07:00 committed by GitHub
parent 296cdf8ac7
commit e73ed0f1c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -73,7 +73,8 @@ class CPUModelRunner:
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int]]:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
Optional[torch.Tensor]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
@@ -347,8 +348,8 @@ class CPUModelRunner:
def prepare_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
SamplingMetadata]:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Optional[torch.Tensor]]:
multi_modal_input = None
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or