[Gen] Make generation work with Tensor Parallel
parent d509832426
commit 7c2191542a
@@ -1,5 +1,7 @@
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAGuard.h>

#include "decoder_masked_multihead_attention.h"

@@ -138,6 +140,10 @@ torch::Tensor single_query_attention(const torch::Tensor q,
TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
}

// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};

torch::Tensor out = torch::empty_like(q);

DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), out.scalar_type(), "single_query_attention", [&] {
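
For reference: the comment above is the whole reason for the new CUDAGuard include. A rough Python-side analog (illustrative only, not part of this commit) is selecting the tensor's device before allocating or launching work, which is what at::cuda::CUDAGuard does inside the extension:

    import torch

    def workspace_like(q: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper, illustration only: inside the block the bare
        # 'cuda' device resolves to q's GPU instead of cuda:0, the same effect
        # the C++ guard has on kernel launches in the extension.
        with torch.cuda.device(q.device):
            return torch.empty(q.shape, dtype=q.dtype, device='cuda')
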
@@ -20,7 +20,7 @@ from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_shared_params
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import GenerationMixin

@@ -146,17 +146,23 @@ class GPTPreTrainedModel(nn.Module):
self.config = config

@classmethod
def from_pretrained(cls, model_name, config, *args, strict=True, device=None, **kwargs):
def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None,
world_size=1, rank=0, **kwargs):
"""
Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
"""
# Instantiate model.
model = cls(config, *args, device=device, **kwargs)
load_return = model.load_state_dict(
remap_state_dict_gpt2(state_dict_from_pretrained(model_name, device=device), config),
strict=strict
model = cls(config, *args, device=device, dtype=dtype, **kwargs)
state_dict = remap_state_dict_gpt2(
# If we're going to shard the model, then don't load fp32 weights to GPU.
state_dict_from_pretrained(model_name, device=device if world_size == 1 else None,
dtype=dtype), config
)
if world_size > 1:
state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
load_return = model.load_state_dict(state_dict, strict=strict)
logger.info(load_return)
return model
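
The world_size > 1 branch above keeps the full checkpoint on CPU, remaps it, and only then cuts out this rank's shard and moves it to the GPU. A minimal sketch of the sharding idea (hypothetical helper, not the actual shard_state_dict_tp; which dimension is split depends on whether a layer is column- or row-parallel):

    import torch

    def shard_along_dim(weight: torch.Tensor, world_size: int, rank: int, dim: int = 0):
        # Column-parallel weights (e.g. the embedding / lm_head) are split along
        # the output dimension (dim 0); row-parallel weights along dim 1.
        assert weight.shape[dim] % world_size == 0
        return weight.chunk(world_size, dim=dim)[rank].clone()

    # e.g. each rank keeps vocab_size // world_size rows of the (padded) embedding:
    # shard = shard_along_dim(state_dict['transformer.embeddings.word_embeddings.weight'],
    #                         world_size, rank, dim=0)
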
@@ -190,17 +196,16 @@ class GPTModel(GPTPreTrainedModel):
self.process_group = process_group
self.sequence_parallel = getattr(config, 'sequence_parallel', True)
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu']
self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
if config.vocab_size % self.pad_vocab_size_multiple != 0:
config.vocab_size += (self.pad_vocab_size_multiple
- (config.vocab_size % self.pad_vocab_size_multiple))
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)

if process_group is None:
self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
self.embeddings = GPT2Embeddings(config.hidden_size, vocab_size,
config.max_position_embeddings, **factory_kwargs)
else:
self.embeddings = ParallelGPT2Embeddings(
config.hidden_size, config.vocab_size, config.max_position_embeddings,
config.hidden_size, vocab_size, config.max_position_embeddings,
process_group=process_group, sequence_parallel=self.sequence_parallel,
**factory_kwargs
)
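
The math.ceil form computes the padded vocabulary without mutating config.vocab_size, and with tensor parallel the multiple is chosen as 8 * world_size so every rank gets an equal slice. Worked numbers for GPT-2 (illustrative):

    import math

    vocab_size, multiple = 50257, 8                     # GPT-2 vocab, padded to a multiple of 8
    assert math.ceil(vocab_size / multiple) * multiple == 50264
    # With tensor parallel, pad_vocab_size_multiple = 8 * world_size, e.g. 16 for 2 ranks:
    assert math.ceil(vocab_size / 16) * 16 == 50272     # 25136 embedding rows per rank
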
@@ -248,8 +253,9 @@
hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
if not self.fused_dropout_add_ln:
residual = self.emb_drop(hidden_states).float()
residual = self.emb_drop(hidden_states)
hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype))
residual = residual.float()
else:
hidden_states, residual = dropout_add_layer_norm(
hidden_states, None, self.ln_0.weight, self.ln_0.bias,

@@ -272,13 +278,16 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
super().__init__(config)
self.process_group = process_group
self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
if process_group is None:
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, **factory_kwargs)
self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False, **factory_kwargs)
else:
if ColumnParallelLinear is None:
raise ImportError('fused_dense_lib is not installed')
self.lm_head = ColumnParallelLinear(
config.n_embd, config.vocab_size, process_group, bias=False,
config.n_embd, vocab_size, process_group, bias=False,
sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
)
# Initialize weights and apply final processing

@@ -299,6 +308,10 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
hidden_states = self.transformer(input_ids, position_ids=position_ids,
inference_params=inference_params)
lm_logits = self.lm_head(hidden_states)
# During inference, we want the full logit for sampling
if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
lm_logits = rearrange(lm_logits, '(n b) s d -> b s (n d)', b=hidden_states.shape[0])
CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
return CausalLMOutput(logits=lm_logits)
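
all_gather_raw concatenates the per-rank logit shards along the leading (batch) dimension, so the '(n b) s d -> b s (n d)' rearrange is what stitches them back into full-vocabulary logits for sampling. A small standalone shape check of that round trip (assumes the gather is a plain concat along dim 0, as all_gather_raw does):

    import torch
    from einops import rearrange

    n, b, s, d = 2, 3, 5, 7                        # ranks, batch, seqlen, per-rank vocab shard
    shards = [torch.randn(b, s, d) for _ in range(n)]
    gathered = torch.cat(shards, dim=0)            # what all_gather_raw returns: (n * b, s, d)
    full = rearrange(gathered, '(n b) s d -> b s (n d)', b=b)
    assert full.shape == (b, s, n * d)
    assert torch.equal(full[..., :d], shards[0])   # rank 0 owns the first vocab slice
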
@@ -310,8 +323,10 @@ def remap_state_dict_gpt2(state_dict, config):
state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('wte.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
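
In the F.pad call, the first (0, 0) pair leaves the embedding dimension untouched and the second pair appends zero rows until the matrix reaches the padded vocab size. Quick illustration (shapes only):

    import torch
    import torch.nn.functional as F

    word_embeddings = torch.randn(50257, 768)      # GPT-2 checkpoint shape
    vocab_size = 50264                             # padded target
    padded = F.pad(word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]))
    assert padded.shape == (50264, 768)
    assert torch.all(padded[50257:] == 0)          # appended rows are zeros
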
@@ -365,10 +380,8 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
with tensor parallel.
"""
vocab_size = config.vocab_size
if config.vocab_size % config.pad_vocab_size_multiple != 0:
vocab_size += (config.pad_vocab_size_multiple
- (config.vocab_size % config.pad_vocab_size_multiple))
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
assert vocab_size % world_size == 0
assert config.hidden_size % world_size == 0
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size

@@ -289,6 +289,60 @@ class LinearResidual(nn.Linear):
return super().forward(input), input


def _update_kv_cache(kv, inference_params, layer_idx):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
"""
# Pre-allocate memory for key-values for inference.
num_heads, head_dim = kv.shape[-2:]
if layer_idx not in inference_params.key_value_memory_dict:
kv_cache = torch.empty(
inference_params.max_batch_size, inference_params.max_sequence_len, 2,
num_heads, head_dim, dtype=kv.dtype, device=kv.device
)
inference_params.key_value_memory_dict[layer_idx] = kv_cache
else:
if not inference_params.fused_ft_kernel:
kv_cache = inference_params.key_value_memory_dict[layer_idx]
else:
# For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
# where packsize = 4 if fp32, 8 if fp16 or bf16.
# v_cache has shape (b, h, s, headdim)
k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
kv_cache = None
# Adjust key and value for inference
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + kv.shape[1]
assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
# Copy key and values.
if not inference_params.fused_ft_kernel:
assert kv_cache is not None
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
return kv
else:
assert inference_params.sequence_len_offset == 0
# FT kernel requires different layouts for the k_cache and v_cache.
assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if kv.dtype == torch.float32 else 8
if kv_cache is not None:
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
packsize=packsize).contiguous()
v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
else:
k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
)
v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
kv[:, :, 1], 'b s h d -> b h s d'
)
return kv


class MHA(nn.Module):
"""Multi-head self-attention and cross-attention
"""
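
The FT ('fused_ft_kernel') layout stores keys so the kernel can read packsize contiguous elements of the head dimension at once (8 halves = 16 bytes), while values stay in a plain (b, h, s, d) layout. A standalone sketch of the two layouts and the rearranges used above (shapes only; the rearrange is lossless):

    import torch
    from einops import rearrange

    b, s, h, d = 2, 16, 12, 64
    packsize = 8                                   # 8 for fp16/bf16, 4 for fp32
    kv = torch.randn(b, s, 2, h, d, dtype=torch.float16)

    k_cache = rearrange(kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
                        packsize=packsize).contiguous()
    v_cache = rearrange(kv[:, :, 1], 'b s h d -> b h s d').contiguous()
    assert k_cache.shape == (b, h, d // packsize, s, packsize)
    assert v_cache.shape == (b, h, s, d)

    # Going back recovers the original keys exactly:
    assert torch.equal(rearrange(k_cache, 'b h d s packsize -> b s h (d packsize)'), kv[:, :, 0])
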
@@ -363,54 +417,7 @@ class MHA(nn.Module):
"""
assert not self.dwconv, 'Generation does not support dwconv yet'
assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
# Pre-allocate memory for key-values for inference.
if self.layer_idx not in inference_params.key_value_memory_dict:
kv_cache = torch.empty(
inference_params.max_batch_size, inference_params.max_sequence_len, 2,
self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device
)
inference_params.key_value_memory_dict[self.layer_idx] = kv_cache
else:
if not inference_params.fused_ft_kernel:
kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
else:
# For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
# where packsize = 4 if fp32, 8 if fp16 or bf16.
# v_cache has shape (b, h, s, headdim)
k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
kv_cache = None
# Adjust key and value for inference
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + kv.shape[1]
assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
# Copy key and values.
if not inference_params.fused_ft_kernel:
assert kv_cache is not None
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
return kv
else:
assert inference_params.sequence_len_offset == 0
# FT kernel requires different layouts for the k_cache and v_cache.
assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if kv.dtype == torch.float32 else 8
if kv_cache is not None:
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
packsize=packsize).contiguous()
v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
else:
k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
)
v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
kv[:, :, 1], 'b s h d -> b h s d'
)
return kv
return _update_kv_cache(kv, inference_params, self.layer_idx)

def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
inference_params=None, **kwargs):

@@ -473,6 +480,7 @@ class MHA(nn.Module):
causal = None if inference_params.sequence_len_offset == 0 else False
context = self.inner_cross_attn(q, kv, causal=causal)
else:
assert inference_params.fused_ft_kernel
assert ft_attention is not None
context = ft_attention.single_query_attention(
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),

@@ -541,13 +549,16 @@ class ParallelMHA(nn.Module):
self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
sequence_parallel=sequence_parallel, **factory_kwargs)
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
attention_dropout=dropout)
self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
attention_dropout=dropout)
# output projection always have the bias (for now)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
sequence_parallel=sequence_parallel, **factory_kwargs)

def forward(self, x, seqlen=None, **kwargs):
def forward(self, x, seqlen=None, inference_params=None, **kwargs):
"""
Arguments:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.

@@ -561,12 +572,34 @@ class ParallelMHA(nn.Module):
else:
qkv = rearrange(qkv, '(b s) (three h d) -> b s three h d', s=seqlen, three=3,
d=self.head_dim)
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv)
if not self.checkpointing:
context = self.inner_attn(qkv, **kwargs)
if inference_params is None:
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv)
if not self.checkpointing:
context = self.inner_attn(qkv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
if (not inference_params.fused_ft_kernel) or inference_params.sequence_len_offset == 0:
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
q = qkv[:, :, 0]
assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
# If we're processing the prompt, causal=None (use self.causal).
# If we're decoding, then causal=False.
causal = None if inference_params.sequence_len_offset == 0 else False
context = self.inner_cross_attn(q, kv, causal=causal)
else:
assert inference_params.fused_ft_kernel
assert ft_attention is not None
context = ft_attention.single_query_attention(
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
*inference_params.key_value_memory_dict[self.layer_idx],
inference_params.lengths_per_sample, inference_params.sequence_len_offset,
self.rotary_emb_dim
)
context = rearrange(context, 'b h d -> b 1 h d')
if seqlen is None:
context = rearrange(context, 'b s h d -> b s (h d)')
else:
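
In the non-fused decode path above, one generation step is just attention of the single new query against everything in the updated KV cache, with causal=False because the query is the last position anyway. A plain-PyTorch reference for what that step computes (illustrative; the fused FT kernel additionally folds in rotary embeddings and its special cache layout):

    import math
    import torch

    def single_step_attention(q, kv):
        # q:  (batch, 1, nheads, head_dim)            -- query for the newest token
        # kv: (batch, seqlen, 2, nheads, head_dim)    -- cache returned by _update_kv_cache
        k, v = kv.unbind(dim=2)
        scores = torch.einsum('bqhd,bkhd->bhqk', q, k) / math.sqrt(q.shape[-1])
        attn = torch.softmax(scores.float(), dim=-1).to(v.dtype)
        return torch.einsum('bhqk,bkhd->bqhd', attn, v)  # (batch, 1, nheads, head_dim)
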
@@ -1,6 +1,7 @@
# Copyright (c) 2022, Tri Dao.
# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from typing import Optional
from typing import Optional, Union, Sequence, Callable
import gc
import time

from dataclasses import dataclass, field

@@ -70,7 +71,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):


def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
fused_ft_kernel=False, cg=False, timing=False):
vocab_size=None, tensor_parallel=1, fused_ft_kernel=False, cg=False, timing=False):
"""Decoding, either greedy or with top-k or top-p sampling.
If top-k = 0, don't limit the number of candidates (pure sampling).
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,

@@ -85,18 +86,30 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
scores: tuples of (batch, vocab_size)
"""
batch_size, seqlen_og = input_ids.shape
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
fused_ft_kernel=fused_ft_kernel)
if cg:
assert fused_ft_kernel
if not hasattr(model, '_decoding_cache'):
model._decoding_cache = None
model._decoding_cache = update_graph_cache(
model, model._decoding_cache, batch_size, seqlen_og, max_length,
tensor_parallel=tensor_parallel
)
inference_params = model._decoding_cache.inference_params
inference_params.max_sequence_len = max_length
inference_params.max_batch_size = batch_size
inference_params.sequence_len_offset = 0
else:
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
fused_ft_kernel=fused_ft_kernel)
scores = []
with torch.inference_mode():
logits = model(input_ids, inference_params=inference_params).logits[:, -1]
if vocab_size is not None:
logits = logits[..., :vocab_size]
scores.append(logits)
next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
sequences = [next_token]
inference_params.sequence_len_offset = seqlen_og
if cg:
assert fused_ft_kernel
run, cg_cache = capture_cg(model, inference_params, batch_size, seqlen_og, max_length)
if timing:
start = time.time()
while True:
@@ -106,8 +119,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
inference_params=inference_params).logits[:, -1]
else:
logits = run(rearrange(next_token, 'b -> b 1'), position_ids,
inference_params.sequence_len_offset)
logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids,
inference_params.sequence_len_offset)
if vocab_size is not None:
logits = logits[..., :vocab_size]
scores.append(logits)
next_token = sample(logits, top_k=top_k, temperature=temperature)
sequences.append(next_token)

@@ -115,6 +130,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
if inference_params.sequence_len_offset >= max_length - 1:
break
if timing:
torch.cuda.synchronize()
print(f'Decoding time: {time.time() - start}')
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
return output_cls(
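
Taken together, decode() now accepts vocab_size (to trim the padded logit columns before sampling), tensor_parallel (so the per-rank KV caches are sized correctly), and the cg flag that reuses the CUDA-graph cache built above. A call through GenerationMixin.generate, as exercised by the test at the end of this commit (argument values are illustrative):

    out = model.generate(
        input_ids=input_ids, max_length=30,
        tensor_parallel=world_size,           # size of the tensor-parallel group
        vocab_size=config.vocab_size,         # drop padded logit columns before sampling
        fused_ft_kernel=True, cg=True,        # fused decoding kernel + CUDA-graph cache
        return_dict_in_generate=True, output_scores=True, timing=True,
    )
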
@@ -134,8 +150,18 @@ class GenerationMixin:
return output if return_dict_in_generate else output.sequences


CgKey = namedtuple('CgKey', ['batch_size', 'seqlen_type', 'max_length'])
CgVal = namedtuple('CgVal', ['graph', 'input_ids', 'position_ids', 'lengths', 'logits'])
def allocate_kv_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
device, dtype=torch.float16):
assert dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if dtype == torch.float32 else 8
assert headdim % packsize == 0
k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
if isinstance(layers, int):
layers = range(layers)
return {i: (torch.empty(k_cache_shape, device=device, dtype=dtype),
torch.empty(v_cache_shape, device=device, dtype=dtype))
for i in layers}


def seqlen_to_seqlen_type(seqlen: int) -> int:
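
allocate_kv_cache pre-allocates the FT-layout cache per layer so the captured CUDA graphs see fixed memory addresses. A quick shape sanity check against the formulas above (illustrative values; under tensor parallel, nheads is num_attention_heads // tensor_parallel):

    import torch

    max_batch, max_seqlen, nheads, headdim = 2, 2048, 12, 64
    packsize = 8                              # fp16
    kv_cache = allocate_kv_cache(max_batch, max_seqlen, nheads, headdim,
                                 layers=12, device='cuda', dtype=torch.float16)
    k_cache, v_cache = kv_cache[0]            # layer 0
    assert k_cache.shape == (max_batch, nheads, headdim // packsize, max_seqlen, packsize)
    assert v_cache.shape == (max_batch, nheads, max_seqlen, headdim)
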
@@ -152,63 +178,91 @@ def seqlen_type_to_seqlen(seqlen_type: int) -> int:
return 1 if seqlen_type == 0 else (32 if seqlen_type == 1 else 2048)


def capture_cg(model, inference_params, batch_size, seqlen_og, max_length, copy_output=False):
"""Build a cache of cuda graphs for decoding.
Arguments:
model: a GPTLMHeadModel
batch_size: int
seqlen_og: int. Length of the prompt.
max_length: int
TODO: how do we deal with the k_cache and v_cache memory? I think the CUDA graph also
has to own the k_cache and v_cache?
Here we assume that the model already has inference_params from the prompt processing.
"""
assert max_length > seqlen_og
cg_cache: dict[CgKey, CgVal] = {}
@dataclass
class DecodingCGCache:
max_batch_size: int = 0
max_seqlen: int = 0
device = None
dtype = None
callables: dict = field(default_factory=dict)
mempool = None
inference_params: Optional[InferenceParams] = None
run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
dtype=None):
if cache is None:
cache = DecodingCGCache()
param_example = next(iter(model.parameters()))
device = param_example.device
if dtype is None:
dtype = param_example.dtype
if ((device, dtype) != (cache.device, cache.dtype) or batch_size > cache.max_batch_size
or max_seqlen > cache.max_seqlen): # Invalidate the cache
cache.callables = {}
cache.mempool = None
cache.inference_params = None
gc.collect()
cache.device, cache.dtype = device, dtype
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
headdim = getattr(model.config, 'head_dim',
model.config.hidden_size // model.config.num_attention_heads)
kv_cache = allocate_kv_cache(
batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
model.config.num_hidden_layers, device, dtype
)
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
cache.inference_params = InferenceParams(
max_sequence_len=max_seqlen, max_batch_size=batch_size,
sequence_len_offset=seqlen_og, key_value_memory_dict=kv_cache, fused_ft_kernel=True,
lengths_per_sample=lengths_per_sample
)
cache.mempool = torch.cuda.graphs.graph_pool_handle()
for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
if s_type not in cache.callables:
seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
cache.callables[s_type] = capture_graph(
model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool
)

def dispatch(input_ids, position_ids, seqlen):
return cache.callables[seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)

cache.run = dispatch
cache.inference_params.sequence_length_offset = 0 # Reset so it's not confusing
return cache


def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None):
assert max_seqlen >= seqlen_og
device = next(iter(model.parameters())).device
sequence_length_offset_og = inference_params.sequence_len_offset
input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
inference_params.lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32,
device=device)
inference_params.lengths_per_sample[:] = seqlen_og

memory_pool = None
for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_length) + 1):
seqlen = max(seqlen_og, seqlen_type_to_seqlen(s_type))
input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
inference_params.lengths_per_sample[:] = seqlen
inference_params.sequence_len_offset = seqlen
g = torch.cuda.CUDAGraph()
# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(2):
logits = model(input_ids, position_ids=position_ids,
inference_params=inference_params).logits[:, -1]
torch.cuda.current_stream().wait_stream(s)
# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
with torch.cuda.graph(g, pool=memory_pool):
# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(2):
logits = model(input_ids, position_ids=position_ids,
inference_params=inference_params).logits[:, -1]
if memory_pool is None:
memory_pool = g.pool()
cg_cache[CgKey(batch_size, s_type, max_length)] = CgVal(
g, input_ids, position_ids, inference_params.lengths_per_sample, logits
)
inference_params=inference_params).logits[:, -1]
s.synchronize()
torch.cuda.current_stream().wait_stream(s)
# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=mempool):
logits = model(input_ids, position_ids=position_ids,
inference_params=inference_params).logits[:, -1]

def run(new_input_ids, new_position_ids, seqlen):
cg_val = cg_cache[CgKey(batch_size, seqlen_to_seqlen_type(seqlen), max_length)]
inference_params.lengths_per_sample = cg_val.lengths
inference_params.lengths_per_sample[:] = seqlen
cg_val.input_ids.copy_(new_input_ids)
cg_val.position_ids.copy_(new_position_ids)
cg_val.graph.replay()
output = cg_val.logits
return output.clone() if copy_output else output
input_ids.copy_(new_input_ids)
position_ids.copy_(new_position_ids)
graph.replay()
return logits

inference_params.sequence_len_offset = sequence_length_offset_og

return run, cg_cache
return run
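
capture_graph follows the standard torch.cuda.CUDAGraph recipe: warm up on a side stream, capture one decoding step whose inputs and outputs are fixed tensors, then replay after copying fresh data into those tensors; update_graph_cache captures one such graph per sequence-length bucket (seqlen_to_seqlen_type) and only rewrites lengths_per_sample in place at dispatch time. A minimal self-contained version of the capture/replay pattern (toy model, illustration only):

    import torch

    model = torch.nn.Linear(16, 16, device='cuda', dtype=torch.float16)
    static_x = torch.zeros(2, 16, device='cuda', dtype=torch.float16)

    # Warm up on a side stream so capture starts from a quiet state.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            static_y = model(static_x)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):             # pool=... lets several graphs share memory
        static_y = model(static_x)

    def run(new_x):
        static_x.copy_(new_x)                 # write into the captured input buffer
        graph.replay()                        # re-launch the recorded kernels
        return static_y                       # captured output buffer, now refreshed

    out = run(torch.randn(2, 16, device='cuda', dtype=torch.float16))
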
@@ -4,5 +4,8 @@ from transformers.utils import WEIGHTS_NAME
from transformers.utils.hub import cached_file


def state_dict_from_pretrained(model_name, device=None):
return torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
def state_dict_from_pretrained(model_name, device=None, dtype=None):
state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
if dtype is not None:
state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
return state_dict
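
The new dtype argument lets callers pull the fp32 checkpoint but hold it in half precision, and together with device=None in from_pretrained it keeps the weights on CPU until after sharding. For example (illustrative):

    import torch
    from flash_attn.utils.pretrained import state_dict_from_pretrained

    # Download (or reuse the cached) GPT-2 checkpoint, keep it on CPU, cast to fp16.
    state_dict = state_dict_from_pretrained('gpt2', device=None, dtype=torch.float16)
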
@@ -1,3 +1,4 @@
import os
import re

import torch

@@ -11,12 +12,13 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.distributed import all_gather_raw


@pytest.mark.parametrize('fused_ft_kernel', [False, True])
# @pytest.mark.parametrize('fused_ft_kernel', [True])
@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [False])
# @pytest.mark.parametrize('optimized', [True])
@pytest.mark.parametrize('rotary', [False, True])
# @pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])

@@ -40,19 +42,20 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):

# if not rotary, we load the weight from HF but ignore the position embeddings.
# The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device)
model = model.to(dtype=dtype)
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
dtype=dtype)
model.eval()

if not rotary:
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
model_ref.eval()
model_hf.eval()

torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
input_ids = tokenizer("Hello, my dog is cute and ",
return_tensors="pt").input_ids.to(device=device)
max_length = 30
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
@@ -100,3 +103,119 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
assert torch.all(out.sequences == out_hf.sequences)

assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()


# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation.py -k "parallel"

# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize('world_size', [2])
# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize('fused_ft_kernel', [True])
# @pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
dtype = torch.float16
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
if rotary:
config.n_positions = 0
config.rotary_emb_dim = 64
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
config.pad_vocab_size_multiple = 8 * world_size
config.sequence_parallel = False # Need to set this to False for generation

os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
assert world_size <= torch.distributed.get_world_size()
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
torch.cuda.set_device(device)

from apex.transformer import parallel_state
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()

# if not rotary, we load the weight from HF but ignore the position embeddings.
# The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
dtype=dtype, process_group=process_group,
world_size=world_size, rank=rank)
model.eval()

if not rotary:
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
model_ref.eval()
model_hf.eval()

torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ",
return_tensors="pt").input_ids.to(device=device)
max_length = 30
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40

# Slow generation for reference
sequences = []
scores = []
cur_input_ids = input_ids
with torch.inference_mode():
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)',
b=input_ids.shape[0])[..., :config.vocab_size]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)',
b=input_ids.shape[0])[..., :config.vocab_size]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
print(sequences)

out = model.generate(input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True, output_scores=True, timing=True)
print(out.sequences)
if fused_ft_kernel:
out_cg = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True)
print(out_cg.sequences)

if not rotary:
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)

print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')

assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
rtol=rtol, atol=atol)
if not rotary:
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)

assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()