From 8a326bbc9e4f6eb7deca5737693b84c7c570569c Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Tue, 29 Aug 2023 14:28:44 -0700
Subject: [PATCH] [Gen] Minor fix to modify logits for top_p

---
 flash_attn/utils/generation.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index de874c9..b9a5ec3 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -32,7 +32,7 @@ class InferenceParams:
 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
 def modify_logits_for_top_k_filtering(logits, top_k):
-    """Set the logits for none top-k values to -inf."""
+    """Set the logits for non-top-k values to -inf. Done in-place."""
     indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
     logits.masked_fill_(indices_to_remove, float("-Inf"))
 
@@ -40,7 +40,7 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
 def modify_logits_for_top_p_filtering(logits, top_p):
-    """Set the logits for none top-p values to -inf."""
+    """Set the logits for non-top-p values to -inf. Done in-place."""
     if top_p <= 0.0 or top_p >= 1.0:
         return
     # First sort and calculate cumulative sum of probabilities.
@@ -52,7 +52,7 @@
     indices_to_remove = sorted_indices_to_remove.scatter(
         1, sorted_indices, sorted_indices_to_remove
     )
-    logits = logits.masked_fill(indices_to_remove, float("-inf"))
+    logits.masked_fill_(indices_to_remove, float("-inf"))
 
 
 def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
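
The substantive change is the third hunk: Tensor.masked_fill is out-of-place, so the old line `logits = logits.masked_fill(...)` only rebound the local name inside modify_logits_for_top_p_filtering, and the caller's logits were never actually filtered. masked_fill_ mutates the tensor in place, which is what the top-k path already did and what the updated docstrings now state. Below is a minimal sketch of the difference, using toy values and hypothetical helper names that are not part of the patch:

    import torch

    # Toy logits and a mask marking the positions to filter (hypothetical values).
    logits = torch.tensor([[1.0, 2.0, 3.0]])
    indices_to_remove = torch.tensor([[True, False, False]])

    def filter_out_of_place(logits, indices_to_remove):
        # Pre-patch behavior: masked_fill allocates a new tensor and rebinds
        # only the local name, so the caller's tensor is left untouched.
        logits = logits.masked_fill(indices_to_remove, float("-inf"))

    def filter_in_place(logits, indices_to_remove):
        # Patched behavior: masked_fill_ mutates the caller's tensor directly.
        logits.masked_fill_(indices_to_remove, float("-inf"))

    a = logits.clone()
    filter_out_of_place(a, indices_to_remove)
    print(a)  # tensor([[1., 2., 3.]]) -- the filtering silently had no effect

    b = logits.clone()
    filter_in_place(b, indices_to_remove)
    print(b)  # tensor([[-inf, 2., 3.]])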