map custom model state_dict back to huggingface format (#465)

* fix name.

* set inv function.

* add map back function.

* handle gqa.

* add type annotation to avoid confusion.

* fix docstr.

* test inverse remap logic.
Xuechen Li 2023-08-18 20:51:39 -07:00 committed by GitHub
parent f1a73d0740
commit 7fcd3e6a04
3 changed files with 342 additions and 131 deletions

flash_attn/models/gpt.py

@@ -785,8 +785,10 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
def combine_state_dicts_tp(state_dicts, config):
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
with tensor parallel.
"""Convert the state_dict of a GPT model with tensor parallel to the state_dict of a
standard GPT model.
This function is meant to be the "reverse" of shard_state_dict_tp.
"""
world_size = len(state_dicts)
keys = state_dicts[0].keys()
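
The updated docstring describes combine_state_dicts_tp as the reverse of shard_state_dict_tp. A minimal round-trip sketch (not part of this diff), using only the two signatures shown here and assuming state_dict and config come from an existing GPT model whose sharded dimensions divide evenly by the world size:

import torch

# state_dict and config stand in for a real GPT model's weights and its config.
world_size = 2
shards = [shard_state_dict_tp(state_dict, config, world_size, rank) for rank in range(world_size)]
combined = combine_state_dicts_tp(shards, config)
assert combined.keys() == state_dict.keys()
for key in state_dict:
    assert torch.equal(combined[key], state_dict[key])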

flash_attn/models/llama.py

@@ -13,7 +13,14 @@ import torch.nn.functional as F
from transformers import GPT2Config, LlamaConfig
def remap_state_dict_meta_llama(state_dict, config):
def remap_state_dict_meta_llama(
state_dict: dict[str, torch.Tensor], config: GPT2Config
) -> dict[str, torch.Tensor]:
"""Convert the state_dict in Meta format to standard GPT format.
This function modifies state_dict in place.
"""
def key_mapping_layers(key):
return f"transformer.{key}" if not key.startswith("output.") else key
@@ -97,7 +104,13 @@ def remap_state_dict_meta_llama(state_dict, config):
return state_dict
def remap_state_dict_hf_llama(state_dict, config):
def remap_state_dict_hf_llama(
state_dict: dict[str, torch.Tensor], config: GPT2Config
) -> dict[str, torch.Tensor]:
"""Convert the state_dict in Hugging Face format to standard GPT format.
This function modifies state_dict in place.
"""
# Embedding
def key_mapping_emb(key):
return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
@@ -183,6 +196,106 @@ def remap_state_dict_hf_llama(state_dict, config):
return state_dict
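
For context, the usual direction of this mapping (mirrored by the test helper later in this commit) loads a Hugging Face LLaMa checkpoint and remaps it to the GPT-format keys that GPTLMHeadModel expects. A minimal sketch, not part of this diff, assuming checkpoint_path and model_name point at an existing <model_name>-hf directory and config came from llama_config_to_gpt2_config:

from pathlib import Path
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.llama import remap_state_dict_hf_llama
from flash_attn.utils.pretrained import state_dict_from_pretrained

hf_state_dict = state_dict_from_pretrained(Path(checkpoint_path) / f"{model_name}-hf")
# Drop rotary inv_freq buffers, as the new round-trip test below does.
hf_state_dict = {k: v for k, v in hf_state_dict.items() if "rotary_emb.inv_freq" not in k}
gpt_state_dict = remap_state_dict_hf_llama(hf_state_dict, config)
model = GPTLMHeadModel(config, device="meta")  # device="meta" avoids slow init, as in the tests
assert model.state_dict().keys() == gpt_state_dict.keys()
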
def inv_remap_state_dict_hf_llama(
state_dict: dict[str, torch.Tensor], config: GPT2Config
) -> dict[str, torch.Tensor]:
"""Convert the state_dict in standard GPT format to Hugging Face format.
This function is meant to be the inverse of remap_state_dict_hf_llama, up to padding
of the embedding and lm_head to a multiple of pad_vocab_size_multiple. That is, if the
original vocabulary size isn't a multiple of pad_vocab_size_multiple, then
inv_remap_state_dict_hf_llama(remap_state_dict_hf_llama(state_dict)) != state_dict.
This function modifies state_dict in place.
"""
# Embedding
def key_mapping_emb(key):
return re.sub(r"^transformer.embeddings.word_embeddings.", "model.embed_tokens.", key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop("model.embed_tokens.weight")
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = (
math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
)
state_dict["model.embed_tokens.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
# LM head
if getattr(config, "tie_word_embeddings"):
state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]
else:
output_embeddings = state_dict.pop("lm_head.weight")
vocab_size = (
math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple
)
state_dict["lm_head.weight"] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
# MLP
for l in range(config.n_layer):
w3, w1 = torch.chunk(
state_dict.pop(f"transformer.layers.{l}.mlp.fc1.weight"), chunks=2, dim=0
)
state_dict[f"model.layers.{l}.mlp.gate_proj.weight"] = w1
state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3
def key_mapping_mlp(key):
return re.sub(r"^transformer.layers.(\d+).mlp.fc2.", r"model.layers.\1.mlp.down_proj.", key)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r"^transformer.ln_f.", r"model.norm.", key)
key = re.sub(r"^transformer.layers.(\d+).norm1.", r"model.layers.\1.input_layernorm.", key)
key = re.sub(
r"^transformer.layers.(\d+).norm2.", r"model.layers.\1.post_attention_layernorm.", key
)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
def permute(w):
return (
w.view(config.n_head, config.n_embd // config.n_head // 2, 2, config.n_embd)
.transpose(1, 2)
.reshape(config.n_embd, config.n_embd)
)
n_head = config.n_head
n_head_kv = getattr(config, "n_head_kv", n_head)
embed_dim = config.hidden_size
head_dim = embed_dim // n_head
q_dim = n_head * head_dim
k_dim = v_dim = n_head_kv * head_dim
# Attention
for l in range(config.n_layer):
Wqkv = state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight")
Wq = Wqkv[:q_dim]
Wk = Wqkv[q_dim : q_dim + k_dim]
Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim]
state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq)
state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk)
state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv
state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)
def key_mapping_attn(key):
return re.sub(
r"^transformer.layers.(\d+).mixer.out_proj.", r"model.layers.\1.self_attn.o_proj.", key
)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
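
To make the GQA handling above concrete, here is a small numeric sketch with assumed values (32 query heads, 8 KV heads, hidden_size 4096; these numbers are illustrative, not taken from the diff) showing how the fused Wqkv weight is sliced row-wise into q_proj, k_proj, and v_proj:

import torch

n_head, n_head_kv, embed_dim = 32, 8, 4096
head_dim = embed_dim // n_head            # 128
q_dim = n_head * head_dim                 # 4096
k_dim = v_dim = n_head_kv * head_dim      # 1024 each
Wqkv = torch.randn(q_dim + k_dim + v_dim, embed_dim)
Wq = Wqkv[:q_dim]                         # (4096, 4096)
Wk = Wqkv[q_dim : q_dim + k_dim]          # (1024, 4096)
Wv = Wqkv[q_dim + k_dim :]                # (1024, 4096)
assert Wq.shape == (4096, 4096) and Wk.shape == (1024, 4096) and Wv.shape == (1024, 4096)

When n_head_kv equals n_head (no GQA), all three slices are (4096, 4096) and the split reduces to the standard multi-head case.
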
def config_from_meta_checkpoint(
checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:

tests/models/test_llama.py

@@ -13,6 +13,7 @@ current_dir = Path(__file__).parent.absolute()
import torch
import pytest
import shutil
from einops import rearrange
@@ -20,7 +21,12 @@ from transformers import LlamaTokenizer, LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.models.llama import remap_state_dict_meta_llama, llama_config_to_gpt2_config, remap_state_dict_hf_llama
from flash_attn.models.llama import (
remap_state_dict_meta_llama,
llama_config_to_gpt2_config,
remap_state_dict_hf_llama,
inv_remap_state_dict_hf_llama,
)
from flash_attn.models.llama import config_from_checkpoint, state_dicts_from_checkpoint
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
@@ -33,37 +39,41 @@ def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config,
pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
else:
pretrained_state_dict = state_dict_from_pretrained(Path(checkpoint_path) / f'{model_name}-hf')
pretrained_state_dict = state_dict_from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf"
)
pretrained_state_dict = remap_state_dict_hf_llama(pretrained_state_dict, config)
return pretrained_state_dict
@pytest.mark.parametrize('model_name', ["7B"])
@pytest.mark.parametrize("model_name", ["7B"])
def test_llama_state_dict(model_name):
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
current_dir.parent.parent / 'checkpoints')) / 'llama'
checkpoint_path = (
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
)
config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["7B", "13B"])
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
@pytest.mark.parametrize("model_name", ["7B", "13B"])
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
def test_llama_optimized(model_name, checkpoint_format):
"""Check that our implementation of LLaMa (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
current_dir.parent.parent / 'checkpoints')) / 'llama'
checkpoint_path = (
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
)
dtype = torch.float16
device = 'cuda'
device = "cuda"
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
config = llama_config_to_gpt2_config(config)
config.use_flash_attn = True
@@ -83,8 +93,9 @@ def test_llama_optimized(model_name, checkpoint_format):
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
@@ -92,39 +103,43 @@ def test_llama_optimized(model_name, checkpoint_format):
# Without device_map, the model is loaded on the CPU, which is very slow
# Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
device_map='auto')
model_ref = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
)
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref
model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
torch_dtype=dtype, device_map={"": device})
model_hf = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
with torch.no_grad():
out_hf = model_hf.model(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('model_name', ["13B"])
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("model_name", ["13B"])
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
def test_llama_parallel(model_name, world_size, checkpoint_format):
"""Check that our implementation of LLaMa (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -132,8 +147,9 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
"""
from apex.transformer import parallel_state
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
current_dir.parent.parent / 'checkpoints')) / 'llama'
checkpoint_path = (
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
)
dtype = torch.float16
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
@@ -145,8 +161,8 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
config.residual_in_fp32 = True
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
@@ -163,8 +179,9 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
out, _ = all_gather_raw(out, process_group=process_group)
@@ -172,13 +189,13 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
logits = model(input_ids).logits
logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
logits, _ = all_gather_raw(logits, process_group)
logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size)
logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
del model
if rank == 0:
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f'{model_name}-hf', device_map="auto"
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
)
model_ref.eval()
with torch.no_grad():
@@ -187,7 +204,7 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
del model_ref
model_hf = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto"
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
)
model_hf.eval()
with torch.no_grad():
@@ -195,28 +212,31 @@ def test_llama_parallel(model_name, world_size, checkpoint_format):
logits_hf = model_hf(input_ids).logits.to(device=device)
del model_hf
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 2 * (
logits_hf - logits_ref
).abs().max().item()
# @pytest.mark.parametrize('model_name', ["7B", "13B"])
@pytest.mark.parametrize('model_name', ["7B"])
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
@pytest.mark.parametrize("model_name", ["7B"])
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
def test_llama_generation(model_name, checkpoint_format):
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
current_dir.parent.parent / 'checkpoints')) / 'llama'
checkpoint_path = (
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
)
dtype = torch.float16
device = 'cuda'
device = "cuda"
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
config = llama_config_to_gpt2_config(config)
config.use_flash_attn = True
@@ -225,34 +245,38 @@ def test_llama_generation(model_name, checkpoint_format):
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf')
tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf")
eos_token_id = tokenizer.eos_token_id
torch.manual_seed(0)
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
torch_dtype=dtype, device_map={"": device})
model_hf = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
# Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
device_map='auto')
model_ref = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
)
model_ref.eval()
with torch.no_grad():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
del model_ref
pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
@@ -262,31 +286,43 @@ def test_llama_generation(model_name, checkpoint_format):
model.load_state_dict(pretrained_state_dict)
model.eval()
print('Without CUDA graph')
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
out = model.generate(input_ids=input_ids, max_length=max_length,
eos_token_id=eos_token_id, fused_ft_kernel=True,
return_dict_in_generate=True, output_scores=True, timing=True,
teacher_outputs=out_hf.sequences)
out = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print('With CUDA graph')
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(input_ids=input_ids, max_length=max_length,
fused_ft_kernel=True, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True,
teacher_outputs=out_hf.sequences)
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=True,
cg=True,
return_dict_in_generate=True,
output_scores=True,
timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
with torch.no_grad():
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1]
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
@@ -295,9 +331,9 @@ def test_llama_generation(model_name, checkpoint_format):
hf_error = (logits_hf - logits_ref).abs().max().item()
print(f'HF fp16 logits max diff: {hf_error}')
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
@@ -305,9 +341,9 @@ def test_llama_generation(model_name, checkpoint_format):
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('model_name', ["13B"])
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("model_name", ["13B"])
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
"""Check that our implementation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -315,8 +351,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
"""
from apex.transformer import parallel_state
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
current_dir.parent.parent / 'checkpoints')) / 'llama'
checkpoint_path = (
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
)
dtype = torch.float16
config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
@@ -331,8 +368,8 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
@@ -342,8 +379,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
@@ -356,23 +394,34 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
print('Without CUDA graph')
print("Without CUDA graph")
out = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=True,
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True, output_scores=True, timing=True
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print('With CUDA graph')
print("With CUDA graph")
out_cg = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=True, cg=True,
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
fused_ft_kernel=True,
cg=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True, output_scores=True, timing=True
return_dict_in_generate=True,
output_scores=True,
timing=True,
)
del model
parallel_state.destroy_model_parallel()
@@ -380,7 +429,7 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
if rank == 0:
# Without device_map, the model is loaded on the CPU, which is very slow
model_hf = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto"
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
)
model_hf.eval()
print("HF fp16")
@@ -388,19 +437,21 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
start = time.time()
with torch.inference_mode():
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True,
output_scores=True
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
torch.cuda.synchronize()
print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f'{model_name}-hf', device_map="auto"
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
)
model_ref.eval()
with torch.inference_mode():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
del model_ref
logits_hf = torch.stack(out_hf.scores, dim=1)
@@ -408,25 +459,27 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
logits_cg = torch.stack(out_cg.scores, dim=1)
hf_error = (logits_hf - logits_ref).abs().max().item()
print(f'HF fp16 logits max diff: {hf_error}')
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
assert torch.equal(logits_cg, logits)
@torch.no_grad()
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("world_size", [2])
def test_llama_parallel_uneven_num_heads(world_size):
from apex.transformer import parallel_state
checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', current_dir.parent.parent / 'checkpoints')) / 'llama'
checkpoint_path = (
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
)
num_attention_heads = world_size + 1
model_name = f'teeny-{num_attention_heads}-heads'
model_name = f"teeny-{num_attention_heads}-heads"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
@@ -434,7 +487,8 @@ def test_llama_parallel_uneven_num_heads(world_size):
dtype = torch.float16
llama_config = LlamaConfig(
hidden_size=256 * num_attention_heads, # ParallelGatedMlp hidden_features must be divisible by 256
hidden_size=256
* num_attention_heads, # ParallelGatedMlp hidden_features must be divisible by 256
intermediate_size=256 * num_attention_heads * 4,
num_hidden_layers=4,
num_attention_heads=num_attention_heads,
@@ -451,8 +505,9 @@ def test_llama_parallel_uneven_num_heads(world_size):
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
# Create a shared test model.
if rank == 0:
@@ -474,11 +529,11 @@ def test_llama_parallel_uneven_num_heads(world_size):
logits = model(input_ids).logits
logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
logits, _ = all_gather_raw(logits, process_group)
logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size)
logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
if rank == 0:
model_ref = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f'{model_name}-hf', device_map="auto"
Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
)
model_ref.eval()
out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
@@ -486,24 +541,65 @@ def test_llama_parallel_uneven_num_heads(world_size):
del model_ref
model_hf = LlamaForCausalLM.from_pretrained(
Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto"
Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
)
model_hf.eval()
out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
logits_hf = model_hf(input_ids).logits.to(device=device)
del model_hf
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 2 * (
logits_hf - logits_ref
).abs().max().item()
import shutil
shutil.rmtree(checkpoint_path / f'{model_name}-hf')
if os.path.exists(checkpoint_path / f"{model_name}-hf"):
shutil.rmtree(checkpoint_path / f"{model_name}-hf")
@torch.no_grad()
def test_inv_remap_state_dict_hf_llama():
checkpoint_path = (
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
)
model_name = f"teeny"
llama_config = LlamaConfig(
num_attention_heads=2,
hidden_size=256 * 2,
intermediate_size=256 * 2 * 4,
num_hidden_layers=4,
)
config = llama_config_to_gpt2_config(llama_config)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
# Set up.
LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
# inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama
state_dict = state_dict_from_pretrained(checkpoint_path / f"{model_name}-hf")
state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key}
pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)
state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)
assert set(state_dict_recover.keys()) == set(state_dict.keys())
for key in state_dict_recover.keys():
torch.testing.assert_close(state_dict_recover[key], state_dict[key])
# Tear down.
if os.path.exists(checkpoint_path / f"{model_name}-hf"):
shutil.rmtree(checkpoint_path / f"{model_name}-hf")
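
Finally, the point of inv_remap_state_dict_hf_llama (per the commit title) is to map a custom GPT-format model's weights back to Hugging Face format so they can be loaded into LlamaForCausalLM or saved as an HF checkpoint. A hedged sketch of that export path, not part of this diff, assuming a matching llama_config / config pair as set up in the test above and a vocabulary size already divisible by pad_vocab_size_multiple (the output directory name is hypothetical):

gpt_model = GPTLMHeadModel(config)  # e.g. a fine-tuned custom model in GPT format
hf_state_dict = inv_remap_state_dict_hf_llama(gpt_model.state_dict(), config)
hf_model = LlamaForCausalLM(llama_config)
# strict=False because HF LLaMa modules may register extra buffers (e.g. rotary_emb.inv_freq)
# that are intentionally absent from the remapped state_dict.
missing, unexpected = hf_model.load_state_dict(hf_state_dict, strict=False)
hf_model.save_pretrained("exported-llama-hf")  # hypothetical output directory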