From 0705d2718dd39a39507dbdac85c538189a8436a1 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 20 Sep 2023 23:36:46 -0700 Subject: [PATCH] [Llama] Fix some tests, add tests for Llama 2 and CodeLlama --- flash_attn/models/llama.py | 31 +-- flash_attn/modules/mha.py | 2 +- tests/models/test_gpt_generation_parallel.py | 1 + tests/models/test_llama.py | 251 ++++++++----------- 4 files changed, 124 insertions(+), 161 deletions(-) diff --git a/flash_attn/models/llama.py b/flash_attn/models/llama.py index 7bea141..2841efd 100644 --- a/flash_attn/models/llama.py +++ b/flash_attn/models/llama.py @@ -13,6 +13,8 @@ import torch.nn.functional as F from sentencepiece import SentencePieceProcessor from transformers import GPT2Config, LlamaConfig +from einops import rearrange + def remap_state_dict_meta_llama( state_dict: dict[str, torch.Tensor], config: GPT2Config @@ -30,9 +32,7 @@ def remap_state_dict_meta_llama( # Word embedding def key_mapping_emb(key): return re.sub( - r"^transformer.tok_embeddings.", - "transformer.embeddings.word_embeddings.", - key, + r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key ) state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) @@ -113,7 +113,7 @@ def remap_state_dict_meta_llama( def remap_state_dict_hf_llama( - state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False + state_dict: dict[str, torch.Tensor], config: GPT2Config ) -> dict[str, torch.Tensor]: """Convert the state_dict in Hugging Face format to standard GPT format. @@ -186,13 +186,11 @@ def remap_state_dict_hf_llama( state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - def inv_permute(w, first_dim=None): + def inv_permute(w): # Inverse of permute implemented in: # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114 - return ( - w.reshape(first_dim or config.n_head, 2, -1, config.n_embd) - .transpose(1, 2) - .reshape(-1, config.n_embd) + return rearrange( + w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2 ) # Attention @@ -202,8 +200,7 @@ def remap_state_dict_hf_llama( Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight") state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat( - (inv_permute(Wq), inv_permute(Wk, getattr(config, "n_head_kv")), Wv), - dim=0, + [inv_permute(Wq), inv_permute(Wk), Wv], dim=0 ) # We don't store these state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None) @@ -220,7 +217,7 @@ def remap_state_dict_hf_llama( def inv_remap_state_dict_hf_llama( - state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False + state_dict: dict[str, torch.Tensor], config: GPT2Config ) -> dict[str, torch.Tensor]: """Convert the state_dict in standard GPT format to Hugging Face format. 
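Note: the rearrange pattern in inv_permute above and the one in permute below (next hunk) are exact inverses: Hugging Face lays out the rotary dimensions of each head as two contiguous halves, while this repo keeps them interleaved in pairs. A minimal round-trip check, with made-up sizes (not part of the patch):

import torch
from einops import rearrange

n_head, n_embd = 4, 64  # illustrative sizes only
d = n_embd // n_head // 2  # half of one head's dimension

def permute(w):  # interleaved pairs -> HF half-split layout
    return rearrange(w, "(h d two) n -> (h two d) n", d=d, two=2)

def inv_permute(w):  # HF half-split layout -> interleaved pairs
    return rearrange(w, "(h two d) n -> (h d two) n", d=d, two=2)

w = torch.randn(n_embd, n_embd)
assert torch.equal(inv_permute(permute(w)), w)
assert torch.equal(permute(inv_permute(w)), w)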
@@ -293,11 +290,9 @@ def inv_remap_state_dict_hf_llama( state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - def permute(w, first_dim=None): - return ( - w.view(first_dim or config.n_head, -1, 2, config.n_embd) - .transpose(1, 2) - .reshape(-1, config.n_embd) + def permute(w): + return rearrange( + w, "(h d two) n -> (h two d) n", d=config.n_embd // config.n_head // 2, two=2 ) n_head = config.n_head @@ -316,7 +311,7 @@ def inv_remap_state_dict_hf_llama( Wk = Wqkv[q_dim : q_dim + k_dim] Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim] state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq) - state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk, n_head_kv) + state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk) state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None) diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py index 4894dac..976bd3d 100644 --- a/flash_attn/modules/mha.py +++ b/flash_attn/modules/mha.py @@ -725,7 +725,7 @@ class ParallelMHA(nn.Module): process_group, bias=qkv_proj_bias, sequence_parallel=sequence_parallel, - multiple_of=self.head_dim * (self.num_heads_per_rank + 2 * self.num_heads_kv_per_rank), + multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2), **factory_kwargs, ) inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention diff --git a/tests/models/test_gpt_generation_parallel.py b/tests/models/test_gpt_generation_parallel.py index b398bf9..bcf2bf5 100644 --- a/tests/models/test_gpt_generation_parallel.py +++ b/tests/models/test_gpt_generation_parallel.py @@ -160,6 +160,7 @@ def test_tensor_parallel(model_name, rotary, world_size): assert torch.allclose( torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol ) + assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1)) if not rotary: assert torch.all(out.sequences == out_ref.sequences) assert torch.all(out.sequences == out_hf.sequences) diff --git a/tests/models/test_llama.py b/tests/models/test_llama.py index e0b8c30..32e9cd2 100644 --- a/tests/models/test_llama.py +++ b/tests/models/test_llama.py @@ -1,6 +1,6 @@ # Copyright (c) 2023, Tri Dao. 
-# To run the huggingface implementation, we first need to convert the weights: +# To run the huggingface implementation of LLaMa (1), we first need to convert the weights: # https://github.com/huggingface/transformers/pull/21955 # python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf # and repeat for 13B, 30B, 65B @@ -30,6 +30,7 @@ from flash_attn.utils.generation import update_graph_cache from flash_attn.utils.pretrained import state_dict_from_pretrained from transformers import LlamaConfig, LlamaTokenizer from transformers.models.llama.modeling_llama import LlamaForCausalLM +from transformers import AutoConfig def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format): @@ -60,9 +61,38 @@ def test_llama_state_dict(model_name): assert state_dict[k].shape == pretrained_state_dict[k].shape -@pytest.mark.parametrize("model_name", ["7B", "13B"]) -@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"]) -def test_llama_optimized(model_name, checkpoint_format): +# TinyLlama-1.1B is to test MQA +@pytest.mark.parametrize( + "model_name", ["meta-llama/Llama-2-7b-hf", "PY007/TinyLlama-1.1B-step-50K-105b"] +) +def test_inv_remap_state_dict_hf_llama(model_name): + config = llama_config_to_gpt2_config( + AutoConfig.from_pretrained(model_name, trust_remote_code=True) + ) + state_dict = state_dict_from_pretrained(model_name) + # inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama + state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key} + pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config) + state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config) + assert set(state_dict_recover.keys()) == set(state_dict.keys()) + for key in state_dict_recover.keys(): + torch.testing.assert_close(state_dict_recover[key], state_dict[key]) + + +# TinyLlama-1.1B is to test MQA +@pytest.mark.parametrize( + "model_name", + [ + "7B", # Llama 1 + "13B", # Llama 1 + "meta-llama/Llama-2-13b-hf", + "codellama/CodeLlama-7b-hf", + "codellama/CodeLlama-13b-hf", + "codellama/CodeLlama-34b-hf", + "PY007/TinyLlama-1.1B-step-50K-105b", + ], +) +def test_llama_optimized(model_name): """Check that our implementation of LLaMa (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF forward pass in fp16, when compared to the HF forward pass in fp32. 
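The acceptance criterion stated in this docstring is the pattern used throughout the file: our fp16 error against the fp32 HF reference must stay within a small multiple of HF's own fp16 error. As a standalone sketch (the helper name is made up; the factor of 3 mirrors the assertions later in this file):

import torch

def within_fp16_budget(out, out_hf, out_ref, factor=3.0):
    # out: our fp16 output, out_hf: HF fp16 output, out_ref: HF fp32 reference
    ours = (out - out_ref).abs().max().item()
    theirs = (out_hf - out_ref).abs().max().item()
    return ours < factor * theirs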
@@ -73,17 +103,27 @@ def test_llama_optimized(model_name, checkpoint_format): dtype = torch.float16 device = "cuda" - config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format) - config = llama_config_to_gpt2_config(config) + if "/" in model_name: # Download from HF + config = llama_config_to_gpt2_config( + AutoConfig.from_pretrained(model_name, trust_remote_code=True) + ) + else: + config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta") + config = llama_config_to_gpt2_config(config) config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = False # We don't have fused GatedMLP yet config.fused_dropout_add_ln = True config.residual_in_fp32 = True - pretrained_state_dict = _pretrained_state_dict_from_checkpoint( - checkpoint_path, model_name, config, checkpoint_format - ) + if "/" in model_name: # Download from HF + pretrained_state_dict = remap_state_dict_hf_llama( + state_dict_from_pretrained(model_name), config + ) + else: + pretrained_state_dict = _pretrained_state_dict_from_checkpoint( + checkpoint_path, model_name, config, checkpoint_format="meta" + ) model = GPTLMHeadModel(config, device=device, dtype=dtype) model.load_state_dict(pretrained_state_dict) model.eval() @@ -103,7 +143,8 @@ def test_llama_optimized(model_name, checkpoint_format): # Without device_map, the model is loaded on the CPU, which is very slow # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB model_ref = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f"{model_name}-hf", device_map="auto" + model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", + device_map="auto", ) model_ref.eval() with torch.no_grad(): @@ -112,7 +153,9 @@ def test_llama_optimized(model_name, checkpoint_format): del model_ref model_hf = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device} + model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", + torch_dtype=dtype, + device_map={"": device}, ) model_hf.eval() with torch.no_grad(): @@ -135,77 +178,12 @@ def test_llama_optimized(model_name, checkpoint_format): ).abs().max().item() - - -@pytest.mark.parametrize("model_name", ["PY007/TinyLlama-1.1B-step-50K-105b"]) -def test_mqa_optimized(model_name): - """Check that our implementation of Llama with MQA/GQA (with all optimizations enabled) matches the - HF implementation: the output of our forward pass in fp16 should be around the same as the HF - forward pass in fp16, when compared to the HF forward pass in fp32. 
- """ - dtype = torch.float16 - device = "cuda" - config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(model_name)) - config.use_flash_attn = True # FlashAttention-2 supports headdim 256 - config.fused_bias_fc = True - config.fused_mlp = False - config.fused_dropout_add_ln = True - config.residual_in_fp32 = True - - # Without device_map, the model is loaded on the CPU, which is very slow - model_ref = LlamaForCausalLM.from_pretrained(model_name, device_map={"": device}) - model_ref.eval() - - model = GPTLMHeadModel(config, device=device, dtype=dtype) - model.load_state_dict(remap_state_dict_hf_llama(model_ref.state_dict(), config)) - model.eval() - - torch.manual_seed(0) - batch_size = 2 - max_seqlen = 256 - input_ids = torch.randint( - 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device - ) - with torch.no_grad(): - out = model.transformer(input_ids) - logits = model(input_ids).logits - del model - - with torch.no_grad(): - out_ref = model_ref.model(input_ids).last_hidden_state - logits_ref = model_ref(input_ids).logits - del model_ref - - model_hf = LlamaForCausalLM.from_pretrained( - model_name, torch_dtype=dtype, device_map={"": device} - ) - model_hf.eval() - out_hf = model_hf.model(input_ids).last_hidden_state - logits_hf = model_hf(input_ids).logits - del model_hf - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}") - print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}") - assert (out - out_ref).abs().max().item() < 3 * ( - out_hf - out_ref - ).abs().max().item() - - print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") - print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}") - print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}") - print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}") - assert (logits - logits_ref).abs().max().item() < 3 * ( - logits_hf - logits_ref - ).abs().max().item() - - # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel" @pytest.mark.parametrize("world_size", [2]) -@pytest.mark.parametrize("model_name", ["13B"]) -@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"]) -def test_llama_parallel(model_name, world_size, checkpoint_format): +@pytest.mark.parametrize( + "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"] +) +def test_llama_parallel(model_name, world_size): """Check that our implementation of LLaMa (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF forward pass in fp16, when compared to the HF forward pass in fp32. 
@@ -217,8 +195,13 @@ def test_llama_parallel(model_name, world_size, checkpoint_format): ) dtype = torch.float16 - config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format) - config = llama_config_to_gpt2_config(config) + if "/" in model_name: # Download from HF + config = llama_config_to_gpt2_config( + AutoConfig.from_pretrained(model_name, trust_remote_code=True) + ) + else: + config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta") + config = llama_config_to_gpt2_config(config) config.use_flash_attn = True config.fused_bias_fc = True config.fused_mlp = False # We don't have fused GatedMLP yet @@ -233,9 +216,14 @@ def test_llama_parallel(model_name, world_size, checkpoint_format): rank = parallel_state.get_tensor_model_parallel_rank() process_group = parallel_state.get_tensor_model_parallel_group() - pretrained_state_dict = _pretrained_state_dict_from_checkpoint( - checkpoint_path, model_name, config, checkpoint_format - ) + if "/" in model_name: # Download from HF + pretrained_state_dict = remap_state_dict_hf_llama( + state_dict_from_pretrained(model_name), config + ) + else: + pretrained_state_dict = _pretrained_state_dict_from_checkpoint( + checkpoint_path, model_name, config, checkpoint_format="meta" + ) model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype) model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)) model.eval() @@ -260,7 +248,8 @@ def test_llama_parallel(model_name, world_size, checkpoint_format): if rank == 0: # Without device_map, the model is loaded on the CPU, which is very slow model_ref = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f"{model_name}-hf", device_map="auto" + model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", + device_map="auto", ) model_ref.eval() with torch.no_grad(): @@ -269,7 +258,9 @@ def test_llama_parallel(model_name, world_size, checkpoint_format): del model_ref model_hf = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto" + model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf", + torch_dtype=dtype, + device_map="auto", ) model_hf.eval() with torch.no_grad(): @@ -405,9 +396,10 @@ def test_llama_generation(model_name, checkpoint_format): # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation" @pytest.mark.parametrize("world_size", [2]) -@pytest.mark.parametrize("model_name", ["13B"]) -@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"]) -def test_llama_parallel_generation(model_name, world_size, checkpoint_format): +@pytest.mark.parametrize( + "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"] +) +def test_llama_parallel_generation(model_name, world_size): """Check that our implementation matches the HF implementation: the scores in fp16 should be around the same as the HF scores in fp16, when compared to the HF scores in fp32. 
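On the comparison itself: out.scores is a tuple with one (batch, vocab_size) tensor per generated token, so the tests stack it into (batch, steps, vocab_size) before comparing. The assertion added to test_gpt_generation_parallel.py above goes further and requires the CUDA-graph path to match the regular decoding path exactly, hence torch.equal with no tolerance. A sketch of the shape handling, with dummy tensors:

import torch

scores = [torch.randn(2, 32000) for _ in range(4)]  # one tensor per decoding step
scores_cg = [s.clone() for s in scores]             # stand-in for the CUDA-graph outputs
assert torch.equal(torch.stack(scores, dim=1), torch.stack(scores_cg, dim=1))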
@@ -419,12 +411,17 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     )
 
     dtype = torch.float16
-    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
-    config = llama_config_to_gpt2_config(config)
-    config.use_flash_attn = False
+    if "/" in model_name:  # Download from HF
+        config = llama_config_to_gpt2_config(
+            AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        )
+    else:
+        config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
+        config = llama_config_to_gpt2_config(config)
+    config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_mlp = False  # We don't have fused GatedMLP yet
-    config.fused_dropout_add_ln = False
+    config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
     config.pad_vocab_size_multiple = 8 * world_size
     config.sequence_parallel = False  # Need to set this to False for generation
@@ -450,9 +447,14 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     # GPU0 and GPU1 and things would hang
     torch.cuda.set_device(device)
 
-    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
-        checkpoint_path, model_name, config, checkpoint_format
-    )
+    if "/" in model_name:  # Download from HF
+        pretrained_state_dict = remap_state_dict_hf_llama(
+            state_dict_from_pretrained(model_name), config
+        )
+    else:
+        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+            checkpoint_path, model_name, config, checkpoint_format="meta"
+        )
     model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
     model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()
@@ -490,7 +492,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     if rank == 0:
         # Without device_map, the model is loaded on the CPU, which is very slow
         model_hf = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
+            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
+            torch_dtype=dtype,
+            device_map="auto",
         )
         model_hf.eval()
         print("HF fp16")
@@ -508,7 +512,8 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
         del model_hf
 
         model_ref = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
+            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
+            device_map="auto",
         )
         model_ref.eval()
         with torch.inference_mode():
@@ -594,15 +599,16 @@ def test_llama_parallel_uneven_num_heads(world_size):
 
     if rank == 0:
         model_ref = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
+            Path(checkpoint_path) / f"{model_name}-hf", device_map={"": device}
         )
+        model_ref = model_ref.to(device=device)
         model_ref.eval()
-        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
-        logits_ref = model_ref(input_ids).logits.to(device=device)
+        out_ref = model_ref.model(input_ids).last_hidden_state
+        logits_ref = model_ref(input_ids).logits
         del model_ref
 
         model_hf = LlamaForCausalLM.from_pretrained(
-            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
+            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
         )
         model_hf.eval()
         out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
@@ -625,42 +631,3 @@ def test_llama_parallel_uneven_num_heads(world_size):
 
     if os.path.exists(checkpoint_path / f"{model_name}-hf"):
/ f"{model_name}-hf"): shutil.rmtree(checkpoint_path / f"{model_name}-hf") - - -@torch.no_grad() -def test_inv_remap_state_dict_hf_llama(): - checkpoint_path = ( - Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama" - ) - model_name = f"teeny" - - llama_config = LlamaConfig( - num_attention_heads=2, - hidden_size=256 * 2, - intermediate_size=256 * 2 * 4, - num_hidden_layers=4, - ) - config = llama_config_to_gpt2_config(llama_config) - config.use_flash_attn = True - config.fused_bias_fc = True - config.fused_mlp = False # We don't have fused GatedMLP yet - config.fused_dropout_add_ln = True - config.residual_in_fp32 = True - - # Set up. - LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf") - - # inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama - state_dict = state_dict_from_pretrained(checkpoint_path / f"{model_name}-hf") - state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key} - pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config) - state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config) - - assert set(state_dict_recover.keys()) == set(state_dict.keys()) - - for key in state_dict_recover.keys(): - torch.testing.assert_close(state_dict_recover[key], state_dict[key]) - - # Tear down. - if os.path.exists(checkpoint_path / f"{model_name}-hf"): - shutil.rmtree(checkpoint_path / f"{model_name}-hf")