flash-attention/flash_attn/models/llama.py

219 lines
10 KiB
Python
Raw Normal View History

2023-04-19 12:43:37 +08:00
# Copyright (c) 2023, Tri Dao.
import json
import math
import os
2023-04-19 12:43:37 +08:00
import re
from collections import OrderedDict
from pathlib import Path
from typing import Union
2023-04-19 12:43:37 +08:00
import torch
import torch.nn.functional as F
from transformers import GPT2Config, LlamaConfig
def remap_state_dict_meta_llama(state_dict, config):
    """Remap a Meta-format LLaMA state dict to this model's parameter names.

    A new ``OrderedDict`` is built; the ``state_dict`` mapping passed in is not
    modified. Besides renaming keys, the function:

    * prefixes every key except ``output.*`` with ``transformer.``;
    * zero-pads the word embeddings (and the output embeddings when they are
      not tied) along dim 0 up to a multiple of
      ``config.pad_vocab_size_multiple`` (default 1, i.e. no padding);
    * concatenates ``feed_forward.w3`` and ``feed_forward.w1`` (in that order)
      into ``mlp.fc1``;
    * concatenates ``wq``, ``wk`` and ``wv`` into ``mixer.Wqkv``;
    * drops the rotary ``freqs`` buffers.

    Args:
        state_dict: mapping from Meta parameter names to tensors.
        config: object exposing ``n_layer`` and, optionally,
            ``pad_vocab_size_multiple`` and ``tie_word_embeddings``. A config
            without ``tie_word_embeddings`` is treated as untied.

    Returns:
        OrderedDict holding the remapped weights.

    Raises:
        KeyError: if an expected weight is missing from ``state_dict``.
    """
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)

    def rename_keys(mapping, rules):
        # Apply every (pattern, replacement) rule to every key, preserving order.
        def rename(key):
            for pattern, replacement in rules:
                key = re.sub(pattern, replacement, key)
            return key
        return OrderedDict((rename(k), v) for k, v in mapping.items())

    def pad_vocab(weight):
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        rows = weight.shape[0]
        padded_rows = math.ceil(rows / pad_vocab_size_multiple) * pad_vocab_size_multiple
        return F.pad(weight, (0, 0, 0, padded_rows - rows))

    state_dict = OrderedDict(
        (k if k.startswith('output.') else f'transformer.{k}', v)
        for k, v in state_dict.items()
    )

    # Word embedding
    # NOTE: dots are escaped so that they only match a literal '.'.
    state_dict = rename_keys(state_dict, [
        (r'^transformer\.tok_embeddings\.', 'transformer.embeddings.word_embeddings.'),
    ])
    emb_key = 'transformer.embeddings.word_embeddings.weight'
    state_dict[emb_key] = pad_vocab(state_dict.pop(emb_key))

    # LM head
    if getattr(config, 'tie_word_embeddings', False):
        state_dict['lm_head.weight'] = state_dict[emb_key]
    else:
        # The padded size is recomputed from the output embeddings themselves since
        # LLaMa shards the word embeddings and output embeddings differently.
        state_dict['lm_head.weight'] = pad_vocab(state_dict.pop('output.weight'))

    # LayerNorm
    state_dict = rename_keys(state_dict, [
        (r'^transformer\.norm\.', r'transformer.ln_f.'),
        (r'^transformer\.layers\.(\d+)\.attention_norm\.', r'transformer.layers.\1.norm1.'),
        (r'^transformer\.layers\.(\d+)\.ffn_norm\.', r'transformer.layers.\1.norm2.'),
    ])

    # MLP
    for layer_idx in range(config.n_layer):
        prefix = f'transformer.layers.{layer_idx}'
        w1 = state_dict.pop(f'{prefix}.feed_forward.w1.weight')
        w3 = state_dict.pop(f'{prefix}.feed_forward.w3.weight')
        # Our ordering is different
        state_dict[f'{prefix}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
    state_dict = rename_keys(state_dict, [
        (r'^transformer\.layers\.(\d+)\.feed_forward\.w2\.', r'transformer.layers.\1.mlp.fc2.'),
    ])

    # Attention
    for layer_idx in range(config.n_layer):
        prefix = f'transformer.layers.{layer_idx}'
        Wq = state_dict.pop(f'{prefix}.attention.wq.weight')
        Wk = state_dict.pop(f'{prefix}.attention.wk.weight')
        Wv = state_dict.pop(f'{prefix}.attention.wv.weight')
        state_dict[f'{prefix}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these
        state_dict.pop(f'{prefix}.attention.inner_attention.rope.freqs', None)
    state_dict = rename_keys(state_dict, [
        (r'^transformer\.layers\.(\d+)\.attention\.wo\.', r'transformer.layers.\1.mixer.out_proj.'),
    ])

    state_dict.pop("transformer.rope.freqs", None)
    return state_dict
def remap_state_dict_hf_llama(state_dict, config):
    """Remap a Hugging Face LLaMA state dict to this model's parameter names.

    A new ``OrderedDict`` is built; the ``state_dict`` mapping passed in is not
    modified. Besides renaming keys, the function:

    * zero-pads the word embeddings (and ``lm_head.weight`` when embeddings are
      not tied) along dim 0 up to a multiple of
      ``config.pad_vocab_size_multiple`` (default 1, i.e. no padding);
    * concatenates ``mlp.up_proj`` and ``mlp.gate_proj`` (in that order) into
      ``mlp.fc1``;
    * undoes the row permutation of the q/k projection weights and concatenates
      q, k and v into ``mixer.Wqkv``;
    * drops the ``rotary_emb.inv_freq`` buffers.

    Args:
        state_dict: mapping from Hugging Face parameter names to tensors.
        config: object exposing ``n_layer``, ``n_head``, ``n_embd`` and,
            optionally, ``pad_vocab_size_multiple`` and ``tie_word_embeddings``.
            A config without ``tie_word_embeddings`` is treated as untied.

    Returns:
        OrderedDict holding the remapped weights.

    Raises:
        KeyError: if an expected weight is missing from ``state_dict``.
    """
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)

    def rename_keys(mapping, rules):
        # Apply every (pattern, replacement) rule to every key, preserving order.
        def rename(key):
            for pattern, replacement in rules:
                key = re.sub(pattern, replacement, key)
            return key
        return OrderedDict((rename(k), v) for k, v in mapping.items())

    def pad_vocab(weight):
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        rows = weight.shape[0]
        padded_rows = math.ceil(rows / pad_vocab_size_multiple) * pad_vocab_size_multiple
        return F.pad(weight, (0, 0, 0, padded_rows - rows))

    def inv_permute(w):
        # Inverse of permute implemented in:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
        return w.reshape(
            config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd
        ).transpose(1, 2).reshape(config.n_embd, config.n_embd)

    # Embedding
    # NOTE: dots are escaped so that they only match a literal '.'.
    state_dict = rename_keys(state_dict, [
        (r'^model\.embed_tokens\.', 'transformer.embeddings.word_embeddings.'),
    ])
    emb_key = 'transformer.embeddings.word_embeddings.weight'
    state_dict[emb_key] = pad_vocab(state_dict.pop(emb_key))

    # LM head
    if getattr(config, 'tie_word_embeddings', False):
        state_dict['lm_head.weight'] = state_dict[emb_key]
    else:
        # The padded size is recomputed from the output embeddings themselves since
        # LLaMa shards the word embeddings and output embeddings differently.
        state_dict['lm_head.weight'] = pad_vocab(state_dict.pop('lm_head.weight'))

    # MLP
    for layer_idx in range(config.n_layer):
        # Fusing weights this way based on difference in the following:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
        # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
        w1 = state_dict.pop(f'model.layers.{layer_idx}.mlp.gate_proj.weight')
        w3 = state_dict.pop(f'model.layers.{layer_idx}.mlp.up_proj.weight')
        state_dict[f'transformer.layers.{layer_idx}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
    state_dict = rename_keys(state_dict, [
        (r'^model\.layers\.(\d+)\.mlp\.down_proj\.', r'transformer.layers.\1.mlp.fc2.'),
    ])

    # LayerNorm
    state_dict = rename_keys(state_dict, [
        (r'^model\.norm\.', r'transformer.ln_f.'),
        (r'^model\.layers\.(\d+)\.input_layernorm\.', r'transformer.layers.\1.norm1.'),
        (r'^model\.layers\.(\d+)\.post_attention_layernorm\.', r'transformer.layers.\1.norm2.'),
    ])

    # Attention
    for layer_idx in range(config.n_layer):
        Wq = state_dict.pop(f'model.layers.{layer_idx}.self_attn.q_proj.weight')
        Wk = state_dict.pop(f'model.layers.{layer_idx}.self_attn.k_proj.weight')
        Wv = state_dict.pop(f'model.layers.{layer_idx}.self_attn.v_proj.weight')
        # Only q and k are permuted back; v is concatenated unchanged.
        state_dict[f'transformer.layers.{layer_idx}.mixer.Wqkv.weight'] = torch.cat(
            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
        )
        # We don't store these
        state_dict.pop(f'model.layers.{layer_idx}.self_attn.rotary_emb.inv_freq', None)
    state_dict = rename_keys(state_dict, [
        (r'^model\.layers\.(\d+)\.self_attn\.o_proj\.', r'transformer.layers.\1.mixer.out_proj.'),
    ])
    return state_dict
def config_from_meta_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig:
    """Build a LlamaConfig from the ``params.json`` of a Meta-format checkpoint.

    Reads ``<checkpoint_path>/<model_name>/params.json`` and maps its ``dim``,
    ``n_heads``, ``n_layers`` and ``norm_eps`` entries onto the matching
    LlamaConfig arguments. ``intermediate_size`` is passed as ``None``.
    """
    params_file = Path(checkpoint_path) / model_name / 'params.json'
    with params_file.open() as fp:
        meta_params = json.load(fp)
    return LlamaConfig(
        hidden_size=meta_params['dim'],
        intermediate_size=None,
        num_attention_heads=meta_params['n_heads'],
        num_hidden_layers=meta_params['n_layers'],
        rms_norm_eps=meta_params['norm_eps'],
    )
def config_from_hf_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> LlamaConfig:
    """Load a LlamaConfig from ``<checkpoint_path>/<model_name>-hf/config.json``."""
    config_file = Path(checkpoint_path) / f'{model_name}-hf' / "config.json"
    return LlamaConfig.from_pretrained(config_file)
def config_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str, checkpoint_format="meta"
) -> LlamaConfig:
    """Load a LlamaConfig for a checkpoint stored in the given format.

    ``checkpoint_format == "meta"`` selects the Meta loader; every other value
    selects the Hugging Face loader.
    """
    load_config = (
        config_from_meta_checkpoint
        if checkpoint_format == "meta"
        else config_from_hf_checkpoint
    )
    return load_config(checkpoint_path, model_name)
def state_dicts_from_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> list[dict]:
    """Load every ``consolidated.*.pth`` shard under ``<checkpoint_path>/<model_name>``.

    Shards are loaded onto the CPU and returned in sorted path order.

    NOTE(review): ``torch.load`` deserializes the files; only point this at
    checkpoints from a trusted source.
    """
    shard_dir = Path(checkpoint_path) / model_name
    # Need to sort, otherwise we mess up the ordering and the weights are wrong
    shard_paths = sorted(shard_dir.glob('consolidated.*.pth'))
    return [torch.load(shard_path, map_location='cpu') for shard_path in shard_paths]
def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
    """Translate a LlamaConfig into the equivalent GPT2Config.

    Sizes, epsilon, initializer range and special-token ids are copied from
    ``llama_config``; the remaining arguments are fixed values.
    """
    # Values taken over from the LLaMA config.
    copied = dict(
        vocab_size=llama_config.vocab_size,
        n_embd=llama_config.hidden_size,
        n_layer=llama_config.num_hidden_layers,
        n_head=llama_config.num_attention_heads,
        n_inner=llama_config.intermediate_size,
        layer_norm_epsilon=llama_config.rms_norm_eps,
        initializer_range=llama_config.initializer_range,
        bos_token_id=llama_config.bos_token_id,
        eos_token_id=llama_config.eos_token_id,
        pad_token_id=llama_config.pad_token_id,  # Idk if this does anything
    )
    # Standard GPT2Config arguments with fixed values.
    fixed = dict(
        n_positions=0,  # No absolute position embedding
        activation_function='swiglu',  # Hardcode since HF calls it 'silu'
        # Llama doesn't have dropout, idk if it's because they only release the inference code
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
    )
    # These are new arguments not in the original GPT2Config
    extensions = dict(
        rms_norm=True,
        rotary_emb_fraction=1.0,
        rotary_emb_interleaved=True,
        tie_word_embeddings=False,
        qkv_proj_bias=False,
        out_proj_bias=False,
        mlp_fc1_bias=False,
        mlp_fc2_bias=False,
    )
    return GPT2Config(**copied, **fixed, **extensions)