Support LLaMa2 and CodeLLaMa (#491)
Co-authored-by: danthe3rd <danthe3rd>
This commit is contained in:
parent
011ec323d6
commit
c9d4a816fa
@ -10,6 +10,7 @@ from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from transformers import GPT2Config, LlamaConfig
|
||||
|
||||
|
||||
@ -308,7 +309,30 @@ def config_from_meta_checkpoint(
|
||||
num_attention_heads=params["n_heads"],
|
||||
num_hidden_layers=params["n_layers"],
|
||||
rms_norm_eps=params["norm_eps"],
|
||||
num_key_value_heads=params.get("n_kv_heads", None),
|
||||
)
|
||||
multiple_of = params.get("multiple_of", 1)
|
||||
ffn_dim_multiplier = params.get("ffn_dim_multiplier", None)
|
||||
|
||||
# Compute the hidden dimension of the MLP
|
||||
# https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L224
|
||||
intermediate_size = 4 * config.hidden_size
|
||||
# https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L195-L199
|
||||
intermediate_size = int(2 * intermediate_size / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
intermediate_size = int(ffn_dim_multiplier * intermediate_size)
|
||||
intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
|
||||
|
||||
config.intermediate_size = intermediate_size
|
||||
if "rope_theta" in params:
|
||||
config.rotary_emb_base = params["rope_theta"]
|
||||
config.vocab_size = 32000
|
||||
# some CodeLLaMa have vocab_size 32000, some 32016
|
||||
# Sadly it's not specified in the `params.json` file :(
|
||||
tokenizer = Path(checkpoint_path) / model_name / "tokenizer.model"
|
||||
if tokenizer.is_file():
|
||||
config.vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size()
|
||||
return config
|
||||
|
||||
|
||||
@ -364,4 +388,6 @@ def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
|
||||
out_proj_bias=False,
|
||||
mlp_fc1_bias=False,
|
||||
mlp_fc2_bias=False,
|
||||
rotary_emb_base=getattr(llama_config, "rotary_emb_base", 10000.0),
|
||||
n_head_kv=llama_config.num_key_value_heads,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user