Fix get_max_num_running_seqs for waiting and swapped seq groups (#1068)
This commit is contained in:
parent
95592fa00a
commit
f029ef94d7
@ -250,8 +250,8 @@ class SequenceGroup:
|
||||
# generation stage, we will have `best_of` sequences running.
|
||||
return self.sampling_params.best_of
|
||||
# At sampling stages, return the number of actual sequences
|
||||
# running.
|
||||
return self.num_seqs(status=SequenceStatus.RUNNING)
|
||||
# that are not finished yet.
|
||||
return self.num_unfinished_seqs()
|
||||
|
||||
def get_seqs(
|
||||
self,
|
||||
@ -264,12 +264,23 @@ class SequenceGroup:
|
||||
seq for seq in self.seqs_dict.values() if seq.status == status
|
||||
]
|
||||
|
||||
def get_unfinished_seqs(self) -> List[Sequence]:
|
||||
return [
|
||||
seq for seq in self.seqs_dict.values() if not seq.is_finished()
|
||||
]
|
||||
|
||||
def get_finished_seqs(self) -> List[Sequence]:
|
||||
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
|
||||
|
||||
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
||||
return len(self.get_seqs(status))
|
||||
|
||||
def num_unfinished_seqs(self) -> int:
|
||||
return len(self.get_unfinished_seqs())
|
||||
|
||||
def num_finished_seqs(self) -> int:
|
||||
return len(self.get_finished_seqs())
|
||||
|
||||
def find(self, seq_id: int) -> Sequence:
|
||||
if seq_id not in self.seqs_dict:
|
||||
raise ValueError(f"Sequence {seq_id} not found.")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user