From c000c3a2c080d1bf2809bcc2a68881f7e8e2a212 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sat, 26 Aug 2023 13:00:40 -0700
Subject: [PATCH] [GPT] Move more tests to test_gpt.py

---
 tests/models/test_gpt.py               | 79 +++++++++++++++++++++++
 tests/models/test_gpt_generation_cg.py | 89 --------------------------
 2 files changed, 79 insertions(+), 89 deletions(-)
 delete mode 100644 tests/models/test_gpt_generation_cg.py

diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py
index cbfce7d..7f9beca 100644
--- a/tests/models/test_gpt.py
+++ b/tests/models/test_gpt.py
@@ -256,3 +256,82 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
     ).abs().max().item() < 3 * (
         torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
     ).abs().max().item()
+
+
+def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
+    out = model.generate(
+        input_ids=input_ids,
+        max_length=max_length,
+        fused_ft_kernel=True,
+        teacher_outputs=teacher_outputs,
+        return_dict_in_generate=True,
+        output_scores=True,
+        timing=True,
+        **kwargs,
+    )
+    return torch.stack(out.scores, dim=1)
+
+
+@pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
+# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
+@pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
+# @pytest.mark.parametrize('rotary', [None])
+@pytest.mark.parametrize("model_name", ["gpt2"])
+def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
+    """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
+    dtype = torch.float16
+    device = "cuda"
+    rtol, atol = 3e-3, 3e-1
+    config = GPT2Config.from_pretrained(model_name)
+    config.n_positions = 16 * 1024
+    assert seqlen <= maxlen <= config.n_positions
+    if rotary is not None:
+        config.n_positions = 0
+        config.rotary_emb_dim = 32
+        config.rotary_emb_interleaved = rotary == "interleaved"
+    config.residual_in_fp32 = True
+    config.use_flash_attn = True
+    config.fused_bias_fc = True
+    config.fused_mlp = True
+    config.fused_dropout_add_ln = True
+
+    model = GPTLMHeadModel(config, device=device, dtype=dtype)
+    model.eval()
+
+    torch.manual_seed(0)
+    batch_size = 1
+    input_ids = torch.randint(
+        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
+    )
+    teacher_outputs = torch.randint(
+        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
+    )
+
+    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
+    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
+    assert torch.equal(logits, logits_cg)
+
+    # Try increasing batch size and seqlen, then decrease them to see if it's still correct
+    batch_size = 3
+    maxlen += 30
+    input_ids = torch.randint(
+        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
+    )
+    teacher_outputs = torch.randint(
+        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
+    )
+    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
+    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
+    assert torch.equal(logits, logits_cg)
+
+    batch_size = 2
+    maxlen -= 35
+    input_ids = torch.randint(
+        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
+    )
+    teacher_outputs = torch.randint(
+        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
+    )
+    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
+    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
+    assert torch.equal(logits, logits_cg)
diff --git a/tests/models/test_gpt_generation_cg.py b/tests/models/test_gpt_generation_cg.py
deleted file mode 100644
index 54c7421..0000000
--- a/tests/models/test_gpt_generation_cg.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import os
-import re
-import time
-
-import pytest
-import torch
-from einops import rearrange
-from flash_attn.models.gpt import GPTLMHeadModel
-from flash_attn.utils.generation import update_graph_cache
-from transformers import GPT2Config
-
-
-def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
-    out = model.generate(
-        input_ids=input_ids,
-        max_length=max_length,
-        fused_ft_kernel=True,
-        teacher_outputs=teacher_outputs,
-        return_dict_in_generate=True,
-        output_scores=True,
-        timing=True,
-        **kwargs,
-    )
-    return torch.stack(out.scores, dim=1)
-
-
-@pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
-# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
-@pytest.mark.parametrize("rotary", [None, "interleaved", "block"])
-# @pytest.mark.parametrize('rotary', [None])
-@pytest.mark.parametrize("model_name", ["gpt2"])
-def test_greedy_decode_gpt2_cg(model_name, rotary, seqlen, maxlen):
-    """Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
-    dtype = torch.float16
-    device = "cuda"
-    rtol, atol = 3e-3, 3e-1
-    config = GPT2Config.from_pretrained(model_name)
-    config.n_positions = 16 * 1024
-    assert seqlen <= maxlen <= config.n_positions
-    if rotary is not None:
-        config.n_positions = 0
-        config.rotary_emb_dim = 32
-        config.rotary_emb_interleaved = rotary == "interleaved"
-    config.residual_in_fp32 = True
-    config.use_flash_attn = True
-    config.fused_bias_fc = True
-    config.fused_mlp = True
-    config.fused_dropout_add_ln = True
-
-    model = GPTLMHeadModel(config, device=device, dtype=dtype)
-    model.eval()
-
-    torch.manual_seed(0)
-    batch_size = 1
-    input_ids = torch.randint(
-        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
-    )
-    teacher_outputs = torch.randint(
-        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
-    )
-
-    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
-    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
-    assert torch.equal(logits, logits_cg)
-
-    # Try increasing batch size and seqlen, then decrease them to see if it's still correct
-    batch_size = 3
-    maxlen += 30
-    input_ids = torch.randint(
-        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
-    )
-    teacher_outputs = torch.randint(
-        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
-    )
-    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
-    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
-    assert torch.equal(logits, logits_cg)
-
-    batch_size = 2
-    maxlen -= 35
-    input_ids = torch.randint(
-        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
-    )
-    teacher_outputs = torch.randint(
-        0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
-    )
-    logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
-    logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
-    assert torch.equal(logits, logits_cg)