[Gen] Minor fix to modify logits for top_p
This commit is contained in:
parent
1d817a8ffc
commit
8a326bbc9e
@ -32,7 +32,7 @@ class InferenceParams:
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
||||
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """Set the logits for non-top-k values to -inf. Done in-place.

    Args:
        logits: tensor of shape (..., vocab_size); mutated in place.
        top_k: number of highest-logit entries to keep per row.
            Assumed to satisfy 1 <= top_k <= logits.size(-1) — TODO confirm callers guard this.

    Returns:
        None. ``logits`` is modified in place via ``masked_fill_``.
    """
    # topk(...)[0][..., -1, None] is the k-th largest logit per row (kept as a
    # broadcastable threshold); everything strictly below it is masked out.
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-Inf"))
|
||||
|
||||
@ -40,7 +40,7 @@ def modify_logits_for_top_k_filtering(logits, top_k):
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
||||
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
|
||||
def modify_logits_for_top_p_filtering(logits, top_p):
|
||||
"""Set the logits for none top-p values to -inf."""
|
||||
"""Set the logits for none top-p values to -inf. Done in-place."""
|
||||
if top_p <= 0.0 or top_p >= 1.0:
|
||||
return
|
||||
# First sort and calculate cumulative sum of probabilities.
|
||||
@ -52,7 +52,7 @@ def modify_logits_for_top_p_filtering(logits, top_p):
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
1, sorted_indices, sorted_indices_to_remove
|
||||
)
|
||||
logits = logits.masked_fill(indices_to_remove, float("-inf"))
|
||||
logits.masked_fill_(indices_to_remove, float("-inf"))
|
||||
|
||||
|
||||
def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user