[Gen] Add back num_last_tokens in gpt.py
parent 5953c4f58c
commit 7b33743a72
@@ -634,6 +634,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             input_ids, position_ids=position_ids, inference_params=inference_params
         )
         assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
+        if num_last_tokens > 0:
+            hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
         lm_logits = self.lm_head(hidden_states)
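For context, a minimal sketch of the behavior this hunk re-adds, using a hypothetical TinyLMHead stand-in rather than the real GPTLMHeadModel: when num_last_tokens > 0, the hidden states are sliced so the LM head only produces logits for the last n positions, which is all that incremental decoding needs.

# Minimal sketch (assumption: toy module, not the flash-attn model itself)
# of what the re-added num_last_tokens argument does before the LM head.
import torch
import torch.nn as nn

class TinyLMHead(nn.Module):
    def __init__(self, d_model=16, vocab_size=100):
        super().__init__()
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, hidden_states, num_last_tokens=0):
        # hidden_states: (batch, seqlen, d_model)
        assert hidden_states.ndim == 3
        if num_last_tokens > 0:
            # Keep only the last n positions before computing logits.
            hidden_states = hidden_states[:, -num_last_tokens:]
        return self.lm_head(hidden_states)

head = TinyLMHead()
h = torch.randn(2, 10, 16)
print(head(h).shape)                     # torch.Size([2, 10, 100])
print(head(h, num_last_tokens=1).shape)  # torch.Size([2, 1, 100])

In the diff above the slicing happens before project_out and lm_head, so a caller that only needs the next-token distribution can pass num_last_tokens=1 and skip computing logits for every prompt position.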