diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 96f8c82..bc227f0 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -495,7 +495,8 @@ class MHA(nn.Module):
                 *inference_params.key_value_memory_dict[self.layer_idx],
                 inference_params.lengths_per_sample, inference_params.sequence_len_offset,
                 self.rotary_emb_dim,
-                not self.rotary_emb.interleaved  # neox_rotary_style
+                # neox_rotary_style
+                (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
             context = rearrange(context, 'b h d -> b 1 h d')
         else:
@@ -609,7 +610,8 @@ class ParallelMHA(nn.Module):
                 *inference_params.key_value_memory_dict[self.layer_idx],
                 inference_params.lengths_per_sample, inference_params.sequence_len_offset,
                 self.rotary_emb_dim,
-                not self.rotary_emb.interleaved  # neox_rotary_style
+                # neox_rotary_style
+                (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
             context = rearrange(context, 'b h d -> b 1 h d')
         if seqlen is None:
diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index b180e8e..043eaec 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -190,9 +190,9 @@ def seqlen_to_seqlen_type(seqlen: int) -> int:
     return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)
 
 
-def seqlen_type_to_seqlen(seqlen_type: int) -> int:
+def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
     assert seqlen_type in [0, 1, 2]
-    return 1 if seqlen_type == 0 else (32 if seqlen_type == 1 else 2048)
+    return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)
 
 
 @dataclass
@@ -239,9 +239,9 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         cache.mempool = torch.cuda.graphs.graph_pool_handle()
     for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
         if s_type not in cache.callables:
-            seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
+            max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
             cache.callables[s_type] = capture_graph(
-                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool,
+                model, cache.inference_params, batch_size, max_seqlen_, mempool=cache.mempool,
                 n_warmups=n_warmups
             )
 
@@ -249,17 +249,19 @@
         return cache.callables[seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen)
 
     cache.run = dispatch
-    cache.inference_params.sequence_length_offset = 0  # Reset so it's not confusing
+    cache.inference_params.sequence_len_offset = 0  # Reset so it's not confusing
     return cache
 
 
-def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None,
-                  n_warmups=2):
-    assert max_seqlen >= seqlen_og
+def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2):
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
     position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
-    inference_params.lengths_per_sample[:] = seqlen_og
+    sequence_len_offset_og = inference_params.sequence_len_offset
+    # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is
+    # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample.
+    inference_params.sequence_len_offset = max_seqlen - 1
+    inference_params.lengths_per_sample[:] = max_seqlen - 1
 
     # Warmup before capture
     s = torch.cuda.Stream()
@@ -289,4 +291,5 @@ def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, me
         graph.replay()
         return logits
 
+    inference_params.sequence_len_offset = sequence_len_offset_og
     return run
diff --git a/tests/models/test_gptj.py b/tests/models/test_gptj.py
index 31b9439..735eb88 100644
--- a/tests/models/test_gptj.py
+++ b/tests/models/test_gptj.py
@@ -1,4 +1,4 @@
-import re
+import time
 
 import torch
 import pytest
@@ -9,6 +9,7 @@ from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gptj import remap_state_dict_hf_gptj, gptj_config_to_gpt2_config
 from flash_attn.utils.pretrained import state_dict_from_pretrained
+from flash_attn.utils.generation import update_graph_cache
 
 
 @pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
@@ -79,3 +80,92 @@ def test_gptj_optimized(model_name):
     print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
     print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
     assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
+
+
+@pytest.mark.parametrize('model_name', ["EleutherAI/gpt-j-6B"])
+def test_gptj_generation(model_name):
+    """Check that our implementation of GPT-J generation (with all optimizations enabled) matches
+    the HF implementation: our fp16 generation scores should be around the same as the HF fp16
+    scores, when both are compared to logits from the HF forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = 'cuda'
+    config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
+    config.use_flash_attn = False  # FlashAttention doesn't support hdim 256 yet
+    config.fused_bias_fc = True
+    config.fused_mlp = True
+    config.fused_dropout_add_ln = True
+    # Only prenorm supports residual_in_fp32
+    config.residual_in_fp32 = True
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    eos_token_id = tokenizer.eos_token_id
+
+    torch.manual_seed(0)
+    batch_size = 1
+    seqlen = 100
+    max_length = 150
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
+                              device=device)
+
+    model_hf = GPTJForCausalLM.from_pretrained(model_name, torch_dtype=dtype,
+                                               device_map={"": device})
+    model_hf.eval()
+    print("HF fp16")
+    torch.cuda.synchronize()
+    start = time.time()
+    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
+                               return_dict_in_generate=True, output_scores=True)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+    del model_hf
+
+    model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
+    model_ref.eval()
+    with torch.no_grad():
+        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
+    del model_ref
+
+    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
+    model.eval()
+
+    print('Without CUDA graph')
+    torch.cuda.synchronize()
+    start = time.time()
+    out = model.generate(input_ids=input_ids, max_length=max_length,
+                         eos_token_id=eos_token_id, fused_ft_kernel=True,
+                         # eos_token_id=eos_token_id, fused_ft_kernel=False,
+                         return_dict_in_generate=True, output_scores=True, timing=True,
+                         teacher_outputs=out_hf.sequences)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+
+    # Capture graph outside the timing loop
+    batch_size, seqlen_og = input_ids.shape
+    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
+    print('With CUDA graph')
+    torch.cuda.synchronize()
+    start = time.time()
+    out_cg = model.generate(input_ids=input_ids, max_length=max_length,
+                            fused_ft_kernel=True, cg=True,
+                            return_dict_in_generate=True, output_scores=True, timing=True,
+                            teacher_outputs=out_hf.sequences)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+
+    with torch.no_grad():
+        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1]
+    logits_hf = torch.stack(out_hf.scores, dim=1)
+    logits = torch.stack(out.scores, dim=1)
+    logits_cg = torch.stack(out_cg.scores, dim=1)
+
+    del model
+
+    hf_error = (logits_hf - logits_ref).abs().max().item()
+    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
+
+    print(f'HF fp16 logits max diff: {hf_error}')
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
+    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
+    assert torch.equal(logits_cg, logits)
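
Note (reviewer sketch, not part of the patch): the renamed helpers in generation.py bucket the
decoding length into three ranges (< 32, < 2048, >= 2048) and capture one CUDA graph per bucket,
each at the largest length that bucket may ever serve, clamped to the actual generation range.
A minimal standalone illustration of that bucketing, reusing the seqlen_og=100 / max_length=150
values from test_gptj_generation above:

    # Helpers copied from flash_attn/utils/generation.py as changed in this patch.
    def seqlen_to_seqlen_type(seqlen: int) -> int:
        return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2)

    def seqlen_type_to_max_seqlen(seqlen_type: int) -> int:
        assert seqlen_type in [0, 1, 2]
        return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32)

    seqlen_og, max_seqlen = 100, 150  # prompt length and max_length from the test above
    for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1):
        # Each bucket's graph is captured at its clamped maximum length; capture_graph now only
        # needs this value, which is why seqlen_og was dropped from its signature.
        capture_len = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen)
        print(s_type, capture_len)  # here only bucket 1 is needed, captured at 150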