[BugFix] Define __eq__ in SequenceGroupOutputs (#1389)

This commit is contained in:
Woosuk Kwon 2023-10-17 01:09:44 -07:00 committed by GitHub
parent a132435204
commit f8a1e39fae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -401,6 +401,12 @@ class SequenceGroupOutputs:
return (f"SequenceGroupOutputs(samples={self.samples}, " return (f"SequenceGroupOutputs(samples={self.samples}, "
f"prompt_logprobs={self.prompt_logprobs})") f"prompt_logprobs={self.prompt_logprobs})")
def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceGroupOutputs):
raise NotImplementedError()
return (self.samples == other.samples
and self.prompt_logprobs == other.prompt_logprobs)
# For each sequence group, we generate a list of SequenceOutputs object, # For each sequence group, we generate a list of SequenceOutputs object,
# each of which contains one possible candidate for the next token. # each of which contains one possible candidate for the next token.