[GPT] Add option to only return the logit for the last token
This commit is contained in:
parent
311d6606bf
commit
3da42d24b1
@ -426,20 +426,24 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
|
||||
if self.process_group is not None:
|
||||
sync_shared_params(self, self.process_group)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, inference_params=None):
|
||||
def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
|
||||
"""
|
||||
inference_params: for generation. Adapted from Megatron-LM (and Apex)
|
||||
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
|
||||
last_token_only: whether to return the logit for the last token only,
|
||||
of shape (batch_size, vocab_size)
|
||||
"""
|
||||
hidden_states = self.transformer(input_ids, position_ids=position_ids,
|
||||
inference_params=inference_params)
|
||||
if last_token_only:
|
||||
hidden_states = hidden_states[:, -1]
|
||||
if self.project_out is not None:
|
||||
hidden_states = self.project_out(hidden_states)
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
# During inference, we want the full logit for sampling
|
||||
if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
|
||||
lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
|
||||
lm_logits = rearrange(lm_logits, '(n b) s d -> b s (n d)', b=hidden_states.shape[0])
|
||||
lm_logits = rearrange(lm_logits, '(n b) ... d -> b ... (n d)', b=hidden_states.shape[0])
|
||||
CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
|
||||
return CausalLMOutput(logits=lm_logits)
|
||||
|
||||
|
||||
@ -112,7 +112,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
||||
torch.distributed.barrier()
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
logits = model(input_ids, inference_params=inference_params).logits[:, -1]
|
||||
logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
|
||||
if vocab_size is not None:
|
||||
logits = logits[..., :vocab_size]
|
||||
scores.append(logits if not cg else logits.clone())
|
||||
@ -127,7 +127,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
|
||||
dtype=torch.long, device=input_ids.device)
|
||||
if not cg:
|
||||
logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
|
||||
inference_params=inference_params).logits[:, -1]
|
||||
inference_params=inference_params, last_token_only=True).logits
|
||||
else:
|
||||
logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids,
|
||||
inference_params.sequence_len_offset)
|
||||
@ -269,8 +269,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
|
||||
s.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(s):
|
||||
for _ in range(n_warmups):
|
||||
logits = model(input_ids, position_ids=position_ids,
|
||||
inference_params=inference_params).logits[:, -1]
|
||||
logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
|
||||
last_token_only=True).logits
|
||||
s.synchronize()
|
||||
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
|
||||
# which requires that graph launch and non-captured launch to not overlap (I think,
|
||||
@ -282,8 +282,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
|
||||
# To allow capture, automatically sets a side stream as the current stream in the context
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, pool=mempool):
|
||||
logits = model(input_ids, position_ids=position_ids,
|
||||
inference_params=inference_params).logits[:, -1]
|
||||
logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
|
||||
last_token_only=True).logits
|
||||
|
||||
def run(new_input_ids, new_position_ids, seqlen):
|
||||
inference_params.lengths_per_sample[:] = seqlen
|
||||
|
||||
Loading…
Reference in New Issue
Block a user