diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index e50fd27..3b4cae8 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -8,7 +8,7 @@
 from torch import Tensor
 from einops import rearrange
 
-from transformers.generation import GreedySearchDecoderOnlyOutput
+from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
 
 
 @dataclass
@@ -24,13 +24,58 @@ class InferenceParams:
     lengths_per_sample: Optional[Tensor] = None
 
 
-def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
-    """Greedy decoding. This is a very simple implementation.
+# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
+# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
+def modify_logits_for_top_p_filtering(logits, top_p):
+    """Set the logits for non-top-p values to -inf, in place."""
+    if top_p <= 0.0:
+        return
+    # First sort and calculate cumulative sum of probabilities.
+    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
+    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+    # Keep tokens in the top-p nucleus; drop those whose cumulative probability is <= 1 - top_p.
+    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+    # Scatter the removal mask back to the original (unsorted) indexing.
+    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+    logits.masked_fill_(indices_to_remove, float('-inf'))
+
+
+def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
+    """Sample from the logits, with optional top-k / top-p filtering and temperature scaling.
+    Arguments:
+        logits: Tensor of shape (batch_size, vocab_size)
+    """
+    if top_k == 1:  # Short-circuit for greedy decoding
+        return logits.argmax(dim=-1)
+    else:
+        if top_p > 0.0:
+            assert top_p <= 1.0, 'top-p should be in (0, 1].'
+        if top_k > 0:
+            top_k = min(top_k, logits.size(-1))  # Safety check
+            logits_top, indices = torch.topk(logits, top_k, dim=-1)
+            logits_top /= temperature
+            modify_logits_for_top_p_filtering(logits_top, top_p)
+            return indices[
+                torch.arange(indices.shape[0], device=indices.device),
+                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
+            ]
+        else:
+            logits_top = logits / temperature
+            modify_logits_for_top_p_filtering(logits_top, top_p)
+            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
+
+
+def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, fused_ft_kernel=True):
+    """Decoding, either greedy or with top-k or top-p sampling.
+    If top-k = 0, don't limit the number of candidates (pure sampling).
+    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
+    then top-p. We assume that all sequences in the same batch have the same length.
+
     Arguments:
         input_ids: (batch, seq_len)
         max_length: int
-    Returns: GreedySearchDecoderOnlyOutput, with the following fields:
+    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
         sequences: (batch, max_length)
         scores: tuples of (batch, vocab_size)
     """
 
@@ -41,7 +86,7 @@ def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
     with torch.inference_mode():
         logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         scores.append(logits)
-        next_token = logits.argmax(dim=-1)
+        next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         sequences = [next_token]
         inference_params.sequence_len_offset = seqlen_og
         while True:
@@ -50,12 +95,13 @@ def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
             logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                            inference_params=inference_params).logits[:, -1]
             scores.append(logits)
-            next_token = logits.argmax(dim=-1)
+            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
             sequences.append(next_token)
             inference_params.sequence_len_offset += 1
             if inference_params.sequence_len_offset >= max_length - 1:
                 break
-    return GreedySearchDecoderOnlyOutput(
+    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
+    return output_cls(
         sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
         scores=tuple(scores)
     )
@@ -63,9 +109,10 @@ def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
 
 class GenerationMixin:
 
-    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False,
-                 **kwargs):
-        output = greedy_decode(input_ids, self, max_length, **kwargs)
+    def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
+                 return_dict_in_generate=False, output_scores=False, **kwargs):
+        output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
+                        temperature=temperature, **kwargs)
         if not output_scores:
             output.scores = None
         return output if return_dict_in_generate else output.sequences
diff --git a/tests/models/test_gpt_generation.py b/tests/models/test_gpt_generation.py
index 4ddebf5..793d247 100644
--- a/tests/models/test_gpt_generation.py
+++ b/tests/models/test_gpt_generation.py
@@ -11,14 +11,12 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead
 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gpt import remap_state_dict_gpt2
 from flash_attn.utils.pretrained import state_dict_from_pretrained
-from flash_attn.utils.generation import greedy_decode
 
 
 @pytest.mark.parametrize('fused_ft_kernel', [False, True])
 @pytest.mark.parametrize('optimized', [False, True])
 # @pytest.mark.parametrize('fused_ft_kernel', [False])
 # @pytest.mark.parametrize('optimized', [True])
-# @pytest.mark.parametrize('optimized', [True])
 @pytest.mark.parametrize('rotary', [False, True])
 @pytest.mark.parametrize('model_name', ["gpt2"])
 def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
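Below is a minimal, self-contained sketch of how the new sampling knobs compose once this patch is applied. The toy logits, the seed, and the chosen values of `top_k`, `top_p`, and `temperature` are illustrative only; the `sample` and `modify_logits_for_top_p_filtering` signatures are the ones introduced in the diff above.

```python
import torch

# These names are added by this patch to flash_attn/utils/generation.py.
from flash_attn.utils.generation import modify_logits_for_top_p_filtering, sample

torch.manual_seed(0)
# Toy batch of logits over a 6-token vocabulary (values are illustrative only).
logits = torch.tensor([[4.0, 3.0, 2.0, 1.0, 0.0, -1.0],
                       [0.5, 0.4, 0.3, 0.2, 0.1, 0.0]])

# top_k=1 short-circuits to greedy argmax, matching the old greedy_decode behavior.
greedy = sample(logits.clone(), top_k=1)

# Keep the 4 highest logits, then drop whatever falls outside the 90% probability
# nucleus, then sample with temperature 0.8 (top-k is applied before top-p).
sampled = sample(logits.clone(), top_k=4, top_p=0.9, temperature=0.8)

# The top-p filter works in place: entries outside the nucleus are set to -inf.
filtered = logits.clone()
modify_logits_for_top_p_filtering(filtered, top_p=0.9)

print(greedy, sampled, filtered, sep='\n')
```

With a model that mixes in `GenerationMixin`, the same knobs flow through end to end, e.g. `model.generate(input_ids, max_length, top_k=40, top_p=0.9, temperature=0.8, return_dict_in_generate=True, output_scores=True)`, which returns a `SampleDecoderOnlyOutput` whenever `top_k != 1`.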