From 7b33743a728a69dbb92d7e52e2ecae3d399e0dc1 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 3 Sep 2023 20:44:40 -0700
Subject: [PATCH] [Gen] Add back num_last_tokens in gpt.py

---
 flash_attn/models/gpt.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 3436745..66d5542 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -634,6 +634,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             input_ids, position_ids=position_ids, inference_params=inference_params
         )
         assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
+        if num_last_tokens > 0:
+            hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
         lm_logits = self.lm_head(hidden_states)
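
For context, a minimal sketch of why slicing `hidden_states` to the last `num_last_tokens` positions before `self.lm_head` matters during generation. The tensor shapes and the standalone `lm_head` here are illustrative assumptions, not taken from the patch; only the slicing pattern mirrors the two added lines.

```python
# Sketch: during the prefill step of generation, only the final position's
# logits are needed to sample the next token. Slicing before the lm_head
# projection avoids a seqlen-times-larger matmul and logits tensor.
# Shapes below are hypothetical; the slicing matches the patched code.
import torch

batch, seqlen, d_model, vocab = 4, 2048, 1024, 50257
hidden_states = torch.randn(batch, seqlen, d_model)
lm_head = torch.nn.Linear(d_model, vocab, bias=False)

# Without the patch: every position is projected through lm_head.
full_logits = lm_head(hidden_states)  # (4, 2048, 50257)

# With num_last_tokens=1: slice first, then project.
num_last_tokens = 1
last_logits = lm_head(hidden_states[:, -num_last_tokens:])  # (4, 1, 50257)

# The last-position logits are the same either way.
assert torch.allclose(full_logits[:, -num_last_tokens:], last_logits)
```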