map custom model state_dict back to huggingface format (#465)
* fix name. * set inv function. * add map back function. * handle gqa. * add type annotation to avoid confusion. * fix docstr. * test inverse remap logic.
This commit is contained in:
parent
f1a73d0740
commit
7fcd3e6a04
@ -785,8 +785,10 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
|
||||
|
||||
|
||||
def combine_state_dicts_tp(state_dicts, config):
|
||||
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
|
||||
with tensor parallel.
|
||||
"""Convert the state_dict of a GPT model with tensor parallel to the state_dict of a
|
||||
standard GPT model.
|
||||
|
||||
This function is meant to be the "reverse" of shard_state_dict_tp.
|
||||
"""
|
||||
world_size = len(state_dicts)
|
||||
keys = state_dicts[0].keys()
|
||||
|
||||
@ -13,7 +13,14 @@ import torch.nn.functional as F
|
||||
from transformers import GPT2Config, LlamaConfig
|
||||
|
||||
|
||||
def remap_state_dict_meta_llama(state_dict, config):
|
||||
def remap_state_dict_meta_llama(
|
||||
state_dict: dict[str, torch.Tensor], config: GPT2Config
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Convert the state_dict in Meta format to standard GPT format.
|
||||
|
||||
This function modifies state_dict in place.
|
||||
"""
|
||||
|
||||
def key_mapping_layers(key):
|
||||
return f"transformer.{key}" if not key.startswith("output.") else key
|
||||
|
||||
@ -97,7 +104,13 @@ def remap_state_dict_meta_llama(state_dict, config):
|
||||
return state_dict
|
||||
|
||||
|
||||
def remap_state_dict_hf_llama(state_dict, config):
|
||||
def remap_state_dict_hf_llama(
|
||||
state_dict: dict[str, torch.Tensor], config: GPT2Config
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Convert the state_dict in Hugging Face format to standard GPT format.
|
||||
|
||||
This function modifies state_dict in place.
|
||||
"""
|
||||
# Embedding
|
||||
def key_mapping_emb(key):
|
||||
return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
|
||||
@ -183,6 +196,106 @@ def remap_state_dict_hf_llama(state_dict, config):
|
||||
return state_dict
|
||||
|
||||
|
||||
def inv_remap_state_dict_hf_llama(
|
||||
state_dict: dict[str, torch.Tensor], config: GPT2Config
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Convert the state_dict in standard GPT format to Hugging Face format.
|
||||
|
||||
This function is meant to be the inverse of remap_state_dict_hf_llama, up to a
|
||||
multiplier pad in the embedding and lm_head. That is if the original embedding
|
||||
isn't a multiple of pad_vocab_size_multiple, then
|
||||
inv_remap_state_dict_hf_llama(remap_state_dict_hf_llama(state_dict)) != state_dict.
|
||||
|
||||
This function modifies state_dict in place.
|
||||
"""
|
||||
|
||||
# Embedding
|
||||
def key_mapping_emb(key):
|
||||
return re.sub(r"^transformer.embeddings.word_embeddings.", "model.embed_tokens.", key)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
||||
word_embeddings = state_dict.pop("model.embed_tokens.weight")
|
||||
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
||||
vocab_size = (
|
||||
math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||
)
|
||||
state_dict["model.embed_tokens.weight"] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
|
||||
# LM head
|
||||
if getattr(config, "tie_word_embeddings"):
|
||||
state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
|
||||
else:
|
||||
output_embeddings = state_dict.pop("lm_head.weight")
|
||||
vocab_size = (
|
||||
math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
|
||||
* pad_vocab_size_multiple
|
||||
)
|
||||
state_dict["lm_head.weight"] = F.pad(
|
||||
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
|
||||
)
|
||||
|
||||
# MLP
|
||||
for l in range(config.n_layer):
|
||||
w3, w1 = torch.chunk(
|
||||
state_dict.pop(f"transformer.layers.{l}.mlp.fc1.weight"), chunks=2, dim=0
|
||||
)
|
||||
state_dict[f"model.layers.{l}.mlp.gate_proj.weight"] = w1
|
||||
state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3
|
||||
|
||||
def key_mapping_mlp(key):
|
||||
return re.sub(r"^transformer.layers.(\d+).mlp.fc2.", r"model.layers.\1.mlp.down_proj.", key)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r"^transformer.ln_f.", r"model.norm.", key)
|
||||
key = re.sub(r"^transformer.layers.(\d+).norm1.", r"model.layers.\1.input_layernorm.", key)
|
||||
key = re.sub(
|
||||
r"^transformer.layers.(\d+).norm2.", r"model.layers.\1.post_attention_layernorm.", key
|
||||
)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
def permute(w):
|
||||
return (
|
||||
w.view(config.n_head, config.n_embd // config.n_head // 2, 2, config.n_embd)
|
||||
.transpose(1, 2)
|
||||
.reshape(config.n_embd, config.n_embd)
|
||||
)
|
||||
|
||||
n_head = config.n_head
|
||||
n_head_kv = getattr(config, "n_head_kv", n_head)
|
||||
|
||||
embed_dim = config.hidden_size
|
||||
head_dim = embed_dim // n_head
|
||||
|
||||
q_dim = n_head * head_dim
|
||||
k_dim = v_dim = n_head_kv * head_dim
|
||||
|
||||
# Attention
|
||||
for l in range(config.n_layer):
|
||||
Wqkv = state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight")
|
||||
Wq = Wqkv[:q_dim]
|
||||
Wk = Wqkv[q_dim : q_dim + k_dim]
|
||||
Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
|
||||
state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
|
||||
state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
|
||||
state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
|
||||
state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)
|
||||
|
||||
def key_mapping_attn(key):
|
||||
return re.sub(
|
||||
r"^transformer.layers.(\d+).mixer.out_proj.", r"model.layers.\1.self_attn.o_proj.", key
|
||||
)
|
||||
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
return state_dict
|
||||
|
||||
|
||||
def config_from_meta_checkpoint(
|
||||
checkpoint_path: Union[str, os.PathLike], model_name: str
|
||||
) -> LlamaConfig:
|
||||
|
||||
@ -13,6 +13,7 @@ current_dir = Path(__file__).parent.absolute()
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
import shutil
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
@ -20,7 +21,12 @@ from transformers import LlamaTokenizer, LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM
|
||||
|
||||
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
|
||||
from flash_attn.models.llama import remap_state_dict_meta_llama, llama_config_to_gpt2_config, remap_state_dict_hf_llama
|
||||
from flash_attn.models.llama import (
|
||||
remap_state_dict_meta_llama,
|
||||
llama_config_to_gpt2_config,
|
||||
remap_state_dict_hf_llama,
|
||||
inv_remap_state_dict_hf_llama,
|
||||
)
|
||||
from flash_attn.models.llama import config_from_checkpoint, state_dicts_from_checkpoint
|
||||
from flash_attn.utils.distributed import all_gather_raw
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
@ -33,37 +39,41 @@ def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config,
|
||||
pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
|
||||
pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
|
||||
else:
|
||||
pretrained_state_dict = state_dict_from_pretrained(Path(checkpoint_path) / f'{model_name}-hf')
|
||||
pretrained_state_dict = state_dict_from_pretrained(
|
||||
Path(checkpoint_path) / f"{model_name}-hf"
|
||||
)
|
||||
pretrained_state_dict = remap_state_dict_hf_llama(pretrained_state_dict, config)
|
||||
return pretrained_state_dict
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_name', ["7B"])
|
||||
@pytest.mark.parametrize("model_name", ["7B"])
|
||||
def test_llama_state_dict(model_name):
|
||||
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
|
||||
current_dir.parent.parent / 'checkpoints')) / 'llama'
|
||||
checkpoint_path = (
|
||||
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
|
||||
)
|
||||
config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
|
||||
ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
|
||||
pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
|
||||
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
|
||||
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
|
||||
state_dict = model.state_dict()
|
||||
assert state_dict.keys() == pretrained_state_dict.keys()
|
||||
for k in state_dict.keys():
|
||||
assert state_dict[k].shape == pretrained_state_dict[k].shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize('model_name', ["7B", "13B"])
|
||||
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
|
||||
@pytest.mark.parametrize("model_name", ["7B", "13B"])
|
||||
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
|
||||
def test_llama_optimized(model_name, checkpoint_format):
|
||||
"""Check that our implementation of LLaMa (with all optimizations enabled) matches the
|
||||
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
|
||||
forward pass in fp16, when compared to the HF forward pass in fp32.
|
||||
"""
|
||||
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
|
||||
current_dir.parent.parent / 'checkpoints')) / 'llama'
|
||||
checkpoint_path = (
|
||||
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
|
||||
)
|
||||
|
||||
dtype = torch.float16
|
||||
device = 'cuda'
|
||||
device = "cuda"
|
||||
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
|
||||
config = llama_config_to_gpt2_config(config)
|
||||
config.use_flash_attn = True
|
||||
@ -83,8 +93,9 @@ def test_llama_optimized(model_name, checkpoint_format):
|
||||
batch_size = 2
|
||||
max_seqlen = 256
|
||||
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
|
||||
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
|
||||
device=device)
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
|
||||
)
|
||||
with torch.no_grad():
|
||||
out = model.transformer(input_ids)
|
||||
logits = model(input_ids).logits
|
||||
@ -92,39 +103,43 @@ def test_llama_optimized(model_name, checkpoint_format):
|
||||
|
||||
# Without device_map, the model is loaded on the CPU, which is very slow
|
||||
# Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
|
||||
model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
|
||||
device_map='auto')
|
||||
model_ref = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
|
||||
)
|
||||
model_ref.eval()
|
||||
with torch.no_grad():
|
||||
out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
|
||||
logits_ref = model_ref(input_ids).logits.to(device=device)
|
||||
del model_ref
|
||||
|
||||
model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
|
||||
torch_dtype=dtype, device_map={"": device})
|
||||
model_hf = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
|
||||
)
|
||||
model_hf.eval()
|
||||
with torch.no_grad():
|
||||
out_hf = model_hf.model(input_ids).last_hidden_state
|
||||
logits_hf = model_hf(input_ids).logits
|
||||
del model_hf
|
||||
|
||||
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
|
||||
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
|
||||
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
|
||||
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
|
||||
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
|
||||
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
|
||||
|
||||
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
|
||||
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
|
||||
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
|
||||
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
|
||||
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
|
||||
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
|
||||
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
|
||||
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
|
||||
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
|
||||
assert (logits - logits_ref).abs().max().item() < 3 * (
|
||||
logits_hf - logits_ref
|
||||
).abs().max().item()
|
||||
|
||||
|
||||
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
@pytest.mark.parametrize('model_name', ["13B"])
|
||||
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
@pytest.mark.parametrize("model_name", ["13B"])
|
||||
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
|
||||
def test_llama_parallel(model_name, world_size, checkpoint_format):
|
||||
"""Check that our implementation of LLaMa (with all optimizations enabled) matches the
|
||||
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
|
||||
@ -132,8 +147,9 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
|
||||
"""
|
||||
from apex.transformer import parallel_state
|
||||
|
||||
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
|
||||
current_dir.parent.parent / 'checkpoints')) / 'llama'
|
||||
checkpoint_path = (
|
||||
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
|
||||
)
|
||||
|
||||
dtype = torch.float16
|
||||
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
|
||||
@ -145,8 +161,8 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
|
||||
config.residual_in_fp32 = True
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||
device = f'cuda:{torch.distributed.get_rank()}'
|
||||
torch.distributed.init_process_group(backend="nccl", init_method="env://")
|
||||
device = f"cuda:{torch.distributed.get_rank()}"
|
||||
assert world_size <= torch.distributed.get_world_size()
|
||||
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
|
||||
rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
@ -163,8 +179,9 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
|
||||
batch_size = 2
|
||||
max_seqlen = 256
|
||||
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
|
||||
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
|
||||
device=device)
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
|
||||
)
|
||||
with torch.no_grad():
|
||||
out = model.transformer(input_ids)
|
||||
out, _ = all_gather_raw(out, process_group=process_group)
|
||||
@ -172,13 +189,13 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
|
||||
logits = model(input_ids).logits
|
||||
logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
|
||||
logits, _ = all_gather_raw(logits, process_group)
|
||||
logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size)
|
||||
logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
|
||||
del model
|
||||
|
||||
if rank == 0:
|
||||
# Without device_map, the model is loaded on the CPU, which is very slow
|
||||
model_ref = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f'{model_name}-hf', device_map="auto"
|
||||
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
|
||||
)
|
||||
model_ref.eval()
|
||||
with torch.no_grad():
|
||||
@ -187,7 +204,7 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
|
||||
del model_ref
|
||||
|
||||
model_hf = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto"
|
||||
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
|
||||
)
|
||||
model_hf.eval()
|
||||
with torch.no_grad():
|
||||
@ -195,28 +212,31 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
|
||||
logits_hf = model_hf(input_ids).logits.to(device=device)
|
||||
del model_hf
|
||||
|
||||
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
|
||||
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
|
||||
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
|
||||
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
|
||||
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
|
||||
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
|
||||
|
||||
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
|
||||
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
|
||||
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
|
||||
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
|
||||
assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
|
||||
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
|
||||
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
|
||||
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
|
||||
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
|
||||
assert (logits - logits_ref).abs().max().item() < 2 * (
|
||||
logits_hf - logits_ref
|
||||
).abs().max().item()
|
||||
|
||||
|
||||
# @pytest.mark.parametrize('model_name', ["7B", "13B"])
|
||||
@pytest.mark.parametrize('model_name', ["7B"])
|
||||
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
|
||||
@pytest.mark.parametrize("model_name", ["7B"])
|
||||
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
|
||||
def test_llama_generation(model_name, checkpoint_format):
|
||||
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
|
||||
current_dir.parent.parent / 'checkpoints')) / 'llama'
|
||||
checkpoint_path = (
|
||||
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
|
||||
)
|
||||
|
||||
dtype = torch.float16
|
||||
device = 'cuda'
|
||||
device = "cuda"
|
||||
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
|
||||
config = llama_config_to_gpt2_config(config)
|
||||
config.use_flash_attn = True
|
||||
@ -225,34 +245,38 @@ def test_llama_generation(model_name, checkpoint_format):
|
||||
config.fused_dropout_add_ln = True
|
||||
config.residual_in_fp32 = True
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf')
|
||||
tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf")
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
torch.manual_seed(0)
|
||||
batch_size = 1
|
||||
seqlen = 100
|
||||
max_length = 150
|
||||
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
|
||||
device=device)
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
|
||||
)
|
||||
|
||||
model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
|
||||
torch_dtype=dtype, device_map={"": device})
|
||||
model_hf = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
|
||||
)
|
||||
model_hf.eval()
|
||||
print("HF fp16")
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
|
||||
return_dict_in_generate=True, output_scores=True)
|
||||
out_hf = model_hf.generate(
|
||||
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
|
||||
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
|
||||
del model_hf
|
||||
|
||||
# Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
|
||||
model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
|
||||
device_map='auto')
|
||||
model_ref = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
|
||||
)
|
||||
model_ref.eval()
|
||||
with torch.no_grad():
|
||||
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
|
||||
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
|
||||
del model_ref
|
||||
|
||||
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
|
||||
@ -262,31 +286,43 @@ def test_llama_generation(model_name, checkpoint_format):
|
||||
model.load_state_dict(pretrained_state_dict)
|
||||
model.eval()
|
||||
|
||||
print('Without CUDA graph')
|
||||
print("Without CUDA graph")
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
out = model.generate(input_ids=input_ids, max_length=max_length,
|
||||
eos_token_id=eos_token_id, fused_ft_kernel=True,
|
||||
return_dict_in_generate=True, output_scores=True, timing=True,
|
||||
teacher_outputs=out_hf.sequences)
|
||||
out = model.generate(
|
||||
input_ids=input_ids,
|
||||
max_length=max_length,
|
||||
eos_token_id=eos_token_id,
|
||||
fused_ft_kernel=True,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
timing=True,
|
||||
teacher_outputs=out_hf.sequences,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
|
||||
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
|
||||
|
||||
# Capture graph outside the timing loop
|
||||
batch_size, seqlen_og = input_ids.shape
|
||||
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
|
||||
print('With CUDA graph')
|
||||
print("With CUDA graph")
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
out_cg = model.generate(input_ids=input_ids, max_length=max_length,
|
||||
fused_ft_kernel=True, cg=True,
|
||||
return_dict_in_generate=True, output_scores=True, timing=True,
|
||||
teacher_outputs=out_hf.sequences)
|
||||
out_cg = model.generate(
|
||||
input_ids=input_ids,
|
||||
max_length=max_length,
|
||||
fused_ft_kernel=True,
|
||||
cg=True,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
timing=True,
|
||||
teacher_outputs=out_hf.sequences,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
|
||||
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
|
||||
|
||||
with torch.no_grad():
|
||||
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1]
|
||||
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
|
||||
logits_hf = torch.stack(out_hf.scores, dim=1)
|
||||
logits = torch.stack(out.scores, dim=1)
|
||||
logits_cg = torch.stack(out_cg.scores, dim=1)
|
||||
@ -295,9 +331,9 @@ def test_llama_generation(model_name, checkpoint_format):
|
||||
|
||||
hf_error = (logits_hf - logits_ref).abs().max().item()
|
||||
|
||||
print(f'HF fp16 logits max diff: {hf_error}')
|
||||
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
|
||||
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
|
||||
print(f"HF fp16 logits max diff: {hf_error}")
|
||||
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
|
||||
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
|
||||
|
||||
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
|
||||
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
|
||||
@ -305,9 +341,9 @@ def test_llama_generation(model_name, checkpoint_format):
|
||||
|
||||
|
||||
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
@pytest.mark.parametrize('model_name', ["13B"])
|
||||
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
@pytest.mark.parametrize("model_name", ["13B"])
|
||||
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
|
||||
def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
|
||||
"""Check that our implementation matches the HF implementation:
|
||||
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
|
||||
@ -315,8 +351,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
|
||||
"""
|
||||
from apex.transformer import parallel_state
|
||||
|
||||
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
|
||||
current_dir.parent.parent / 'checkpoints')) / 'llama'
|
||||
checkpoint_path = (
|
||||
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
|
||||
)
|
||||
|
||||
dtype = torch.float16
|
||||
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
|
||||
@ -331,8 +368,8 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
|
||||
|
||||
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||
device = f'cuda:{torch.distributed.get_rank()}'
|
||||
torch.distributed.init_process_group(backend="nccl", init_method="env://")
|
||||
device = f"cuda:{torch.distributed.get_rank()}"
|
||||
assert world_size <= torch.distributed.get_world_size()
|
||||
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
|
||||
rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
@ -342,8 +379,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
|
||||
batch_size = 1
|
||||
seqlen = 100
|
||||
max_length = 150
|
||||
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
|
||||
device=device)
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
|
||||
)
|
||||
|
||||
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
|
||||
# GPU0 and GPU1 and things would hang
|
||||
@ -356,23 +394,34 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
|
||||
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
|
||||
model.eval()
|
||||
|
||||
print('Without CUDA graph')
|
||||
print("Without CUDA graph")
|
||||
out = model.generate(
|
||||
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
|
||||
vocab_size=config.vocab_size, fused_ft_kernel=True,
|
||||
input_ids=input_ids,
|
||||
max_length=max_length,
|
||||
tensor_parallel=world_size,
|
||||
vocab_size=config.vocab_size,
|
||||
fused_ft_kernel=True,
|
||||
# teacher_outputs=out_hf.sequences,
|
||||
return_dict_in_generate=True, output_scores=True, timing=True
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
timing=True,
|
||||
)
|
||||
|
||||
# Capture graph outside the timing loop
|
||||
batch_size, seqlen_og = input_ids.shape
|
||||
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
|
||||
print('With CUDA graph')
|
||||
print("With CUDA graph")
|
||||
out_cg = model.generate(
|
||||
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
|
||||
vocab_size=config.vocab_size, fused_ft_kernel=True, cg=True,
|
||||
input_ids=input_ids,
|
||||
max_length=max_length,
|
||||
tensor_parallel=world_size,
|
||||
vocab_size=config.vocab_size,
|
||||
fused_ft_kernel=True,
|
||||
cg=True,
|
||||
# teacher_outputs=out_hf.sequences,
|
||||
return_dict_in_generate=True, output_scores=True, timing=True
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
timing=True,
|
||||
)
|
||||
del model
|
||||
parallel_state.destroy_model_parallel()
|
||||
@ -380,7 +429,7 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
|
||||
if rank == 0:
|
||||
# Without device_map, the model is loaded on the CPU, which is very slow
|
||||
model_hf = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto"
|
||||
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
|
||||
)
|
||||
model_hf.eval()
|
||||
print("HF fp16")
|
||||
@ -388,19 +437,21 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
|
||||
start = time.time()
|
||||
with torch.inference_mode():
|
||||
out_hf = model_hf.generate(
|
||||
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True,
|
||||
output_scores=True
|
||||
input_ids=input_ids,
|
||||
max_length=max_length,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
|
||||
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
|
||||
del model_hf
|
||||
|
||||
model_ref = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f'{model_name}-hf', device_map="auto"
|
||||
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
|
||||
)
|
||||
model_ref.eval()
|
||||
with torch.inference_mode():
|
||||
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
|
||||
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
|
||||
del model_ref
|
||||
logits_hf = torch.stack(out_hf.scores, dim=1)
|
||||
|
||||
@ -408,25 +459,27 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
|
||||
logits_cg = torch.stack(out_cg.scores, dim=1)
|
||||
|
||||
hf_error = (logits_hf - logits_ref).abs().max().item()
|
||||
print(f'HF fp16 logits max diff: {hf_error}')
|
||||
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
|
||||
print(f"HF fp16 logits max diff: {hf_error}")
|
||||
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
|
||||
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
|
||||
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
|
||||
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
|
||||
assert torch.equal(logits_cg, logits)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
def test_llama_parallel_uneven_num_heads(world_size):
|
||||
from apex.transformer import parallel_state
|
||||
|
||||
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', current_dir.parent.parent / 'checkpoints')) / 'llama'
|
||||
checkpoint_path = (
|
||||
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
|
||||
)
|
||||
num_attention_heads = world_size + 1
|
||||
model_name = f'teeny-{num_attention_heads}-heads'
|
||||
model_name = f"teeny-{num_attention_heads}-heads"
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||
device = f'cuda:{torch.distributed.get_rank()}'
|
||||
torch.distributed.init_process_group(backend="nccl", init_method="env://")
|
||||
device = f"cuda:{torch.distributed.get_rank()}"
|
||||
assert world_size <= torch.distributed.get_world_size()
|
||||
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
|
||||
rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
@ -434,7 +487,8 @@ def test_llama_parallel_uneven_num_heads(world_size):
|
||||
|
||||
dtype = torch.float16
|
||||
llama_config = LlamaConfig(
|
||||
hidden_size=256 * num_attention_heads, # ParallelGatedMlp hidden_features must be divisible by 256
|
||||
hidden_size=256
|
||||
* num_attention_heads, # ParallelGatedMlp hidden_features must be divisible by 256
|
||||
intermediate_size=256 * num_attention_heads * 4,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=num_attention_heads,
|
||||
@ -451,8 +505,9 @@ def test_llama_parallel_uneven_num_heads(world_size):
|
||||
batch_size = 2
|
||||
max_seqlen = 256
|
||||
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
|
||||
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
|
||||
device=device)
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
|
||||
)
|
||||
|
||||
# Create a shared test model.
|
||||
if rank == 0:
|
||||
@ -474,11 +529,11 @@ def test_llama_parallel_uneven_num_heads(world_size):
|
||||
logits = model(input_ids).logits
|
||||
logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
|
||||
logits, _ = all_gather_raw(logits, process_group)
|
||||
logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size)
|
||||
logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
|
||||
|
||||
if rank == 0:
|
||||
model_ref = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f'{model_name}-hf', device_map="auto"
|
||||
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
|
||||
)
|
||||
model_ref.eval()
|
||||
out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
|
||||
@ -486,24 +541,65 @@ def test_llama_parallel_uneven_num_heads(world_size):
|
||||
del model_ref
|
||||
|
||||
model_hf = LlamaForCausalLM.from_pretrained(
|
||||
Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto"
|
||||
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
|
||||
)
|
||||
model_hf.eval()
|
||||
out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
|
||||
logits_hf = model_hf(input_ids).logits.to(device=device)
|
||||
del model_hf
|
||||
|
||||
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
|
||||
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
|
||||
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
|
||||
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
|
||||
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
|
||||
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
|
||||
|
||||
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
|
||||
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
|
||||
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
|
||||
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
|
||||
assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
|
||||
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
|
||||
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
|
||||
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
|
||||
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
|
||||
assert (logits - logits_ref).abs().max().item() < 2 * (
|
||||
logits_hf - logits_ref
|
||||
).abs().max().item()
|
||||
|
||||
import shutil
|
||||
shutil.rmtree(checkpoint_path / f'{model_name}-hf')
|
||||
if os.path.exists(checkpoint_path / f"{model_name}-hf"):
|
||||
shutil.rmtree(checkpoint_path / f"{model_name}-hf")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_inv_remap_state_dict_hf_llama():
|
||||
checkpoint_path = (
|
||||
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
|
||||
)
|
||||
model_name = f"teeny"
|
||||
|
||||
llama_config = LlamaConfig(
|
||||
num_attention_heads=2,
|
||||
hidden_size=256 * 2,
|
||||
intermediate_size=256 * 2 * 4,
|
||||
num_hidden_layers=4,
|
||||
)
|
||||
config = llama_config_to_gpt2_config(llama_config)
|
||||
config.use_flash_attn = True
|
||||
config.fused_bias_fc = True
|
||||
config.fused_mlp = False # We don't have fused GatedMLP yet
|
||||
config.fused_dropout_add_ln = True
|
||||
config.residual_in_fp32 = True
|
||||
|
||||
# Set up.
|
||||
LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
|
||||
|
||||
# inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama
|
||||
state_dict = state_dict_from_pretrained(checkpoint_path / f"{model_name}-hf")
|
||||
state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key}
|
||||
pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)
|
||||
state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)
|
||||
|
||||
assert set(state_dict_recover.keys()) == set(state_dict.keys())
|
||||
|
||||
for key in state_dict_recover.keys():
|
||||
torch.testing.assert_close(state_dict_recover[key], state_dict[key])
|
||||
|
||||
# Tear down.
|
||||
if os.path.exists(checkpoint_path / f"{model_name}-hf"):
|
||||
shutil.rmtree(checkpoint_path / f"{model_name}-hf")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user