2023-04-19 12:43:37 +08:00
|
|
|
# Copyright (c) 2023, Tri Dao.
|
|
|
|
|
|
|
|
|
|
import json
|
2023-08-15 23:33:15 +08:00
|
|
|
import math
|
|
|
|
|
import os
|
2023-04-19 12:43:37 +08:00
|
|
|
import re
|
|
|
|
|
from collections import OrderedDict
|
2023-08-15 23:33:15 +08:00
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Union
|
2023-04-19 12:43:37 +08:00
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
from transformers import GPT2Config, LlamaConfig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remap_state_dict_meta_llama(state_dict, config):
    """Convert a Meta (original-release) LLaMA state dict to this repo's GPT parameter names.

    Every key except ``output.*`` is first prefixed with ``transformer.``. Then the
    token embedding, norm, MLP and attention weights are renamed. The w3/w1 MLP
    projections are fused into ``mlp.fc1`` and wq/wk/wv into ``mixer.Wqkv``. Stored
    rotary frequency buffers (``rope.freqs``) are dropped.

    Args:
        state_dict: Mapping from Meta parameter names (``tok_embeddings.weight``,
            ``layers.{i}.attention.wq.weight``, ...) to tensors. It is not mutated;
            a new mapping is built.
        config: Config object read for ``n_layer``, and optionally
            ``pad_vocab_size_multiple`` (default 1) and ``tie_word_embeddings``
            (default False).

    Returns:
        A new ``OrderedDict`` with ``transformer.*`` keys plus ``lm_head.weight``.
    """
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)

    def _pad_vocab(weight):
        # Zero-pad the vocab (row) dimension up to a multiple of
        # pad_vocab_size_multiple (e.g. a multiple of 8).
        vocab_size = (
            math.ceil(weight.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        return F.pad(weight, (0, 0, 0, vocab_size - weight.shape[0]))

    def key_mapping_layers(key):
        return f"transformer.{key}" if not key.startswith("output.") else key

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    state_dict["transformer.embeddings.word_embeddings.weight"] = _pad_vocab(word_embeddings)
    if getattr(config, "tie_word_embeddings", False):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
        # A tied head has no separate output projection. Drop any "output.weight" so it
        # is not left behind under its un-remapped Meta name.
        state_dict.pop("output.weight", None)
    else:
        output_embeddings = state_dict.pop("output.weight")
        # Pad separately: LLaMa shards the word embeddings and output embeddings
        # differently, so their row counts can differ.
        state_dict["lm_head.weight"] = _pad_vocab(output_embeddings)

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for i in range(config.n_layer):
        w1 = state_dict.pop(f"transformer.layers.{i}.feed_forward.w1.weight")
        w3 = state_dict.pop(f"transformer.layers.{i}.feed_forward.w3.weight")
        # Our ordering is different: fc1 stacks the w3 rows first, then the w1 rows.
        state_dict[f"transformer.layers.{i}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for i in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{i}.attention.wq.weight")
        Wk = state_dict.pop(f"transformer.layers.{i}.attention.wk.weight")
        Wv = state_dict.pop(f"transformer.layers.{i}.attention.wv.weight")
        state_dict[f"transformer.layers.{i}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these
        state_dict.pop(f"transformer.layers.{i}.attention.inner_attention.rope.freqs", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).attention.wo.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    state_dict.pop("transformer.rope.freqs", None)

    return state_dict
|
|
|
|
|
|
|
|
|
|
|
2023-08-15 23:33:15 +08:00
|
|
|
def remap_state_dict_hf_llama(state_dict, config):
    """Convert a Hugging Face LLaMA state dict to this repo's GPT parameter names.

    Keys under ``model.*`` (plus ``lm_head.weight``) are renamed to the
    ``transformer.*`` / ``lm_head.*`` layout. The gate/up MLP projections are fused into
    ``mlp.fc1``. The Q/K/V projections are fused into ``mixer.Wqkv`` after undoing the
    rotary row permutation that HF's conversion script applies to Q and K.

    Args:
        state_dict: Mapping from HF parameter names to tensors. It is not mutated;
            a new mapping is built.
        config: Config object read for ``n_layer``, ``n_head``, ``n_embd``, and
            optionally ``pad_vocab_size_multiple`` (default 1) and
            ``tie_word_embeddings`` (default False).

    Returns:
        A new ``OrderedDict`` with the remapped parameters.
    """
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)

    def _pad_vocab(weight):
        # Zero-pad the vocab (row) dimension up to a multiple of
        # pad_vocab_size_multiple (e.g. a multiple of 8).
        vocab_size = (
            math.ceil(weight.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
        )
        return F.pad(weight, (0, 0, 0, vocab_size - weight.shape[0]))

    # Embedding
    def key_mapping_emb(key):
        return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    state_dict["transformer.embeddings.word_embeddings.weight"] = _pad_vocab(word_embeddings)

    # LM head
    if getattr(config, "tie_word_embeddings", False):
        # Any lm_head.weight from the checkpoint is overwritten by the padded word embeddings.
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # Padded from its own row count, which need not equal the word embeddings'.
        state_dict["lm_head.weight"] = _pad_vocab(output_embeddings)

    # MLP
    for i in range(config.n_layer):
        # Fusing weights this way based on difference in the following:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
        # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
        w1 = state_dict.pop(f"model.layers.{i}.mlp.gate_proj.weight")
        w3 = state_dict.pop(f"model.layers.{i}.mlp.up_proj.weight")
        state_dict[f"transformer.layers.{i}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key)

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
        key = re.sub(r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key)
        key = re.sub(
            r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def inv_permute(w):
        # Inverse of permute implemented in:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
        # Within each head, HF stores the two rotary halves as contiguous row blocks.
        # Regroup them into interleaved row pairs. The number of heads is inferred from
        # w's row count rather than taken from config.n_head. This also handles K
        # projections with fewer heads than Q (grouped-query attention).
        head_dim = config.n_embd // config.n_head
        out_features, in_features = w.shape
        return (
            w.reshape(-1, 2, head_dim // 2, in_features)
            .transpose(1, 2)
            .reshape(out_features, in_features)
        )

    # Attention
    for i in range(config.n_layer):
        Wq = state_dict.pop(f"model.layers.{i}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"model.layers.{i}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"model.layers.{i}.self_attn.v_proj.weight")
        state_dict[f"transformer.layers.{i}.mixer.Wqkv.weight"] = torch.cat(
            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
        )
        # We don't store these
        state_dict.pop(f"model.layers.{i}.self_attn.rotary_emb.inv_freq", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^model.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
|
|
|
|
|
|
|
|
|
|
|
2023-08-19 05:22:11 +08:00
|
|
|
def config_from_meta_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    """Load a LlamaConfig from a Meta checkpoint directory.

    Reads ``<checkpoint_path>/<model_name>/params.json`` and maps its ``dim``,
    ``n_heads``, ``n_layers`` and ``norm_eps`` fields onto ``hidden_size``,
    ``num_attention_heads``, ``num_hidden_layers`` and ``rms_norm_eps``.
    ``intermediate_size`` is passed as ``None``.

    Raises:
        FileNotFoundError: if ``params.json`` does not exist.
        KeyError: if one of the fields above is missing from ``params.json``.
    """
    params_path = Path(checkpoint_path) / model_name / "params.json"
    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
    with open(params_path, encoding="utf-8") as f:
        params = json.load(f)
    config = LlamaConfig(
        hidden_size=params["dim"],
        intermediate_size=None,
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
    )
    return config
|
|
|
|
|
|
|
|
|
|
|
2023-08-19 05:22:11 +08:00
|
|
|
def config_from_hf_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    """Load the LlamaConfig of a Hugging Face-format LLaMA checkpoint.

    The config is read from ``<checkpoint_path>/<model_name>-hf/config.json``.
    """
    hf_dir = Path(checkpoint_path) / f"{model_name}-hf"
    config_file = hf_dir / "config.json"
    return LlamaConfig.from_pretrained(config_file)
|
2023-08-15 23:33:15 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def config_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta"
) -> LlamaConfig:
    """Load a LlamaConfig using the loader that matches the checkpoint layout.

    ``checkpoint_format == "meta"`` selects the Meta ``params.json`` loader. Any other
    value falls through to the Hugging Face ``config.json`` loader.
    """
    loader = (
        config_from_meta_checkpoint
        if checkpoint_format == "meta"
        else config_from_hf_checkpoint
    )
    return loader(checkpoint_path, model_name)
|
|
|
|
|
|
|
|
|
|
|
2023-08-19 05:22:11 +08:00
|
|
|
def state_dicts_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> list[dict]:
    """Load every ``consolidated.*.pth`` shard under ``<checkpoint_path>/<model_name>``.

    Each shard is loaded with ``map_location="cpu"``. Shards are returned in sorted
    path order; the list is empty if nothing matches.
    """
    shard_dir = Path(checkpoint_path) / model_name
    # Sorting is required: glob order is unspecified, and loading the shards out of
    # order produces wrong weights.
    shard_paths = sorted(shard_dir.glob("consolidated.*.pth"))
    # NOTE(review): torch.load unpickles the file; only use this on trusted checkpoints.
    return [torch.load(shard_path, map_location="cpu") for shard_path in shard_paths]
|
2023-04-19 12:43:37 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
    """Build the GPT2Config that describes a LLaMA model to this repo's GPT implementation.

    Copied from ``llama_config``:
        - model sizes;
        - the RMSNorm epsilon;
        - the initializer range;
        - the bos/eos/pad token ids.

    All remaining fields are fixed values for the LLaMA architecture:
        - RMSNorm and SwiGLU;
        - interleaved rotary embeddings over the full head dimension;
        - no biases, no dropout, and no absolute position embedding;
        - untied embeddings.
    """
    src = llama_config
    gpt2_kwargs = dict(
        vocab_size=src.vocab_size,
        n_positions=0,  # rotary only: no absolute position embedding
        n_embd=src.hidden_size,
        n_layer=src.num_hidden_layers,
        n_head=src.num_attention_heads,
        n_inner=src.intermediate_size,
        activation_function="swiglu",  # hardcoded, since HF names this activation 'silu'
        # No dropout anywhere. The released LLaMA code has none, possibly because only
        # the inference code was published.
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=src.rms_norm_eps,
        initializer_range=src.initializer_range,
        bos_token_id=src.bos_token_id,
        eos_token_id=src.eos_token_id,
        # From here on: extension arguments that the stock GPT2Config does not define.
        pad_token_id=src.pad_token_id,  # original author was unsure whether this has any effect
        rms_norm=True,
        rotary_emb_fraction=1.0,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
    )
    return GPT2Config(**gpt2_kwargs)
|