[GPT] Test generation when passing in multiple tokens
parent c000c3a2c0
commit 371e20658c
@@ -4,6 +4,7 @@ import pytest
 import torch
 from einops import rearrange
 from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
+from flash_attn.utils.generation import InferenceParams
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import GPT2Config, GPT2Tokenizer
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
@@ -335,3 +336,45 @@ def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
     logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
     logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
     assert torch.equal(logits, logits_cg)
+
+
+@pytest.mark.parametrize("optimized", [False, True])
+# @pytest.mark.parametrize("optimized", [False])
+@pytest.mark.parametrize("model_name", ["gpt2"])
+def test_gpt2_multiple_token_generation(model_name, optimized):
+    """Generation when we pass in multiple tokens at a time, not just one."""
+    dtype = torch.float16
+    device = "cuda"
+    rtol, atol = 3e-3, 3e-1
+    config = GPT2Config.from_pretrained(model_name)
+    config.residual_in_fp32 = True
+    if optimized:
+        config.use_flash_attn = True
+        config.fused_bias_fc = True
+        config.fused_mlp = True
+        config.fused_dropout_add_ln = True
+        # fused_ft_kernel currently doesn't work with multiple tokens at a time
+
+    # if not rotary, we load the weight from HF but ignore the position embeddings.
+    # The model would be nonsense but it doesn't matter for the test.
+    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
+    model.eval()
+
+    torch.manual_seed(0)
+    input_ids = torch.randint(0, config.vocab_size, (1, 20), dtype=torch.long, device=device)
+    # Reference logits
+    logits_ref = model(input_ids).logits
+
+    # Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
+    inference_params = InferenceParams(max_sequence_len=20, max_batch_size=1)
+    logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
+    inference_params.sequence_len_offset += 10
+    position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
+    logits_1014 = model(input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params).logits
+    inference_params.sequence_len_offset += 4
+    position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
+    logits_1420 = model(input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params).logits
+    logits = torch.cat([logits_10, logits_1014, logits_1420], dim=1)
+    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
+    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
+    assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)
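The new test exercises the InferenceParams bookkeeping that chunked prompt processing relies on: advance sequence_len_offset by the number of tokens just fed in, and pass position_ids that continue from that offset. As a rough illustration (not part of the commit), the same pattern generalizes to arbitrary chunk sizes. The sketch below is hypothetical helper code, assuming a GPTLMHeadModel constructed as in the test above and using only the calls that appear in the diff:

import torch

from flash_attn.utils.generation import InferenceParams


def chunked_logits(model, input_ids, chunk_size, device="cuda"):
    """Hypothetical helper (illustration only): feed input_ids to the model in
    fixed-size chunks, reusing the KV cache tracked by InferenceParams, and
    return the concatenated logits."""
    batch_size, seqlen = input_ids.shape
    inference_params = InferenceParams(max_sequence_len=seqlen, max_batch_size=batch_size)
    chunks = []
    for start in range(0, seqlen, chunk_size):
        end = min(start + chunk_size, seqlen)
        kwargs = {}
        if start > 0:
            # Later chunks need position ids that continue from the previous offset,
            # mirroring the 10 / 4 / 6 split in the test above.
            kwargs["position_ids"] = torch.arange(start, end, dtype=torch.long, device=device)
        logits = model(input_ids[:, start:end], inference_params=inference_params, **kwargs).logits
        inference_params.sequence_len_offset += end - start
        chunks.append(logits)
    return torch.cat(chunks, dim=1)

Run with any chunk size, the concatenated logits should agree with a single full forward pass within the loose fp16 tolerances the test uses (rtol=3e-3, atol=3e-1); the test above performs exactly this check for the 10/4/6 split.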