Fix get_max_num_running_seqs for waiting and swapped seq groups (#1068)
This commit is contained in:
parent
95592fa00a
commit
f029ef94d7
@ -250,8 +250,8 @@ class SequenceGroup:
|
|||||||
# generation stage, we will have `best_of` sequences running.
|
# generation stage, we will have `best_of` sequences running.
|
||||||
return self.sampling_params.best_of
|
return self.sampling_params.best_of
|
||||||
# At sampling stages, return the number of actual sequences
|
# At sampling stages, return the number of actual sequences
|
||||||
# running.
|
# that are not finished yet.
|
||||||
return self.num_seqs(status=SequenceStatus.RUNNING)
|
return self.num_unfinished_seqs()
|
||||||
|
|
||||||
def get_seqs(
|
def get_seqs(
|
||||||
self,
|
self,
|
||||||
@ -264,12 +264,23 @@ class SequenceGroup:
|
|||||||
seq for seq in self.seqs_dict.values() if seq.status == status
|
seq for seq in self.seqs_dict.values() if seq.status == status
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def get_unfinished_seqs(self) -> List[Sequence]:
|
||||||
|
return [
|
||||||
|
seq for seq in self.seqs_dict.values() if not seq.is_finished()
|
||||||
|
]
|
||||||
|
|
||||||
def get_finished_seqs(self) -> List[Sequence]:
|
def get_finished_seqs(self) -> List[Sequence]:
|
||||||
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
|
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
|
||||||
|
|
||||||
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
|
||||||
return len(self.get_seqs(status))
|
return len(self.get_seqs(status))
|
||||||
|
|
||||||
|
def num_unfinished_seqs(self) -> int:
|
||||||
|
return len(self.get_unfinished_seqs())
|
||||||
|
|
||||||
|
def num_finished_seqs(self) -> int:
|
||||||
|
return len(self.get_finished_seqs())
|
||||||
|
|
||||||
def find(self, seq_id: int) -> Sequence:
|
def find(self, seq_id: int) -> Sequence:
|
||||||
if seq_id not in self.seqs_dict:
|
if seq_id not in self.seqs_dict:
|
||||||
raise ValueError(f"Sequence {seq_id} not found.")
|
raise ValueError(f"Sequence {seq_id} not found.")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user