# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from typing import Optional
from dataclasses import dataclass, field

import torch
from torch import Tensor
from einops import rearrange

from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order to efficiently
    calculate and store the context during inference."""
    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    fused_ft_kernel: bool = False
    lengths_per_sample: Optional[Tensor] = None


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits of tokens outside the top-p nucleus to -inf, in place."""
    if top_p <= 0.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens whose cumulative probability (counting up from the smallest logit)
    # stays below 1 - top_p; the most likely token is always kept.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Scatter the mask back to the original (unsorted) indexing.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    # Must modify `logits` in place, otherwise the filtering is never seen by the caller.
    logits.masked_fill_(indices_to_remove, float('-inf'))
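
# A hand-worked illustration of the filtering above (added for clarity; the numbers are
# made up for this example): for a row of logits [2.0, 1.0, 0.0] the softmax probabilities
# are roughly [0.665, 0.245, 0.090]. Sorted ascending, the cumulative sums are
# [0.090, 0.335, 1.000]; with top_p = 0.9 the mask `cumulative_probs <= 0.1` is
# [True, False, False], so only the least likely token is set to -inf and the tokens with
# probability 0.665 and 0.245 remain candidates for sampling.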

def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from top-k logits.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, 'top-p should be in (0, 1].'
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
            ]
        else:
            logits_top = logits / temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1),
                                     num_samples=1).squeeze(dim=-1)


def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
           fused_ft_kernel=True):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top_k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together: if top_k > 0 and top_p > 0, top-k is applied
    first, then top-p. We assume that all sequences in the same batch have the same length.
    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuple of (batch, vocab_size)
    """
    batch_size, seqlen_og = input_ids.shape
    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
                                       fused_ft_kernel=fused_ft_kernel)
    scores = []
    with torch.inference_mode():
        # Prefill: run the whole prompt through the model and sample the first new token.
        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
        scores.append(logits)
        next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        while True:
            # Decode one token at a time, reusing the cached context via inference_params.
            position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                      dtype=torch.long, device=input_ids.device)
            logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                           inference_params=inference_params).logits[:, -1]
            scores.append(logits)
            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
            if inference_params.sequence_len_offset >= max_length - 1:
                break
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
                      scores=tuple(scores))


class GenerationMixin:

    def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
                 return_dict_in_generate=False, output_scores=False, **kwargs):
        output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
                        temperature=temperature, **kwargs)
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
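

if __name__ == "__main__":
    # Minimal smoke test, added for illustration (not part of the original module).
    # `ToyLM` is a made-up stand-in for a real model: it only needs to accept the
    # `position_ids` and `inference_params` keyword arguments that `decode` passes and to
    # return an object with a `.logits` field of shape (batch, seqlen, vocab_size).
    # It ignores the key/value cache entirely, so it is not a faithful incremental
    # decoder, just enough to exercise the generation loop on CPU.
    from collections import namedtuple

    Output = namedtuple('Output', ['logits'])

    class ToyLM(torch.nn.Module, GenerationMixin):
        def __init__(self, vocab_size=128, dim=32):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, dim)
            self.lm_head = torch.nn.Linear(dim, vocab_size)

        def forward(self, input_ids, position_ids=None, inference_params=None):
            # No caching: just embed the tokens it is given and project to the vocabulary.
            return Output(logits=self.lm_head(self.embed(input_ids)))

    torch.manual_seed(0)
    model = ToyLM()
    input_ids = torch.randint(0, 128, (2, 8))
    out = model.generate(input_ids, max_length=16, top_k=5, top_p=0.9, temperature=1.0,
                         return_dict_in_generate=True, output_scores=True,
                         fused_ft_kernel=False)
    print(out.sequences.shape)  # expected: torch.Size([2, 16])
    print(len(out.scores))      # one (batch, vocab_size) logits tensor per generated token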