[Gen] Add back num_last_tokens in gpt.py

This commit is contained in:
Tri Dao 2023-09-03 20:44:40 -07:00
parent 5953c4f58c
commit 7b33743a72

View File

@ -634,6 +634,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
input_ids, position_ids=position_ids, inference_params=inference_params
)
assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
lm_logits = self.lm_head(hidden_states)