[Gen] Fix decode function not using top_p during iterative decoding
This commit is contained in:
parent
847abe653c
commit
7a3bd55f1a
@ -173,7 +173,7 @@ def decode(
|
||||
teacher_outputs is None
|
||||
or teacher_output_len <= inference_params.sequence_len_offset + 1
|
||||
):
|
||||
next_token = sample(logits, top_k=top_k, temperature=temperature)
|
||||
next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
|
||||
else:
|
||||
next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
|
||||
sequences.append(next_token)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user