# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
import gc
import time
from typing import Optional, Union, Sequence, Callable
from dataclasses import dataclass, field
from collections import namedtuple

import torch
from torch import Tensor
from torch.profiler import profile, record_function, ProfilerActivity

from einops import rearrange

from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""
    max_sequence_len: int
    max_batch_size: int
    sequence_len_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    fused_ft_kernel: bool = False
    lengths_per_sample: Optional[Tensor] = None
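

# Illustrative sketch (not part of the original module): a fresh InferenceParams for
# decoding a batch of 4 prompts of up to 512 tokens. The per-layer KV cache in
# key_value_memory_dict is typically populated by the attention layers on the first
# forward pass, or pre-allocated via allocate_kv_cache below.
#
#     params = InferenceParams(max_sequence_len=512, max_batch_size=4)
#     assert params.sequence_len_offset == 0 and params.key_value_memory_dict == {}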


# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Set the logits for non top-p values to -inf, in place."""
    if top_p <= 0.0:
        return
    # First sort and calculate cumulative sum of probabilities.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative probability below the threshold (the most probable
    # token is always kept).
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Scatter sorted tensors back to the original indexing.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    # Mask in place: `masked_fill` (without the underscore) would allocate a new tensor
    # and leave the caller's logits unchanged.
    logits.masked_fill_(indices_to_remove, float('-inf'))
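

# Illustrative example of the in-place semantics (not part of the original module):
# after filtering with top_p=0.9, only the smallest set of tokens whose cumulative
# probability exceeds 0.9 keep finite logits; the rest become -inf and can never be sampled.
#
#     logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
#     modify_logits_for_top_p_filtering(logits, top_p=0.9)
#     probs = torch.softmax(logits, dim=-1)  # masked entries now have probability 0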


def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
    """Sample from top-k logits, with optional top-p filtering and temperature scaling.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, 'top-p should be in (0, 1].'
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
            ]
        else:
            logits_top = logits / temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
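

# Illustrative usage (not part of the original module): draw the next token id for each
# sequence in a batch, combining top-k, top-p (nucleus) filtering and temperature.
#
#     logits = torch.randn(2, 32000)                                   # (batch_size, vocab_size)
#     greedy = sample(logits)                                          # top_k=1: argmax, shape (2,)
#     sampled = sample(logits, top_k=50, top_p=0.9, temperature=0.8)   # shape (2,)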


def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
           eos_token_id=None, vocab_size=None, tensor_parallel=1, fused_ft_kernel=False,
           cg=False, timing=False):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuple of (batch, vocab_size), one entry per generated token
    """
    batch_size, seqlen_og = input_ids.shape
    if cg:
        assert fused_ft_kernel
        if not hasattr(model, '_decoding_cache'):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model, model._decoding_cache, batch_size, seqlen_og, max_length,
            tensor_parallel=tensor_parallel
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.max_sequence_len = max_length
        inference_params.max_batch_size = batch_size
        inference_params.sequence_len_offset = 0
    else:
        inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
                                           fused_ft_kernel=fused_ft_kernel)
    scores = []
    with torch.inference_mode():
        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
        if timing:
            torch.cuda.synchronize()
            start = time.time()
        if vocab_size is not None:
            logits = logits[..., :vocab_size]
        scores.append(logits)
        next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
        sequences = [next_token]
        inference_params.sequence_len_offset = seqlen_og
        while True:
            position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                      dtype=torch.long, device=input_ids.device)
            if not cg:
                logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
                               inference_params=inference_params).logits[:, -1]
            else:
                logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids,
                                                   inference_params.sequence_len_offset)
            if vocab_size is not None:
                logits = logits[..., :vocab_size]
            scores.append(logits)
            # Apply the same top-k / top-p / temperature settings as for the first token.
            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
            sequences.append(next_token)
            inference_params.sequence_len_offset += 1
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
            if inference_params.sequence_len_offset >= max_length - 1:
                break
    if timing:
        torch.cuda.synchronize()
        print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms')
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(
        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
        scores=tuple(scores)
    )
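

# Illustrative usage sketch (hypothetical `model` and `tokenizer`; assumes a causal LM whose
# forward accepts `position_ids` and `inference_params` and returns an object with `.logits`):
#
#     input_ids = tokenizer("Hello", return_tensors='pt').input_ids.cuda()   # (1, seqlen)
#     out = decode(input_ids, model, max_length=128, top_k=40, top_p=0.95,
#                  temperature=0.8, fused_ft_kernel=True, cg=True)
#     out.sequences   # (1, 128) unless EOS stops early: the prompt followed by generated tokens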


class GenerationMixin:
    """Mixin that adds a HuggingFace-style `generate` method on top of `decode`."""

    def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
                 return_dict_in_generate=False, output_scores=False, **kwargs):
        output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
                        temperature=temperature, **kwargs)
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
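

# Illustrative sketch (hypothetical class, not part of the original module): a model picks up
# `.generate(...)` simply by inheriting from GenerationMixin, as long as its forward signature
# matches what `decode` expects.
#
#     class MyLM(nn.Module, GenerationMixin):
#         def forward(self, input_ids, position_ids=None, inference_params=None):
#             ...  # must return an object with a `.logits` attribute
#
#     tokens = MyLM(...).generate(input_ids, max_length=64, top_k=1)   # greedy decoding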


def allocate_kv_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
                      device, dtype=torch.float16):
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    packsize = 4 if dtype == torch.float32 else 8
    assert headdim % packsize == 0
    # K is stored in the packed layout expected by the fused decoding kernel
    # (FasterTransformer-style): (batch, nheads, headdim // packsize, seqlen, packsize).
    k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
    v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    return {i: (torch.empty(k_cache_shape, device=device, dtype=dtype),
                torch.empty(v_cache_shape, device=device, dtype=dtype))
            for i in layers}
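

# Illustrative example (not part of the original module): per-layer cache for a GPT-2-small-like
# model (12 heads, head dim 64) with batch 4 and max sequence length 2048 in fp16. K and V each
# take 4 * 12 * 64 * 2048 * 2 bytes = 12 MiB per layer, so roughly 24 MiB per layer in total.
#
#     kv_cache = allocate_kv_cache(max_batch_size=4, max_seqlen=2048, nheads=12, headdim=64,
#                                  layers=12, device='cuda', dtype=torch.float16)
#     k0, v0 = kv_cache[0]
#     k0.shape  # (4, 12, 8, 2048, 8): headdim split into 64 // packsize = 8 groups of 8
#     v0.shape  # (4, 12, 2048, 64)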


def seqlen_to_seqlen_type(seqlen: int) -> int:
    """Convert sequence length to a seqlen_type.
    This is used to determine which CUDA graph to use: lengths are bucketed into
    [1, 32), [32, 2048) and [2048, ...) so one captured graph can serve a whole bucket.
    Arguments:
        seqlen: int
    """
    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)


def seqlen_type_to_seqlen(seqlen_type: int) -> int:
    """Smallest sequence length belonging to a given seqlen_type."""
    assert seqlen_type in [0, 1, 2]
    return 1 if seqlen_type == 0 else (32 if seqlen_type == 1 else 2048)
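

# Quick check of the bucketing (illustrative):
#
#     seqlen_to_seqlen_type(16)    # -> 0  (bucket [1, 32))
#     seqlen_to_seqlen_type(512)   # -> 1  (bucket [32, 2048))
#     seqlen_to_seqlen_type(4096)  # -> 2  (bucket [2048, ...))
#     seqlen_type_to_seqlen(1)     # -> 32, the shortest length in bucket 1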


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
                       dtype=None):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if ((device, dtype) != (cache.device, cache.dtype) or batch_size > cache.max_batch_size
            or max_seqlen > cache.max_seqlen):  # Invalidate the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        headdim = getattr(model.config, 'head_dim',
                          model.config.hidden_size // model.config.num_attention_heads)
        kv_cache = allocate_kv_cache(
            batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
            model.config.num_hidden_layers, device, dtype
        )
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_sequence_len=max_seqlen, max_batch_size=batch_size,
            sequence_len_offset=seqlen_og, key_value_memory_dict=kv_cache, fused_ft_kernel=True,
            lengths_per_sample=lengths_per_sample
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
        if s_type not in cache.callables:
            seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
            cache.callables[s_type] = capture_graph(
                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool
            )

    def dispatch(input_ids, position_ids, seqlen):
        return cache.callables[seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
    return cache
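

# How `decode` drives this cache (sketch of the call sequence above, for orientation):
#
#     model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
#     # ... then, once per generated token:
#     logits = model._decoding_cache.run(next_token_ids, position_ids, seqlen)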


def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None):
    assert max_seqlen >= seqlen_og
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
    inference_params.lengths_per_sample[:] = seqlen_og

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            logits = model(input_ids, position_ids=position_ids,
                           inference_params=inference_params).logits[:, -1]
        s.synchronize()
    torch.cuda.current_stream().wait_stream(s)
    # Capture the graph. torch.cuda.graph automatically sets a side stream as the
    # current stream inside the context, which is required for capture.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(input_ids, position_ids=position_ids,
                       inference_params=inference_params).logits[:, -1]

    def run(new_input_ids, new_position_ids, seqlen):
        # Copy the new inputs into the static buffers captured by the graph, then replay.
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits

    return run
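

# Minimal replay sketch (illustrative; assumes a model compatible with capture_graph above):
#
#     run = capture_graph(model, inference_params, batch_size=1, seqlen_og=16, max_seqlen=512)
#     logits = run(new_input_ids, new_position_ids, seqlen=17)
#     # Note: `logits` aliases the graph's static output buffer and is overwritten on the
#     # next replay, so clone it if it needs to outlive the following decoding step.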