[Llama] Fix some tests, add tests for Llama 2 and CodeLlama

Tri Dao 2023-09-20 23:36:46 -07:00
parent e0fbaa7016
commit 0705d2718d
4 changed files with 124 additions and 161 deletions

View File

@@ -13,6 +13,8 @@ import torch.nn.functional as F
from sentencepiece import SentencePieceProcessor
from transformers import GPT2Config, LlamaConfig
from einops import rearrange
def remap_state_dict_meta_llama(
state_dict: dict[str, torch.Tensor], config: GPT2Config
@@ -30,9 +32,7 @@ def remap_state_dict_meta_llama(
# Word embedding
def key_mapping_emb(key):
return re.sub(
r"^transformer.tok_embeddings.",
"transformer.embeddings.word_embeddings.",
key,
r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
@@ -113,7 +113,7 @@ def remap_state_dict_meta_llama(
def remap_state_dict_hf_llama(
state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
state_dict: dict[str, torch.Tensor], config: GPT2Config
) -> dict[str, torch.Tensor]:
"""Convert the state_dict in Hugging Face format to standard GPT format.
@@ -186,13 +186,11 @@ def remap_state_dict_hf_llama(
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
def inv_permute(w, first_dim=None):
def inv_permute(w):
# Inverse of permute implemented in:
# https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
return (
w.reshape(first_dim or config.n_head, 2, -1, config.n_embd)
.transpose(1, 2)
.reshape(-1, config.n_embd)
return rearrange(
w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2
)
# Attention
@@ -202,8 +200,7 @@ def remap_state_dict_hf_llama(
Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
(inv_permute(Wq), inv_permute(Wk, getattr(config, "n_head_kv")), Wv),
dim=0,
[inv_permute(Wq), inv_permute(Wk), Wv], dim=0
)
# We don't store these
state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)
@@ -220,7 +217,7 @@ def remap_state_dict_hf_llama(
def inv_remap_state_dict_hf_llama(
state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
state_dict: dict[str, torch.Tensor], config: GPT2Config
) -> dict[str, torch.Tensor]:
"""Convert the state_dict in standard GPT format to Hugging Face format.
@@ -293,11 +290,9 @@ def inv_remap_state_dict_hf_llama(
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
def permute(w, first_dim=None):
return (
w.view(first_dim or config.n_head, -1, 2, config.n_embd)
.transpose(1, 2)
.reshape(-1, config.n_embd)
def permute(w):
return rearrange(
w, "(h d two) n -> (h two d) n", d=config.n_embd // config.n_head // 2, two=2
)
n_head = config.n_head
@@ -316,7 +311,7 @@ def inv_remap_state_dict_hf_llama(
Wk = Wqkv[q_dim : q_dim + k_dim]
Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk, n_head_kv)
state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

View File

@@ -725,7 +725,7 @@ class ParallelMHA(nn.Module):
process_group,
bias=qkv_proj_bias,
sequence_parallel=sequence_parallel,
multiple_of=self.head_dim * (self.num_heads_per_rank + 2 * self.num_heads_kv_per_rank),
multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
**factory_kwargs,
)
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
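A small standalone check of the new multiple_of value (numbers are illustrative, and this is a reading of the intent rather than text from the commit): head_dim * (num_heads // num_heads_kv + 2) is the width of one GQA group in the fused QKV projection, i.e. the query heads that share a K/V pair plus that K head and V head, so splitting the projection across ranks in these units keeps whole groups on each rank.

head_dim, num_heads, num_heads_kv, world_size = 128, 40, 8, 2  # illustrative GQA shape
group_width = head_dim * (num_heads // num_heads_kv + 2)       # one group: 5 Q + 1 K + 1 V heads
qkv_dim = head_dim * (num_heads + 2 * num_heads_kv)            # fused QKV output dim
assert qkv_dim == num_heads_kv * group_width                   # the projection is num_heads_kv groups
assert (qkv_dim // world_size) % group_width == 0              # 2 ranks -> 4 whole groups each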

View File

@@ -160,6 +160,7 @@ def test_tensor_parallel(model_name, rotary, world_size):
assert torch.allclose(
torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
)
assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))
if not rotary:
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)

View File

@@ -1,6 +1,6 @@
# Copyright (c) 2023, Tri Dao.
# To run the huggingface implementation, we first need to convert the weights:
# To run the huggingface implementation of LLaMa (1), we first need to convert the weights:
# https://github.com/huggingface/transformers/pull/21955
# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
# and repeat for 13B, 30B, 65B
@@ -30,6 +30,7 @@ from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import LlamaConfig, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers import AutoConfig
def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
@@ -60,9 +61,38 @@ def test_llama_state_dict(model_name):
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize("model_name", ["7B", "13B"])
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
def test_llama_optimized(model_name, checkpoint_format):
# TinyLlama-1.1B is to test MQA
@pytest.mark.parametrize(
"model_name", ["meta-llama/Llama-2-7b-hf", "PY007/TinyLlama-1.1B-step-50K-105b"]
)
def test_inv_remap_state_dict_hf_llama(model_name):
config = llama_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
state_dict = state_dict_from_pretrained(model_name)
# inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama
state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key}
pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)
state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)
assert set(state_dict_recover.keys()) == set(state_dict.keys())
for key in state_dict_recover.keys():
torch.testing.assert_close(state_dict_recover[key], state_dict[key])
# TinyLlama-1.1B is to test MQA
@pytest.mark.parametrize(
"model_name",
[
"7B", # Llama 1
"13B", # Llama 1
"meta-llama/Llama-2-13b-hf",
"codellama/CodeLlama-7b-hf",
"codellama/CodeLlama-13b-hf",
"codellama/CodeLlama-34b-hf",
"PY007/TinyLlama-1.1B-step-50K-105b",
],
)
def test_llama_optimized(model_name):
"""Check that our implementation of LLaMa (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -73,17 +103,27 @@ def test_llama_optimized(model_name, checkpoint_format):
dtype = torch.float16
device = "cuda"
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
config = llama_config_to_gpt2_config(config)
if "/" in model_name: # Download from HF
config = llama_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
else:
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
config = llama_config_to_gpt2_config(config)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
checkpoint_path, model_name, config, checkpoint_format
)
if "/" in model_name: # Download from HF
pretrained_state_dict = remap_state_dict_hf_llama(
state_dict_from_pretrained(model_name), config
)
else:
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
checkpoint_path, model_name, config, checkpoint_format="meta"
)
model = GPTLMHeadModel(config, device=device, dtype=dtype)
model.load_state_dict(pretrained_state_dict)
model.eval()
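Both the config and the weights now branch on whether model_name contains "/" (an HF Hub id such as meta-llama/Llama-2-13b-hf) or is a local Meta LLaMa 1 checkpoint name such as 7B. A hypothetical helper, not part of the commit, sketching the repeated pattern with the functions already used in this test file (config_from_checkpoint, llama_config_to_gpt2_config, remap_state_dict_hf_llama, state_dict_from_pretrained, AutoConfig, and the local _pretrained_state_dict_from_checkpoint):

def _load_llama_config_and_state_dict(model_name, checkpoint_path, checkpoint_format="meta"):
    if "/" in model_name:  # download config and weights from the HF Hub
        config = llama_config_to_gpt2_config(
            AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        )
        state_dict = remap_state_dict_hf_llama(state_dict_from_pretrained(model_name), config)
    else:  # local Meta (LLaMa 1) checkpoint, e.g. "7B" or "13B"
        config = llama_config_to_gpt2_config(
            config_from_checkpoint(checkpoint_path, model_name, checkpoint_format=checkpoint_format)
        )
        state_dict = _pretrained_state_dict_from_checkpoint(
            checkpoint_path, model_name, config, checkpoint_format=checkpoint_format
        )
    return config, state_dict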
@@ -103,7 +143,8 @@ def test_llama_optimized(model_name, checkpoint_format):
# Without device_map, the model is loaded on the CPU, which is very slow
# Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
model_ref = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
device_map="auto",
)
model_ref.eval()
with torch.no_grad():
@@ -112,7 +153,9 @@ def test_llama_optimized(model_name, checkpoint_format):
del model_ref
model_hf = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
torch_dtype=dtype,
device_map={"": device},
)
model_hf.eval()
with torch.no_grad():
@@ -135,77 +178,12 @@ def test_llama_optimized(model_name, checkpoint_format):
).abs().max().item()
@pytest.mark.parametrize("model_name", ["PY007/TinyLlama-1.1B-step-50K-105b"])
def test_mqa_optimized(model_name):
"""Check that our implementation of Llama with MQA/GQA (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = "cuda"
config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(model_name))
config.use_flash_attn = True # FlashAttention-2 supports headdim 256
config.fused_bias_fc = True
config.fused_mlp = False
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = LlamaForCausalLM.from_pretrained(model_name, device_map={"": device})
model_ref.eval()
model = GPTLMHeadModel(config, device=device, dtype=dtype)
model.load_state_dict(remap_state_dict_hf_llama(model_ref.state_dict(), config))
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
del model
with torch.no_grad():
out_ref = model_ref.model(input_ids).last_hidden_state
logits_ref = model_ref(input_ids).logits
del model_ref
model_hf = LlamaForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
out_hf = model_hf.model(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (
out_hf - out_ref
).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("model_name", ["13B"])
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
def test_llama_parallel(model_name, world_size, checkpoint_format):
@pytest.mark.parametrize(
"model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
)
def test_llama_parallel(model_name, world_size):
"""Check that our implementation of LLaMa (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
@@ -217,8 +195,13 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
)
dtype = torch.float16
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
config = llama_config_to_gpt2_config(config)
if "/" in model_name: # Download from HF
config = llama_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
else:
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
config = llama_config_to_gpt2_config(config)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
@@ -233,9 +216,14 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
checkpoint_path, model_name, config, checkpoint_format
)
if "/" in model_name: # Download from HF
pretrained_state_dict = remap_state_dict_hf_llama(
state_dict_from_pretrained(model_name), config
)
else:
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
checkpoint_path, model_name, config, checkpoint_format="meta"
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
@@ -260,7 +248,8 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
if rank == 0:
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
device_map="auto",
)
model_ref.eval()
with torch.no_grad():
@@ -269,7 +258,9 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
del model_ref
model_hf = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
torch_dtype=dtype,
device_map="auto",
)
model_hf.eval()
with torch.no_grad():
@@ -405,9 +396,10 @@ def test_llama_generation(model_name, checkpoint_format):
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("model_name", ["13B"])
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
@pytest.mark.parametrize(
"model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
)
def test_llama_parallel_generation(model_name, world_size):
"""Check that our implementation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
@@ -419,12 +411,17 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
)
dtype = torch.float16
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
config = llama_config_to_gpt2_config(config)
config.use_flash_attn = False
if "/" in model_name: # Download from HF
config = llama_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
else:
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
config = llama_config_to_gpt2_config(config)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = False
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
config.pad_vocab_size_multiple = 8 * world_size
config.sequence_parallel = False # Need to set this to False for generation
@@ -450,9 +447,14 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
# GPU0 and GPU1 and things would hang
torch.cuda.set_device(device)
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
checkpoint_path, model_name, config, checkpoint_format
)
if "/" in model_name: # Download from HF
pretrained_state_dict = remap_state_dict_hf_llama(
state_dict_from_pretrained(model_name), config
)
else:
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
checkpoint_path, model_name, config, checkpoint_format="meta"
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
@@ -490,7 +492,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
if rank == 0:
# Without device_map, the model is loaded on the CPU, which is very slow
model_hf = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
torch_dtype=dtype,
device_map="auto",
)
model_hf.eval()
print("HF fp16")
@@ -508,7 +512,8 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
del model_hf
model_ref = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
device_map="auto",
)
model_ref.eval()
with torch.inference_mode():
@@ -594,15 +599,16 @@ def test_llama_parallel_uneven_num_heads(world_size):
if rank == 0:
model_ref = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
Path(checkpoint_path) / f"{model_name}-hf", device_map={"": device}
)
model_ref = model_ref.to(device=device)
model_ref.eval()
out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
out_ref = model_ref.model(input_ids).last_hidden_state
logits_ref = model_ref(input_ids).logits
del model_ref
model_hf = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
@@ -625,42 +631,3 @@ def test_llama_parallel_uneven_num_heads(world_size):
if os.path.exists(checkpoint_path / f"{model_name}-hf"):
shutil.rmtree(checkpoint_path / f"{model_name}-hf")
@torch.no_grad()
def test_inv_remap_state_dict_hf_llama():
checkpoint_path = (
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
)
model_name = f"teeny"
llama_config = LlamaConfig(
num_attention_heads=2,
hidden_size=256 * 2,
intermediate_size=256 * 2 * 4,
num_hidden_layers=4,
)
config = llama_config_to_gpt2_config(llama_config)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
# Set up.
LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
# inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama
state_dict = state_dict_from_pretrained(checkpoint_path / f"{model_name}-hf")
state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key}
pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)
state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)
assert set(state_dict_recover.keys()) == set(state_dict.keys())
for key in state_dict_recover.keys():
torch.testing.assert_close(state_dict_recover[key], state_dict[key])
# Tear down.
if os.path.exists(checkpoint_path / f"{model_name}-hf"):
shutil.rmtree(checkpoint_path / f"{model_name}-hf")