From f8aea6ead01028599c8f47a56691dad3828935c0 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sat, 26 Aug 2023 20:47:53 -0700
Subject: [PATCH] [GPT] Generalize last_token_only arg to num_last_tokens

---
 flash_attn/models/gpt.py       |  9 ++++-----
 flash_attn/utils/generation.py | 36 ++++++++++++++++++++++------------
 tests/models/test_gpt.py       |  2 --
 3 files changed, 28 insertions(+), 19 deletions(-)

diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index b2e0076..5a3eb85 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -621,18 +621,17 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             batch_size, max_seqlen, dtype=dtype, **kwargs
         )
 
-    def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
+    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
         """
         inference_params: for generation. Adapted from Megatron-LM (and Apex)
         https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
-        last_token_only: whether to return the logit for the last token only,
-            of shape (batch_size, vocab_size)
+        num_last_tokens: if > 0, only return the logits for the last n tokens
         """
         hidden_states = self.transformer(
             input_ids, position_ids=position_ids, inference_params=inference_params
         )
-        if last_token_only:
-            hidden_states = hidden_states[:, -1]
+        if num_last_tokens > 0:
+            hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
         lm_logits = self.lm_head(hidden_states)
diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index f91870b..e6d4794 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -27,11 +27,19 @@ class InferenceParams:
     lengths_per_sample: Optional[Tensor] = None
 
 
+# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
+# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
+def modify_logits_for_top_k_filtering(logits, top_k):
+    """Set the logits for none top-k values to -inf."""
+    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+    logits.masked_fill_(indices_to_remove, float("-Inf"))
+
+
 # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
 # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
 def modify_logits_for_top_p_filtering(logits, top_p):
     """Set the logits for none top-p values to -inf."""
-    if top_p <= 0.0:
+    if top_p <= 0.0 or top_p >= 1.0:
         return
     # First sort and calculate cumulative sum of probabilities.
     sorted_logits, sorted_indices = torch.sort(logits, descending=False)
@@ -58,14 +66,16 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
     if top_k > 0:
         top_k = min(top_k, logits.size(-1))  # Safety check
         logits_top, indices = torch.topk(logits, top_k, dim=-1)
-        logits_top /= temperature
+        if temperature != 1.0:
+            logits_top /= temperature
         modify_logits_for_top_p_filtering(logits_top, top_p)
         return indices[
             torch.arange(indices.shape[0], device=indices.device),
             torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
         ]
     else:
-        logits_top = logits / temperature
+        # Clone so that when we modify for top_p we don't change the original logits
+        logits_top = logits / temperature if temperature != 1.0 else logits.clone()
         modify_logits_for_top_p_filtering(logits_top, top_p)
         return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
             dim=-1
@@ -131,8 +141,8 @@ def decode(
                 input_ids,
                 position_ids=position_ids,
                 inference_params=inference_params,
-                last_token_only=True,
-            ).logits
+                num_last_tokens=1,
+            ).logits.squeeze(dim=1)
         else:
             return model._decoding_cache.run(
                 input_ids, position_ids, inference_params.sequence_len_offset
@@ -149,7 +159,9 @@ def decode(
             torch.distributed.barrier()
         torch.cuda.synchronize()
         start = time.time()
-    logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
+    logits = model(
+        input_ids, inference_params=inference_params, num_last_tokens=1
+    ).logits.squeeze(dim=1)
     logits = logits_postprocess_fn(logits)
     scores.append(logits if not cg else logits.clone())
     if teacher_outputs is None or teacher_output_len <= seqlen_og:
@@ -165,9 +177,9 @@ def decode(
             dtype=torch.long,
             device=input_ids.device,
         )
-        logits = logits_postprocess_fn(logits_forward_fn(
-            rearrange(next_token, "b -> b 1"), position_ids, inference_params
-        ))
+        logits = logits_postprocess_fn(
+            logits_forward_fn(rearrange(next_token, "b -> b 1"), position_ids, inference_params)
+        )
         scores.append(logits)
         if (
             teacher_outputs is None
@@ -357,7 +369,7 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
                 input_ids,
                 position_ids=position_ids,
                 inference_params=inference_params,
-                last_token_only=True,
+                num_last_tokens=1,
             ).logits
         s.synchronize()
         # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
@@ -374,8 +386,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
             input_ids,
             position_ids=position_ids,
             inference_params=inference_params,
-            last_token_only=True,
-        ).logits
+            num_last_tokens=1,
+        ).logits.squeeze(dim=1)
 
     def run(new_input_ids, new_position_ids, seqlen):
         inference_params.lengths_per_sample[:] = seqlen
diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py
index b48352d..4ac519e 100644
--- a/tests/models/test_gpt.py
+++ b/tests/models/test_gpt.py
@@ -355,8 +355,6 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
         config.fused_dropout_add_ln = True
 
     # fused_ft_kernel currently doesn't work with multiple tokens at a time
-    # if not rotary, we load the weight from HF but ignore the position embeddings.
-    # The model would be nonsense but it doesn't matter for the test.
     model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
     model.eval()
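
A minimal usage sketch of the generalized argument (the small GPT2Config below is hypothetical and only for illustration; the num_last_tokens semantics are the ones introduced by the hunk in flash_attn/models/gpt.py above):

    import torch
    from transformers import GPT2Config
    from flash_attn.models.gpt import GPTLMHeadModel

    # Hypothetical tiny config; any GPT2Config works for this illustration.
    config = GPT2Config(n_embd=256, n_head=8, n_layer=2)
    model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16), dtype=torch.long, device="cuda")

    # The old last_token_only=True corresponds to num_last_tokens=1;
    # decode() then squeezes away the length-1 sequence dimension.
    logits_last = model(input_ids, num_last_tokens=1).logits.squeeze(dim=1)  # (batch, vocab_size)

    # The generalization: keep the logits for the last n tokens, e.g. the last 4.
    logits_tail = model(input_ids, num_last_tokens=4).logits  # (batch, 4, vocab_size)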
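
Likewise, a quick worked example of the in-place top-k filter this patch adds to flash_attn/utils/generation.py (the tensor values are made up):

    import torch
    from flash_attn.utils.generation import modify_logits_for_top_k_filtering

    logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
    modify_logits_for_top_k_filtering(logits, top_k=2)
    # Everything below the 2nd-largest logit is masked to -inf, in place:
    print(logits)  # tensor([[-inf, 3., 2., -inf]])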