[Gen] Add back num_last_tokens in gpt.py
This commit is contained in:
parent 5953c4f58c
commit 7b33743a72
@@ -634,6 +634,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             input_ids, position_ids=position_ids, inference_params=inference_params
         )
         assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
+        if num_last_tokens > 0:
+            hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
         lm_logits = self.lm_head(hidden_states)
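
Why the change matters: slicing hidden_states down to the last num_last_tokens positions before project_out / lm_head means the vocabulary projection only runs on the positions whose logits are actually needed (typically just the final one when prefilling a prompt during generation). A minimal, self-contained sketch of the effect; the batch size, sequence length, and hidden dimension below are illustrative assumptions, not values from the repo:

import torch

# Stand-in for the transformer output: (batch, seqlen, hidden_dim).
hidden_states = torch.randn(2, 512, 768)

num_last_tokens = 1  # e.g. prefill a prompt but keep only the next-token position
if num_last_tokens > 0:
    hidden_states = hidden_states[:, -num_last_tokens:]

# lm_head (a hidden_dim -> vocab_size projection) now runs on 1 position instead of 512.
print(hidden_states.shape)  # torch.Size([2, 1, 768])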