[Bugfix] Fix type annotations in CPU model runner (#4256)
parent 296cdf8ac7
commit e73ed0f1c6
@@ -73,7 +73,8 @@ class CPUModelRunner:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
+               Optional[torch.Tensor]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[int] = []
         input_positions: List[int] = []
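The widened annotation records that _prepare_prompt also returns an optional multi-modal tensor as its last element, which may be None when the batch carries no multi-modal data. Below is a minimal, self-contained sketch of that pattern, not the runner's actual code: it omits AttentionMetadata, and the variable values are illustrative stand-ins.

from typing import List, Optional, Tuple

import torch


def _prepare_prompt_sketch(
    has_multi_modal_data: bool,
) -> Tuple[torch.Tensor, torch.Tensor, List[int], Optional[torch.Tensor]]:
    # Illustrative stand-ins for the tensors the runner actually builds.
    input_tokens = torch.tensor([1, 2, 3], dtype=torch.long)
    input_positions = torch.tensor([0, 1, 2], dtype=torch.long)
    seq_lens = [3]
    # Only some requests carry multi-modal data, so the last element can be
    # None, which is exactly what Optional[torch.Tensor] expresses.
    multi_modal_input = torch.zeros(1, 4) if has_multi_modal_data else None
    return input_tokens, input_positions, seq_lens, multi_modal_input


print(_prepare_prompt_sketch(False)[-1])        # None
print(_prepare_prompt_sketch(True)[-1].shape)   # torch.Size([1, 4])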
@@ -347,8 +348,8 @@ class CPUModelRunner:
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata,
-               SamplingMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
+               Optional[torch.Tensor]]:
         multi_modal_input = None
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
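On the caller side, the five-element hint means a type checker verifies the unpacking arity and forces the None case of multi_modal_input to be handled. The sketch below is an assumed, simplified illustration, not vLLM's execute_model: the list and dict are stand-ins for AttentionMetadata and SamplingMetadata, and "image_input" is a hypothetical keyword name.

from typing import Any, Dict, List, Optional, Tuple

import torch

prepared: Tuple[torch.Tensor, torch.Tensor, List[int], Dict[str, Any],
                Optional[torch.Tensor]] = (
    torch.tensor([1, 2, 3]),   # input_tokens
    torch.tensor([0, 1, 2]),   # input_positions
    [3],                       # stand-in for AttentionMetadata
    {"temperature": 1.0},      # stand-in for SamplingMetadata
    None,                      # multi_modal_input: absent for this batch
)

tokens, positions, attn_meta, sampling_meta, multi_modal_input = prepared
model_kwargs: Dict[str, torch.Tensor] = {}
if multi_modal_input is not None:
    # Hypothetical keyword name; forwarded only when multi-modal data exists.
    model_kwargs["image_input"] = multi_modal_input
print(tokens.shape, positions.shape, attn_meta, sorted(model_kwargs))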