From 371e20658cbd27aa69566e46ae38524809e42290 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sat, 26 Aug 2023 13:56:41 -0700
Subject: [PATCH] [GPT] Test generation when passing in multiple tokens

---
 tests/models/test_gpt.py | 43 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py
index 7f9beca..b48352d 100644
--- a/tests/models/test_gpt.py
+++ b/tests/models/test_gpt.py
@@ -4,6 +4,7 @@ import pytest
 import torch
 from einops import rearrange
 from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
+from flash_attn.utils.generation import InferenceParams
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import GPT2Config, GPT2Tokenizer
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
@@ -335,3 +336,45 @@ def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
     logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
     logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
     assert torch.equal(logits, logits_cg)
+
+
+@pytest.mark.parametrize("optimized", [False, True])
+# @pytest.mark.parametrize("optimized", [False])
+@pytest.mark.parametrize("model_name", ["gpt2"])
+def test_gpt2_multiple_token_generation(model_name, optimized):
+    """Generation when we pass in multiple tokens at a time, not just one."""
+    dtype = torch.float16
+    device = "cuda"
+    rtol, atol = 3e-3, 3e-1
+    config = GPT2Config.from_pretrained(model_name)
+    config.residual_in_fp32 = True
+    if optimized:
+        config.use_flash_attn = True
+        config.fused_bias_fc = True
+        config.fused_mlp = True
+        config.fused_dropout_add_ln = True
+        # fused_ft_kernel currently doesn't work with multiple tokens at a time
+
+    # If not rotary, we load the weights from HF but ignore the position embeddings.
+    # The model would be nonsense, but that doesn't matter for this test.
+    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
+    model.eval()
+
+    torch.manual_seed(0)
+    input_ids = torch.randint(0, config.vocab_size, (1, 20), dtype=torch.long, device=device)
+    # Reference logits
+    logits_ref = model(input_ids).logits
+
+    # Run the first 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
+    inference_params = InferenceParams(max_sequence_len=20, max_batch_size=1)
+    logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
+    inference_params.sequence_len_offset += 10
+    position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
+    logits_1014 = model(input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params).logits
+    inference_params.sequence_len_offset += 4
+    position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
+    logits_1420 = model(input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params).logits
+    logits = torch.cat([logits_10, logits_1014, logits_1420], dim=1)
+    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
+    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
+    assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)
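
For reference, the chunked-prefill pattern the new test exercises generalizes to arbitrary chunk sizes. Below is a minimal sketch of that pattern, not part of the patch: `forward_in_chunks` is a hypothetical helper name, and it assumes the same `GPTLMHeadModel` forward signature and the `InferenceParams` fields used in the test above (`max_sequence_len`, `max_batch_size`, `sequence_len_offset`).

    # Minimal sketch (hypothetical helper, not part of this patch): run a prompt
    # through the model in chunks, reusing the KV cache via InferenceParams, and
    # return the logits concatenated along the sequence dimension.
    import torch

    from flash_attn.utils.generation import InferenceParams


    def forward_in_chunks(model, input_ids, chunk_sizes):
        batch_size, seqlen = input_ids.shape
        assert sum(chunk_sizes) == seqlen
        inference_params = InferenceParams(max_sequence_len=seqlen, max_batch_size=batch_size)
        logits_chunks = []
        offset = 0
        for chunk in chunk_sizes:
            # Position ids must continue from where the previous chunk stopped.
            position_ids = torch.arange(
                offset, offset + chunk, dtype=torch.long, device=input_ids.device
            )
            logits_chunks.append(
                model(
                    input_ids[:, offset : offset + chunk],
                    position_ids=position_ids,
                    inference_params=inference_params,
                ).logits
            )
            # Advance the KV-cache offset by the number of tokens just processed.
            inference_params.sequence_len_offset += chunk
            offset += chunk
        return torch.cat(logits_chunks, dim=1)

With the model and `input_ids` from the test, `forward_in_chunks(model, input_ids, [10, 4, 6])` should match `model(input_ids).logits` within the same rtol/atol tolerances.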