From f68d41ec7768db6c5578ca8718d081de36fa1246 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 17 Jan 2023 19:59:06 -0800 Subject: [PATCH] [Gen] Add OPT to generation test --- flash_attn/utils/generation.py | 12 ++- flash_attn/utils/pretrained.py | 28 +++++- tests/models/test_gpt_generation.py | 134 +++++++++++++++++++++++++++- tests/models/test_opt.py | 1 + 4 files changed, 164 insertions(+), 11 deletions(-) diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py index 297496c..f264f4c 100644 --- a/flash_attn/utils/generation.py +++ b/flash_attn/utils/generation.py @@ -71,7 +71,8 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0): def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, - vocab_size=None, tensor_parallel=1, fused_ft_kernel=False, cg=False, timing=False): + eos_token_id=None, vocab_size=None, tensor_parallel=1, fused_ft_kernel=False, + cg=False, timing=False): """Decoding, either greedy or with top-k or top-p sampling. If top-k = 0, don't limit the number of candidates (pure sampling). Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, @@ -104,14 +105,15 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, scores = [] with torch.inference_mode(): logits = model(input_ids, inference_params=inference_params).logits[:, -1] + if timing: + torch.cuda.synchronize() + start = time.time() if vocab_size is not None: logits = logits[..., :vocab_size] scores.append(logits) next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) sequences = [next_token] inference_params.sequence_len_offset = seqlen_og - if timing: - start = time.time() while True: position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset, dtype=torch.long, device=input_ids.device) @@ -127,11 +129,13 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, next_token = sample(logits, top_k=top_k, temperature=temperature) sequences.append(next_token) inference_params.sequence_len_offset += 1 + if eos_token_id is not None and (next_token == eos_token_id).all(): + break if inference_params.sequence_len_offset >= max_length - 1: break if timing: torch.cuda.synchronize() - print(f'Decoding time: {time.time() - start}') + print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms') output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput return output_cls( sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), diff --git a/flash_attn/utils/pretrained.py b/flash_attn/utils/pretrained.py index 6732892..4b170a3 100644 --- a/flash_attn/utils/pretrained.py +++ b/flash_attn/utils/pretrained.py @@ -1,11 +1,33 @@ import torch -from transformers.utils import WEIGHTS_NAME -from transformers.utils.hub import cached_file +from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME +from transformers.utils import is_remote_url +from transformers.modeling_utils import load_state_dict +from transformers.utils.hub import cached_file, get_checkpoint_shard_files def state_dict_from_pretrained(model_name, device=None, dtype=None): - state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device) + is_sharded = False + resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, + _raise_exceptions_for_missing_entries=False) + if resolved_archive_file is None: + resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, + _raise_exceptions_for_missing_entries=False) + if 
resolved_archive_file is not None: + is_sharded = True + if resolved_archive_file is None: + raise EnvironmentError(f"Model name {model_name} was not found.") + if is_sharded: + # resolved_archive_file becomes a list of files that point to the different + # checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + model_name, resolved_archive_file + ) + state_dict = {} + for sharded_file in resolved_archive_file: + state_dict.update(torch.load(sharded_file, map_location=device)) + else: + state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device) if dtype is not None: state_dict = {k: v.to(dtype) for k, v in state_dict.items()} return state_dict diff --git a/tests/models/test_gpt_generation.py b/tests/models/test_gpt_generation.py index 2d1fdc7..a347387 100644 --- a/tests/models/test_gpt_generation.py +++ b/tests/models/test_gpt_generation.py @@ -1,18 +1,22 @@ import os import re +import time import torch import pytest from einops import rearrange -from transformers import GPT2Config, GPT2Tokenizer +from transformers import GPT2Config, GPT2Tokenizer, OPTConfig, AutoTokenizer from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF +from transformers.models.opt.modeling_opt import OPTForCausalLM from flash_attn.models.gpt import GPTLMHeadModel from flash_attn.models.gpt import remap_state_dict_gpt2 +from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config from flash_attn.utils.pretrained import state_dict_from_pretrained from flash_attn.utils.distributed import all_gather_raw +from flash_attn.utils.generation import update_graph_cache @pytest.mark.parametrize('fused_ft_kernel', [False, True]) @@ -22,7 +26,7 @@ from flash_attn.utils.distributed import all_gather_raw @pytest.mark.parametrize('rotary', [False, True]) # @pytest.mark.parametrize('rotary', [False]) @pytest.mark.parametrize('model_name', ["gpt2"]) -def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel): +def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel): """Check that our implementation of GPT2 generation matches the HF implementation: the scores in fp16 should be around the same as the HF scores in fp16, when compared to the HF scores in fp32. 
@@ -49,13 +53,14 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel): if not rotary: model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device) - model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype) + model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, + torch_dtype=dtype).to(device=device) model_ref.eval() model_hf.eval() torch.manual_seed(0) tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - input_ids = tokenizer("Hello, my dog is cute and ", + input_ids = tokenizer("Hello, my dog is cute and", return_tensors="pt").input_ids.to(device=device) max_length = 30 # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda') @@ -106,3 +111,124 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel): assert torch.all(out.sequences == out_hf.sequences) assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() + + +@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"]) +# @pytest.mark.parametrize('model_name', ["facebook/opt-6.7b"]) +def test_greedy_decode_opt(model_name): + """Check that our implementation of OPT generation matches the HF implementation: + the scores in fp16 should be around the same as the HF scores in fp16, when compared to + the HF scores in fp32. + """ + print(f'\nMODEL: {model_name}') + verbose = False + dtype = torch.float16 + device = 'cuda' + rtol, atol = 3e-3, 3e-1 + fused_ft_kernel = True + config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name)) + # Only prenorm supports residual_in_fp32 + config.residual_in_fp32 = getattr(config, 'prenorm', True) + config.use_flash_attn = True + config.fused_bias_fc = True + config.fused_mlp = True + config.fused_dropout_add_ln = True + + model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype) + model.eval() + + torch.manual_seed(0) + # OPT tokenizer requires use_fast=False + # https://huggingface.co/docs/transformers/model_doc/opt + tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) + eos_token_id = tokenizer.eos_token_id + + input_ids = tokenizer("Hello, my dog is cute and", + return_tensors="pt").input_ids.to(device=device) + max_length = 30 + # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda') + # max_length = input_ids.shape[1] + 40 + + # Slow generation for reference + sequences = [] + scores = [] + cur_input_ids = input_ids + with torch.inference_mode(): + scores.append(model(cur_input_ids).logits[:, -1]) + sequences.append(scores[-1].argmax(dim=-1)) + for _ in range(input_ids.shape[1] + 1, max_length): + cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1) + scores.append(model(cur_input_ids).logits[:, -1]) + sequences.append(scores[-1].argmax(dim=-1)) + if eos_token_id is not None and (sequences[-1] == eos_token_id).all(): + break + sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1) + scores = tuple(scores) + + print('Without CUDA graph') + torch.cuda.synchronize() + start = time.time() + out = model.generate(input_ids=input_ids, max_length=max_length, + eos_token_id=eos_token_id, fused_ft_kernel=fused_ft_kernel, + return_dict_in_generate=True, output_scores=True, timing=True) + torch.cuda.synchronize() + print(f'Prompt processing + decoding time: 
{(time.time() - start) * 1000:.0f}ms') + if verbose: + print(out.sequences) + print(tokenizer.batch_decode(out.sequences.tolist())) + if fused_ft_kernel: + # Capture graph outside the timing loop + batch_size, seqlen_og = input_ids.shape + model._decoding_cache = update_graph_cache( + model, None, batch_size, seqlen_og, max_length + ) + print('With CUDA graph') + torch.cuda.synchronize() + start = time.time() + out_cg = model.generate(input_ids=input_ids, max_length=max_length, + fused_ft_kernel=fused_ft_kernel, cg=True, + return_dict_in_generate=True, output_scores=True, timing=True) + torch.cuda.synchronize() + print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') + if verbose: + print(out_cg.sequences) + print(tokenizer.batch_decode(out.sequences.tolist())) + + del model + + model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device) + model_hf.eval() + print("HF fp16") + torch.cuda.synchronize() + start = time.time() + out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length, + return_dict_in_generate=True, output_scores=True) + torch.cuda.synchronize() + print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') + del model_hf + + model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device) + model_ref.eval() + print("HF fp32") + torch.cuda.synchronize() + start = time.time() + out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length, + return_dict_in_generate=True, output_scores=True) + torch.cuda.synchronize() + print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') + del model_ref + print(tokenizer.batch_decode(out_ref.sequences.tolist())) + + if verbose: + print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') + print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') + print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}') + print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}') + + assert torch.all(out.sequences == sequences) + assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), + rtol=rtol, atol=atol) + assert torch.all(out.sequences == out_ref.sequences) + assert torch.all(out.sequences == out_hf.sequences) + + assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() diff --git a/tests/models/test_opt.py b/tests/models/test_opt.py index 5b5529b..04ebfe5 100644 --- a/tests/models/test_opt.py +++ b/tests/models/test_opt.py @@ -35,6 +35,7 @@ def test_opt_optimized(model_name): config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name)) config.use_flash_attn = True config.fused_bias_fc = True + config.fused_mlp = True config.fused_dropout_add_ln = True # Only prenorm supports residual_in_fp32 config.residual_in_fp32 = getattr(config, 'prenorm', True)
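
Usage sketch for the eos_token_id early stop added to decode() above, following the new test_greedy_decode_opt: a minimal sketch only, assuming a CUDA device; the checkpoint name is illustrative (any OPT model from the test's parametrize list works).

    import torch
    from transformers import AutoTokenizer, OPTConfig

    from flash_attn.models.gpt import GPTLMHeadModel
    from flash_attn.models.opt import opt_config_to_gpt2_config

    device, dtype = 'cuda', torch.float16
    model_name = 'facebook/opt-125m'  # illustrative checkpoint

    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    # Only prenorm supports residual_in_fp32 (same setting as the test).
    config.residual_in_fp32 = getattr(config, 'prenorm', True)
    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    # The OPT tokenizer requires use_fast=False (see the test above).
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    input_ids = tokenizer("Hello, my dog is cute and",
                          return_tensors="pt").input_ids.to(device=device)

    # eos_token_id stops decoding once every sequence in the batch has emitted EOS,
    # instead of always running to max_length.
    out = model.generate(input_ids=input_ids, max_length=30,
                         eos_token_id=tokenizer.eos_token_id,
                         return_dict_in_generate=True, output_scores=True)
    print(tokenizer.batch_decode(out.sequences.tolist()))

When the fused inference kernel is built, fused_ft_kernel=True and cg=True can be passed as in the test, after pre-capturing the CUDA graph with update_graph_cache(model, None, batch_size, seqlen_og, max_length).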
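
Likewise, a minimal sketch of the sharded-checkpoint path added to state_dict_from_pretrained: the checkpoint name is illustrative, and the sharded branch is only taken when the hub repo ships a weight index file (WEIGHTS_INDEX_NAME); otherwise the original single-file path is used.

    import torch
    from flash_attn.utils.pretrained import state_dict_from_pretrained

    # With the change above, the function first looks for the single weights file;
    # if only the index is present, it resolves the shard files via
    # get_checkpoint_shard_files, loads each shard with torch.load, and merges
    # them into one state dict.
    state_dict = state_dict_from_pretrained('facebook/opt-6.7b', device='cpu',
                                            dtype=torch.float16)
    print(f'{len(state_dict)} tensors loaded')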