From 7a3bd55f1aafdcfc8167358c180261db71f7c0d8 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sat, 26 Aug 2023 15:14:22 -0700
Subject: [PATCH] [Gen] Fix decode function not using top_p during iterative decoding

---
 flash_attn/utils/generation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index ed774d8..f91870b 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -173,7 +173,7 @@ def decode(
                 teacher_outputs is None
                 or teacher_output_len <= inference_params.sequence_len_offset + 1
             ):
-                next_token = sample(logits, top_k=top_k, temperature=temperature)
+                next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
             else:
                 next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
             sequences.append(next_token)