From 42832575d40d1b21c43ee570549cfa395e1f6e51 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Tue, 19 Sep 2023 22:15:59 -0700 Subject: [PATCH] Fix Llama GQA/MQA (#546) * Fix llama MQA * Fix permute shape * Update llama.py --- flash_attn/models/llama.py | 76 +++++++++++++++++++++++++++----------- tests/models/test_llama.py | 66 +++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 21 deletions(-) diff --git a/flash_attn/models/llama.py b/flash_attn/models/llama.py index 993a282..7bea141 100644 --- a/flash_attn/models/llama.py +++ b/flash_attn/models/llama.py @@ -26,10 +26,13 @@ def remap_state_dict_meta_llama( return f"transformer.{key}" if not key.startswith("output.") else key state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) + # Word embedding def key_mapping_emb(key): return re.sub( - r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key + r"^transformer.tok_embeddings.", + "transformer.embeddings.word_embeddings.", + key, ) state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) @@ -61,7 +64,9 @@ def remap_state_dict_meta_llama( def key_mapping_ln(key): key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key) key = re.sub( - r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key + r"^transformer.layers.(\d+).attention_norm.", + r"transformer.layers.\1.norm1.", + key, ) key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key) return key @@ -77,7 +82,9 @@ def remap_state_dict_meta_llama( def key_mapping_mlp(key): return re.sub( - r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key + r"^transformer.layers.(\d+).feed_forward.w2.", + r"transformer.layers.\1.mlp.fc2.", + key, ) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) @@ -106,12 +113,13 @@ def remap_state_dict_meta_llama( def remap_state_dict_hf_llama( - state_dict: dict[str, torch.Tensor], config: GPT2Config + state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False ) -> dict[str, torch.Tensor]: """Convert the state_dict in Hugging Face format to standard GPT format. This function modifies state_dict in place. 
""" + # Embedding def key_mapping_emb(key): return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key) @@ -153,28 +161,38 @@ def remap_state_dict_hf_llama( state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0) def key_mapping_mlp(key): - return re.sub(r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key) + return re.sub( + r"^model.layers.(\d+).mlp.down_proj.", + r"transformer.layers.\1.mlp.fc2.", + key, + ) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) # LayerNorm def key_mapping_ln(key): key = re.sub(r"^model.norm.", r"transformer.ln_f.", key) - key = re.sub(r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key) key = re.sub( - r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key + r"^model.layers.(\d+).input_layernorm.", + r"transformer.layers.\1.norm1.", + key, + ) + key = re.sub( + r"^model.layers.(\d+).post_attention_layernorm.", + r"transformer.layers.\1.norm2.", + key, ) return key state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - def inv_permute(w): + def inv_permute(w, first_dim=None): # Inverse of permute implemented in: # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114 return ( - w.reshape(config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd) + w.reshape(first_dim or config.n_head, 2, -1, config.n_embd) .transpose(1, 2) - .reshape(config.n_embd, config.n_embd) + .reshape(-1, config.n_embd) ) # Attention @@ -182,15 +200,19 @@ def remap_state_dict_hf_llama( Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight") Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight") Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight") + state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat( - [inv_permute(Wq), inv_permute(Wk), Wv], dim=0 + (inv_permute(Wq), inv_permute(Wk, getattr(config, "n_head_kv")), Wv), + dim=0, ) # We don't store these state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None) def key_mapping_attn(key): return re.sub( - r"^model.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key + r"^model.layers.(\d+).self_attn.o_proj.", + r"transformer.layers.\1.mixer.out_proj.", + key, ) state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) @@ -198,7 +220,7 @@ def remap_state_dict_hf_llama( def inv_remap_state_dict_hf_llama( - state_dict: dict[str, torch.Tensor], config: GPT2Config + state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False ) -> dict[str, torch.Tensor]: """Convert the state_dict in standard GPT format to Hugging Face format. 
@@ -246,26 +268,36 @@ def inv_remap_state_dict_hf_llama( state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3 def key_mapping_mlp(key): - return re.sub(r"^transformer.layers.(\d+).mlp.fc2.", r"model.layers.\1.mlp.down_proj.", key) + return re.sub( + r"^transformer.layers.(\d+).mlp.fc2.", + r"model.layers.\1.mlp.down_proj.", + key, + ) state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) # LayerNorm def key_mapping_ln(key): key = re.sub(r"^transformer.ln_f.", r"model.norm.", key) - key = re.sub(r"^transformer.layers.(\d+).norm1.", r"model.layers.\1.input_layernorm.", key) key = re.sub( - r"^transformer.layers.(\d+).norm2.", r"model.layers.\1.post_attention_layernorm.", key + r"^transformer.layers.(\d+).norm1.", + r"model.layers.\1.input_layernorm.", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).norm2.", + r"model.layers.\1.post_attention_layernorm.", + key, ) return key state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) - def permute(w): + def permute(w, first_dim=None): return ( - w.view(config.n_head, config.n_embd // config.n_head // 2, 2, config.n_embd) + w.view(first_dim or config.n_head, -1, 2, config.n_embd) .transpose(1, 2) - .reshape(config.n_embd, config.n_embd) + .reshape(-1, config.n_embd) ) n_head = config.n_head @@ -284,13 +316,15 @@ def inv_remap_state_dict_hf_llama( Wk = Wqkv[q_dim : q_dim + k_dim] Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim] state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq) - state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk) + state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk, n_head_kv) state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None) def key_mapping_attn(key): return re.sub( - r"^transformer.layers.(\d+).mixer.out_proj.", r"model.layers.\1.self_attn.o_proj.", key + r"^transformer.layers.(\d+).mixer.out_proj.", + r"model.layers.\1.self_attn.o_proj.", + key, ) state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) diff --git a/tests/models/test_llama.py b/tests/models/test_llama.py index 6456e10..e0b8c30 100644 --- a/tests/models/test_llama.py +++ b/tests/models/test_llama.py @@ -135,6 +135,72 @@ def test_llama_optimized(model_name, checkpoint_format): ).abs().max().item() + + +@pytest.mark.parametrize("model_name", ["PY007/TinyLlama-1.1B-step-50K-105b"]) +def test_mqa_optimized(model_name): + """Check that our implementation of Llama with MQA/GQA (with all optimizations enabled) matches the + HF implementation: the output of our forward pass in fp16 should be around the same as the HF + forward pass in fp16, when compared to the HF forward pass in fp32. 
+ """ + dtype = torch.float16 + device = "cuda" + config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(model_name)) + config.use_flash_attn = True # FlashAttention-2 supports headdim 256 + config.fused_bias_fc = True + config.fused_mlp = False + config.fused_dropout_add_ln = True + config.residual_in_fp32 = True + + # Without device_map, the model is loaded on the CPU, which is very slow + model_ref = LlamaForCausalLM.from_pretrained(model_name, device_map={"": device}) + model_ref.eval() + + model = GPTLMHeadModel(config, device=device, dtype=dtype) + model.load_state_dict(remap_state_dict_hf_llama(model_ref.state_dict(), config)) + model.eval() + + torch.manual_seed(0) + batch_size = 2 + max_seqlen = 256 + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device + ) + with torch.no_grad(): + out = model.transformer(input_ids) + logits = model(input_ids).logits + del model + + with torch.no_grad(): + out_ref = model_ref.model(input_ids).last_hidden_state + logits_ref = model_ref(input_ids).logits + del model_ref + + model_hf = LlamaForCausalLM.from_pretrained( + model_name, torch_dtype=dtype, device_map={"": device} + ) + model_hf.eval() + out_hf = model_hf.model(input_ids).last_hidden_state + logits_hf = model_hf(input_ids).logits + del model_hf + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}") + print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}") + assert (out - out_ref).abs().max().item() < 3 * ( + out_hf - out_ref + ).abs().max().item() + + print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") + print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}") + print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}") + print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}") + assert (logits - logits_ref).abs().max().item() < 3 * ( + logits_hf - logits_ref + ).abs().max().item() + + # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel" @pytest.mark.parametrize("world_size", [2]) @pytest.mark.parametrize("model_name", ["13B"])
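A standalone sketch of the reshape logic the patch changes, with assumed GQA sizes (n_head=32, n_head_kv=4, head_dim=64) and with permute/inv_permute re-declared locally rather than imported from flash_attn. It only demonstrates why passing an explicit first_dim and using -1 in the reshape lets the same helpers handle both the full-size Q projection and the smaller K/V projections, and that slicing the fused Wqkv apart mirrors the q_dim/k_dim/v_dim logic in inv_remap_state_dict_hf_llama.

# Standalone sketch with assumed GQA sizes; permute/inv_permute are local
# re-declarations, not imports from flash_attn.
import torch

n_head, n_head_kv, head_dim = 32, 4, 64
n_embd = n_head * head_dim


def permute(w, first_dim):
    # Interleaving of the two rotary halves as in HF's
    # convert_llama_weights_to_hf.py, generalized so K/V can use
    # first_dim = n_head_kv instead of n_head.
    return w.view(first_dim, -1, 2, n_embd).transpose(1, 2).reshape(-1, n_embd)


def inv_permute(w, first_dim):
    # Inverse of permute; the -1 (rather than a size hard-coded from n_embd)
    # is what allows Wk/Wv with only n_head_kv * head_dim rows.
    return w.reshape(first_dim, 2, -1, n_embd).transpose(1, 2).reshape(-1, n_embd)


Wq = torch.randn(n_head * head_dim, n_embd)
Wk = torch.randn(n_head_kv * head_dim, n_embd)  # GQA/MQA: fewer KV heads
Wv = torch.randn(n_head_kv * head_dim, n_embd)

# inv_permute undoes permute for both head counts.
assert torch.equal(inv_permute(permute(Wq, n_head), n_head), Wq)
assert torch.equal(inv_permute(permute(Wk, n_head_kv), n_head_kv), Wk)

# Packing and unpacking the fused Wqkv, mirroring the q_dim/k_dim/v_dim
# slicing in inv_remap_state_dict_hf_llama.
Wqkv = torch.cat((inv_permute(Wq, n_head), inv_permute(Wk, n_head_kv), Wv), dim=0)
q_dim, kv_dim = n_head * head_dim, n_head_kv * head_dim
assert torch.equal(permute(Wqkv[:q_dim], n_head), Wq)
assert torch.equal(permute(Wqkv[q_dim : q_dim + kv_dim], n_head_kv), Wk)
assert torch.equal(Wqkv[q_dim + kv_dim :], Wv)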
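A usage sketch for the two conversion functions themselves, under the same assumptions as the new test (the TinyLlama checkpoint can be downloaded and flash_attn is installed). It only shows the two directions of the remapping for a GQA model and asserts nothing about the resulting weights.

from transformers import LlamaConfig, LlamaForCausalLM

from flash_attn.models.llama import (
    inv_remap_state_dict_hf_llama,
    llama_config_to_gpt2_config,
    remap_state_dict_hf_llama,
)

model_name = "PY007/TinyLlama-1.1B-step-50K-105b"  # GQA model used by the new test
config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(model_name))

# Loaded on CPU in fp32 here; slow, but enough to exercise the key remapping.
hf_state_dict = LlamaForCausalLM.from_pretrained(model_name).state_dict()

gpt_state_dict = remap_state_dict_hf_llama(hf_state_dict, config)   # HF -> GPT layout
hf_again = inv_remap_state_dict_hf_llama(gpt_state_dict, config)    # GPT -> HF layout

Note that the HF-to-GPT remapping drops keys it does not store (e.g. the rotary inv_freq buffers), so the round trip is not expected to be byte-identical.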