[GPT] Add option to only return the logit for the last token

Tri Dao 2023-04-20 17:21:08 -07:00
parent 311d6606bf
commit 3da42d24b1
2 changed files with 12 additions and 8 deletions


@@ -426,20 +426,24 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         if self.process_group is not None:
             sync_shared_params(self, self.process_group)
 
-    def forward(self, input_ids, position_ids=None, inference_params=None):
+    def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
         """
         inference_params: for generation. Adapted from Megatron-LM (and Apex)
         https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
+        last_token_only: whether to return the logit for the last token only,
+            of shape (batch_size, vocab_size)
         """
         hidden_states = self.transformer(input_ids, position_ids=position_ids,
                                          inference_params=inference_params)
+        if last_token_only:
+            hidden_states = hidden_states[:, -1]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
         lm_logits = self.lm_head(hidden_states)
         # During inference, we want the full logit for sampling
         if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
             lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
-            lm_logits = rearrange(lm_logits, '(n b) s d -> b s (n d)', b=hidden_states.shape[0])
+            lm_logits = rearrange(lm_logits, '(n b) ... d -> b ... (n d)', b=hidden_states.shape[0])
         CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
         return CausalLMOutput(logits=lm_logits)
 
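For context, a minimal sketch of what the new flag does, using a toy model (an nn.Embedding, an nn.GRU backbone, and an nn.Linear LM head; none of these are the real GPTLMHeadModel): slicing the hidden states down to the last position before the output projection turns the LM-head matmul from (batch, seqlen, d_model) x (d_model, vocab) into (batch, d_model) x (d_model, vocab), which is all that the generation path needs.

import torch
import torch.nn as nn

class ToyLM(nn.Module):
    """Toy stand-in for an LM-head model, just to illustrate last_token_only."""
    def __init__(self, vocab_size=128, d_model=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.backbone = nn.GRU(d_model, d_model, batch_first=True)  # any sequence model
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, last_token_only=False):
        hidden_states, _ = self.backbone(self.embed(input_ids))
        if last_token_only:
            # Keep only the final position: (batch, d_model) instead of
            # (batch, seqlen, d_model), so the LM-head matmul runs over one
            # token instead of the whole sequence.
            hidden_states = hidden_states[:, -1]
        return self.lm_head(hidden_states)

model = ToyLM()
input_ids = torch.randint(0, 128, (2, 10))
print(model(input_ids).shape)                        # torch.Size([2, 10, 128])
print(model(input_ids, last_token_only=True).shape)  # torch.Size([2, 128])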


@@ -112,7 +112,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                 torch.distributed.barrier()
             torch.cuda.synchronize()
             start = time.time()
-        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
+        logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits if not cg else logits.clone())
@@ -127,7 +127,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                                       dtype=torch.long, device=input_ids.device)
             if not cg:
                 logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
-                               inference_params=inference_params).logits[:, -1]
+                               inference_params=inference_params, last_token_only=True).logits
             else:
                 logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids,
                                                    inference_params.sequence_len_offset)
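These two hunks are the decoding loop: one forward pass over the whole prompt (prefill), then one forward pass per generated token, and in both cases only the last position's logits are needed. The flag returns them directly instead of slicing .logits[:, -1] afterwards. A rough sketch of that loop, with a hypothetical toy_forward standing in for the real model and KV cache (the actual code threads inference_params through so the backbone only processes the new token each step):

import torch

vocab_size = 128

def toy_forward(input_ids, last_token_only=False):
    # Hypothetical stand-in for model(...).logits: random logits of shape
    # (batch, seqlen, vocab), optionally reduced to the last position.
    logits = torch.randn(input_ids.shape[0], input_ids.shape[1], vocab_size)
    return logits[:, -1] if last_token_only else logits

def greedy_decode(input_ids, max_new_tokens=5):
    # Prefill: run the whole prompt, keep only the last token's logits.
    logits = toy_forward(input_ids, last_token_only=True)    # (batch, vocab)
    next_token = logits.argmax(dim=-1, keepdim=True)         # (batch, 1)
    sequences = [input_ids, next_token]
    for _ in range(max_new_tokens - 1):
        # Decode: feed only the newly generated token; with a KV cache the
        # backbone reuses past keys/values instead of re-running the prompt.
        logits = toy_forward(next_token, last_token_only=True)
        next_token = logits.argmax(dim=-1, keepdim=True)
        sequences.append(next_token)
    return torch.cat(sequences, dim=1)

print(greedy_decode(torch.randint(0, vocab_size, (2, 10))).shape)  # torch.Size([2, 15])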
@@ -269,8 +269,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
     s.wait_stream(torch.cuda.current_stream())
     with torch.cuda.stream(s):
         for _ in range(n_warmups):
-            logits = model(input_ids, position_ids=position_ids,
-                           inference_params=inference_params).logits[:, -1]
+            logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
+                           last_token_only=True).logits
         s.synchronize()
     # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
     # which requires that graph launch and non-captured launch to not overlap (I think,
@@ -282,8 +282,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
     # To allow capture, automatically sets a side stream as the current stream in the context
     graph = torch.cuda.CUDAGraph()
     with torch.cuda.graph(graph, pool=mempool):
-        logits = model(input_ids, position_ids=position_ids,
-                       inference_params=inference_params).logits[:, -1]
+        logits = model(input_ids, position_ids=position_ids, inference_params=inference_params,
+                       last_token_only=True).logits
 
     def run(new_input_ids, new_position_ids, seqlen):
         inference_params.lengths_per_sample[:] = seqlen
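The last two hunks touch the CUDA-graph path: the forward call is warmed up on a side stream, captured once into a torch.cuda.CUDAGraph, and then each decoding step replays the graph on static buffers; the commit only changes the captured call to use last_token_only=True. A minimal sketch of that warmup/capture/replay pattern, assuming a plain nn.Linear in place of the real model and omitting the mempool and inference_params plumbing (requires a CUDA device):

import torch
import torch.nn as nn

device = 'cuda'
model = nn.Linear(64, 128).to(device)
static_input = torch.randn(2, 64, device=device)

# Warm up on a side stream so lazy initialization (cuBLAS handles, autotuning, ...)
# happens outside the capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture a single forward call; replays reuse the same memory addresses,
# so new inputs must be copied into static_input before each replay.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

def run(new_input):
    static_input.copy_(new_input)
    graph.replay()
    return static_output

out = run(torch.randn(2, 64, device=device))
print(out.shape)  # torch.Size([2, 128])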