[Gen] Fix decode function not using top_p during iterative decoding

Tri Dao 2023-08-26 15:14:22 -07:00
parent 847abe653c
commit 7a3bd55f1a


@@ -173,7 +173,7 @@ def decode(
                 teacher_outputs is None
                 or teacher_output_len <= inference_params.sequence_len_offset + 1
             ):
-                next_token = sample(logits, top_k=top_k, temperature=temperature)
+                next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
             else:
                 next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
             sequences.append(next_token)
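
Context: the `sample` helper itself is defined elsewhere in the repository and is not shown in this hunk. Before this fix, the iterative decoding path dropped the `top_p` argument, so generation silently fell back to top-k/temperature sampling only. The snippet below is a minimal, hypothetical sketch of what passing `top_p` enables (nucleus sampling); it assumes PyTorch logits of shape (batch, vocab) and is not the repository's actual implementation.

    import torch

    def sample_top_p(logits, top_p=1.0, temperature=1.0):
        """Minimal nucleus (top-p) sampling sketch: keep the smallest set of
        tokens whose cumulative probability reaches top_p, then sample from it."""
        if temperature != 1.0:
            logits = logits / temperature
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            sorted_probs = torch.softmax(sorted_logits, dim=-1)
            cum_probs = sorted_probs.cumsum(dim=-1)
            # Drop tokens whose preceding cumulative mass already exceeds top_p;
            # the highest-probability token is always kept.
            to_remove = (cum_probs - sorted_probs) > top_p
            sorted_logits = sorted_logits.masked_fill(to_remove, float("-inf"))
            logits = torch.full_like(logits, float("-inf")).scatter(
                -1, sorted_indices, sorted_logits
            )
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

With top_p=1.0 the filter is a no-op, which is why the missing argument went unnoticed until sampling with a stricter nucleus threshold.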