diff --git a/vllm/sequence.py b/vllm/sequence.py index 5847626b..ecfaee6e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -401,6 +401,12 @@ class SequenceGroupOutputs: return (f"SequenceGroupOutputs(samples={self.samples}, " f"prompt_logprobs={self.prompt_logprobs})") + def __eq__(self, other: object) -> bool: + if not isinstance(other, SequenceGroupOutputs): + raise NotImplementedError() + return (self.samples == other.samples + and self.prompt_logprobs == other.prompt_logprobs) + # For each sequence group, we generate a list of SequenceOutputs object, # each of which contains one possible candidate for the next token.