[Gen] Test generation with rotary embedding
This commit is contained in:
parent
8d9674ed08
commit
11be742aa3
@ -146,15 +146,17 @@ class GPTPreTrainedModel(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, model_name, config, *inputs, **kwargs):
|
def from_pretrained(cls, model_name, config, *args, strict=True, device=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||||
Download and cache the pre-trained model file if needed.
|
Download and cache the pre-trained model file if needed.
|
||||||
"""
|
"""
|
||||||
# Instantiate model.
|
# Instantiate model.
|
||||||
model = cls(config, *inputs, **kwargs)
|
model = cls(config, *args, device=device, **kwargs)
|
||||||
load_return = model.load_state_dict(
|
load_return = model.load_state_dict(
|
||||||
remap_state_dict_gpt2(state_dict_from_pretrained(model_name), config))
|
remap_state_dict_gpt2(state_dict_from_pretrained(model_name, device=device), config),
|
||||||
|
strict=strict
|
||||||
|
)
|
||||||
logger.info(load_return)
|
logger.info(load_return)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|||||||
@ -341,7 +341,6 @@ class MHA(nn.Module):
|
|||||||
self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
|
self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
|
||||||
groups=3 * embed_dim)
|
groups=3 * embed_dim)
|
||||||
else:
|
else:
|
||||||
inner_attn_cls = inner_cross_attn_cls
|
|
||||||
self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
|
self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
|
||||||
if not self.return_residual:
|
if not self.return_residual:
|
||||||
self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
|
self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
|
||||||
@ -482,9 +481,9 @@ class MHA(nn.Module):
|
|||||||
'b d s -> b s d').contiguous()
|
'b d s -> b s d').contiguous()
|
||||||
if inference_params is None:
|
if inference_params is None:
|
||||||
if not self.checkpointing:
|
if not self.checkpointing:
|
||||||
context = self.inner_attn(q, kv, **kwargs)
|
context = self.inner_cross_attn(q, kv, **kwargs)
|
||||||
else:
|
else:
|
||||||
context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
|
context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv, **kwargs)
|
||||||
else:
|
else:
|
||||||
kv = self._update_kv_cache(kv)
|
kv = self._update_kv_cache(kv)
|
||||||
context = self.inner_cross_attn(q, kv, causal=False)
|
context = self.inner_cross_attn(q, kv, causal=False)
|
||||||
|
|||||||
@ -4,5 +4,5 @@ from transformers.utils import WEIGHTS_NAME
|
|||||||
from transformers.utils.hub import cached_file
|
from transformers.utils.hub import cached_file
|
||||||
|
|
||||||
|
|
||||||
def state_dict_from_pretrained(model_name):
|
def state_dict_from_pretrained(model_name, device=None):
|
||||||
return torch.load(cached_file(model_name, WEIGHTS_NAME))
|
return torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
|
||||||
|
|||||||
@ -14,32 +14,40 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
|
|||||||
from flash_attn.utils.generation import greedy_decode
|
from flash_attn.utils.generation import greedy_decode
|
||||||
|
|
||||||
|
|
||||||
# TODO: test with rotary embedding
|
|
||||||
@pytest.mark.parametrize('fused_ft_kernel', [False, True])
|
@pytest.mark.parametrize('fused_ft_kernel', [False, True])
|
||||||
@pytest.mark.parametrize('optimized', [False, True])
|
@pytest.mark.parametrize('optimized', [False, True])
|
||||||
|
# @pytest.mark.parametrize('fused_ft_kernel', [False])
|
||||||
# @pytest.mark.parametrize('optimized', [True])
|
# @pytest.mark.parametrize('optimized', [True])
|
||||||
|
# @pytest.mark.parametrize('optimized', [True])
|
||||||
|
@pytest.mark.parametrize('rotary', [False, True])
|
||||||
@pytest.mark.parametrize('model_name', ["gpt2"])
|
@pytest.mark.parametrize('model_name', ["gpt2"])
|
||||||
def test_greedy_decode(model_name, optimized, fused_ft_kernel):
|
def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
|
||||||
"""Check that our implementation of GPT2 generation matches the HF implementation:
|
"""Check that our implementation of GPT2 generation matches the HF implementation:
|
||||||
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
|
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
|
||||||
the HF scores in fp32.
|
the HF scores in fp32.
|
||||||
"""
|
"""
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
|
device = 'cuda'
|
||||||
rtol, atol = 3e-3, 3e-1
|
rtol, atol = 3e-3, 3e-1
|
||||||
config = GPT2Config.from_pretrained(model_name)
|
config = GPT2Config.from_pretrained(model_name)
|
||||||
|
if rotary:
|
||||||
|
config.n_positions = 0
|
||||||
|
config.rotary_emb_dim = 64
|
||||||
if optimized:
|
if optimized:
|
||||||
config.use_flash_attn = True
|
config.use_flash_attn = True
|
||||||
config.fused_bias_fc = True
|
config.fused_bias_fc = True
|
||||||
config.fused_dense_gelu_dense = True
|
config.fused_dense_gelu_dense = True
|
||||||
config.fused_dropout_add_ln = True
|
config.fused_dropout_add_ln = True
|
||||||
|
|
||||||
model = GPTLMHeadModel.from_pretrained(model_name, config)
|
# if not rotary, we load the weight from HF but ignore the position embeddings.
|
||||||
model = model.cuda().to(dtype=dtype)
|
# The model would be nonsense but it doesn't matter for the test.
|
||||||
|
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device)
|
||||||
|
model = model.to(dtype=dtype)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if not rotary:
|
||||||
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
|
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
|
||||||
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
|
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
|
||||||
|
|
||||||
model.eval()
|
|
||||||
model_ref.eval()
|
model_ref.eval()
|
||||||
model_hf.eval()
|
model_hf.eval()
|
||||||
|
|
||||||
@ -47,6 +55,8 @@ def test_greedy_decode(model_name, optimized, fused_ft_kernel):
|
|||||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||||
input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
|
input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
|
||||||
max_length = 30
|
max_length = 30
|
||||||
|
# input_ids = torch.randint(0, 100, (1, 512), dtype=torch.long, device='cuda')
|
||||||
|
# max_length = 512 + 50
|
||||||
|
|
||||||
# Slow generation for reference
|
# Slow generation for reference
|
||||||
sequences = []
|
sequences = []
|
||||||
@ -66,6 +76,7 @@ def test_greedy_decode(model_name, optimized, fused_ft_kernel):
|
|||||||
fused_ft_kernel=fused_ft_kernel,
|
fused_ft_kernel=fused_ft_kernel,
|
||||||
return_dict_in_generate=True, output_scores=True)
|
return_dict_in_generate=True, output_scores=True)
|
||||||
|
|
||||||
|
if not rotary:
|
||||||
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
|
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
|
||||||
return_dict_in_generate=True, output_scores=True)
|
return_dict_in_generate=True, output_scores=True)
|
||||||
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
|
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
|
||||||
@ -79,6 +90,7 @@ def test_greedy_decode(model_name, optimized, fused_ft_kernel):
|
|||||||
assert torch.all(out.sequences == sequences)
|
assert torch.all(out.sequences == sequences)
|
||||||
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
|
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
|
||||||
rtol=rtol, atol=atol)
|
rtol=rtol, atol=atol)
|
||||||
|
if not rotary:
|
||||||
assert torch.all(out.sequences == out_ref.sequences)
|
assert torch.all(out.sequences == out_ref.sequences)
|
||||||
assert torch.all(out.sequences == out_hf.sequences)
|
assert torch.all(out.sequences == out_hf.sequences)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user