Support LLaMa2 and CodeLLaMa (#491)

Co-authored-by: danthe3rd <danthe3rd>
dan_the_3rd authored 2023-08-30 19:31:14 +02:00 · committed by GitHub
parent 011ec323d6
commit c9d4a816fa

@@ -10,6 +10,7 @@ from typing import Union
import torch
import torch.nn.functional as F
from sentencepiece import SentencePieceProcessor
from transformers import GPT2Config, LlamaConfig
@@ -308,7 +309,30 @@ def config_from_meta_checkpoint(
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
        num_key_value_heads=params.get("n_kv_heads", None),
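        # LLaMa-2 70B sets n_kv_heads=8 (grouped-query attention); checkpoints
        # without the key leave this None, i.e. one KV head per query head.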
    )
    multiple_of = params.get("multiple_of", 1)
    ffn_dim_multiplier = params.get("ffn_dim_multiplier", None)
    # Compute the hidden dimension of the MLP
    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L224
    intermediate_size = 4 * config.hidden_size
    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L195-L199
    intermediate_size = int(2 * intermediate_size / 3)
    # custom dim factor multiplier
    if ffn_dim_multiplier is not None:
        intermediate_size = int(ffn_dim_multiplier * intermediate_size)
    intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
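    # e.g. LLaMa-2 7B (hidden_size=4096, multiple_of=256, no ffn_dim_multiplier):
    # 4 * 4096 = 16384 -> int(2 * 16384 / 3) = 10922 -> rounded up to 43 * 256 = 11008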
    config.intermediate_size = intermediate_size
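    # CodeLLaMa checkpoints set rope_theta=1000000.0 in params.json; earlier
    # LLaMa checkpoints omit the key and keep the default rotary base of 10000.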
if "rope_theta" in params:
config.rotary_emb_base = params["rope_theta"]
    config.vocab_size = 32000
    # Some CodeLLaMa checkpoints have vocab_size 32000, others 32016, and sadly
    # it's not specified in the `params.json` file :( so read the real size
    # from the tokenizer model when one is present.
    tokenizer = Path(checkpoint_path) / model_name / "tokenizer.model"
    if tokenizer.is_file():
        config.vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size()
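        # (The 32016-token vocab appears to be the CodeLLaMa variants that add
        # the extra infilling special tokens; an assumption, not stated here.)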
    return config
@@ -364,4 +388,6 @@ def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
        rotary_emb_base=getattr(llama_config, "rotary_emb_base", 10000.0),
        n_head_kv=llama_config.num_key_value_heads,
    )
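
A minimal sketch of how the two helpers compose after this change. The module
path flash_attn.models.llama and the checkpoint directory layout are
assumptions (the diff view doesn't show the filename); the function names and
signatures come from the hunks above.

from flash_attn.models.llama import (
    config_from_meta_checkpoint,
    llama_config_to_gpt2_config,
)

# Expects <checkpoint_path>/<model_name>/params.json, plus tokenizer.model for
# an exact vocab_size. "CodeLlama-7b" is a hypothetical directory name.
llama_config = config_from_meta_checkpoint("/checkpoints", "CodeLlama-7b")
gpt2_config = llama_config_to_gpt2_config(llama_config)
print(gpt2_config.n_head_kv, gpt2_config.rotary_emb_base, gpt2_config.vocab_size)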