# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from typing import Optional
from dataclasses import dataclass, field
import torch
from torch import Tensor
from einops import rearrange
from transformers.generation import GreedySearchDecoderOnlyOutput


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""
    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    fused_ft_kernel: bool = False
    lengths_per_sample: Optional[Tensor] = None
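# Illustrative sketch (not part of the original file): the model's attention layers are
# expected to treat InferenceParams as a key/value cache. A layer might lazily allocate
# its cache in `key_value_memory_dict` (keyed here by a hypothetical `layer_idx`) and
# write the new keys/values at `sequence_len_offset`. The (batch, seqlen, 2, nheads,
# head_dim) cache layout is an assumption; the fused_ft_kernel path may use a different one.
def _update_kv_cache_sketch(kv, inference_params, layer_idx):
    """kv: (batch, seqlen, 2, nheads, head_dim). Returns the cache up to the current position."""
    batch, seqlen = kv.shape[0], kv.shape[1]
    if layer_idx not in inference_params.key_value_memory_dict:
        # Allocate a buffer large enough for the largest batch / longest sequence we'll see.
        kv_cache = torch.empty(inference_params.max_batch_size, inference_params.max_sequence_len,
                               *kv.shape[2:], dtype=kv.dtype, device=kv.device)
        inference_params.key_value_memory_dict[layer_idx] = kv_cache
    else:
        kv_cache = inference_params.key_value_memory_dict[layer_idx]
    start = inference_params.sequence_len_offset
    kv_cache[:batch, start:start + seqlen] = kv
    return kv_cache[:batch, :start + seqlen]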
def greedy_decode(input_ids, model, max_length, fused_ft_kernel=True):
    """Greedy decoding. This is a very simple implementation.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        model: a model whose forward pass accepts `position_ids` and `inference_params`
            and returns an object with a `logits` attribute.
        max_length: int, total length of the returned sequences (prompt + generated tokens).
        fused_ft_kernel: bool, stored on InferenceParams and passed through to the model.
    Returns: GreedySearchDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuple of (batch, vocab_size) tensors, one per generated token
    """
    batch_size, seqlen_og = input_ids.shape
    inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
                                       fused_ft_kernel=fused_ft_kernel)
    scores = []
    with torch.inference_mode():
        # Prompt processing: run the whole prompt through the model once (filling the KV cache)
        # and take the argmax of the logits at the last position as the first generated token.
        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
        scores.append(logits)
        next_token = logits.argmax(dim=-1)
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        # Decoding loop: feed back one token at a time, reusing the cached keys/values.
        while True:
            position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                      dtype=torch.long, device=input_ids.device)
            logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                           inference_params=inference_params).logits[:, -1]
            scores.append(logits)
            next_token = logits.argmax(dim=-1)
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
            if inference_params.sequence_len_offset >= max_length - 1:
                break
    return GreedySearchDecoderOnlyOutput(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
        scores=tuple(scores)
    )
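# Example usage (illustrative sketch, not part of the original file; `model` and `input_ids`
# are assumed to exist and to follow the interface documented in the docstring above):
#
#     out = greedy_decode(input_ids, model, max_length=input_ids.shape[1] + 20)
#     out.sequences   # (batch, prompt_len + 20): the prompt followed by 20 greedy tokens
#     out.scores      # tuple of 20 tensors, each of shape (batch, vocab_size)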
class GenerationMixin:
    """Mixin that provides a simple `generate` method wrapping `greedy_decode`."""

    def generate(self, input_ids, max_length, return_dict_in_generate=False, output_scores=False,
                 **kwargs):
        output = greedy_decode(input_ids, self, max_length, **kwargs)
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
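# Example usage (illustrative sketch, not part of the original file; `MyLMHeadModel`, `config`,
# and `input_ids` are hypothetical): a decoder-only model opts into `.generate()` by inheriting
# from GenerationMixin, provided its forward pass accepts `position_ids` and `inference_params`
# and returns an object with a `.logits` attribute.
#
#     class MyLMHeadModel(nn.Module, GenerationMixin):
#         ...
#
#     model = MyLMHeadModel(config).eval()
#     sequences = model.generate(input_ids, max_length=100)            # (batch, 100)
#     out = model.generate(input_ids, max_length=100,
#                          return_dict_in_generate=True, output_scores=True)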