[Gen] Add OPT to generation test

This commit is contained in:
Tri Dao 2023-01-17 19:59:06 -08:00
parent 88173a1aaf
commit f68d41ec77
4 changed files with 164 additions and 11 deletions

View File

@@ -71,7 +71,8 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
 def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
-           vocab_size=None, tensor_parallel=1, fused_ft_kernel=False, cg=False, timing=False):
+           eos_token_id=None, vocab_size=None, tensor_parallel=1, fused_ft_kernel=False,
+           cg=False, timing=False):
     """Decoding, either greedy or with top-k or top-p sampling.
     If top-k = 0, don't limit the number of candidates (pure sampling).
     Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
@@ -104,14 +105,15 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
     scores = []
     with torch.inference_mode():
         logits = model(input_ids, inference_params=inference_params).logits[:, -1]
+        if timing:
+            torch.cuda.synchronize()
+            start = time.time()
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits)
         next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         sequences = [next_token]
         inference_params.sequence_len_offset = seqlen_og
-        if timing:
-            start = time.time()
         while True:
             position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset,
                                       dtype=torch.long, device=input_ids.device)
@@ -127,11 +129,13 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
             next_token = sample(logits, top_k=top_k, temperature=temperature)
             sequences.append(next_token)
             inference_params.sequence_len_offset += 1
+            if eos_token_id is not None and (next_token == eos_token_id).all():
+                break
             if inference_params.sequence_len_offset >= max_length - 1:
                 break
     if timing:
         torch.cuda.synchronize()
-        print(f'Decoding time: {time.time() - start}')
+        print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms')
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(
         sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
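
Usage sketch (illustrative, not part of the diff): how a caller might pass the new eos_token_id argument so that decode() stops as soon as every sequence in the batch has emitted EOS. The model/tokenizer setup mirrors the OPT test added below; the checkpoint name and prompt are placeholders.

    import torch
    from transformers import AutoTokenizer, OPTConfig
    from flash_attn.models.gpt import GPTLMHeadModel
    from flash_attn.models.opt import opt_config_to_gpt2_config
    from flash_attn.utils.generation import decode

    model_name = "facebook/opt-125m"  # placeholder checkpoint
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    model = GPTLMHeadModel.from_pretrained(model_name, config, device='cuda', dtype=torch.float16)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    input_ids = tokenizer("Hello, my dog is cute and", return_tensors="pt").input_ids.to('cuda')

    # eos_token_id is the new argument: decoding stops early once all sequences hit EOS
    out = decode(input_ids, model, max_length=30,
                 eos_token_id=tokenizer.eos_token_id, timing=True)
    print(tokenizer.batch_decode(out.sequences.tolist()))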

View File

@@ -1,10 +1,32 @@
 import torch
-from transformers.utils import WEIGHTS_NAME
-from transformers.utils.hub import cached_file
+from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from transformers.utils import is_remote_url
+from transformers.modeling_utils import load_state_dict
+from transformers.utils.hub import cached_file, get_checkpoint_shard_files


 def state_dict_from_pretrained(model_name, device=None, dtype=None):
-    state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
+    is_sharded = False
+    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
+                                        _raise_exceptions_for_missing_entries=False)
+    if resolved_archive_file is None:
+        resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
+                                            _raise_exceptions_for_missing_entries=False)
+        if resolved_archive_file is not None:
+            is_sharded = True
+    if resolved_archive_file is None:
+        raise EnvironmentError(f"Model name {model_name} was not found.")
+    if is_sharded:
+        # resolved_archive_file becomes a list of files that point to the different
+        # checkpoint shards in this case.
+        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
+            model_name, resolved_archive_file
+        )
+        state_dict = {}
+        for sharded_file in resolved_archive_file:
+            state_dict.update(torch.load(sharded_file, map_location=device))
+    else:
+        state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
     if dtype is not None:
         state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
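
Usage sketch (illustrative, not part of the diff): with this change, state_dict_from_pretrained can load checkpoints that are split into shards on the Hugging Face Hub. When pytorch_model.bin is absent, the index file (WEIGHTS_INDEX_NAME) is resolved instead and all shards are merged into one state dict. The checkpoint name below is assumed to be sharded.

    import torch
    from flash_attn.utils.pretrained import state_dict_from_pretrained

    # Assumed sharded checkpoint; a single-file checkpoint takes the non-sharded path.
    state_dict = state_dict_from_pretrained("facebook/opt-6.7b", device="cpu",
                                            dtype=torch.float16)
    print(f"{len(state_dict)} tensors loaded")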

View File

@@ -1,18 +1,22 @@
 import os
 import re
+import time

 import torch
 import pytest
 from einops import rearrange
-from transformers import GPT2Config, GPT2Tokenizer
+from transformers import GPT2Config, GPT2Tokenizer, OPTConfig, AutoTokenizer
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
+from transformers.models.opt.modeling_opt import OPTForCausalLM

 from flash_attn.models.gpt import GPTLMHeadModel
 from flash_attn.models.gpt import remap_state_dict_gpt2
+from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.distributed import all_gather_raw
+from flash_attn.utils.generation import update_graph_cache


 @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@@ -22,7 +26,7 @@ from flash_attn.utils.distributed import all_gather_raw
 @pytest.mark.parametrize('rotary', [False, True])
 # @pytest.mark.parametrize('rotary', [False])
 @pytest.mark.parametrize('model_name', ["gpt2"])
-def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
+def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
     """Check that our implementation of GPT2 generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
     the HF scores in fp32.
@@ -49,7 +53,8 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     if not rotary:
         model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
-        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
+        model_hf = GPT2LMHeadModelHF.from_pretrained(model_name,
+                                                     torch_dtype=dtype).to(device=device)
     model_ref.eval()
     model_hf.eval()
@@ -106,3 +111,124 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     assert torch.all(out.sequences == out_hf.sequences)
     assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
+
+
+@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"])
+# @pytest.mark.parametrize('model_name', ["facebook/opt-6.7b"])
+def test_greedy_decode_opt(model_name):
+    """Check that our implementation of OPT generation matches the HF implementation:
+    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
+    the HF scores in fp32.
+    """
+    print(f'\nMODEL: {model_name}')
+    verbose = False
+    dtype = torch.float16
+    device = 'cuda'
+    rtol, atol = 3e-3, 3e-1
+    fused_ft_kernel = True
+    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
+    # Only prenorm supports residual_in_fp32
+    config.residual_in_fp32 = getattr(config, 'prenorm', True)
+    config.use_flash_attn = True
+    config.fused_bias_fc = True
+    config.fused_mlp = True
+    config.fused_dropout_add_ln = True
+    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
+    model.eval()
+
+    torch.manual_seed(0)
+    # OPT tokenizer requires use_fast=False
+    # https://huggingface.co/docs/transformers/model_doc/opt
+    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
+    eos_token_id = tokenizer.eos_token_id
+    input_ids = tokenizer("Hello, my dog is cute and",
+                          return_tensors="pt").input_ids.to(device=device)
+    max_length = 30
+    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
+    # max_length = input_ids.shape[1] + 40
+
+    # Slow generation for reference
+    sequences = []
+    scores = []
+    cur_input_ids = input_ids
+    with torch.inference_mode():
+        scores.append(model(cur_input_ids).logits[:, -1])
+        sequences.append(scores[-1].argmax(dim=-1))
+        for _ in range(input_ids.shape[1] + 1, max_length):
+            cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
+            scores.append(model(cur_input_ids).logits[:, -1])
+            sequences.append(scores[-1].argmax(dim=-1))
+            if eos_token_id is not None and (sequences[-1] == eos_token_id).all():
+                break
+    sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
+    scores = tuple(scores)
+
+    print('Without CUDA graph')
+    torch.cuda.synchronize()
+    start = time.time()
+    out = model.generate(input_ids=input_ids, max_length=max_length,
+                         eos_token_id=eos_token_id, fused_ft_kernel=fused_ft_kernel,
+                         return_dict_in_generate=True, output_scores=True, timing=True)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+    if verbose:
+        print(out.sequences)
+    print(tokenizer.batch_decode(out.sequences.tolist()))
+
+    if fused_ft_kernel:
+        # Capture graph outside the timing loop
+        batch_size, seqlen_og = input_ids.shape
+        model._decoding_cache = update_graph_cache(
+            model, None, batch_size, seqlen_og, max_length
+        )
+        print('With CUDA graph')
+        torch.cuda.synchronize()
+        start = time.time()
+        out_cg = model.generate(input_ids=input_ids, max_length=max_length,
+                                fused_ft_kernel=fused_ft_kernel, cg=True,
+                                return_dict_in_generate=True, output_scores=True, timing=True)
+        torch.cuda.synchronize()
+        print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+        if verbose:
+            print(out_cg.sequences)
+        print(tokenizer.batch_decode(out.sequences.tolist()))
+
+    del model
+
+    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
+    model_hf.eval()
+    print("HF fp16")
+    torch.cuda.synchronize()
+    start = time.time()
+    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
+                               return_dict_in_generate=True, output_scores=True)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+    del model_hf
+
+    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
+    model_ref.eval()
+    print("HF fp32")
+    torch.cuda.synchronize()
+    start = time.time()
+    out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
+                                 return_dict_in_generate=True, output_scores=True)
+    torch.cuda.synchronize()
+    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
+    del model_ref
+    print(tokenizer.batch_decode(out_ref.sequences.tolist()))
+
+    if verbose:
+        print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
+        print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+        print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
+        print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
+
+    assert torch.all(out.sequences == sequences)
+    assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
+                          rtol=rtol, atol=atol)
+    assert torch.all(out.sequences == out_ref.sequences)
+    assert torch.all(out.sequences == out_hf.sequences)
+    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
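
For context (illustrative, not part of the diff): the final assertion in both tests is a relative tolerance check, i.e. the fp16 error of this implementation against the fp32 HF reference must stay within 3x of HF's own fp16 error against the same reference. A toy sketch of that criterion with dummy score tensors:

    import torch

    # Dummy stand-ins for the stacked scores, shape (batch, num_steps, vocab_size)
    scores_ref = torch.randn(1, 5, 8)   # pretend fp32 reference scores
    scores_hf16 = scores_ref + 0.01     # pretend HF fp16 scores (small numerical error)
    scores_ours = scores_ref + 0.02     # pretend flash-attn fp16 scores
    ours_err = (scores_ours - scores_ref).abs().max().item()
    hf_err = (scores_hf16 - scores_ref).abs().max().item()
    # Accept our output if its error is within 3x of HF's own fp16 error
    assert ours_err < 3 * hf_err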

View File

@@ -35,6 +35,7 @@ def test_opt_optimized(model_name):
     config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
     config.use_flash_attn = True
     config.fused_bias_fc = True
+    config.fused_mlp = True
     config.fused_dropout_add_ln = True
     # Only prenorm supports residual_in_fp32
     config.residual_in_fp32 = getattr(config, 'prenorm', True)