diff --git a/flash_attn/models/llama.py b/flash_attn/models/llama.py
index 40d4073..993a282 100644
--- a/flash_attn/models/llama.py
+++ b/flash_attn/models/llama.py
@@ -10,6 +10,7 @@ from typing import Union
 
 import torch
 import torch.nn.functional as F
+from sentencepiece import SentencePieceProcessor
 from transformers import GPT2Config, LlamaConfig
 
 
@@ -308,7 +309,30 @@ def config_from_meta_checkpoint(
         num_attention_heads=params["n_heads"],
         num_hidden_layers=params["n_layers"],
         rms_norm_eps=params["norm_eps"],
+        num_key_value_heads=params.get("n_kv_heads", None),
     )
+    multiple_of = params.get("multiple_of", 1)
+    ffn_dim_multiplier = params.get("ffn_dim_multiplier", None)
+
+    # Compute the hidden dimension of the MLP
+    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L224
+    intermediate_size = 4 * config.hidden_size
+    # https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/model.py#L195-L199
+    intermediate_size = int(2 * intermediate_size / 3)
+    # custom dim factor multiplier
+    if ffn_dim_multiplier is not None:
+        intermediate_size = int(ffn_dim_multiplier * intermediate_size)
+    intermediate_size = multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)
+
+    config.intermediate_size = intermediate_size
+    if "rope_theta" in params:
+        config.rotary_emb_base = params["rope_theta"]
+    config.vocab_size = 32000
+    # Some CodeLlama models have vocab_size 32000, some 32016.
+    # Sadly it's not specified in the `params.json` file :(
+    tokenizer = Path(checkpoint_path) / model_name / "tokenizer.model"
+    if tokenizer.is_file():
+        config.vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size()
     return config
 
 
@@ -364,4 +388,6 @@ def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
         out_proj_bias=False,
         mlp_fc1_bias=False,
         mlp_fc2_bias=False,
+        rotary_emb_base=getattr(llama_config, "rotary_emb_base", 10000.0),
+        n_head_kv=llama_config.num_key_value_heads,
     )
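
Note: the MLP sizing arithmetic added to config_from_meta_checkpoint can be sanity-checked in isolation. Below is a minimal standalone sketch of the same rule (the helper name llama_intermediate_size is ours, not part of the patch); the two assertions correspond to the published LLaMA-7B (dim=4096, multiple_of=256) and Llama-2-70B (dim=8192, multiple_of=4096, ffn_dim_multiplier=1.3) configurations.

    def llama_intermediate_size(hidden_size, multiple_of=1, ffn_dim_multiplier=None):
        # Classic 4x expansion, scaled by 2/3 because the SwiGLU MLP has three
        # weight matrices where a GELU MLP has two, keeping parameter count similar.
        intermediate_size = int(2 * (4 * hidden_size) / 3)
        # Optional checkpoint-specific multiplier (e.g. 1.3 for Llama-2-70B).
        if ffn_dim_multiplier is not None:
            intermediate_size = int(ffn_dim_multiplier * intermediate_size)
        # Round up to the nearest multiple of `multiple_of`.
        return multiple_of * ((intermediate_size + multiple_of - 1) // multiple_of)

    # Checked against published checkpoints:
    assert llama_intermediate_size(4096, multiple_of=256) == 11008  # LLaMA-7B
    assert llama_intermediate_size(8192, multiple_of=4096, ffn_dim_multiplier=1.3) == 28672  # Llama-2-70B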
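
The new fields are read with .get because pre-CodeLlama checkpoints omit them from params.json; for the same reason, llama_config_to_gpt2_config fetches rotary_emb_base via getattr with LLaMA's default base of 10000.0, since config_from_meta_checkpoint only sets that attribute when rope_theta is present. A minimal sketch of this defensive handling, with the dict shaped like a CodeLlama params.json (exact values illustrative):

    params = {
        "dim": 4096,
        "n_layers": 32,
        "n_heads": 32,
        "norm_eps": 1e-5,
        "multiple_of": 256,
        "rope_theta": 1000000.0,  # CodeLlama raises the RoPE base for long contexts
        # no "n_kv_heads" key: this model uses plain multi-head attention
    }
    n_kv_heads = params.get("n_kv_heads", None)  # None -> as many KV heads as query heads
    rotary_emb_base = params.get("rope_theta", 10000.0)  # older checkpoints omit rope_theta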
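
The vocab-size probe works because tokenizer.model, shipped next to the weights, is a standard SentencePiece model whose vocabulary size can be read directly (per the patch comment, some CodeLlama models use 32000 tokens and others 32016, and params.json does not say which). A sketch of the fallback behavior, with a hypothetical checkpoint layout:

    from pathlib import Path
    from sentencepiece import SentencePieceProcessor

    checkpoint_path = Path("/checkpoints/codellama")  # hypothetical location
    model_name = "CodeLlama-7b"                       # hypothetical model directory
    tokenizer = checkpoint_path / model_name / "tokenizer.model"
    vocab_size = 32000  # LLaMA default, used when no tokenizer file is found
    if tokenizer.is_file():
        vocab_size = SentencePieceProcessor(str(tokenizer)).vocab_size()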