[GPT] Generalize last_token_only arg to num_last_tokens

Tri Dao 2023-08-26 20:47:53 -07:00
parent 7a3bd55f1a
commit f8aea6ead0
3 changed files with 28 additions and 19 deletions

View File

@@ -621,18 +621,17 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
batch_size, max_seqlen, dtype=dtype, **kwargs
)
-def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
+def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
"""
inference_params: for generation. Adapted from Megatron-LM (and Apex)
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
-last_token_only: whether to return the logit for the last token only,
-of shape (batch_size, vocab_size)
+num_last_tokens: if > 0, only return the logits for the last n tokens
"""
hidden_states = self.transformer(
input_ids, position_ids=position_ids, inference_params=inference_params
)
-if last_token_only:
-hidden_states = hidden_states[:, -1]
+if num_last_tokens > 0:
+hidden_states = hidden_states[:, -num_last_tokens:]
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
lm_logits = self.lm_head(hidden_states)

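The shape contract changes with the rename: last_token_only=True dropped the sequence dimension, while num_last_tokens=n keeps a sequence dimension of length n, so callers that want the old shape have to squeeze it (as the generation code below now does). A minimal sketch of the slicing with made-up shapes, not the library code:

import torch

# hidden_states stands in for the transformer output, shape (batch, seqlen, d_model).
hidden_states = torch.randn(2, 10, 64)

# Old behavior: last_token_only=True kept only the last position, shape (batch, d_model).
last_only = hidden_states[:, -1]              # (2, 64)

# New behavior: num_last_tokens=n keeps a length-n sequence dimension.
num_last_tokens = 1
last_n = hidden_states[:, -num_last_tokens:]  # (2, 1, 64)
old_shape = last_n.squeeze(dim=1)             # (2, 64), matches what last_token_only produced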
View File

@@ -27,11 +27,19 @@ class InferenceParams:
lengths_per_sample: Optional[Tensor] = None
+# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
+# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
+def modify_logits_for_top_k_filtering(logits, top_k):
+"""Set the logits for none top-k values to -inf."""
+indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+logits.masked_fill_(indices_to_remove, float("-Inf"))
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
"""Set the logits for none top-p values to -inf."""
-if top_p <= 0.0:
+if top_p <= 0.0 or top_p >= 1.0:
return
# First sort and calculate cumulative sum of probabilities.
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
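A quick usage sketch of the two in-place filters above; the import path is an assumption (the file path is not shown in this view) and the tensor values are made up:

import torch

# Assumed module path for these helpers; adjust if it differs in your checkout.
from flash_attn.utils.generation import (
    modify_logits_for_top_k_filtering,
    modify_logits_for_top_p_filtering,
)

logits = torch.tensor([[4.0, 3.0, 1.0, 0.0]])

top_k_logits = logits.clone()
modify_logits_for_top_k_filtering(top_k_logits, top_k=2)    # entries outside the top 2 become -inf in place

top_p_logits = logits.clone()
modify_logits_for_top_p_filtering(top_p_logits, top_p=1.0)  # new early return: top_p >= 1.0 leaves logits untouched

probs = torch.softmax(top_k_logits, dim=-1)                 # -inf entries get probability 0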
@@ -58,14 +66,16 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
if top_k > 0:
top_k = min(top_k, logits.size(-1)) # Safety check
logits_top, indices = torch.topk(logits, top_k, dim=-1)
-logits_top /= temperature
+if temperature != 1.0:
+logits_top /= temperature
modify_logits_for_top_p_filtering(logits_top, top_p)
return indices[
torch.arange(indices.shape[0], device=indices.device),
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
]
else:
-logits_top = logits / temperature
+# Clone so that when we modify for top_p we don't change the original logits
+logits_top = logits / temperature if temperature != 1.0 else logits.clone()
modify_logits_for_top_p_filtering(logits_top, top_p)
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
dim=-1
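For reference, a sketch of calling sample with the signature shown above (same assumed import path as before; shapes are illustrative):

import torch

from flash_attn.utils.generation import sample  # assumed path

logits = torch.randn(4, 32000)  # (batch, vocab_size), e.g. squeezed last-token logits

next_token = sample(logits, top_k=40, top_p=0.9, temperature=0.8)  # (batch,) sampled token ids
greedy_token = sample(logits)  # defaults to top_k=1, i.e. greedy decoding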
@@ -131,8 +141,8 @@ def decode(
input_ids,
position_ids=position_ids,
inference_params=inference_params,
-last_token_only=True,
-).logits
+num_last_tokens=1,
+).logits.squeeze(dim=1)
else:
return model._decoding_cache.run(
input_ids, position_ids, inference_params.sequence_len_offset
@@ -149,7 +159,9 @@ def decode(
torch.distributed.barrier()
torch.cuda.synchronize()
start = time.time()
-logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
+logits = model(
+input_ids, inference_params=inference_params, num_last_tokens=1
+).logits.squeeze(dim=1)
logits = logits_postprocess_fn(logits)
scores.append(logits if not cg else logits.clone())
if teacher_outputs is None or teacher_output_len <= seqlen_og:
@@ -165,9 +177,9 @@ def decode(
dtype=torch.long,
device=input_ids.device,
)
-logits = logits_postprocess_fn(logits_forward_fn(
-rearrange(next_token, "b -> b 1"), position_ids, inference_params
-))
+logits = logits_postprocess_fn(
+logits_forward_fn(rearrange(next_token, "b -> b 1"), position_ids, inference_params)
+)
scores.append(logits)
if (
teacher_outputs is None
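Putting the pieces together, a toy sketch of the call pattern decode uses after this change; it is not the library's decode, which also tracks position_ids and inference_params state between steps (elided here):

from flash_attn.utils.generation import sample  # assumed path, as above

def greedy_two_steps(model, input_ids, inference_params):
    # Prompt pass: request only the last position's logits, then drop the
    # length-1 sequence dimension that num_last_tokens now preserves.
    logits = model(
        input_ids, inference_params=inference_params, num_last_tokens=1
    ).logits.squeeze(dim=1)               # (batch, vocab_size)
    next_token = sample(logits, top_k=1)  # greedy step
    # One decoding step: feed the sampled token back in with shape (batch, 1).
    step_logits = model(
        next_token.unsqueeze(1), inference_params=inference_params, num_last_tokens=1
    ).logits.squeeze(dim=1)
    return next_token, step_logits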
@@ -357,7 +369,7 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
input_ids,
position_ids=position_ids,
inference_params=inference_params,
-last_token_only=True,
+num_last_tokens=1,
).logits
s.synchronize()
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
@@ -374,8 +386,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
input_ids,
position_ids=position_ids,
inference_params=inference_params,
-last_token_only=True,
-).logits
+num_last_tokens=1,
+).logits.squeeze(dim=1)
def run(new_input_ids, new_position_ids, seqlen):
inference_params.lengths_per_sample[:] = seqlen

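For context on the CUDA-graph path: a captured call replays into fixed input and output buffers, which is why the decode loop above clones the logits when cg=True. A generic sketch of that capture-and-replay pattern (illustrative only; the real capture_graph also takes a memory pool argument and does extra stream synchronization, as the hunk above shows):

import torch

def make_graphed_step(step_fn, static_input, n_warmups=2):
    """Capture step_fn(static_input) into a CUDA graph and return a replay wrapper."""
    # Warm up on a side stream so one-time allocations are not recorded in the graph.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            static_output = step_fn(static_input)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = step_fn(static_input)

    def run(new_input):
        static_input.copy_(new_input)  # write into the captured input buffer
        graph.replay()
        return static_output           # overwritten on the next replay; clone() to keep it

    return run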
View File

@@ -355,8 +355,6 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
config.fused_dropout_add_ln = True
# fused_ft_kernel currently doesn't work with multiple tokens at a time
# if not rotary, we load the weight from HF but ignore the position embeddings.
# The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()