From 311d6606bf82cece56b600f7f500b4c41d88e27d Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Thu, 20 Apr 2023 16:26:12 -0700
Subject: [PATCH] [Gen] Fix FT kernel smem size, CG when batch size changed

---
 .../decoder_masked_multihead_attention.cu |  15 ++--
 flash_attn/modules/mha.py                 |  20 +++--
 flash_attn/utils/generation.py            |   7 +-
 tests/models/test_gpt_generation_cg.py    |  80 ++++++++++++++++++++
 4 files changed, 106 insertions(+), 16 deletions(-)
 create mode 100644 tests/models/test_gpt_generation_cg.py

diff --git a/csrc/ft_attention/decoder_masked_multihead_attention.cu b/csrc/ft_attention/decoder_masked_multihead_attention.cu
index 406cbb5..5b5966a 100644
--- a/csrc/ft_attention/decoder_masked_multihead_attention.cu
+++ b/csrc/ft_attention/decoder_masked_multihead_attention.cu
@@ -29,14 +29,13 @@
 
 #define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \
     size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
-    dim3 grid(params.num_heads, params.batch_size); \
-    mmha::masked_multihead_attention_kernel<T, \
-                                            Dh, \
-                                            Dh_MAX, \
-                                            THDS_PER_KEY, \
-                                            THDS_PER_VALUE, \
-                                            THDS_PER_BLOCK, \
-                                            DO_CROSS_ATTENTION><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
+    auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, \
+                                                          THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION>; \
+    if (smem_sz >= 48 * 1024) { \
+        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
+    } \
+    dim3 grid(params.num_heads, params.batch_size); \
+    kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 528bb3c..f4d7d64 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -490,10 +490,15 @@ class MHA(nn.Module):
             else:
                 assert inference_params.fused_ft_kernel
                 assert ft_attention is not None
+                batch_start = inference_params.batch_size_offset
+                batch_end = batch_start + qkv.shape[0]
+                k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
+                lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
+                                      if inference_params.lengths_per_sample is not None else None)
                 context = ft_attention.single_query_attention(
                     *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
-                    *inference_params.key_value_memory_dict[self.layer_idx],
-                    inference_params.lengths_per_sample, inference_params.sequence_len_offset,
+                    k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
+                    lengths_per_sample, inference_params.sequence_len_offset,
                     self.rotary_emb_dim,
                     # neox_rotary_style
                     (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
@@ -605,11 +610,16 @@ class ParallelMHA(nn.Module):
             else:
                 assert inference_params.fused_ft_kernel
                 assert ft_attention is not None
+                batch_start = inference_params.batch_size_offset
+                batch_end = batch_start + qkv.shape[0]
+                k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
+                lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
+                                      if inference_params.lengths_per_sample is not None else None)
                 context = ft_attention.single_query_attention(
                     *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
-                    *inference_params.key_value_memory_dict[self.layer_idx],
-                    inference_params.lengths_per_sample, inference_params.sequence_len_offset,
-                    self.rotary_emb_dim,
+                    k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
+                    lengths_per_sample, inference_params.sequence_len_offset,
+                    self.rotary_emb_dim, inference_params.sequence_len_offset,
                     # neox_rotary_style
                     (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
                 )
diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index 043eaec..ce4367b 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -238,15 +238,16 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         )
         cache.mempool = torch.cuda.graphs.graph_pool_handle()
     for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
-        if s_type not in cache.callables:
+        if (batch_size, s_type) not in cache.callables:
             max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
-            cache.callables[s_type] = capture_graph(
+            cache.callables[batch_size, s_type] = capture_graph(
                 model, cache.inference_params, batch_size, max_seqlen_, mempool=cache.mempool,
                 n_warmups=n_warmups
             )
 
     def dispatch(input_ids, position_ids, seqlen):
-        return cache.callables[seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)
+        batch_size = input_ids.shape[0]
+        return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)
 
     cache.run = dispatch
     cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
diff --git a/tests/models/test_gpt_generation_cg.py b/tests/models/test_gpt_generation_cg.py
new file mode 100644
index 0000000..e0b0b82
--- /dev/null
+++ b/tests/models/test_gpt_generation_cg.py
@@ -0,0 +1,80 @@
+import os
+import re
+import time
+
+import torch
+import pytest
+
+from einops import rearrange
+
+from transformers import GPT2Config
+
+from flash_attn.models.gpt import GPTLMHeadModel
+from flash_attn.utils.generation import update_graph_cache
+
+
+def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
+    out = model.generate(input_ids=input_ids, max_length=max_length, fused_ft_kernel=True,
+                         teacher_outputs=teacher_outputs, return_dict_in_generate=True,
+                         output_scores=True, timing=True, **kwargs)
+    return torch.stack(out.scores, dim=1)
+
+
+@pytest.mark.parametrize('seqlen,maxlen', [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
+# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
+@pytest.mark.parametrize('rotary', [None, "interleaved", "block"])
+# @pytest.mark.parametrize('rotary', [None])
+@pytest.mark.parametrize('model_name', ["gpt2"])
+def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
+    """Check that decoding with CUDA graph is the same as decoding without CUDA graph.
+    """
+    dtype = torch.float16
+    device = 'cuda'
+    rtol, atol = 3e-3, 3e-1
+    config = GPT2Config.from_pretrained(model_name)
+    config.n_positions = 16 * 1024
+    assert seqlen <= maxlen <= config.n_positions
+    if rotary is not None:
+        config.n_positions = 0
+        config.rotary_emb_dim = 32
+        config.rotary_emb_interleaved = rotary == "interleaved"
+    config.residual_in_fp32 = True
+    config.use_flash_attn = True
+    config.fused_bias_fc = True
+    config.fused_mlp = True
+    config.fused_dropout_add_ln = True
+
+    model = GPTLMHeadModel(config, device=device, dtype=dtype)
+    model.eval()
+
+    torch.manual_seed(0)
+    batch_size = 1
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
+                              device=device)
+    teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
+                                    device=device)
+
+    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
+    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
+    assert torch.equal(logits, logits_cg)
+
+    # Try increasing batch size and seqlen, then decrease them to see if it's still correct
+    batch_size = 3
+    maxlen += 30
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
+                              device=device)
+    teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
+                                    device=device)
+    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
+    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
+    assert torch.equal(logits, logits_cg)
+
+    batch_size = 2
+    maxlen -= 35
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
+                              device=device)
+    teacher_outputs = torch.randint(0, config.vocab_size, (batch_size, maxlen), dtype=torch.long,
+                                    device=device)
+    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
+    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
+    assert torch.equal(logits, logits_cg)
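
Note (illustration only, not part of the patch): the change to update_graph_cache() above keys each captured CUDA-graph callable by (batch_size, seqlen_type) instead of by seqlen_type alone, so a later generate() call with a different batch size captures a fresh graph rather than replaying one recorded for another batch size. Below is a minimal, self-contained Python sketch of that dispatch pattern; the helpers (toy_capture, the simplified seqlen_to_seqlen_type buckets) are made up for illustration and do not match flash-attn's real capture_graph machinery.

    from typing import Callable, Dict, Tuple

    def seqlen_to_seqlen_type(seqlen: int) -> int:
        # Bucket sequence lengths so one graph can serve a range of lengths.
        # (The real bucketing in flash_attn.utils.generation may differ.)
        return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)

    def toy_capture(batch_size: int, s_type: int) -> Callable[[int], str]:
        # Stand-in for capture_graph(): returns a callable that "replays" the captured graph.
        return lambda seqlen: f"replay graph for bs={batch_size}, bucket={s_type}, seqlen={seqlen}"

    callables: Dict[Tuple[int, int], Callable[[int], str]] = {}

    def dispatch(batch_size: int, seqlen: int) -> str:
        key = (batch_size, seqlen_to_seqlen_type(seqlen))
        if key not in callables:      # capture lazily, once per (batch size, seqlen bucket)
            callables[key] = toy_capture(*key)
        return callables[key](seqlen)

    print(dispatch(1, 10))   # first call: captures a graph for (1, 0)
    print(dispatch(3, 10))   # batch size changed: captures a new graph for (3, 0)
    print(dispatch(1, 11))   # same (1, 0) bucket: replays the first graph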
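
A second illustration (also outside the patch): the mha.py hunks slice the preallocated kv-cache and lengths_per_sample down to the rows that belong to the current batch, [batch_size_offset : batch_size_offset + qkv.shape[0]], before handing them to ft_attention.single_query_attention. A hedged sketch of that slicing with made-up tensor shapes (the real cache layout in flash-attn differs):

    import torch

    # Caches sized for the largest batch (shapes are illustrative only).
    max_batch, nheads, headdim, max_seqlen = 4, 12, 64, 128
    k_cache = torch.zeros(max_batch, nheads, headdim, max_seqlen)
    v_cache = torch.zeros(max_batch, nheads, max_seqlen, headdim)
    lengths_per_sample = torch.zeros(max_batch, dtype=torch.int32)

    batch_size_offset = 0        # where the live batch starts inside the cache
    cur_batch_size = 2           # qkv.shape[0] for this decoding step
    batch_end = batch_size_offset + cur_batch_size

    # Only the live rows are passed to the single-query attention kernel.
    k_slice = k_cache[batch_size_offset:batch_end]
    v_slice = v_cache[batch_size_offset:batch_end]
    lengths = lengths_per_sample[batch_size_offset:batch_end]
    print(k_slice.shape, v_slice.shape, lengths.shape)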