diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py index ed774d8..f91870b 100644 --- a/flash_attn/utils/generation.py +++ b/flash_attn/utils/generation.py @@ -173,7 +173,7 @@ def decode( teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1 ): - next_token = sample(logits, top_k=top_k, temperature=temperature) + next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) else: next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1] sequences.append(next_token)