[Fix] Fix best_of behavior when n=1 (#3298)
This commit is contained in:
parent
9e8744a545
commit
4b59f00e91
@ -87,12 +87,12 @@ class RequestOutput:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
|
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
|
||||||
# Get the top-n sequences.
|
|
||||||
n = seq_group.sampling_params.n
|
|
||||||
seqs = seq_group.get_seqs()
|
seqs = seq_group.get_seqs()
|
||||||
if n == 1:
|
if len(seqs) == 1:
|
||||||
top_n_seqs = seqs
|
top_n_seqs = seqs
|
||||||
else:
|
else:
|
||||||
|
# Get the top-n sequences.
|
||||||
|
n = seq_group.sampling_params.n
|
||||||
if seq_group.sampling_params.use_beam_search:
|
if seq_group.sampling_params.use_beam_search:
|
||||||
sorting_key = lambda seq: seq.get_beam_search_score(
|
sorting_key = lambda seq: seq.get_beam_search_score(
|
||||||
seq_group.sampling_params.length_penalty)
|
seq_group.sampling_params.length_penalty)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user