[GPT] Generalize last_token_only arg to num_last_tokens
This commit is contained in:
parent
7a3bd55f1a
commit
f8aea6ead0
@ -621,18 +621,17 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
|
||||
batch_size, max_seqlen, dtype=dtype, **kwargs
|
||||
)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
|
||||
def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
|
||||
"""
|
||||
inference_params: for generation. Adapted from Megatron-LM (and Apex)
|
||||
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
|
||||
last_token_only: whether to return the logit for the last token only,
|
||||
of shape (batch_size, vocab_size)
|
||||
num_last_tokens: if > 0, only return the logits for the last n tokens
|
||||
"""
|
||||
hidden_states = self.transformer(
|
||||
input_ids, position_ids=position_ids, inference_params=inference_params
|
||||
)
|
||||
if last_token_only:
|
||||
hidden_states = hidden_states[:, -1]
|
||||
if num_last_tokens > 0:
|
||||
hidden_states = hidden_states[:, -num_last_tokens:]
|
||||
if self.project_out is not None:
|
||||
hidden_states = self.project_out(hidden_states)
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
@ -27,11 +27,19 @@ class InferenceParams:
|
||||
lengths_per_sample: Optional[Tensor] = None
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
||||
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
|
||||
def modify_logits_for_top_k_filtering(logits, top_k):
|
||||
"""Set the logits for none top-k values to -inf."""
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits.masked_fill_(indices_to_remove, float("-Inf"))
|
||||
|
||||
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
||||
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
|
||||
def modify_logits_for_top_p_filtering(logits, top_p):
|
||||
"""Set the logits for none top-p values to -inf."""
|
||||
if top_p <= 0.0:
|
||||
if top_p <= 0.0 or top_p >= 1.0:
|
||||
return
|
||||
# First sort and calculate cumulative sum of probabilities.
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
|
||||
@ -58,14 +66,16 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
|
||||
if top_k > 0:
|
||||
top_k = min(top_k, logits.size(-1)) # Safety check
|
||||
logits_top, indices = torch.topk(logits, top_k, dim=-1)
|
||||
logits_top /= temperature
|
||||
if temperature != 1.0:
|
||||
logits_top /= temperature
|
||||
modify_logits_for_top_p_filtering(logits_top, top_p)
|
||||
return indices[
|
||||
torch.arange(indices.shape[0], device=indices.device),
|
||||
torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
|
||||
]
|
||||
else:
|
||||
logits_top = logits / temperature
|
||||
# Clone so that when we modify for top_p we don't change the original logits
|
||||
logits_top = logits / temperature if temperature != 1.0 else logits.clone()
|
||||
modify_logits_for_top_p_filtering(logits_top, top_p)
|
||||
return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
|
||||
dim=-1
|
||||
@ -131,8 +141,8 @@ def decode(
|
||||
input_ids,
|
||||
position_ids=position_ids,
|
||||
inference_params=inference_params,
|
||||
last_token_only=True,
|
||||
).logits
|
||||
num_last_tokens=1,
|
||||
).logits.squeeze(dim=1)
|
||||
else:
|
||||
return model._decoding_cache.run(
|
||||
input_ids, position_ids, inference_params.sequence_len_offset
|
||||
@ -149,7 +159,9 @@ def decode(
|
||||
torch.distributed.barrier()
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits
|
||||
logits = model(
|
||||
input_ids, inference_params=inference_params, num_last_tokens=1
|
||||
).logits.squeeze(dim=1)
|
||||
logits = logits_postprocess_fn(logits)
|
||||
scores.append(logits if not cg else logits.clone())
|
||||
if teacher_outputs is None or teacher_output_len <= seqlen_og:
|
||||
@ -165,9 +177,9 @@ def decode(
|
||||
dtype=torch.long,
|
||||
device=input_ids.device,
|
||||
)
|
||||
logits = logits_postprocess_fn(logits_forward_fn(
|
||||
rearrange(next_token, "b -> b 1"), position_ids, inference_params
|
||||
))
|
||||
logits = logits_postprocess_fn(
|
||||
logits_forward_fn(rearrange(next_token, "b -> b 1"), position_ids, inference_params)
|
||||
)
|
||||
scores.append(logits)
|
||||
if (
|
||||
teacher_outputs is None
|
||||
@ -357,7 +369,7 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
|
||||
input_ids,
|
||||
position_ids=position_ids,
|
||||
inference_params=inference_params,
|
||||
last_token_only=True,
|
||||
num_last_tokens=1,
|
||||
).logits
|
||||
s.synchronize()
|
||||
# This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
|
||||
@ -374,8 +386,8 @@ def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None,
|
||||
input_ids,
|
||||
position_ids=position_ids,
|
||||
inference_params=inference_params,
|
||||
last_token_only=True,
|
||||
).logits
|
||||
num_last_tokens=1,
|
||||
).logits.squeeze(dim=1)
|
||||
|
||||
def run(new_input_ids, new_position_ids, seqlen):
|
||||
inference_params.lengths_per_sample[:] = seqlen
|
||||
|
||||
@ -355,8 +355,6 @@ def test_gpt2_multiple_token_generation(model_name, optimized):
|
||||
config.fused_dropout_add_ln = True
|
||||
# fused_ft_kernel currently doesn't work with multiple tokens at a time
|
||||
|
||||
# if not rotary, we load the weight from HF but ignore the position embeddings.
|
||||
# The model would be nonsense but it doesn't matter for the test.
|
||||
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
|
||||
model.eval()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user