Fix Llama GQA/MQA (#546)

* Fix llama MQA

* Fix permute shape

* Update llama.py
Kevin Hu 2023-09-19 22:15:59 -07:00 committed by GitHub
parent dfe29f5e2b
commit 42832575d4
2 changed files with 121 additions and 21 deletions
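Background for the diff below, as a standalone illustration rather than code from this commit (the sizes and helper signatures here are made up): under GQA/MQA the k_proj and v_proj weights have n_head_kv * head_dim rows instead of n_embd, so the rotary (de)interleaving permute has to be told how many heads the tensor actually carries and can no longer hard-code a square (n_embd, n_embd) reshape. A minimal sketch of the round-trip the fix restores:

import torch

n_embd, n_head, n_head_kv = 2048, 32, 4   # GQA: 32 query heads share 4 KV heads
head_dim = n_embd // n_head

def permute(w, first_dim):
    # Rotary interleaving in the style of HF's convert_llama_weights_to_hf.py,
    # parameterized by the head count of the tensor being permuted.
    return w.view(first_dim, -1, 2, n_embd).transpose(1, 2).reshape(-1, n_embd)

def inv_permute(w, first_dim):
    # Undo the interleaving; the leading dim must match the head count used above.
    return w.reshape(first_dim, 2, -1, n_embd).transpose(1, 2).reshape(-1, n_embd)

Wq = torch.randn(n_head * head_dim, n_embd)      # (2048, 2048)
Wk = torch.randn(n_head_kv * head_dim, n_embd)   # (256, 2048): fewer rows than n_embd

# HF permutes k_proj with the KV head count, so the inverse must use the same count.
assert torch.equal(inv_permute(permute(Wq, n_head), n_head), Wq)
assert torch.equal(inv_permute(permute(Wk, n_head_kv), n_head_kv), Wk)

# The pre-fix inv_permute reshaped everything to (n_head, 2, ..., n_embd) and back to
# (n_embd, n_embd), which cannot even hold Wk's 256 x 2048 elements under GQA.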

flash_attn/models/llama.py

@@ -26,10 +26,13 @@ def remap_state_dict_meta_llama(
         return f"transformer.{key}" if not key.startswith("output.") else key
     state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
     # Word embedding
     def key_mapping_emb(key):
         return re.sub(
-            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
+            r"^transformer.tok_embeddings.",
+            "transformer.embeddings.word_embeddings.",
+            key,
         )
     state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
@@ -61,7 +64,9 @@ def remap_state_dict_meta_llama(
     def key_mapping_ln(key):
         key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
         key = re.sub(
-            r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key
+            r"^transformer.layers.(\d+).attention_norm.",
+            r"transformer.layers.\1.norm1.",
+            key,
         )
         key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
         return key
@@ -77,7 +82,9 @@ def remap_state_dict_meta_llama(
     def key_mapping_mlp(key):
         return re.sub(
-            r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key
+            r"^transformer.layers.(\d+).feed_forward.w2.",
+            r"transformer.layers.\1.mlp.fc2.",
+            key,
         )
     state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
@@ -106,12 +113,13 @@ def remap_state_dict_meta_llama(
 def remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
+    state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
 ) -> dict[str, torch.Tensor]:
     """Convert the state_dict in Hugging Face format to standard GPT format.
     This function modifies state_dict in place.
     """
     # Embedding
     def key_mapping_emb(key):
         return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
@@ -153,28 +161,38 @@ def remap_state_dict_hf_llama(
         state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)
     def key_mapping_mlp(key):
-        return re.sub(r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key)
+        return re.sub(
+            r"^model.layers.(\d+).mlp.down_proj.",
+            r"transformer.layers.\1.mlp.fc2.",
+            key,
+        )
     state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
     # LayerNorm
     def key_mapping_ln(key):
         key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
-        key = re.sub(r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key)
         key = re.sub(
-            r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key
+            r"^model.layers.(\d+).input_layernorm.",
+            r"transformer.layers.\1.norm1.",
+            key,
+        )
+        key = re.sub(
+            r"^model.layers.(\d+).post_attention_layernorm.",
+            r"transformer.layers.\1.norm2.",
+            key,
         )
         return key
     state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
-    def inv_permute(w):
+    def inv_permute(w, first_dim=None):
         # Inverse of permute implemented in:
         # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
         return (
-            w.reshape(config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd)
+            w.reshape(first_dim or config.n_head, 2, -1, config.n_embd)
             .transpose(1, 2)
-            .reshape(config.n_embd, config.n_embd)
+            .reshape(-1, config.n_embd)
         )
     # Attention
@@ -182,15 +200,19 @@ def remap_state_dict_hf_llama(
         Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
         Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
         Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
         state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
-            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
+            (inv_permute(Wq), inv_permute(Wk, getattr(config, "n_head_kv")), Wv),
+            dim=0,
         )
         # We don't store these
         state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)
     def key_mapping_attn(key):
         return re.sub(
-            r"^model.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key
+            r"^model.layers.(\d+).self_attn.o_proj.",
+            r"transformer.layers.\1.mixer.out_proj.",
+            key,
         )
     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
@@ -198,7 +220,7 @@ def remap_state_dict_hf_llama(
 def inv_remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
+    state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
 ) -> dict[str, torch.Tensor]:
     """Convert the state_dict in standard GPT format to Hugging Face format.
@@ -246,26 +268,36 @@ def inv_remap_state_dict_hf_llama(
         state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3
     def key_mapping_mlp(key):
-        return re.sub(r"^transformer.layers.(\d+).mlp.fc2.", r"model.layers.\1.mlp.down_proj.", key)
+        return re.sub(
+            r"^transformer.layers.(\d+).mlp.fc2.",
+            r"model.layers.\1.mlp.down_proj.",
+            key,
+        )
     state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
     # LayerNorm
     def key_mapping_ln(key):
         key = re.sub(r"^transformer.ln_f.", r"model.norm.", key)
-        key = re.sub(r"^transformer.layers.(\d+).norm1.", r"model.layers.\1.input_layernorm.", key)
         key = re.sub(
-            r"^transformer.layers.(\d+).norm2.", r"model.layers.\1.post_attention_layernorm.", key
+            r"^transformer.layers.(\d+).norm1.",
+            r"model.layers.\1.input_layernorm.",
+            key,
+        )
+        key = re.sub(
+            r"^transformer.layers.(\d+).norm2.",
+            r"model.layers.\1.post_attention_layernorm.",
+            key,
         )
         return key
     state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
-    def permute(w):
+    def permute(w, first_dim=None):
         return (
-            w.view(config.n_head, config.n_embd // config.n_head // 2, 2, config.n_embd)
+            w.view(first_dim or config.n_head, -1, 2, config.n_embd)
             .transpose(1, 2)
-            .reshape(config.n_embd, config.n_embd)
+            .reshape(-1, config.n_embd)
         )
     n_head = config.n_head
@@ -284,13 +316,15 @@ def inv_remap_state_dict_hf_llama(
         Wk = Wqkv[q_dim : q_dim + k_dim]
         Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
         state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
-        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
+        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk, n_head_kv)
         state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
         state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)
     def key_mapping_attn(key):
         return re.sub(
-            r"^transformer.layers.(\d+).mixer.out_proj.", r"model.layers.\1.self_attn.o_proj.", key
+            r"^transformer.layers.(\d+).mixer.out_proj.",
+            r"model.layers.\1.self_attn.o_proj.",
+            key,
         )
     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
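A note on the packed Wqkv layout that both remap directions above rely on, again as a standalone sketch with made-up sizes (rotary de-interleaving omitted for brevity): Q keeps n_head heads while K and V keep n_head_kv heads, so the concatenated weight has (n_head + 2 * n_head_kv) * head_dim rows and is sliced back with the same q_dim/k_dim/v_dim offsets used in inv_remap_state_dict_hf_llama.

import torch

n_embd, n_head, n_head_kv = 2048, 32, 4
head_dim = n_embd // n_head
q_dim = n_head * head_dim             # 2048
k_dim = v_dim = n_head_kv * head_dim  # 256 each

Wq = torch.randn(q_dim, n_embd)
Wk = torch.randn(k_dim, n_embd)
Wv = torch.randn(v_dim, n_embd)

# Pack as in remap_state_dict_hf_llama (permute omitted here) ...
Wqkv = torch.cat((Wq, Wk, Wv), dim=0)  # (2560, n_embd), not (3 * n_embd, n_embd)

# ... and slice back as in inv_remap_state_dict_hf_llama.
assert torch.equal(Wqkv[:q_dim], Wq)
assert torch.equal(Wqkv[q_dim : q_dim + k_dim], Wk)
assert torch.equal(Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim], Wv)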

tests/models/test_llama.py

@@ -135,6 +135,72 @@ def test_llama_optimized(model_name, checkpoint_format):
     ).abs().max().item()
+@pytest.mark.parametrize("model_name", ["PY007/TinyLlama-1.1B-step-50K-105b"])
+def test_mqa_optimized(model_name):
+    """Check that our implementation of Llama with MQA/GQA (with all optimizations enabled) matches the
+    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
+    forward pass in fp16, when compared to the HF forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = "cuda"
+    config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(model_name))
+    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
+    config.fused_bias_fc = True
+    config.fused_mlp = False
+    config.fused_dropout_add_ln = True
+    config.residual_in_fp32 = True
+    # Without device_map, the model is loaded on the CPU, which is very slow
+    model_ref = LlamaForCausalLM.from_pretrained(model_name, device_map={"": device})
+    model_ref.eval()
+    model = GPTLMHeadModel(config, device=device, dtype=dtype)
+    model.load_state_dict(remap_state_dict_hf_llama(model_ref.state_dict(), config))
+    model.eval()
+    torch.manual_seed(0)
+    batch_size = 2
+    max_seqlen = 256
+    input_ids = torch.randint(
+        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
+    )
+    with torch.no_grad():
+        out = model.transformer(input_ids)
+        logits = model(input_ids).logits
+    del model
+    with torch.no_grad():
+        out_ref = model_ref.model(input_ids).last_hidden_state
+        logits_ref = model_ref(input_ids).logits
+    del model_ref
+    model_hf = LlamaForCausalLM.from_pretrained(
+        model_name, torch_dtype=dtype, device_map={"": device}
+    )
+    model_hf.eval()
+    out_hf = model_hf.model(input_ids).last_hidden_state
+    logits_hf = model_hf(input_ids).logits
+    del model_hf
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
+    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
+    assert (out - out_ref).abs().max().item() < 3 * (
+        out_hf - out_ref
+    ).abs().max().item()
+    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
+    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
+    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
+    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
+    assert (logits - logits_ref).abs().max().item() < 3 * (
+        logits_hf - logits_ref
+    ).abs().max().item()
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
 @pytest.mark.parametrize("world_size", [2])
 @pytest.mark.parametrize("model_name", ["13B"])