[Gen] Implement top-k and top-p sampling

Tri Dao 2023-01-07 17:00:02 -08:00
parent 11be742aa3
commit e02fd588aa
2 changed files with 57 additions and 12 deletions

flash_attn/utils/generation.py

@@ -8,7 +8,7 @@ from torch import Tensor
 from einops import rearrange

-from transformers.generation import GreedySearchDecoderOnlyOutput
+from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput


 @dataclass
@@ -24,13 +24,58 @@ class InferenceParams:
     lengths_per_sample: Optional[Tensor] = None


-def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
-    """Greedy decoding. This is a very simple implementation.
+# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
+# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
+def modify_logits_for_top_p_filtering(logits, top_p):
+    """Set the logits for non-top-p values to -inf (in place)."""
+    if top_p <= 0.0:
+        return
+    # First sort and calculate cumulative sum of probabilities.
+    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
+    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+    # Remove tokens whose cumulative probability (in ascending order) is at or below 1 - top_p;
+    # the most likely token is always kept.
+    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+    # Scatter the mask back to the original (unsorted) indexing.
+    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+    # Mask in place so the caller's logits tensor is modified.
+    logits.masked_fill_(indices_to_remove, float('-inf'))
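(Illustration, not part of the diff.) A quick sanity check of the filter above on a toy batch, assuming the in-place masked_fill_ form shown so the caller sees the masked logits:

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])        # softmax probs ~ [0.61, 0.22, 0.14, 0.03]
modify_logits_for_top_p_filtering(logits, top_p=0.8)
print(logits)  # tensor([[2., 1., -inf, -inf]]): the kept tokens carry ~0.83 >= 0.8 of the probability mass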
+def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
+    """Sample from top-k logits.
+    Arguments:
+        logits: Tensor of shape (batch_size, vocab_size)
+    """
+    if top_k == 1:  # Short-circuit for greedy decoding
+        return logits.argmax(dim=-1)
+    else:
+        if top_p > 0.0:
+            assert top_p <= 1.0, 'top-p should be in (0, 1].'
+        if top_k > 0:
+            top_k = min(top_k, logits.size(-1))  # Safety check
+            logits_top, indices = torch.topk(logits, top_k, dim=-1)
+            logits_top /= temperature
+            modify_logits_for_top_p_filtering(logits_top, top_p)
+            return indices[
+                torch.arange(indices.shape[0], device=indices.device),
+                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
+            ]
+        else:
+            logits_top = logits / temperature
+            modify_logits_for_top_p_filtering(logits_top, top_p)
+            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
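(Illustration, not part of the diff.) Typical call patterns for sample, with shapes following the docstring above; the vocabulary size is arbitrary:

logits = torch.randn(4, 50257)                              # (batch_size, vocab_size)
greedy_tok  = sample(logits, top_k=1)                       # plain argmax, shape (4,)
topk_tok    = sample(logits, top_k=50, temperature=0.8)     # sample among the 50 largest logits
nucleus_tok = sample(logits, top_k=0, top_p=0.9)            # pure top-p (nucleus) sampling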
+def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, fused_ft_kernel=True):
+    """Decoding, either greedy or with top-k or top-p sampling.
+    If top-k = 0, don't limit the number of candidates (pure sampling).
+    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
+    then top-p.
     We assume that all sequences in the same batch have the same length.
     Arguments:
         input_ids: (batch, seq_len)
         max_length: int
-    Returns: GreedySearchDecoderOnlyOutput, with the following fields:
+    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
         sequences: (batch, max_length)
         scores: tuples of (batch, vocab_size)
     """
@@ -41,7 +86,7 @@ def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
     with torch.inference_mode():
         logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         scores.append(logits)
-        next_token = logits.argmax(dim=-1)
+        next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         sequences = [next_token]
         inference_params.sequence_len_offset = seqlen_og
         while True:
@@ -50,12 +95,13 @@ def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
             logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                            inference_params=inference_params).logits[:, -1]
             scores.append(logits)
-            next_token = logits.argmax(dim=-1)
+            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
             sequences.append(next_token)
             inference_params.sequence_len_offset += 1
             if inference_params.sequence_len_offset >= max_length - 1:
                 break
-    return GreedySearchDecoderOnlyOutput(
+    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
+    return output_cls(
         sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
         scores=tuple(scores)
     )
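(Illustration, not part of the diff.) Calling decode directly, assuming input_ids and a compatible model are already built; generate below is a thin wrapper around this:

out = decode(input_ids, model, max_length=100, top_k=0, top_p=0.9, temperature=1.0)
out.sequences   # (batch, max_length): input_ids followed by the sampled continuation
out.scores      # tuple of per-step logits, each of shape (batch, vocab_size)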
@@ -63,9 +109,10 @@ def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):

 class GenerationMixin:

-    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False,
-                 **kwargs):
-        output = greedy_decode(input_ids, self, max_length, **kwargs)
+    def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
+                 return_dict_in_generate=False, output_scores=False, **kwargs):
+        output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
+                        temperature=temperature, **kwargs)
         if not output_scores:
             output.scores = None
         return output if return_dict_in_generate else output.sequences
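(Illustration, not part of the diff.) How the new arguments surface to users; model construction is assumed, e.g. any GPTLMHeadModel that mixes in GenerationMixin:

out = model.generate(input_ids, max_length=100, top_k=40, top_p=0.95, temperature=0.9,
                     return_dict_in_generate=True, output_scores=True)
print(out.sequences.shape)   # (batch, max_length)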

tests/models/test_gpt_generation.py

@@ -11,14 +11,12 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead
 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gpt import remap_state_dict_gpt2
 from flash_attn.utils.pretrained import state_dict_from_pretrained
-from flash_attn.utils.generation import greedy_decode
 @pytest.mark.parametrize('fused_ft_kernel', [False, True])
 @pytest.mark.parametrize('optimized', [False, True])
 # @pytest.mark.parametrize('fused_ft_kernel', [False])
-# @pytest.mark.parametrize('optimized', [True])
 # @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize('rotary', [False, True])
 @pytest.mark.parametrize('model_name', ["gpt2"])
 def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):