From f029ef94d72528aaccd7d48f973031e9dff23447 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 18 Sep 2023 11:49:40 -0700 Subject: [PATCH] Fix get_max_num_running_seqs for waiting and swapped seq groups (#1068) --- vllm/sequence.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index eac3af28..4f0adac7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -250,8 +250,8 @@ class SequenceGroup: # generation stage, we will have `best_of` sequences running. return self.sampling_params.best_of # At sampling stages, return the number of actual sequences - # running. - return self.num_seqs(status=SequenceStatus.RUNNING) + # that are not finished yet. + return self.num_unfinished_seqs() def get_seqs( self, @@ -264,12 +264,23 @@ class SequenceGroup: seq for seq in self.seqs_dict.values() if seq.status == status ] + def get_unfinished_seqs(self) -> List[Sequence]: + return [ + seq for seq in self.seqs_dict.values() if not seq.is_finished() + ] + def get_finished_seqs(self) -> List[Sequence]: return [seq for seq in self.seqs_dict.values() if seq.is_finished()] def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: return len(self.get_seqs(status)) + def num_unfinished_seqs(self) -> int: + return len(self.get_unfinished_seqs()) + + def num_finished_seqs(self) -> int: + return len(self.get_finished_seqs()) + def find(self, seq_id: int) -> Sequence: if seq_id not in self.seqs_dict: raise ValueError(f"Sequence {seq_id} not found.")