Fix Llama GQA/MQA (#546)
* Fix llama MQA
* Fix permute shape
* Update llama.py
parent dfe29f5e2b
commit 42832575d4
llama.py
@@ -26,10 +26,13 @@ def remap_state_dict_meta_llama(
         return f"transformer.{key}" if not key.startswith("output.") else key

     state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

     # Word embedding
     def key_mapping_emb(key):
         return re.sub(
-            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
+            r"^transformer.tok_embeddings.",
+            "transformer.embeddings.word_embeddings.",
+            key,
         )

     state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
@@ -61,7 +64,9 @@ def remap_state_dict_meta_llama(
     def key_mapping_ln(key):
         key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
         key = re.sub(
-            r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key
+            r"^transformer.layers.(\d+).attention_norm.",
+            r"transformer.layers.\1.norm1.",
+            key,
         )
         key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
         return key
@@ -77,7 +82,9 @@ def remap_state_dict_meta_llama(

     def key_mapping_mlp(key):
         return re.sub(
-            r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key
+            r"^transformer.layers.(\d+).feed_forward.w2.",
+            r"transformer.layers.\1.mlp.fc2.",
+            key,
         )

     state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
@@ -106,12 +113,13 @@ def remap_state_dict_meta_llama(


 def remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
+    state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
 ) -> dict[str, torch.Tensor]:
     """Convert the state_dict in Hugging Face format to standard GPT format.

     This function modifies state_dict in place.
     """

     # Embedding
     def key_mapping_emb(key):
         return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
@@ -153,28 +161,38 @@ def remap_state_dict_hf_llama(
         state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

     def key_mapping_mlp(key):
-        return re.sub(r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key)
+        return re.sub(
+            r"^model.layers.(\d+).mlp.down_proj.",
+            r"transformer.layers.\1.mlp.fc2.",
+            key,
+        )

     state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

     # LayerNorm
     def key_mapping_ln(key):
         key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
-        key = re.sub(r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key)
-        key = re.sub(
-            r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key
+        key = re.sub(
+            r"^model.layers.(\d+).input_layernorm.",
+            r"transformer.layers.\1.norm1.",
+            key,
+        )
+        key = re.sub(
+            r"^model.layers.(\d+).post_attention_layernorm.",
+            r"transformer.layers.\1.norm2.",
+            key,
         )
         return key

     state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

-    def inv_permute(w):
+    def inv_permute(w, first_dim=None):
         # Inverse of permute implemented in:
         # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
         return (
-            w.reshape(config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd)
+            w.reshape(first_dim or config.n_head, 2, -1, config.n_embd)
             .transpose(1, 2)
-            .reshape(config.n_embd, config.n_embd)
+            .reshape(-1, config.n_embd)
         )

     # Attention
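Note: the inv_permute change above is the core of the GQA/MQA fix. With grouped/multi-query attention, k_proj.weight has n_head_kv * head_dim rows rather than n_embd, so the old reshape hard-coded to config.n_head (and to a square n_embd x n_embd output) no longer fits. A minimal standalone sketch with toy sizes (names and sizes assumed for illustration, not part of the commit):

import torch

# Toy GQA sizes (assumed): 8 query heads, 2 KV heads, head_dim 16.
n_head, n_head_kv, head_dim = 8, 2, 16
n_embd = n_head * head_dim                      # 128
Wk = torch.randn(n_head_kv * head_dim, n_embd)  # k_proj.weight: (32, 128), not (128, 128)

def inv_permute(w, first_dim=None):
    # New behaviour: first_dim defaults to the query head count; callers pass
    # n_head_kv for K, and -1 lets reshape infer head_dim // 2.
    return w.reshape(first_dim or n_head, 2, -1, n_embd).transpose(1, 2).reshape(-1, n_embd)

# The old version would do Wk.reshape(n_head, 2, head_dim // 2, n_embd), which needs
# 8 * 2 * 8 * 128 elements and raises a RuntimeError for this (32, 128) tensor.
print(inv_permute(Wk, n_head_kv).shape)  # torch.Size([32, 128]) - row count preserved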
@@ -182,15 +200,19 @@ def remap_state_dict_hf_llama(
         Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
         Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
         Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")

         state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
-            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
+            (inv_permute(Wq), inv_permute(Wk, getattr(config, "n_head_kv")), Wv),
+            dim=0,
         )
         # We don't store these
         state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)

     def key_mapping_attn(key):
         return re.sub(
-            r"^model.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key
+            r"^model.layers.(\d+).self_attn.o_proj.",
+            r"transformer.layers.\1.mixer.out_proj.",
+            key,
         )

     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
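With that, the packed Wqkv produced by the torch.cat above has (n_head + 2 * n_head_kv) * head_dim rows instead of 3 * n_embd. A hypothetical sanity-check helper (not in the commit) that one could run on the remapped state dict:

def check_wqkv_rows(state_dict, config, layer=0):
    # Verify the packed QKV projection of one layer has the GQA/MQA row count.
    head_dim = config.n_embd // config.n_head
    n_head_kv = getattr(config, "n_head_kv", config.n_head)  # plain-MHA fallback
    expected_rows = (config.n_head + 2 * n_head_kv) * head_dim
    w = state_dict[f"transformer.layers.{layer}.mixer.Wqkv.weight"]
    assert w.shape == (expected_rows, config.n_embd), w.shape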
@@ -198,7 +220,7 @@ def remap_state_dict_hf_llama(


 def inv_remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
+    state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
 ) -> dict[str, torch.Tensor]:
     """Convert the state_dict in standard GPT format to Hugging Face format.

@@ -246,26 +268,36 @@ def inv_remap_state_dict_hf_llama(
         state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3

     def key_mapping_mlp(key):
-        return re.sub(r"^transformer.layers.(\d+).mlp.fc2.", r"model.layers.\1.mlp.down_proj.", key)
+        return re.sub(
+            r"^transformer.layers.(\d+).mlp.fc2.",
+            r"model.layers.\1.mlp.down_proj.",
+            key,
+        )

     state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

     # LayerNorm
     def key_mapping_ln(key):
         key = re.sub(r"^transformer.ln_f.", r"model.norm.", key)
-        key = re.sub(r"^transformer.layers.(\d+).norm1.", r"model.layers.\1.input_layernorm.", key)
-        key = re.sub(
-            r"^transformer.layers.(\d+).norm2.", r"model.layers.\1.post_attention_layernorm.", key
+        key = re.sub(
+            r"^transformer.layers.(\d+).norm1.",
+            r"model.layers.\1.input_layernorm.",
+            key,
+        )
+        key = re.sub(
+            r"^transformer.layers.(\d+).norm2.",
+            r"model.layers.\1.post_attention_layernorm.",
+            key,
         )
         return key

     state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

-    def permute(w):
+    def permute(w, first_dim=None):
         return (
-            w.view(config.n_head, config.n_embd // config.n_head // 2, 2, config.n_embd)
+            w.view(first_dim or config.n_head, -1, 2, config.n_embd)
             .transpose(1, 2)
-            .reshape(config.n_embd, config.n_embd)
+            .reshape(-1, config.n_embd)
         )

     n_head = config.n_head
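permute here is the mirror image of inv_permute in remap_state_dict_hf_llama, so with the same first_dim the two should round-trip exactly. A self-contained check with the same toy sizes as above (assumed, for illustration only):

import torch

n_head_kv, head_dim, n_embd = 2, 16, 128  # toy GQA sizes (assumed)
w = torch.randn(n_head_kv * head_dim, n_embd)

def permute(w, first_dim):
    return w.view(first_dim, -1, 2, n_embd).transpose(1, 2).reshape(-1, n_embd)

def inv_permute(w, first_dim):
    return w.reshape(first_dim, 2, -1, n_embd).transpose(1, 2).reshape(-1, n_embd)

# GPT -> HF -> GPT (and vice versa) should leave the K weight unchanged.
assert torch.equal(permute(inv_permute(w, n_head_kv), n_head_kv), w)
assert torch.equal(inv_permute(permute(w, n_head_kv), n_head_kv), w)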
@@ -284,13 +316,15 @@ def inv_remap_state_dict_hf_llama(
         Wk = Wqkv[q_dim : q_dim + k_dim]
         Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
         state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
-        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
+        state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk, n_head_kv)
         state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
         state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

     def key_mapping_attn(key):
         return re.sub(
-            r"^transformer.layers.(\d+).mixer.out_proj.", r"model.layers.\1.self_attn.o_proj.", key
+            r"^transformer.layers.(\d+).mixer.out_proj.",
+            r"model.layers.\1.self_attn.o_proj.",
+            key,
         )

     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
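The q_dim/k_dim/v_dim slices above (defined earlier in the function, outside this hunk) are what make only k_proj need the n_head_kv first_dim: the query block still spans n_head * head_dim rows, while the K and V blocks span n_head_kv * head_dim rows each. Presumably the widths are computed roughly as follows (a sketch consistent with the shapes used in the hunk, not a verbatim excerpt of the file):

def qkv_split_dims(config):
    # Assumed width computation for slicing the packed Wqkv back into Q, K, V.
    head_dim = config.n_embd // config.n_head
    n_head_kv = getattr(config, "n_head_kv", config.n_head)
    q_dim = config.n_head * head_dim        # == config.n_embd
    k_dim = v_dim = n_head_kv * head_dim    # smaller than n_embd under GQA/MQA
    return q_dim, k_dim, v_dim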
tests/models/test_llama.py
@@ -135,6 +135,72 @@ def test_llama_optimized(model_name, checkpoint_format):
     ).abs().max().item()


+@pytest.mark.parametrize("model_name", ["PY007/TinyLlama-1.1B-step-50K-105b"])
+def test_mqa_optimized(model_name):
+    """Check that our implementation of Llama with MQA/GQA (with all optimizations enabled) matches the
+    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
+    forward pass in fp16, when compared to the HF forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = "cuda"
+    config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(model_name))
+    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
+    config.fused_bias_fc = True
+    config.fused_mlp = False
+    config.fused_dropout_add_ln = True
+    config.residual_in_fp32 = True
+
+    # Without device_map, the model is loaded on the CPU, which is very slow
+    model_ref = LlamaForCausalLM.from_pretrained(model_name, device_map={"": device})
+    model_ref.eval()
+
+    model = GPTLMHeadModel(config, device=device, dtype=dtype)
+    model.load_state_dict(remap_state_dict_hf_llama(model_ref.state_dict(), config))
+    model.eval()
+
+    torch.manual_seed(0)
+    batch_size = 2
+    max_seqlen = 256
+    input_ids = torch.randint(
+        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
+    )
+    with torch.no_grad():
+        out = model.transformer(input_ids)
+        logits = model(input_ids).logits
+    del model
+
+    with torch.no_grad():
+        out_ref = model_ref.model(input_ids).last_hidden_state
+        logits_ref = model_ref(input_ids).logits
+    del model_ref
+
+    model_hf = LlamaForCausalLM.from_pretrained(
+        model_name, torch_dtype=dtype, device_map={"": device}
+    )
+    model_hf.eval()
+    out_hf = model_hf.model(input_ids).last_hidden_state
+    logits_hf = model_hf(input_ids).logits
+    del model_hf
+
+    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
+    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
+    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
+    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
+    assert (out - out_ref).abs().max().item() < 3 * (
+        out_hf - out_ref
+    ).abs().max().item()
+
+    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
+    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
+    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
+    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
+    assert (logits - logits_ref).abs().max().item() < 3 * (
+        logits_hf - logits_ref
+    ).abs().max().item()
+
+
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
 @pytest.mark.parametrize("world_size", [2])
 @pytest.mark.parametrize("model_name", ["13B"])
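For reference, the acceptance criterion used twice at the end of the new test can be read as the following predicate (hypothetical helper, equivalent to the inline assertions): our fp16 output must deviate from the fp32 HF reference by less than 3x the deviation of HF's own fp16 output.

def close_enough(out, out_ref, out_hf):
    # out: our fp16 output, out_ref: HF fp32 reference, out_hf: HF fp16 output.
    return (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()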