[Gen] Test generation with rotary embedding

Tri Dao 2023-01-07 14:33:54 -08:00
parent 8d9674ed08
commit 11be742aa3
4 changed files with 42 additions and 29 deletions

flash_attn/models/gpt.py

@@ -146,15 +146,17 @@ class GPTPreTrainedModel(nn.Module):
         self.config = config
 
     @classmethod
-    def from_pretrained(cls, model_name, config, *inputs, **kwargs):
+    def from_pretrained(cls, model_name, config, *args, strict=True, device=None, **kwargs):
         """
         Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
         Download and cache the pre-trained model file if needed.
         """
         # Instantiate model.
-        model = cls(config, *inputs, **kwargs)
+        model = cls(config, *args, device=device, **kwargs)
         load_return = model.load_state_dict(
-            remap_state_dict_gpt2(state_dict_from_pretrained(model_name), config))
+            remap_state_dict_gpt2(state_dict_from_pretrained(model_name, device=device), config),
+            strict=strict
+        )
         logger.info(load_return)
         return model
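The new signature lets callers load a checkpoint non-strictly and deserialize it straight onto a device. A minimal usage sketch, assuming the "gpt2" checkpoint from the HF hub; the snippet is illustrative, not part of the commit:

from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config.from_pretrained("gpt2")
config.n_positions = 0       # drop the learned position table...
config.rotary_emb_dim = 64   # ...in favor of rotary, as in the test below
# strict=False tolerates checkpoint keys the model no longer has (e.g. the wpe
# weights); device='cuda' maps the loaded tensors directly onto the GPU.
model = GPTLMHeadModel.from_pretrained("gpt2", config, strict=False, device="cuda")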

flash_attn/modules/mha.py

@@ -341,7 +341,6 @@ class MHA(nn.Module):
             self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                         groups=3 * embed_dim)
         else:
-            inner_attn_cls = inner_cross_attn_cls
             self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
             if not self.return_residual:
                 self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
@@ -482,9 +481,9 @@ class MHA(nn.Module):
                                     'b d s -> b s d').contiguous()
         if inference_params is None:
             if not self.checkpointing:
-                context = self.inner_attn(q, kv, **kwargs)
+                context = self.inner_cross_attn(q, kv, **kwargs)
             else:
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
+                context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv, **kwargs)
         else:
             kv = self._update_kv_cache(kv)
             context = self.inner_cross_attn(q, kv, causal=False)
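Two fixes here: the stray inner_attn_cls aliasing in the separate-Wq/Wkv branch is dropped, and the cross-attention path now calls self.inner_cross_attn directly instead of the shared self.inner_attn. For intuition, the inference branch above writes the step's key/value pairs into a preallocated cache before attending; _update_kv_cache itself is not shown in this diff, so the following is only a generic sketch of that pattern, with illustrative names and shapes:

import torch

def update_kv_cache(cache, kv, seqlen_offset):
    # cache: (batch, max_seqlen, 2, nheads, head_dim)
    # kv:    (batch, step_len,   2, nheads, head_dim) for the tokens of this step
    batch, step_len = kv.shape[0], kv.shape[1]
    cache[:batch, seqlen_offset:seqlen_offset + step_len].copy_(kv)
    # The new queries then attend, non-causally, over every cached position so far,
    # which is why the branch above passes causal=False.
    return cache[:batch, :seqlen_offset + step_len]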

flash_attn/utils/pretrained.py

@@ -4,5 +4,5 @@ from transformers.utils import WEIGHTS_NAME
 from transformers.utils.hub import cached_file
 
 
-def state_dict_from_pretrained(model_name):
-    return torch.load(cached_file(model_name, WEIGHTS_NAME))
+def state_dict_from_pretrained(model_name, device=None):
+    return torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
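map_location=device tells torch.load where to deserialize the tensors; leaving device=None keeps torch.load's default of restoring tensors to the devices recorded in the checkpoint. A quick illustration (the model name is just an example):

from flash_attn.utils.pretrained import state_dict_from_pretrained

sd_gpu = state_dict_from_pretrained("gpt2", device="cuda")  # tensors land on the GPU directly
sd_default = state_dict_from_pretrained("gpt2")             # device=None -> torch.load default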

tests/models/test_gpt_generation.py

@@ -14,39 +14,49 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import greedy_decode
 
-# TODO: test with rotary embedding
 @pytest.mark.parametrize('fused_ft_kernel', [False, True])
+# @pytest.mark.parametrize('fused_ft_kernel', [False])
 @pytest.mark.parametrize('optimized', [False, True])
 # @pytest.mark.parametrize('optimized', [True])
+@pytest.mark.parametrize('rotary', [False, True])
 @pytest.mark.parametrize('model_name', ["gpt2"])
-def test_greedy_decode(model_name, optimized, fused_ft_kernel):
+def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
     """
     dtype = torch.float16
+    device = 'cuda'
     rtol, atol = 3e-3, 3e-1
     config = GPT2Config.from_pretrained(model_name)
+    if rotary:
+        config.n_positions = 0
+        config.rotary_emb_dim = 64
     if optimized:
         config.use_flash_attn = True
         config.fused_bias_fc = True
         config.fused_dense_gelu_dense = True
         config.fused_dropout_add_ln = True
-    model = GPTLMHeadModel.from_pretrained(model_name, config)
-    model = model.cuda().to(dtype=dtype)
-    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
-    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
+    # If rotary, we still load the HF weights but ignore the (now absent) position
+    # embeddings. The model would be nonsense but it doesn't matter for the test.
+    model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device)
+    model = model.to(dtype=dtype)
     model.eval()
-    model_ref.eval()
-    model_hf.eval()
+
+    if not rotary:
+        model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
+        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
+        model_ref.eval()
+        model_hf.eval()
 
     torch.manual_seed(0)
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.cuda()
     max_length = 30
+    # input_ids = torch.randint(0, 100, (1, 512), dtype=torch.long, device='cuda')
+    # max_length = 512 + 50
 
     # Slow generation for reference
     sequences = []
@@ -66,20 +76,22 @@ def test_greedy_decode(model_name, optimized, fused_ft_kernel):
                               fused_ft_kernel=fused_ft_kernel,
                               return_dict_in_generate=True, output_scores=True)
-    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
-                               return_dict_in_generate=True, output_scores=True)
-    out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
-                                 return_dict_in_generate=True, output_scores=True)
-
-    print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
-    print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
-    print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
-    print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+    if not rotary:
+        out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
+                                   return_dict_in_generate=True, output_scores=True)
+        out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
+                                     return_dict_in_generate=True, output_scores=True)
+
+        print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
+        print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+        print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
+        print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
 
     assert torch.all(out.sequences == sequences)
     assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
                           rtol=rtol, atol=atol)
-    assert torch.all(out.sequences == out_ref.sequences)
-    assert torch.all(out.sequences == out_hf.sequences)
-    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
+    if not rotary:
+        assert torch.all(out.sequences == out_ref.sequences)
+        assert torch.all(out.sequences == out_hf.sequences)
+        assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
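With config.n_positions = 0 the rotary model has no learned position table, so position information enters only as a rotation applied to q and k inside attention; that is also why the HF comparison is skipped on the rotary path. For reference, a minimal standalone version of the standard rotary transform (interleaved variant; this is a sketch for intuition, not flash_attn's kernel):

import torch

def apply_rotary(x, base=10000.0):
    # x: (batch, seqlen, nheads, head_dim) with even head_dim
    _, seqlen, _, head_dim = x.shape
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=x.device).float() / head_dim))
    angles = torch.outer(torch.arange(seqlen, device=x.device).float(), inv_freq)  # (seqlen, head_dim/2)
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]  # broadcast over heads
    x1, x2 = x[..., ::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., ::2] = x1 * cos - x2 * sin   # rotate each (even, odd) pair by its position's angle
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

Applied to q and k (never v), equal rotations cancel inside the q·k dot product, so attention scores depend only on relative position.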