diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index 3436745..66d5542 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -634,6 +634,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): input_ids, position_ids=position_ids, inference_params=inference_params ) assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode" + if num_last_tokens > 0: + hidden_states = hidden_states[:, -num_last_tokens:] if self.project_out is not None: hidden_states = self.project_out(hidden_states) lm_logits = self.lm_head(hidden_states)