[Gen] Make generation work with Tensor Parallel

Tri Dao 2023-01-15 11:34:27 -08:00
parent d509832426
commit 7c2191542a
6 changed files with 373 additions and 145 deletions

View File

@ -1,5 +1,7 @@
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAGuard.h>
#include "decoder_masked_multihead_attention.h"
@ -138,6 +140,10 @@ torch::Tensor single_query_attention(const torch::Tensor q,
TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
}
// Otherwise the kernel will be launched from the cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
torch::Tensor out = torch::empty_like(q);
DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), out.scalar_type(), "single_query_attention", [&] {
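The new `CUDAGuard` is what lets this kernel be called from any rank under tensor parallel: each rank keeps its tensors on its own GPU, and without the guard the launch would default to `cuda:0`. A minimal Python-side illustration with hypothetical shapes and rank (not part of the diff):

```python
import torch

# Hypothetical rank/device values, just to show what q.get_device() feeds the guard.
rank = 1 if torch.cuda.device_count() > 1 else 0
device = f'cuda:{rank}'
torch.cuda.set_device(device)  # the distributed test below does the same before graph capture
q = torch.randn(2, 12, 64, device=device, dtype=torch.float16)
assert q.get_device() == rank  # this index is what the CUDAGuard pins the launch to
```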

View File

@ -20,7 +20,7 @@ from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_shared_params
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import GenerationMixin
@ -146,17 +146,23 @@ class GPTPreTrainedModel(nn.Module):
self.config = config
@classmethod
def from_pretrained(cls, model_name, config, *args, strict=True, device=None, **kwargs):
def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None,
world_size=1, rank=0, **kwargs):
"""
Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
"""
# Instantiate model.
model = cls(config, *args, device=device, **kwargs)
load_return = model.load_state_dict(
remap_state_dict_gpt2(state_dict_from_pretrained(model_name, device=device), config),
strict=strict
model = cls(config, *args, device=device, dtype=dtype, **kwargs)
state_dict = remap_state_dict_gpt2(
# If we're going to shard the model, then don't load fp32 weights to GPU.
state_dict_from_pretrained(model_name, device=device if world_size == 1 else None,
dtype=dtype), config
)
if world_size > 1:
state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
load_return = model.load_state_dict(state_dict, strict=strict)
logger.info(load_return)
return model
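A sketch of how the extended `from_pretrained` might be called under tensor parallel, mirroring the distributed test at the end of this diff (the apex `parallel_state` helpers, the `'gpt2'` checkpoint, and the fp16 dtype are taken from that test, not invented here):

```python
import torch
from apex.transformer import parallel_state
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

# Assumes torch.distributed and apex's tensor-model-parallel state are already
# initialized, as in test_tensor_parallel below.
world_size = parallel_state.get_tensor_model_parallel_world_size()
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()

config = GPT2Config.from_pretrained('gpt2')
config.pad_vocab_size_multiple = 8 * world_size
config.sequence_parallel = False  # the test disables this for generation

model = GPTLMHeadModel.from_pretrained(
    'gpt2', config, device=f'cuda:{torch.distributed.get_rank()}', dtype=torch.float16,
    process_group=process_group, world_size=world_size, rank=rank,
)
```

With `world_size > 1` the fp32 checkpoint stays on CPU, gets remapped and sharded, and only each rank's shard is moved to its GPU, which is exactly the branch added above.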
@ -190,17 +196,16 @@ class GPTModel(GPTPreTrainedModel):
self.process_group = process_group
self.sequence_parallel = getattr(config, 'sequence_parallel', True)
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu']
self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
if config.vocab_size % self.pad_vocab_size_multiple != 0:
config.vocab_size += (self.pad_vocab_size_multiple
- (config.vocab_size % self.pad_vocab_size_multiple))
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
if process_group is None:
self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
self.embeddings = GPT2Embeddings(config.hidden_size, vocab_size,
config.max_position_embeddings, **factory_kwargs)
else:
self.embeddings = ParallelGPT2Embeddings(
config.hidden_size, config.vocab_size, config.max_position_embeddings,
config.hidden_size, vocab_size, config.max_position_embeddings,
process_group=process_group, sequence_parallel=self.sequence_parallel,
**factory_kwargs
)
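The padding arithmetic is easy to sanity-check with GPT-2's 50257-token vocabulary; the `8 * world_size` multiple is the value the tensor-parallel test below sets:

```python
import math

vocab_size = 50257  # GPT-2
for world_size in (1, 2, 4, 8):
    multiple = 8 * world_size
    padded = math.ceil(vocab_size / multiple) * multiple
    assert padded % world_size == 0
    per_rank = padded // world_size  # embedding / lm_head rows held by each rank
    print(world_size, padded, per_rank)
# world_size=2 -> padded=50272, 25136 rows per rank; world_size=8 -> padded=50304, 6288 rows.
```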
@ -248,8 +253,9 @@ class GPTModel(GPTPreTrainedModel):
hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
if not self.fused_dropout_add_ln:
residual = self.emb_drop(hidden_states).float()
residual = self.emb_drop(hidden_states)
hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype))
residual = residual.float()
else:
hidden_states, residual = dropout_add_layer_norm(
hidden_states, None, self.ln_0.weight, self.ln_0.bias,
@ -272,13 +278,16 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
super().__init__(config)
self.process_group = process_group
self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
if process_group is None:
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, **factory_kwargs)
self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False, **factory_kwargs)
else:
if ColumnParallelLinear is None:
raise ImportError('fused_dense_lib is not installed')
self.lm_head = ColumnParallelLinear(
config.n_embd, config.vocab_size, process_group, bias=False,
config.n_embd, vocab_size, process_group, bias=False,
sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
)
# Initialize weights and apply final processing
@ -299,6 +308,10 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
hidden_states = self.transformer(input_ids, position_ids=position_ids,
inference_params=inference_params)
lm_logits = self.lm_head(hidden_states)
# During inference, we want the full logit for sampling
if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
lm_logits = rearrange(lm_logits, '(n b) s d -> b s (n d)', b=hidden_states.shape[0])
CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
return CausalLMOutput(logits=lm_logits)
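A shape-level illustration of that gather-and-rearrange step. The layout assumption (rank shards stacked along the leading batch dimension, so `'(n b) s d -> b s (n d)'` puts rank 0's vocabulary slice first) is simulated here with `torch.cat` in place of a real `all_gather_raw`:

```python
import torch
from einops import rearrange

world_size, batch, seqlen, vocab_per_rank = 2, 3, 5, 7
# Each rank's lm_head produces (batch, seqlen, vocab_per_rank); fill with the rank index
# so we can see where each shard ends up after the rearrange.
shards = [torch.full((batch, seqlen, vocab_per_rank), float(r)) for r in range(world_size)]
gathered = torch.cat(shards, dim=0)  # what all_gather_raw would hand back: (n*b, s, d)
full = rearrange(gathered, '(n b) s d -> b s (n d)', b=batch)
assert full.shape == (batch, seqlen, world_size * vocab_per_rank)
assert torch.all(full[..., :vocab_per_rank] == 0) and torch.all(full[..., vocab_per_rank:] == 1)
# The unpadded logits are then full[..., :config.vocab_size], which is what
# decode() slices off via the new vocab_size argument.
```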
@ -310,8 +323,10 @@ def remap_state_dict_gpt2(state_dict, config):
state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('wte.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
@ -365,10 +380,8 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
with tensor parallel.
"""
vocab_size = config.vocab_size
if config.vocab_size % config.pad_vocab_size_multiple != 0:
vocab_size += (config.pad_vocab_size_multiple
- (config.vocab_size % config.pad_vocab_size_multiple))
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
assert vocab_size % world_size == 0
assert config.hidden_size % world_size == 0
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
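The rest of `shard_state_dict_tp` is cut off by the hunk, but for the column-parallel weights (embedding rows, `lm_head`, the first projection of each block) the sharding amounts to taking this rank's contiguous slice along the output dimension. An illustrative helper, not the function's actual body:

```python
import torch

def shard_dim0(weight: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    """Illustration only: each rank keeps its contiguous slice along dim 0."""
    assert weight.shape[0] % world_size == 0
    return weight.chunk(world_size, dim=0)[rank]

# e.g. the padded word embedding of shape (vocab_size, hidden_size):
vocab_size, hidden = 50272, 768
emb = torch.randn(vocab_size, hidden)
emb_rank0 = shard_dim0(emb, world_size=2, rank=0)
assert emb_rank0.shape == (vocab_size // 2, hidden)  # 25136 rows on each of the 2 ranks
```

The `vocab_size % world_size == 0` assert above is what guarantees this split is always even.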

View File

@ -289,6 +289,60 @@ class LinearResidual(nn.Linear):
return super().forward(input), input
def _update_kv_cache(kv, inference_params, layer_idx):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
"""
# Pre-allocate memory for key-values for inference.
num_heads, head_dim = kv.shape[-2:]
if layer_idx not in inference_params.key_value_memory_dict:
kv_cache = torch.empty(
inference_params.max_batch_size, inference_params.max_sequence_len, 2,
num_heads, head_dim, dtype=kv.dtype, device=kv.device
)
inference_params.key_value_memory_dict[layer_idx] = kv_cache
else:
if not inference_params.fused_ft_kernel:
kv_cache = inference_params.key_value_memory_dict[layer_idx]
else:
# For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
# where packsize = 4 if fp32, 8 if fp16 or bf16.
# v_cache has shape (b, h, s, headdim)
k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
kv_cache = None
# Adjust key and value for inference
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + kv.shape[1]
assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
# Copy key and values.
if not inference_params.fused_ft_kernel:
assert kv_cache is not None
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
return kv
else:
assert inference_params.sequence_len_offset == 0
# FT kernel requires different layouts for the k_cache and v_cache.
assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if kv.dtype == torch.float32 else 8
if kv_cache is not None:
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
packsize=packsize).contiguous()
v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
else:
k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
)
v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
kv[:, :, 1], 'b s h d -> b h s d'
)
return kv
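The two FT cache layouts described in the comments (`k_cache` packed as `(b, h, headdim/packsize, s, packsize)`, `v_cache` as `(b, h, s, headdim)`) can be verified with a quick shape experiment using the same `rearrange` patterns; the sizes here are made up:

```python
import torch
from einops import rearrange

b, s, h, d = 2, 16, 12, 64
dtype = torch.float16
packsize = 4 if dtype == torch.float32 else 8  # same rule as in _update_kv_cache
kv = torch.randn(b, s, 2, h, d, dtype=dtype)

k_cache = rearrange(kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
                    packsize=packsize).contiguous()
v_cache = rearrange(kv[:, :, 1], 'b s h d -> b h s d').contiguous()
assert k_cache.shape == (b, h, d // packsize, s, packsize)
assert v_cache.shape == (b, h, s, d)
```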
class MHA(nn.Module):
"""Multi-head self-attention and cross-attention
"""
@ -363,54 +417,7 @@ class MHA(nn.Module):
"""
assert not self.dwconv, 'Generation does not support dwconv yet'
assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
# Pre-allocate memory for key-values for inference.
if self.layer_idx not in inference_params.key_value_memory_dict:
kv_cache = torch.empty(
inference_params.max_batch_size, inference_params.max_sequence_len, 2,
self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device
)
inference_params.key_value_memory_dict[self.layer_idx] = kv_cache
else:
if not inference_params.fused_ft_kernel:
kv_cache = inference_params.key_value_memory_dict[self.layer_idx]
else:
# For FT, k_cache has shape (b, h, headdim / packsize, s, packsize)
# where packsize = 4 if fp32, 8 if fp16 or bf16.
# v_cache has shape (b, h, s, headdim)
k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
kv_cache = None
# Adjust key and value for inference
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + kv.shape[1]
assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0])
assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2])
# Copy key and values.
if not inference_params.fused_ft_kernel:
assert kv_cache is not None
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
return kv
else:
assert inference_params.sequence_len_offset == 0
# FT kernel requires different layouts for the k_cache and v_cache.
assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if kv.dtype == torch.float32 else 8
if kv_cache is not None:
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
packsize=packsize).contiguous()
v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache)
else:
k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
)
v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
kv[:, :, 1], 'b s h d -> b h s d'
)
return kv
return _update_kv_cache(kv, inference_params, self.layer_idx)
def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
inference_params=None, **kwargs):
@ -473,6 +480,7 @@ class MHA(nn.Module):
causal = None if inference_params.sequence_len_offset == 0 else False
context = self.inner_cross_attn(q, kv, causal=causal)
else:
assert inference_params.fused_ft_kernel
assert ft_attention is not None
context = ft_attention.single_query_attention(
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
@ -541,13 +549,16 @@ class ParallelMHA(nn.Module):
self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
sequence_parallel=sequence_parallel, **factory_kwargs)
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
attention_dropout=dropout)
self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
attention_dropout=dropout)
# output projection always has the bias (for now)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
sequence_parallel=sequence_parallel, **factory_kwargs)
def forward(self, x, seqlen=None, **kwargs):
def forward(self, x, seqlen=None, inference_params=None, **kwargs):
"""
Arguments:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
@ -561,12 +572,34 @@ class ParallelMHA(nn.Module):
else:
qkv = rearrange(qkv, '(b s) (three h d) -> b s three h d', s=seqlen, three=3,
d=self.head_dim)
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv)
if not self.checkpointing:
context = self.inner_attn(qkv, **kwargs)
if inference_params is None:
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv)
if not self.checkpointing:
context = self.inner_attn(qkv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
else:
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
if (not inference_params.fused_ft_kernel) or inference_params.sequence_len_offset == 0:
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset)
q = qkv[:, :, 0]
assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
# If we're processing the prompt, causal=None (use self.causal).
# If we're decoding, then causal=False.
causal = None if inference_params.sequence_len_offset == 0 else False
context = self.inner_cross_attn(q, kv, causal=causal)
else:
assert inference_params.fused_ft_kernel
assert ft_attention is not None
context = ft_attention.single_query_attention(
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
*inference_params.key_value_memory_dict[self.layer_idx],
inference_params.lengths_per_sample, inference_params.sequence_len_offset,
self.rotary_emb_dim
)
context = rearrange(context, 'b h d -> b 1 h d')
if seqlen is None:
context = rearrange(context, 'b s h d -> b s (h d)')
else:

View File

@ -1,6 +1,7 @@
# Copyright (c) 2022, Tri Dao.
# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31
from typing import Optional
from typing import Optional, Union, Sequence, Callable
import gc
import time
from dataclasses import dataclass, field
@ -70,7 +71,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
fused_ft_kernel=False, cg=False, timing=False):
vocab_size=None, tensor_parallel=1, fused_ft_kernel=False, cg=False, timing=False):
"""Decoding, either greedy or with top-k or top-p sampling.
If top-k = 0, don't limit the number of candidates (pure sampling).
Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
@ -85,18 +86,30 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
scores: tuples of (batch, vocab_size)
"""
batch_size, seqlen_og = input_ids.shape
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
fused_ft_kernel=fused_ft_kernel)
if cg:
assert fused_ft_kernel
if not hasattr(model, '_decoding_cache'):
model._decoding_cache = None
model._decoding_cache = update_graph_cache(
model, model._decoding_cache, batch_size, seqlen_og, max_length,
tensor_parallel=tensor_parallel
)
inference_params = model._decoding_cache.inference_params
inference_params.max_sequence_len = max_length
inference_params.max_batch_size = batch_size
inference_params.sequence_len_offset = 0
else:
inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size,
fused_ft_kernel=fused_ft_kernel)
scores = []
with torch.inference_mode():
logits = model(input_ids, inference_params=inference_params).logits[:, -1]
if vocab_size is not None:
logits = logits[..., :vocab_size]
scores.append(logits)
next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
sequences = [next_token]
inference_params.sequence_len_offset = seqlen_og
if cg:
assert fused_ft_kernel
run, cg_cache = capture_cg(model, inference_params, batch_size, seqlen_og, max_length)
if timing:
start = time.time()
while True:
@ -106,8 +119,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids,
inference_params=inference_params).logits[:, -1]
else:
logits = run(rearrange(next_token, 'b -> b 1'), position_ids,
inference_params.sequence_len_offset)
logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids,
inference_params.sequence_len_offset)
if vocab_size is not None:
logits = logits[..., :vocab_size]
scores.append(logits)
next_token = sample(logits, top_k=top_k, temperature=temperature)
sequences.append(next_token)
@ -115,6 +130,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
if inference_params.sequence_len_offset >= max_length - 1:
break
if timing:
torch.cuda.synchronize()
print(f'Decoding time: {time.time() - start}')
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
return output_cls(
@ -134,8 +150,18 @@ class GenerationMixin:
return output if return_dict_in_generate else output.sequences
CgKey = namedtuple('CgKey', ['batch_size', 'seqlen_type', 'max_length'])
CgVal = namedtuple('CgVal', ['graph', 'input_ids', 'position_ids', 'lengths', 'logits'])
def allocate_kv_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence],
device, dtype=torch.float16):
assert dtype in [torch.float16, torch.bfloat16, torch.float32]
packsize = 4 if dtype == torch.float32 else 8
assert headdim % packsize == 0
k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
if isinstance(layers, int):
layers = range(layers)
return {i: (torch.empty(k_cache_shape, device=device, dtype=dtype),
torch.empty(v_cache_shape, device=device, dtype=dtype))
for i in layers}
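For example, with GPT-2 small under 2-way tensor parallel each rank caches only 6 of the 12 heads, which is why `update_graph_cache` below divides `num_attention_heads` by `tensor_parallel`. A usage sketch (fp16, so `packsize = 8`; the TP=2 sizes are hypothetical):

```python
import torch
from flash_attn.utils.generation import allocate_kv_cache  # defined just above

nheads, headdim, nlayers, world_size = 12, 64, 12, 2  # GPT-2 small, hypothetical TP=2
kv_cache = allocate_kv_cache(
    max_batch_size=1, max_seqlen=30, nheads=nheads // world_size, headdim=headdim,
    layers=nlayers, device='cuda', dtype=torch.float16,
)
k_cache, v_cache = kv_cache[0]                  # per-layer (k_cache, v_cache) pair
assert k_cache.shape == (1, 6, 64 // 8, 30, 8)  # (b, h, headdim/packsize, s, packsize)
assert v_cache.shape == (1, 6, 30, 64)          # (b, h, s, headdim)
```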
def seqlen_to_seqlen_type(seqlen: int) -> int:
@ -152,63 +178,91 @@ def seqlen_type_to_seqlen(seqlen_type: int) -> int:
return 1 if seqlen_type == 0 else (32 if seqlen_type == 1 else 2048)
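The diff truncates the body of `seqlen_to_seqlen_type`, but `seqlen_type_to_seqlen` just above pins down the bucket boundaries (1, 32, 2048). A plausible reconstruction, offered only to make the bucketing idea concrete:

```python
def seqlen_to_seqlen_type(seqlen: int) -> int:
    # Reconstruction consistent with seqlen_type_to_seqlen above (not copied from the source):
    # bucket 0 for offsets under 32, bucket 1 up to 2048, bucket 2 beyond. One CUDA graph is
    # captured per bucket, so a single graph serves every decoding step in that range.
    return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)

assert seqlen_to_seqlen_type(1) == 0 and seqlen_to_seqlen_type(32) == 1
assert seqlen_to_seqlen_type(2048) == 2
```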
def capture_cg(model, inference_params, batch_size, seqlen_og, max_length, copy_output=False):
"""Build a cache of cuda graphs for decoding.
Arguments:
model: a GPTLMHeadModel
batch_size: int
seqlen_og: int. Length of the prompt.
max_length: int
TODO: how do we deal with the k_cache and v_cache memory? I think the CUDA graph also
has to own the k_cache and v_cache?
Here we assume that the model already has inference_params from the prompt processing.
"""
assert max_length > seqlen_og
cg_cache: dict[CgKey, CgVal] = {}
@dataclass
class DecodingCGCache:
max_batch_size: int = 0
max_seqlen: int = 0
device = None
dtype = None
callables: dict = field(default_factory=dict)
mempool = None
inference_params: Optional[InferenceParams] = None
run: Optional[Callable] = None
@torch.inference_mode()
def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
dtype=None):
if cache is None:
cache = DecodingCGCache()
param_example = next(iter(model.parameters()))
device = param_example.device
if dtype is None:
dtype = param_example.dtype
if ((device, dtype) != (cache.device, cache.dtype) or batch_size > cache.max_batch_size
or max_seqlen > cache.max_seqlen): # Invalidate the cache
cache.callables = {}
cache.mempool = None
cache.inference_params = None
gc.collect()
cache.device, cache.dtype = device, dtype
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
headdim = getattr(model.config, 'head_dim',
model.config.hidden_size // model.config.num_attention_heads)
kv_cache = allocate_kv_cache(
batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
model.config.num_hidden_layers, device, dtype
)
lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
cache.inference_params = InferenceParams(
max_sequence_len=max_seqlen, max_batch_size=batch_size,
sequence_len_offset=seqlen_og, key_value_memory_dict=kv_cache, fused_ft_kernel=True,
lengths_per_sample=lengths_per_sample
)
cache.mempool = torch.cuda.graphs.graph_pool_handle()
for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
if s_type not in cache.callables:
seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
cache.callables[s_type] = capture_graph(
model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool
)
def dispatch(input_ids, position_ids, seqlen):
return cache.callables[seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)
cache.run = dispatch
cache.inference_params.sequence_len_offset = 0 # Reset so it's not confusing
return cache
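Because the cache hangs off the model (`model._decoding_cache` in `decode` above) and is only rebuilt when the device or dtype changes or the batch size / max length grows, repeated generation calls with the same shapes replay the already-captured graphs. A hypothetical usage, assuming `model`, `input_ids`, `config`, and `world_size` are set up as in `test_tensor_parallel` below:

```python
# Second call with identical shapes reuses the captured graphs instead of re-capturing.
out1 = model.generate(input_ids=input_ids, max_length=30, fused_ft_kernel=True, cg=True,
                      vocab_size=config.vocab_size, tensor_parallel=world_size)
out2 = model.generate(input_ids=input_ids, max_length=30, fused_ft_kernel=True, cg=True,
                      vocab_size=config.vocab_size, tensor_parallel=world_size)
```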
def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None):
assert max_seqlen >= seqlen_og
device = next(iter(model.parameters())).device
sequence_length_offset_og = inference_params.sequence_len_offset
input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
inference_params.lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32,
device=device)
inference_params.lengths_per_sample[:] = seqlen_og
memory_pool = None
for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_length) + 1):
seqlen = max(seqlen_og, seqlen_type_to_seqlen(s_type))
input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
inference_params.lengths_per_sample[:] = seqlen
inference_params.sequence_len_offset = seqlen
g = torch.cuda.CUDAGraph()
# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(2):
logits = model(input_ids, position_ids=position_ids,
inference_params=inference_params).logits[:, -1]
torch.cuda.current_stream().wait_stream(s)
# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
with torch.cuda.graph(g, pool=memory_pool):
# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(2):
logits = model(input_ids, position_ids=position_ids,
inference_params=inference_params).logits[:, -1]
if memory_pool is None:
memory_pool = g.pool()
cg_cache[CgKey(batch_size, s_type, max_length)] = CgVal(
g, input_ids, position_ids, inference_params.lengths_per_sample, logits
)
inference_params=inference_params).logits[:, -1]
s.synchronize()
torch.cuda.current_stream().wait_stream(s)
# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=mempool):
logits = model(input_ids, position_ids=position_ids,
inference_params=inference_params).logits[:, -1]
def run(new_input_ids, new_position_ids, seqlen):
cg_val = cg_cache[CgKey(batch_size, seqlen_to_seqlen_type(seqlen), max_length)]
inference_params.lengths_per_sample = cg_val.lengths
inference_params.lengths_per_sample[:] = seqlen
cg_val.input_ids.copy_(new_input_ids)
cg_val.position_ids.copy_(new_position_ids)
cg_val.graph.replay()
output = cg_val.logits
return output.clone() if copy_output else output
input_ids.copy_(new_input_ids)
position_ids.copy_(new_position_ids)
graph.replay()
return logits
inference_params.sequence_len_offset = sequence_length_offset_og
return run, cg_cache
return run
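The structure of `capture_graph` (warm-up iterations on a side stream, capture into a shared memory pool, replay with the static input buffers updated in place) is the standard PyTorch CUDA Graphs recipe. A stripped-down, self-contained version of the same idiom, using a toy linear layer rather than the GPT model:

```python
import torch

device = 'cuda'
layer = torch.nn.Linear(16, 16, device=device)
static_x = torch.randn(8, 16, device=device)

# Warm up on a side stream so one-time allocations/launches aren't recorded.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):
        static_y = layer(static_x)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
mempool = torch.cuda.graphs.graph_pool_handle()  # shared pool, like cache.mempool above
with torch.cuda.graph(graph, pool=mempool):
    static_y = layer(static_x)

def run(new_x):
    static_x.copy_(new_x)  # overwrite the captured input buffer in place
    graph.replay()         # re-launch the recorded kernels
    return static_y        # captured output buffer, refreshed by the replay

out = run(torch.randn(8, 16, device=device))
```

This is also why `run` in the diff returns `logits` directly: the tensor is the graph's static output buffer, overwritten on every replay.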

View File

@ -4,5 +4,8 @@ from transformers.utils import WEIGHTS_NAME
from transformers.utils.hub import cached_file
def state_dict_from_pretrained(model_name, device=None):
return torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
def state_dict_from_pretrained(model_name, device=None, dtype=None):
state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
if dtype is not None:
state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
return state_dict
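A small usage note: leaving `device=None` keeps the fp32 checkpoint on CPU (which is what `from_pretrained` does when the model will be sharded), and the new `dtype` argument casts it there before anything is moved to a GPU. For instance:

```python
import torch
from flash_attn.utils.pretrained import state_dict_from_pretrained

# Load on CPU and cast to fp16 there; each rank later moves only its shard to its own GPU.
state_dict = state_dict_from_pretrained('gpt2', device=None, dtype=torch.float16)
assert all(v.dtype == torch.float16 for v in state_dict.values() if v.is_floating_point())
```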

View File

@ -1,3 +1,4 @@
import os
import re
import torch
@ -11,12 +12,13 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.distributed import all_gather_raw
@pytest.mark.parametrize('fused_ft_kernel', [False, True])
# @pytest.mark.parametrize('fused_ft_kernel', [True])
@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [False])
# @pytest.mark.parametrize('optimized', [True])
@pytest.mark.parametrize('rotary', [False, True])
# @pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
@ -40,19 +42,20 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
# if not rotary, we load the weights from HF but ignore the position embeddings.
# The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device)
model = model.to(dtype=dtype)
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
dtype=dtype)
model.eval()
if not rotary:
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
input_ids = tokenizer("Hello, my dog is cute and ",
return_tensors="pt").input_ids.to(device=device)
max_length = 30
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
@ -100,3 +103,119 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
assert torch.all(out.sequences == out_hf.sequences)
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation.py -k "parallel"
# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize('world_size', [2])
# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize('fused_ft_kernel', [True])
# @pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
dtype = torch.float16
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
if rotary:
config.n_positions = 0
config.rotary_emb_dim = 64
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
config.pad_vocab_size_multiple = 8 * world_size
config.sequence_parallel = False # Need to set this to False for generation
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
assert world_size <= torch.distributed.get_world_size()
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
torch.cuda.set_device(device)
from apex.transformer import parallel_state
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
# if not rotary, we load the weights from HF but ignore the position embeddings.
# The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
dtype=dtype, process_group=process_group,
world_size=world_size, rank=rank)
model.eval()
if not rotary:
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ",
return_tensors="pt").input_ids.to(device=device)
max_length = 30
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
# Slow generation for reference
sequences = []
scores = []
cur_input_ids = input_ids
with torch.inference_mode():
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)',
b=input_ids.shape[0])[..., :config.vocab_size]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)',
b=input_ids.shape[0])[..., :config.vocab_size]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
print(sequences)
out = model.generate(input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True, output_scores=True, timing=True)
print(out.sequences)
if fused_ft_kernel:
out_cg = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True)
print(out_cg.sequences)
if not rotary:
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
rtol=rtol, atol=atol)
if not rotary:
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()