2022-12-28 12:58:50 +08:00
|
|
|
# Copyright (c) 2022, Tri Dao.
|
|
|
|
|
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
|
2023-01-04 14:10:31 +08:00
|
|
|
from typing import Optional
|
|
|
|
|
|
2022-12-28 12:58:50 +08:00
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
import torch
|
2023-01-04 14:10:31 +08:00
|
|
|
from torch import Tensor
|
2022-12-28 12:58:50 +08:00
|
|
|
|
|
|
|
|
from einops import rearrange
|
|
|
|
|
|
2023-01-08 09:00:02 +08:00
|
|
|
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
|
2022-12-28 12:58:50 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""
    # Upper bound on the total sequence length; decode() passes its max_length here.
    max_sequence_len: int
    # Upper bound on the batch size; decode() passes the batch dimension of input_ids.
    max_batch_size: int
    # Number of positions already processed. decode() sets it to the prompt length after
    # the first forward pass and increments it by one for every generated token.
    sequence_len_offset: int = 0
    # NOTE(review): presumably the starting batch index into the cache; it is never
    # written in this file -- confirm against the model code.
    batch_size_offset: int = 0
    # A fresh dict per instance (default_factory avoids a shared mutable default).
    # NOTE(review): presumably filled by the model with per-layer key/value caches -- confirm.
    key_value_memory_dict: dict = field(default_factory=dict)
    # Forwarded unchanged from decode().
    # NOTE(review): presumably selects a fused decoding kernel inside the model -- confirm.
    fused_ft_kernel: bool = False
    # NOTE(review): presumably per-sample sequence lengths for variable-length batches;
    # it is never set in this file -- confirm.
    lengths_per_sample: Optional[Tensor] = None
|
2022-12-28 12:58:50 +08:00
|
|
|
|
|
|
|
|
|
2023-01-08 09:00:02 +08:00
|
|
|
# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
|
|
|
|
|
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
|
|
|
|
|
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non-top-p values to -inf. Done in-place.

    Tokens are ranked by probability and the least likely ones, whose cumulative
    probability mass is at most ``1 - top_p``, are masked out so that a subsequent
    softmax gives them zero probability.

    Arguments:
        logits: floating-point Tensor whose last dimension indexes the candidate
            tokens, e.g. (batch_size, vocab_size). It is modified in place.
        top_p: float. Values <= 0 or >= 1 disable the filtering (nothing is masked).

    Returns:
        None. Callers rely on the in-place modification of ``logits``.
    """
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Never mask the most likely token: rounding in the cumulative sum must not be
    # able to leave a row with every logit at -inf (softmax would then produce NaN).
    sorted_indices_to_remove[..., -1] = False
    # scatter sorted tensors to original indexing
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    # In-place fill: the callers discard the return value, so an out-of-place
    # masked_fill would silently make this function a no-op.
    logits.masked_fill_(indices_to_remove, float('-inf'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Pick one token id per batch row, greedily or by top-k / top-p sampling.

    top_k == 1 takes the argmax. top_k > 1 restricts the candidates to the k
    largest logits before sampling. Any other top_k samples from the whole
    vocabulary. In the sampling paths the logits are divided by ``temperature``
    and handed to ``modify_logits_for_top_p_filtering`` together with ``top_p``.

    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    # Greedy decoding needs no randomness, so answer straight away.
    if top_k == 1:
        return logits.argmax(dim=-1)

    if top_p > 0.0:
        assert top_p <= 1.0, 'top-p should be in (0, 1].'

    if top_k > 0:
        # Never ask for more candidates than the vocabulary holds.
        num_candidates = min(top_k, logits.size(-1))
        candidate_logits, candidate_ids = torch.topk(logits, num_candidates, dim=-1)
        candidate_logits /= temperature
        modify_logits_for_top_p_filtering(candidate_logits, top_p)
        candidate_probs = torch.softmax(candidate_logits, dim=-1)
        picked = torch.multinomial(candidate_probs, num_samples=1).squeeze(dim=-1)
        # `picked` indexes into the top-k slice; map it back to vocabulary ids.
        rows = torch.arange(candidate_ids.shape[0], device=candidate_ids.device)
        return candidate_ids[rows, picked]

    # No top-k limit: sample from the full distribution.
    scaled_logits = logits / temperature
    modify_logits_for_top_p_filtering(scaled_logits, top_p)
    full_probs = torch.softmax(scaled_logits, dim=-1)
    return torch.multinomial(full_probs, num_samples=1).squeeze(dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, fused_ft_kernel=True):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    The prompt is run through the model once, then one token is generated per
    forward pass until the sequence holds max_length tokens. The first token is
    always sampled from the prompt's logits, so max_length is expected to be
    greater than the prompt length.

    Arguments:
        input_ids: (batch, seq_len)
        model: callable returning an object with a `.logits` attribute; it is called
            with `inference_params` (and `position_ids` for the incremental steps).
        max_length: int
        top_k, top_p, temperature: sampling controls forwarded to `sample` at every step.
        fused_ft_kernel: bool, stored on the InferenceParams passed to the model.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
                                       fused_ft_kernel=fused_ft_kernel)
    scores = []
    with torch.inference_mode():
        # Prompt pass: only the logits of the last position are needed.
        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
        scores.append(logits)
        next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        # Check the budget before each step: when max_length == seqlen_og + 1 the
        # token sampled above already completes the sequence, and another step
        # would go one position past max_sequence_len.
        while inference_params.sequence_len_offset < max_length - 1:
            position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                      dtype=torch.long, device=input_ids.device)
            logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                           inference_params=inference_params).logits[:, -1]
            scores.append(logits)
            # top_p must be forwarded on every step, not only for the first token.
            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
        scores=tuple(scores)
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GenerationMixin:
    """Mixin that adds a `generate` method built on the module-level `decode`."""

    def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
                 return_dict_in_generate=False, output_scores=False, **kwargs):
        """Generate tokens by calling `decode` with this object as the model.

        Extra keyword arguments are forwarded to `decode` untouched.

        Returns the full output object when `return_dict_in_generate` is true,
        otherwise only its `sequences`. The `scores` field is cleared unless
        `output_scores` is true.
        """
        result = decode(
            input_ids,
            self,
            max_length,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            **kwargs,
        )
        # Drop the per-step logits unless the caller explicitly asked for them.
        if not output_scores:
            result.scores = None
        if return_dict_in_generate:
            return result
        return result.sequences
|