[Llama] Fix some tests, add tests for Llama 2 and CodeLlama
This commit is contained in:
parent
e0fbaa7016
commit
0705d2718d
@ -13,6 +13,8 @@ import torch.nn.functional as F
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from transformers import GPT2Config, LlamaConfig
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def remap_state_dict_meta_llama(
|
||||
state_dict: dict[str, torch.Tensor], config: GPT2Config
|
||||
@ -30,9 +32,7 @@ def remap_state_dict_meta_llama(
|
||||
# Word embedding
|
||||
def key_mapping_emb(key):
|
||||
return re.sub(
|
||||
r"^transformer.tok_embeddings.",
|
||||
"transformer.embeddings.word_embeddings.",
|
||||
key,
|
||||
r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
|
||||
)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
||||
@ -113,7 +113,7 @@ def remap_state_dict_meta_llama(
|
||||
|
||||
|
||||
def remap_state_dict_hf_llama(
|
||||
state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
|
||||
state_dict: dict[str, torch.Tensor], config: GPT2Config
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Convert the state_dict in Hugging Face format to standard GPT format.
|
||||
|
||||
@ -186,13 +186,11 @@ def remap_state_dict_hf_llama(
|
||||
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
def inv_permute(w, first_dim=None):
|
||||
def inv_permute(w):
|
||||
# Inverse of permute implemented in:
|
||||
# https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
|
||||
return (
|
||||
w.reshape(first_dim or config.n_head, 2, -1, config.n_embd)
|
||||
.transpose(1, 2)
|
||||
.reshape(-1, config.n_embd)
|
||||
return rearrange(
|
||||
w, "(h two d) n -> (h d two) n", d=config.n_embd // config.n_head // 2, two=2
|
||||
)
|
||||
|
||||
# Attention
|
||||
@ -202,8 +200,7 @@ def remap_state_dict_hf_llama(
|
||||
Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
|
||||
|
||||
state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
|
||||
(inv_permute(Wq), inv_permute(Wk, getattr(config, "n_head_kv")), Wv),
|
||||
dim=0,
|
||||
[inv_permute(Wq), inv_permute(Wk), Wv], dim=0
|
||||
)
|
||||
# We don't store these
|
||||
state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)
|
||||
@ -220,7 +217,7 @@ def remap_state_dict_hf_llama(
|
||||
|
||||
|
||||
def inv_remap_state_dict_hf_llama(
|
||||
state_dict: dict[str, torch.Tensor], config: GPT2Config, multi_query: bool = False
|
||||
state_dict: dict[str, torch.Tensor], config: GPT2Config
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Convert the state_dict in standard GPT format to Hugging Face format.
|
||||
|
||||
@ -293,11 +290,9 @@ def inv_remap_state_dict_hf_llama(
|
||||
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
def permute(w, first_dim=None):
|
||||
return (
|
||||
w.view(first_dim or config.n_head, -1, 2, config.n_embd)
|
||||
.transpose(1, 2)
|
||||
.reshape(-1, config.n_embd)
|
||||
def permute(w):
|
||||
return rearrange(
|
||||
w, "(h d two) n -> (h two d) n", d=config.n_embd // config.n_head // 2, two=2
|
||||
)
|
||||
|
||||
n_head = config.n_head
|
||||
@ -316,7 +311,7 @@ def inv_remap_state_dict_hf_llama(
|
||||
Wk = Wqkv[q_dim : q_dim + k_dim]
|
||||
Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
|
||||
state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
|
||||
state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk, n_head_kv)
|
||||
state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
|
||||
state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
|
||||
state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)
|
||||
|
||||
|
||||
@ -725,7 +725,7 @@ class ParallelMHA(nn.Module):
|
||||
process_group,
|
||||
bias=qkv_proj_bias,
|
||||
sequence_parallel=sequence_parallel,
|
||||
multiple_of=self.head_dim * (self.num_heads_per_rank + 2 * self.num_heads_kv_per_rank),
|
||||
multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2),
|
||||
**factory_kwargs,
|
||||
)
|
||||
inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
|
||||
|
||||
@ -160,6 +160,7 @@ def test_tensor_parallel(model_name, rotary, world_size):
|
||||
assert torch.allclose(
|
||||
torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
|
||||
)
|
||||
assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))
|
||||
if not rotary:
|
||||
assert torch.all(out.sequences == out_ref.sequences)
|
||||
assert torch.all(out.sequences == out_hf.sequences)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
# To run the huggingface implementation, we first need to convert the weights:
|
||||
# To run the huggingface implementation of LLaMa (1), we first need to convert the weights:
|
||||
# https://github.com/huggingface/transformers/pull/21955
|
||||
# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
|
||||
# and repeat for 13B, 30B, 65B
|
||||
@ -30,6 +30,7 @@ from flash_attn.utils.generation import update_graph_cache
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
from transformers import LlamaConfig, LlamaTokenizer
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
from transformers import AutoConfig
|
||||
|
||||
|
||||
def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
|
||||
@ -60,9 +61,38 @@ def test_llama_state_dict(model_name):
|
||||
assert state_dict[k].shape == pretrained_state_dict[k].shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["7B", "13B"])
|
||||
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
|
||||
def test_llama_optimized(model_name, checkpoint_format):
|
||||
# TinyLlama-1.1B is to test MQA
|
||||
@pytest.mark.parametrize(
|
||||
"model_name", ["meta-llama/Llama-2-7b-hf", "PY007/TinyLlama-1.1B-step-50K-105b"]
|
||||
)
|
||||
def test_inv_remap_state_dict_hf_llama(model_name):
|
||||
config = llama_config_to_gpt2_config(
|
||||
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
)
|
||||
state_dict = state_dict_from_pretrained(model_name)
|
||||
# inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama
|
||||
state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key}
|
||||
pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)
|
||||
state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)
|
||||
assert set(state_dict_recover.keys()) == set(state_dict.keys())
|
||||
for key in state_dict_recover.keys():
|
||||
torch.testing.assert_close(state_dict_recover[key], state_dict[key])
|
||||
|
||||
|
||||
# TinyLlama-1.1B is to test MQA
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[
|
||||
"7B", # Llama 1
|
||||
"13B", # Llama 1
|
||||
"meta-llama/Llama-2-13b-hf",
|
||||
"codellama/CodeLlama-7b-hf",
|
||||
"codellama/CodeLlama-13b-hf",
|
||||
"codellama/CodeLlama-34b-hf",
|
||||
"PY007/TinyLlama-1.1B-step-50K-105b",
|
||||
],
|
||||
)
|
||||
def test_llama_optimized(model_name):
|
||||
"""Check that our implementation of LLaMa (with all optimizations enabled) matches the
|
||||
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
|
||||
forward pass in fp16, when compared to the HF forward pass in fp32.
|
||||
@ -73,17 +103,27 @@ def test_llama_optimized(model_name, checkpoint_format):
|
||||
|
||||
dtype = torch.float16
|
||||
device = "cuda"
|
||||
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
|
||||
config = llama_config_to_gpt2_config(config)
|
||||
if "/" in model_name: # Download from HF
|
||||
config = llama_config_to_gpt2_config(
|
||||
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
)
|
||||
else:
|
||||
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
|
||||
config = llama_config_to_gpt2_config(config)
|
||||
config.use_flash_attn = True
|
||||
config.fused_bias_fc = True
|
||||
config.fused_mlp = False # We don't have fused GatedMLP yet
|
||||
config.fused_dropout_add_ln = True
|
||||
config.residual_in_fp32 = True
|
||||
|
||||
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
|
||||
checkpoint_path, model_name, config, checkpoint_format
|
||||
)
|
||||
if "/" in model_name: # Download from HF
|
||||
pretrained_state_dict = remap_state_dict_hf_llama(
|
||||
state_dict_from_pretrained(model_name), config
|
||||
)
|
||||
else:
|
||||
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
|
||||
checkpoint_path, model_name, config, checkpoint_format="meta"
|
||||
)
|
||||
model = GPTLMHeadModel(config, device=device, dtype=dtype)
|
||||
model.load_state_dict(pretrained_state_dict)
|
||||
model.eval()
|
||||
@ -103,7 +143,8 @@ def test_llama_optimized(model_name, checkpoint_format):
|
||||
# Without device_map, the model is loaded on the CPU, which is very slow
|
||||
# Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
|
||||
model_ref = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
|
||||
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
|
||||
device_map="auto",
|
||||
)
|
||||
model_ref.eval()
|
||||
with torch.no_grad():
|
||||
@ -112,7 +153,9 @@ def test_llama_optimized(model_name, checkpoint_format):
|
||||
del model_ref
|
||||
|
||||
model_hf = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
|
||||
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
|
||||
torch_dtype=dtype,
|
||||
device_map={"": device},
|
||||
)
|
||||
model_hf.eval()
|
||||
with torch.no_grad():
|
||||
@ -135,77 +178,12 @@ def test_llama_optimized(model_name, checkpoint_format):
|
||||
).abs().max().item()
|
||||
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["PY007/TinyLlama-1.1B-step-50K-105b"])
|
||||
def test_mqa_optimized(model_name):
|
||||
"""Check that our implementation of Llama with MQA/GQA (with all optimizations enabled) matches the
|
||||
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
|
||||
forward pass in fp16, when compared to the HF forward pass in fp32.
|
||||
"""
|
||||
dtype = torch.float16
|
||||
device = "cuda"
|
||||
config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(model_name))
|
||||
config.use_flash_attn = True # FlashAttention-2 supports headdim 256
|
||||
config.fused_bias_fc = True
|
||||
config.fused_mlp = False
|
||||
config.fused_dropout_add_ln = True
|
||||
config.residual_in_fp32 = True
|
||||
|
||||
# Without device_map, the model is loaded on the CPU, which is very slow
|
||||
model_ref = LlamaForCausalLM.from_pretrained(model_name, device_map={"": device})
|
||||
model_ref.eval()
|
||||
|
||||
model = GPTLMHeadModel(config, device=device, dtype=dtype)
|
||||
model.load_state_dict(remap_state_dict_hf_llama(model_ref.state_dict(), config))
|
||||
model.eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
batch_size = 2
|
||||
max_seqlen = 256
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
|
||||
)
|
||||
with torch.no_grad():
|
||||
out = model.transformer(input_ids)
|
||||
logits = model(input_ids).logits
|
||||
del model
|
||||
|
||||
with torch.no_grad():
|
||||
out_ref = model_ref.model(input_ids).last_hidden_state
|
||||
logits_ref = model_ref(input_ids).logits
|
||||
del model_ref
|
||||
|
||||
model_hf = LlamaForCausalLM.from_pretrained(
|
||||
model_name, torch_dtype=dtype, device_map={"": device}
|
||||
)
|
||||
model_hf.eval()
|
||||
out_hf = model_hf.model(input_ids).last_hidden_state
|
||||
logits_hf = model_hf(input_ids).logits
|
||||
del model_hf
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
|
||||
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
|
||||
assert (out - out_ref).abs().max().item() < 3 * (
|
||||
out_hf - out_ref
|
||||
).abs().max().item()
|
||||
|
||||
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
|
||||
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
|
||||
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
|
||||
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
|
||||
assert (logits - logits_ref).abs().max().item() < 3 * (
|
||||
logits_hf - logits_ref
|
||||
).abs().max().item()
|
||||
|
||||
|
||||
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
@pytest.mark.parametrize("model_name", ["13B"])
|
||||
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
|
||||
def test_llama_parallel(model_name, world_size, checkpoint_format):
|
||||
@pytest.mark.parametrize(
|
||||
"model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
|
||||
)
|
||||
def test_llama_parallel(model_name, world_size):
|
||||
"""Check that our implementation of LLaMa (with all optimizations enabled) matches the
|
||||
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
|
||||
forward pass in fp16, when compared to the HF forward pass in fp32.
|
||||
@ -217,8 +195,13 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
|
||||
)
|
||||
|
||||
dtype = torch.float16
|
||||
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
|
||||
config = llama_config_to_gpt2_config(config)
|
||||
if "/" in model_name: # Download from HF
|
||||
config = llama_config_to_gpt2_config(
|
||||
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
)
|
||||
else:
|
||||
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
|
||||
config = llama_config_to_gpt2_config(config)
|
||||
config.use_flash_attn = True
|
||||
config.fused_bias_fc = True
|
||||
config.fused_mlp = False # We don't have fused GatedMLP yet
|
||||
@ -233,9 +216,14 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
|
||||
rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
process_group = parallel_state.get_tensor_model_parallel_group()
|
||||
|
||||
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
|
||||
checkpoint_path, model_name, config, checkpoint_format
|
||||
)
|
||||
if "/" in model_name: # Download from HF
|
||||
pretrained_state_dict = remap_state_dict_hf_llama(
|
||||
state_dict_from_pretrained(model_name), config
|
||||
)
|
||||
else:
|
||||
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
|
||||
checkpoint_path, model_name, config, checkpoint_format="meta"
|
||||
)
|
||||
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
|
||||
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
|
||||
model.eval()
|
||||
@ -260,7 +248,8 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
|
||||
if rank == 0:
|
||||
# Without device_map, the model is loaded on the CPU, which is very slow
|
||||
model_ref = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
|
||||
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
|
||||
device_map="auto",
|
||||
)
|
||||
model_ref.eval()
|
||||
with torch.no_grad():
|
||||
@ -269,7 +258,9 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
|
||||
del model_ref
|
||||
|
||||
model_hf = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
|
||||
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
|
||||
torch_dtype=dtype,
|
||||
device_map="auto",
|
||||
)
|
||||
model_hf.eval()
|
||||
with torch.no_grad():
|
||||
@ -405,9 +396,10 @@ def test_llama_generation(model_name, checkpoint_format):
|
||||
|
||||
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
@pytest.mark.parametrize("model_name", ["13B"])
|
||||
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
|
||||
def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
|
||||
@pytest.mark.parametrize(
|
||||
"model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
|
||||
)
|
||||
def test_llama_parallel_generation(model_name, world_size):
|
||||
"""Check that our implementation matches the HF implementation:
|
||||
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
|
||||
the HF scores in fp32.
|
||||
@ -419,12 +411,17 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
|
||||
)
|
||||
|
||||
dtype = torch.float16
|
||||
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
|
||||
config = llama_config_to_gpt2_config(config)
|
||||
config.use_flash_attn = False
|
||||
if "/" in model_name: # Download from HF
|
||||
config = llama_config_to_gpt2_config(
|
||||
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
)
|
||||
else:
|
||||
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
|
||||
config = llama_config_to_gpt2_config(config)
|
||||
config.use_flash_attn = True
|
||||
config.fused_bias_fc = True
|
||||
config.fused_mlp = False # We don't have fused GatedMLP yet
|
||||
config.fused_dropout_add_ln = False
|
||||
config.fused_dropout_add_ln = True
|
||||
config.residual_in_fp32 = True
|
||||
config.pad_vocab_size_multiple = 8 * world_size
|
||||
config.sequence_parallel = False # Need to set this to False for generation
|
||||
@ -450,9 +447,14 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
|
||||
# GPU0 and GPU1 and things would hang
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
|
||||
checkpoint_path, model_name, config, checkpoint_format
|
||||
)
|
||||
if "/" in model_name: # Download from HF
|
||||
pretrained_state_dict = remap_state_dict_hf_llama(
|
||||
state_dict_from_pretrained(model_name), config
|
||||
)
|
||||
else:
|
||||
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
|
||||
checkpoint_path, model_name, config, checkpoint_format="meta"
|
||||
)
|
||||
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
|
||||
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
|
||||
model.eval()
|
||||
@ -490,7 +492,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
|
||||
if rank == 0:
|
||||
# Without device_map, the model is loaded on the CPU, which is very slow
|
||||
model_hf = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
|
||||
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
|
||||
torch_dtype=dtype,
|
||||
device_map="auto",
|
||||
)
|
||||
model_hf.eval()
|
||||
print("HF fp16")
|
||||
@ -508,7 +512,8 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
|
||||
del model_hf
|
||||
|
||||
model_ref = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
|
||||
model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
|
||||
device_map="auto",
|
||||
)
|
||||
model_ref.eval()
|
||||
with torch.inference_mode():
|
||||
@ -594,15 +599,16 @@ def test_llama_parallel_uneven_num_heads(world_size):
|
||||
|
||||
if rank == 0:
|
||||
model_ref = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
|
||||
Path(checkpoint_path) / f"{model_name}-hf", device_map={"": device}
|
||||
)
|
||||
model_ref = model_ref.to(device=device)
|
||||
model_ref.eval()
|
||||
out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
|
||||
logits_ref = model_ref(input_ids).logits.to(device=device)
|
||||
out_ref = model_ref.model(input_ids).last_hidden_state
|
||||
logits_ref = model_ref(input_ids).logits
|
||||
del model_ref
|
||||
|
||||
model_hf = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
|
||||
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
|
||||
)
|
||||
model_hf.eval()
|
||||
out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
|
||||
@ -625,42 +631,3 @@ def test_llama_parallel_uneven_num_heads(world_size):
|
||||
|
||||
if os.path.exists(checkpoint_path / f"{model_name}-hf"):
|
||||
shutil.rmtree(checkpoint_path / f"{model_name}-hf")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_inv_remap_state_dict_hf_llama():
|
||||
checkpoint_path = (
|
||||
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
|
||||
)
|
||||
model_name = f"teeny"
|
||||
|
||||
llama_config = LlamaConfig(
|
||||
num_attention_heads=2,
|
||||
hidden_size=256 * 2,
|
||||
intermediate_size=256 * 2 * 4,
|
||||
num_hidden_layers=4,
|
||||
)
|
||||
config = llama_config_to_gpt2_config(llama_config)
|
||||
config.use_flash_attn = True
|
||||
config.fused_bias_fc = True
|
||||
config.fused_mlp = False # We don't have fused GatedMLP yet
|
||||
config.fused_dropout_add_ln = True
|
||||
config.residual_in_fp32 = True
|
||||
|
||||
# Set up.
|
||||
LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
|
||||
|
||||
# inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama
|
||||
state_dict = state_dict_from_pretrained(checkpoint_path / f"{model_name}-hf")
|
||||
state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key}
|
||||
pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)
|
||||
state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)
|
||||
|
||||
assert set(state_dict_recover.keys()) == set(state_dict.keys())
|
||||
|
||||
for key in state_dict_recover.keys():
|
||||
torch.testing.assert_close(state_dict_recover[key], state_dict[key])
|
||||
|
||||
# Tear down.
|
||||
if os.path.exists(checkpoint_path / f"{model_name}-hf"):
|
||||
shutil.rmtree(checkpoint_path / f"{model_name}-hf")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user