diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index 9f140f0b..364f3b2e 100644 --- a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -86,6 +86,9 @@ class SequenceGroup: return seq raise ValueError(f'Sequence {seq_id} not found.') + def is_finished(self) -> bool: + return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs) + def __repr__(self) -> str: return (f'SequenceGroup(group_id={self.group_id}, ' f'num_seqs={len(self.seqs)})')