From 7c2191542afe110c87f61f227f8df4e95d0ea0af Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 15 Jan 2023 11:34:27 -0800 Subject: [PATCH] [Gen] Make generation work with Tensor Parallel --- csrc/ft_attention/ft_attention.cpp | 6 + flash_attn/models/gpt.py | 53 ++++---- flash_attn/modules/mha.py | 141 +++++++++++++--------- flash_attn/utils/generation.py | 180 ++++++++++++++++++---------- flash_attn/utils/pretrained.py | 7 +- tests/models/test_gpt_generation.py | 131 +++++++++++++++++++- 6 files changed, 373 insertions(+), 145 deletions(-) diff --git a/csrc/ft_attention/ft_attention.cpp b/csrc/ft_attention/ft_attention.cpp index 99e5a55..41a8485 100644 --- a/csrc/ft_attention/ft_attention.cpp +++ b/csrc/ft_attention/ft_attention.cpp @@ -1,5 +1,7 @@ #include #include "ATen/cuda/CUDAContext.h" +#include + #include "decoder_masked_multihead_attention.h" @@ -138,6 +140,10 @@ torch::Tensor single_query_attention(const torch::Tensor q, TORCH_CHECK(length_per_sample.dtype() == torch::kInt32); } + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + torch::Tensor out = torch::empty_like(q); DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), out.scalar_type(), "single_query_attention", [&] { diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index 99f508d..3f5f115 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -20,7 +20,7 @@ from flash_attn.modules.mha import MHA, ParallelMHA from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense from flash_attn.modules.block import Block from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings -from flash_attn.utils.distributed import sync_shared_params +from flash_attn.utils.distributed import sync_shared_params, all_gather_raw from flash_attn.utils.pretrained import state_dict_from_pretrained from flash_attn.utils.generation import GenerationMixin @@ -146,17 +146,23 @@ class GPTPreTrainedModel(nn.Module): self.config = config @classmethod - def from_pretrained(cls, model_name, config, *args, strict=True, device=None, **kwargs): + def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None, + world_size=1, rank=0, **kwargs): """ Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. """ # Instantiate model. - model = cls(config, *args, device=device, **kwargs) - load_return = model.load_state_dict( - remap_state_dict_gpt2(state_dict_from_pretrained(model_name, device=device), config), - strict=strict + model = cls(config, *args, device=device, dtype=dtype, **kwargs) + state_dict = remap_state_dict_gpt2( + # If we're going to shard the model, then don't load fp32 weights to GPU. 
+ state_dict_from_pretrained(model_name, device=device if world_size == 1 else None, + dtype=dtype), config ) + if world_size > 1: + state_dict = shard_state_dict_tp(state_dict, config, world_size, rank) + state_dict = {k: v.to(device=device) for k, v in state_dict.items()} + load_return = model.load_state_dict(state_dict, strict=strict) logger.info(load_return) return model @@ -190,17 +196,16 @@ class GPTModel(GPTPreTrainedModel): self.process_group = process_group self.sequence_parallel = getattr(config, 'sequence_parallel', True) assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu'] - self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) - if config.vocab_size % self.pad_vocab_size_multiple != 0: - config.vocab_size += (self.pad_vocab_size_multiple - - (config.vocab_size % self.pad_vocab_size_multiple)) + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) + * pad_vocab_size_multiple) if process_group is None: - self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size, + self.embeddings = GPT2Embeddings(config.hidden_size, vocab_size, config.max_position_embeddings, **factory_kwargs) else: self.embeddings = ParallelGPT2Embeddings( - config.hidden_size, config.vocab_size, config.max_position_embeddings, + config.hidden_size, vocab_size, config.max_position_embeddings, process_group=process_group, sequence_parallel=self.sequence_parallel, **factory_kwargs ) @@ -248,8 +253,9 @@ class GPTModel(GPTPreTrainedModel): hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs) # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable if not self.fused_dropout_add_ln: - residual = self.emb_drop(hidden_states).float() + residual = self.emb_drop(hidden_states) hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype)) + residual = residual.float() else: hidden_states, residual = dropout_add_layer_norm( hidden_states, None, self.ln_0.weight, self.ln_0.bias, @@ -272,13 +278,16 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): super().__init__(config) self.process_group = process_group self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs) + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) + * pad_vocab_size_multiple) if process_group is None: - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False, **factory_kwargs) + self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False, **factory_kwargs) else: if ColumnParallelLinear is None: raise ImportError('fused_dense_lib is not installed') self.lm_head = ColumnParallelLinear( - config.n_embd, config.vocab_size, process_group, bias=False, + config.n_embd, vocab_size, process_group, bias=False, sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs ) # Initialize weights and apply final processing @@ -299,6 +308,10 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): hidden_states = self.transformer(input_ids, position_ids=position_ids, inference_params=inference_params) lm_logits = self.lm_head(hidden_states) + # During inference, we want the full logit for sampling + if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None: + lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group) + lm_logits = rearrange(lm_logits, 
'(n b) s d -> b s (n d)', b=hidden_states.shape[0]) CausalLMOutput = namedtuple('CausalLMOutput', ['logits']) return CausalLMOutput(logits=lm_logits) @@ -310,8 +323,10 @@ def remap_state_dict_gpt2(state_dict, config): state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) word_embeddings = state_dict.pop('wte.weight') # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( - word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0]) + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) ) state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] @@ -365,10 +380,8 @@ def shard_state_dict_tp(state_dict, config, world_size, rank): """Convert the state_dict of a standard GPT model to the state_dict of a GPT model with tensor parallel. """ - vocab_size = config.vocab_size - if config.vocab_size % config.pad_vocab_size_multiple != 0: - vocab_size += (config.pad_vocab_size_multiple - - (config.vocab_size % config.pad_vocab_size_multiple)) + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) assert vocab_size % world_size == 0 assert config.hidden_size % world_size == 0 inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index e6477cc..4998d93 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -289,6 +289,60 @@ class LinearResidual(nn.Linear): return super().forward(input), input +def _update_kv_cache(kv, inference_params, layer_idx): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim) + """ + # Pre-allocate memory for key-values for inference. + num_heads, head_dim = kv.shape[-2:] + if layer_idx not in inference_params.key_value_memory_dict: + kv_cache = torch.empty( + inference_params.max_batch_size, inference_params.max_sequence_len, 2, + num_heads, head_dim, dtype=kv.dtype, device=kv.device + ) + inference_params.key_value_memory_dict[layer_idx] = kv_cache + else: + if not inference_params.fused_ft_kernel: + kv_cache = inference_params.key_value_memory_dict[layer_idx] + else: + # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize) + # where packsize = 4 if fp32, 8 if fp16 or bf16. + # v_cache has shape (b, h, s, headdim) + k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx] + kv_cache = None + # Adjust key and value for inference + batch_start = inference_params.batch_size_offset + batch_end = batch_start + kv.shape[0] + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + kv.shape[1] + assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]) + assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]) + # Copy key and values. + if not inference_params.fused_ft_kernel: + assert kv_cache is not None + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + kv = kv_cache[batch_start:batch_end, :sequence_end, ...] 
+ return kv + else: + assert inference_params.sequence_len_offset == 0 + # FT kernel requires different layouts for the k_cache and v_cache. + assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32] + packsize = 4 if kv.dtype == torch.float32 else 8 + if kv_cache is not None: + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize', + packsize=packsize).contiguous() + v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous() + inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache) + else: + k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange( + kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize + ) + v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange( + kv[:, :, 1], 'b s h d -> b h s d' + ) + return kv + + class MHA(nn.Module): """Multi-head self-attention and cross-attention """ @@ -363,54 +417,7 @@ class MHA(nn.Module): """ assert not self.dwconv, 'Generation does not support dwconv yet' assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor' - # Pre-allocate memory for key-values for inference. - if self.layer_idx not in inference_params.key_value_memory_dict: - kv_cache = torch.empty( - inference_params.max_batch_size, inference_params.max_sequence_len, 2, - self.num_heads, self.head_dim, dtype=kv.dtype, device=kv.device - ) - inference_params.key_value_memory_dict[self.layer_idx] = kv_cache - else: - if not inference_params.fused_ft_kernel: - kv_cache = inference_params.key_value_memory_dict[self.layer_idx] - else: - # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize) - # where packsize = 4 if fp32, 8 if fp16 or bf16. - # v_cache has shape (b, h, s, headdim) - k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx] - kv_cache = None - # Adjust key and value for inference - batch_start = inference_params.batch_size_offset - batch_end = batch_start + kv.shape[0] - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + kv.shape[1] - assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]) - assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]) - # Copy key and values. - if not inference_params.fused_ft_kernel: - assert kv_cache is not None - kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv - kv = kv_cache[batch_start:batch_end, :sequence_end, ...] - return kv - else: - assert inference_params.sequence_len_offset == 0 - # FT kernel requires different layouts for the k_cache and v_cache. - assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32] - packsize = 4 if kv.dtype == torch.float32 else 8 - if kv_cache is not None: - kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] 
= kv - k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize', - packsize=packsize).contiguous() - v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous() - inference_params.key_value_memory_dict[self.layer_idx] = (k_cache, v_cache) - else: - k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange( - kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize - ) - v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange( - kv[:, :, 1], 'b s h d -> b h s d' - ) - return kv + return _update_kv_cache(kv, inference_params, self.layer_idx) def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None, inference_params=None, **kwargs): @@ -473,6 +480,7 @@ class MHA(nn.Module): causal = None if inference_params.sequence_len_offset == 0 else False context = self.inner_cross_attn(q, kv, causal=causal) else: + assert inference_params.fused_ft_kernel assert ft_attention is not None context = ft_attention.single_query_attention( *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1), @@ -541,13 +549,16 @@ class ParallelMHA(nn.Module): self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias, sequence_parallel=sequence_parallel, **factory_kwargs) inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention + inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale, + attention_dropout=dropout) # output projection always have the bias (for now) self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, sequence_parallel=sequence_parallel, **factory_kwargs) - def forward(self, x, seqlen=None, **kwargs): + def forward(self, x, seqlen=None, inference_params=None, **kwargs): """ Arguments: x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. @@ -561,12 +572,34 @@ class ParallelMHA(nn.Module): else: qkv = rearrange(qkv, '(b s) (three h d) -> b s three h d', s=seqlen, three=3, d=self.head_dim) - if self.rotary_emb_dim > 0: - qkv = self.rotary_emb(qkv) - if not self.checkpointing: - context = self.inner_attn(qkv, **kwargs) + if inference_params is None: + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb(qkv) + if not self.checkpointing: + context = self.inner_attn(qkv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) else: - context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) + if (not inference_params.fused_ft_kernel) or inference_params.sequence_len_offset == 0: + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb(qkv, seqlen_offset=inference_params.sequence_len_offset) + q = qkv[:, :, 0] + assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor' + kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx) + # If we're processing the prompt, causal=None (use self.causal). + # If we're decoding, then causal=False. 
+ causal = None if inference_params.sequence_len_offset == 0 else False + context = self.inner_cross_attn(q, kv, causal=causal) + else: + assert inference_params.fused_ft_kernel + assert ft_attention is not None + context = ft_attention.single_query_attention( + *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1), + *inference_params.key_value_memory_dict[self.layer_idx], + inference_params.lengths_per_sample, inference_params.sequence_len_offset, + self.rotary_emb_dim + ) + context = rearrange(context, 'b h d -> b 1 h d') if seqlen is None: context = rearrange(context, 'b s h d -> b s (h d)') else: diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py index ea01ea1..297496c 100644 --- a/flash_attn/utils/generation.py +++ b/flash_attn/utils/generation.py @@ -1,6 +1,7 @@ -# Copyright (c) 2022, Tri Dao. +# Copyright (c) 2023, Tri Dao. # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31 -from typing import Optional +from typing import Optional, Union, Sequence, Callable +import gc import time from dataclasses import dataclass, field @@ -70,7 +71,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0): def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, - fused_ft_kernel=False, cg=False, timing=False): + vocab_size=None, tensor_parallel=1, fused_ft_kernel=False, cg=False, timing=False): """Decoding, either greedy or with top-k or top-p sampling. If top-k = 0, don't limit the number of candidates (pure sampling). Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, @@ -85,18 +86,30 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, scores: tuples of (batch, vocab_size) """ batch_size, seqlen_og = input_ids.shape - inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size, - fused_ft_kernel=fused_ft_kernel) + if cg: + assert fused_ft_kernel + if not hasattr(model, '_decoding_cache'): + model._decoding_cache = None + model._decoding_cache = update_graph_cache( + model, model._decoding_cache, batch_size, seqlen_og, max_length, + tensor_parallel=tensor_parallel + ) + inference_params = model._decoding_cache.inference_params + inference_params.max_sequence_len = max_length + inference_params.max_batch_size = batch_size + inference_params.sequence_len_offset = 0 + else: + inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size, + fused_ft_kernel=fused_ft_kernel) scores = [] with torch.inference_mode(): logits = model(input_ids, inference_params=inference_params).logits[:, -1] + if vocab_size is not None: + logits = logits[..., :vocab_size] scores.append(logits) next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) sequences = [next_token] inference_params.sequence_len_offset = seqlen_og - if cg: - assert fused_ft_kernel - run, cg_cache = capture_cg(model, inference_params, batch_size, seqlen_og, max_length) if timing: start = time.time() while True: @@ -106,8 +119,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids, inference_params=inference_params).logits[:, -1] else: - logits = run(rearrange(next_token, 'b -> b 1'), position_ids, - inference_params.sequence_len_offset) + logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids, + 
inference_params.sequence_len_offset) + if vocab_size is not None: + logits = logits[..., :vocab_size] scores.append(logits) next_token = sample(logits, top_k=top_k, temperature=temperature) sequences.append(next_token) @@ -115,6 +130,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, if inference_params.sequence_len_offset >= max_length - 1: break if timing: + torch.cuda.synchronize() print(f'Decoding time: {time.time() - start}') output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput return output_cls( @@ -134,8 +150,18 @@ class GenerationMixin: return output if return_dict_in_generate else output.sequences -CgKey = namedtuple('CgKey', ['batch_size', 'seqlen_type', 'max_length']) -CgVal = namedtuple('CgVal', ['graph', 'input_ids', 'position_ids', 'lengths', 'logits']) +def allocate_kv_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence], + device, dtype=torch.float16): + assert dtype in [torch.float16, torch.bfloat16, torch.float32] + packsize = 4 if dtype == torch.float32 else 8 + assert headdim % packsize == 0 + k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize) + v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim) + if isinstance(layers, int): + layers = range(layers) + return {i: (torch.empty(k_cache_shape, device=device, dtype=dtype), + torch.empty(v_cache_shape, device=device, dtype=dtype)) + for i in layers} def seqlen_to_seqlen_type(seqlen: int) -> int: @@ -152,63 +178,91 @@ def seqlen_type_to_seqlen(seqlen_type: int) -> int: return 1 if seqlen_type == 0 else (32 if seqlen_type == 1 else 2048) -def capture_cg(model, inference_params, batch_size, seqlen_og, max_length, copy_output=False): - """Build a cache of cuda graphs for decoding. - Arguments: - model: a GPTLMHeadModel - batch_size: int - seqlen_og: int. Length of the prompt. - max_length: int - TODO: how do we deal with the k_cache and v_cache memory? I think the CUDA graph also - has to own the k_cache and v_cache? - Here we assume that the model already has inference_params from the prompt processing. 
-    """
-    assert max_length > seqlen_og
-    cg_cache: dict[CgKey, CgVal] = {}
+@dataclass
+class DecodingCGCache:
+    max_batch_size: int = 0
+    max_seqlen: int = 0
+    device = None
+    dtype = None
+    callables: dict = field(default_factory=dict)
+    mempool = None
+    inference_params: Optional[InferenceParams] = None
+    run: Optional[Callable] = None
+
+
+@torch.inference_mode()
+def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
+                       dtype=None):
+    if cache is None:
+        cache = DecodingCGCache()
+    param_example = next(iter(model.parameters()))
+    device = param_example.device
+    if dtype is None:
+        dtype = param_example.dtype
+    if ((device, dtype) != (cache.device, cache.dtype) or batch_size > cache.max_batch_size
+        or max_seqlen > cache.max_seqlen):  # Invalidate the cache
+        cache.callables = {}
+        cache.mempool = None
+        cache.inference_params = None
+        gc.collect()
+        cache.device, cache.dtype = device, dtype
+        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
+        headdim = getattr(model.config, 'head_dim',
+                          model.config.hidden_size // model.config.num_attention_heads)
+        kv_cache = allocate_kv_cache(
+            batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
+            model.config.num_hidden_layers, device, dtype
+        )
+        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
+        cache.inference_params = InferenceParams(
+            max_sequence_len=max_seqlen, max_batch_size=batch_size,
+            sequence_len_offset=seqlen_og, key_value_memory_dict=kv_cache, fused_ft_kernel=True,
+            lengths_per_sample=lengths_per_sample
+        )
+        cache.mempool = torch.cuda.graphs.graph_pool_handle()
+    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
+        if s_type not in cache.callables:
+            seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
+            cache.callables[s_type] = capture_graph(
+                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool
+            )
+
+    def dispatch(input_ids, position_ids, seqlen):
+        return cache.callables[seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)
+
+    cache.run = dispatch
+    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
+    return cache
+
+
+def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None):
+    assert max_seqlen >= seqlen_og
     device = next(iter(model.parameters())).device
-    sequence_length_offset_og = inference_params.sequence_len_offset
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
     position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-    inference_params.lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32,
-                                                     device=device)
+    inference_params.lengths_per_sample[:] = seqlen_og
 
-    memory_pool = None
-    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_length) + 1):
-        seqlen = max(seqlen_og, seqlen_type_to_seqlen(s_type))
-        input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-        position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-        inference_params.lengths_per_sample[:] = seqlen
-        inference_params.sequence_len_offset = seqlen
-        g = torch.cuda.CUDAGraph()
-        # Warmup before capture
-        s = torch.cuda.Stream()
-        s.wait_stream(torch.cuda.current_stream())
-        with torch.cuda.stream(s):
-            for _ in range(2):
-                logits = model(input_ids, position_ids=position_ids,
-                               inference_params=inference_params).logits[:, -1]
-        
torch.cuda.current_stream().wait_stream(s) - # Captures the graph - # To allow capture, automatically sets a side stream as the current stream in the context - with torch.cuda.graph(g, pool=memory_pool): + # Warmup before capture + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(2): logits = model(input_ids, position_ids=position_ids, - inference_params=inference_params).logits[:, -1] - if memory_pool is None: - memory_pool = g.pool() - cg_cache[CgKey(batch_size, s_type, max_length)] = CgVal( - g, input_ids, position_ids, inference_params.lengths_per_sample, logits - ) + inference_params=inference_params).logits[:, -1] + s.synchronize() + torch.cuda.current_stream().wait_stream(s) + # Captures the graph + # To allow capture, automatically sets a side stream as the current stream in the context + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, pool=mempool): + logits = model(input_ids, position_ids=position_ids, + inference_params=inference_params).logits[:, -1] def run(new_input_ids, new_position_ids, seqlen): - cg_val = cg_cache[CgKey(batch_size, seqlen_to_seqlen_type(seqlen), max_length)] - inference_params.lengths_per_sample = cg_val.lengths inference_params.lengths_per_sample[:] = seqlen - cg_val.input_ids.copy_(new_input_ids) - cg_val.position_ids.copy_(new_position_ids) - cg_val.graph.replay() - output = cg_val.logits - return output.clone() if copy_output else output + input_ids.copy_(new_input_ids) + position_ids.copy_(new_position_ids) + graph.replay() + return logits - inference_params.sequence_len_offset = sequence_length_offset_og - - return run, cg_cache + return run diff --git a/flash_attn/utils/pretrained.py b/flash_attn/utils/pretrained.py index c91391a..6732892 100644 --- a/flash_attn/utils/pretrained.py +++ b/flash_attn/utils/pretrained.py @@ -4,5 +4,8 @@ from transformers.utils import WEIGHTS_NAME from transformers.utils.hub import cached_file -def state_dict_from_pretrained(model_name, device=None): - return torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device) +def state_dict_from_pretrained(model_name, device=None, dtype=None): + state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device) + if dtype is not None: + state_dict = {k: v.to(dtype) for k, v in state_dict.items()} + return state_dict diff --git a/tests/models/test_gpt_generation.py b/tests/models/test_gpt_generation.py index df1ab6a..c4e026e 100644 --- a/tests/models/test_gpt_generation.py +++ b/tests/models/test_gpt_generation.py @@ -1,3 +1,4 @@ +import os import re import torch @@ -11,12 +12,13 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHead from flash_attn.models.gpt import GPTLMHeadModel from flash_attn.models.gpt import remap_state_dict_gpt2 from flash_attn.utils.pretrained import state_dict_from_pretrained +from flash_attn.utils.distributed import all_gather_raw @pytest.mark.parametrize('fused_ft_kernel', [False, True]) # @pytest.mark.parametrize('fused_ft_kernel', [True]) @pytest.mark.parametrize('optimized', [False, True]) -# @pytest.mark.parametrize('optimized', [False]) +# @pytest.mark.parametrize('optimized', [True]) @pytest.mark.parametrize('rotary', [False, True]) # @pytest.mark.parametrize('rotary', [False]) @pytest.mark.parametrize('model_name', ["gpt2"]) @@ -40,19 +42,20 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel): # if not rotary, we load the weight from HF but ignore the position embeddings. 
# The model would be nonsense but it doesn't matter for the test. - model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device) - model = model.to(dtype=dtype) + model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device, + dtype=dtype) model.eval() if not rotary: - model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda() - model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype) + model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device) + model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype) model_ref.eval() model_hf.eval() torch.manual_seed(0) tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda() + input_ids = tokenizer("Hello, my dog is cute and ", + return_tensors="pt").input_ids.to(device=device) max_length = 30 # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda') # max_length = input_ids.shape[1] + 40 @@ -100,3 +103,119 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel): assert torch.all(out.sequences == out_hf.sequences) assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() + + +# Run test with: +# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation.py -k "parallel" + +# @pytest.mark.parametrize('world_size', [1, 2, 4, 8]) +@pytest.mark.parametrize('world_size', [2]) +# @pytest.mark.parametrize('fused_ft_kernel', [False, True]) +@pytest.mark.parametrize('fused_ft_kernel', [True]) +# @pytest.mark.parametrize('rotary', [False, True]) +@pytest.mark.parametrize('rotary', [False]) +@pytest.mark.parametrize('model_name', ["gpt2"]) +def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size): + """Check that our implementation of GPT2 generation matches the HF implementation: + the scores in fp16 should be around the same as the HF scores in fp16, when compared to + the HF scores in fp32. + """ + dtype = torch.float16 + rtol, atol = 3e-3, 3e-1 + config = GPT2Config.from_pretrained(model_name) + if rotary: + config.n_positions = 0 + config.rotary_emb_dim = 64 + config.use_flash_attn = True + config.fused_bias_fc = True + config.fused_dense_gelu_dense = True + config.fused_dropout_add_ln = True + config.pad_vocab_size_multiple = 8 * world_size + config.sequence_parallel = False # Need to set this to False for generation + + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend='nccl', init_method='env://') + device = f'cuda:{torch.distributed.get_rank()}' + assert world_size <= torch.distributed.get_world_size() + # Need this, otherwise when we capture the graph the process for GPU 1 would run on both + # GPU0 and GPU1 and things would hang + torch.cuda.set_device(device) + + from apex.transformer import parallel_state + parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) + rank = parallel_state.get_tensor_model_parallel_rank() + process_group = parallel_state.get_tensor_model_parallel_group() + + # if not rotary, we load the weight from HF but ignore the position embeddings. + # The model would be nonsense but it doesn't matter for the test. 
+ model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device, + dtype=dtype, process_group=process_group, + world_size=world_size, rank=rank) + model.eval() + + if not rotary: + model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device) + model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype) + model_ref.eval() + model_hf.eval() + + torch.manual_seed(0) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + input_ids = tokenizer("Hello, my dog is cute and ", + return_tensors="pt").input_ids.to(device=device) + max_length = 30 + # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda') + # max_length = input_ids.shape[1] + 40 + + # Slow generation for reference + sequences = [] + scores = [] + cur_input_ids = input_ids + with torch.inference_mode(): + logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group) + logits = rearrange(logits, '(n b) d -> b (n d)', + b=input_ids.shape[0])[..., :config.vocab_size] + scores.append(logits) + sequences.append(scores[-1].argmax(dim=-1)) + for _ in range(input_ids.shape[1] + 1, max_length): + cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1) + logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group) + logits = rearrange(logits, '(n b) d -> b (n d)', + b=input_ids.shape[0])[..., :config.vocab_size] + scores.append(logits) + sequences.append(scores[-1].argmax(dim=-1)) + sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1) + scores = tuple(scores) + print(sequences) + + out = model.generate(input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, + vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, + return_dict_in_generate=True, output_scores=True, timing=True) + print(out.sequences) + if fused_ft_kernel: + out_cg = model.generate( + input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, + vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, cg=True, + return_dict_in_generate=True, output_scores=True, timing=True) + print(out_cg.sequences) + + if not rotary: + out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length, + return_dict_in_generate=True, output_scores=True) + out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length, + return_dict_in_generate=True, output_scores=True) + + print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') + print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') + print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') + print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') + + assert torch.all(out.sequences == sequences) + assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), + rtol=rtol, atol=atol) + if not rotary: + assert torch.all(out.sequences == out_ref.sequences) + assert torch.all(out.sequences == out_hf.sequences) + + assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
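
Usage sketch (not part of the patch): the snippet below condenses how the new tensor-parallel generation path is exercised by test_tensor_parallel above. It assumes apex and the fused flash-attn extensions are installed and that the script is launched with one process per GPU, e.g. torchrun --nproc_per_node=2 <script>.py; the world size of 2 and all config flags mirror the test rather than prescribing defaults.

import torch
from apex.transformer import parallel_state
from transformers import GPT2Config, GPT2Tokenizer
from flash_attn.models.gpt import GPTLMHeadModel

world_size = 2  # number of tensor-parallel GPUs (assumed for this sketch)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.cuda.set_device(device)  # keep CUDA graph capture on this rank's GPU
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()

config = GPT2Config.from_pretrained('gpt2')
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
config.pad_vocab_size_multiple = 8 * world_size  # padded vocab must divide evenly across ranks
config.sequence_parallel = False  # generation does not support sequence parallelism

# Weights are loaded on CPU in fp32, sharded per rank, then moved to the GPU in fp16.
model = GPTLMHeadModel.from_pretrained('gpt2', config, device=device, dtype=torch.float16,
                                       process_group=process_group,
                                       world_size=world_size, rank=rank)
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
input_ids = tokenizer('Hello, my dog is cute and ', return_tensors='pt').input_ids.to(device)
out = model.generate(input_ids=input_ids, max_length=30, tensor_parallel=world_size,
                     vocab_size=config.vocab_size,   # trim the padded, all-gathered logits
                     fused_ft_kernel=True, cg=True,  # FT kernel plus CUDA-graph decoding
                     return_dict_in_generate=True, output_scores=True)
print(out.sequences)

Note that cg=True requires fused_ft_kernel=True, matching the assertion in decode().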
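
For reference, a small self-contained illustration (arbitrary tensor sizes, runs on CPU) of the cache layout the fused FT kernel expects, mirroring the rearranges in _update_kv_cache and the shapes allocated by allocate_kv_cache: k_cache packs groups of packsize head-dim elements into the trailing axis (packsize = 8 for fp16/bf16, 4 for fp32), while v_cache is stored as (batch, nheads, seqlen, headdim).

import torch
from einops import rearrange

batch, seqlen, nheads, headdim = 2, 16, 12, 64   # arbitrary sizes for illustration
dtype = torch.float16
packsize = 4 if dtype == torch.float32 else 8    # 4 for fp32, 8 for fp16/bf16

# kv as stored by the attention modules: (batch, seqlen, 2, nheads, headdim)
kv = torch.randn(batch, seqlen, 2, nheads, headdim, dtype=dtype)

# FT kernel layout: pack the head dimension of K, keep V in (b, h, s, d) order.
k_cache = rearrange(kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
                    packsize=packsize).contiguous()
v_cache = rearrange(kv[:, :, 1], 'b s h d -> b h s d').contiguous()
assert k_cache.shape == (batch, nheads, headdim // packsize, seqlen, packsize)
assert v_cache.shape == (batch, nheads, seqlen, headdim)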