[Bugfix] Fix type annotations in CPU model runner (#4256)
This commit is contained in:
parent
296cdf8ac7
commit
e73ed0f1c6
@ -73,7 +73,8 @@ class CPUModelRunner:
|
||||
def _prepare_prompt(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int]]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
|
||||
Optional[torch.Tensor]]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
@ -347,8 +348,8 @@ class CPUModelRunner:
|
||||
def prepare_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
|
||||
SamplingMetadata]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
|
||||
Optional[torch.Tensor]]:
|
||||
multi_modal_input = None
|
||||
if self.is_driver_worker:
|
||||
# NOTE: We assume that all sequences in the group are all prompts or
|
||||
|
||||
Loading…
Reference in New Issue
Block a user