[BugFix] Define __eq__ in SequenceGroupOutputs (#1389)
This commit is contained in:
parent
a132435204
commit
f8a1e39fae
@ -401,6 +401,12 @@ class SequenceGroupOutputs:
|
||||
return (f"SequenceGroupOutputs(samples={self.samples}, "
|
||||
f"prompt_logprobs={self.prompt_logprobs})")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SequenceGroupOutputs):
|
||||
raise NotImplementedError()
|
||||
return (self.samples == other.samples
|
||||
and self.prompt_logprobs == other.prompt_logprobs)
|
||||
|
||||
|
||||
# For each sequence group, we generate a list of SequenceOutputs object,
|
||||
# each of which contains one possible candidate for the next token.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user