From 7fcd3e6a04fa6810cf6f87310d89955f01f9b786 Mon Sep 17 00:00:00 2001 From: Xuechen Li <12689993+lxuechen@users.noreply.github.com> Date: Fri, 18 Aug 2023 20:51:39 -0700 Subject: [PATCH] map custom model state_dict back to huggingface format (#465) * fix name. * set inv function. * add map back function. * handle gqa. * add type annotation to avoid confusion. * fix docstr. * test inverse remap logic. --- flash_attn/models/gpt.py | 6 +- flash_attn/models/llama.py | 117 ++++++++++++- tests/models/test_llama.py | 350 +++++++++++++++++++++++-------------- 3 files changed, 342 insertions(+), 131 deletions(-) diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index ca85319..8955f31 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -785,8 +785,10 @@ def shard_state_dict_tp(state_dict, config, world_size, rank): def combine_state_dicts_tp(state_dicts, config): - """Convert the state_dict of a standard GPT model to the state_dict of a GPT model - with tensor parallel. + """Convert the state_dict of a GPT model with tensor parallel to the state_dict of a + standard GPT model. + + This function is meant to be the "reverse" of shard_state_dict_tp. """ world_size = len(state_dicts) keys = state_dicts[0].keys() diff --git a/flash_attn/models/llama.py b/flash_attn/models/llama.py index 766f5bd..40d4073 100644 --- a/flash_attn/models/llama.py +++ b/flash_attn/models/llama.py @@ -13,7 +13,14 @@ import torch.nn.functional as F from transformers import GPT2Config, LlamaConfig -def remap_state_dict_meta_llama(state_dict, config): +def remap_state_dict_meta_llama( + state_dict: dict[str, torch.Tensor], config: GPT2Config +) -> dict[str, torch.Tensor]: + """Convert the state_dict in Meta format to standard GPT format. + + This function modifies state_dict in place. + """ + def key_mapping_layers(key): return f"transformer.{key}" if not key.startswith("output.") else key @@ -97,7 +104,13 @@ def remap_state_dict_meta_llama(state_dict, config): return state_dict -def remap_state_dict_hf_llama(state_dict, config): +def remap_state_dict_hf_llama( + state_dict: dict[str, torch.Tensor], config: GPT2Config +) -> dict[str, torch.Tensor]: + """Convert the state_dict in Hugging Face format to standard GPT format. + + This function modifies state_dict in place. + """ # Embedding def key_mapping_emb(key): return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key) @@ -183,6 +196,106 @@ def remap_state_dict_hf_llama(state_dict, config): return state_dict +def inv_remap_state_dict_hf_llama( + state_dict: dict[str, torch.Tensor], config: GPT2Config +) -> dict[str, torch.Tensor]: + """Convert the state_dict in standard GPT format to Hugging Face format. + + This function is meant to be the inverse of remap_state_dict_hf_llama, up to a + multiplier pad in the embedding and lm_head. That is if the original embedding + isn't a multiple of pad_vocab_size_multiple, then + inv_remap_state_dict_hf_llama(remap_state_dict_hf_llama(state_dict)) != state_dict. + + This function modifies state_dict in place. 
+ """ + + # Embedding + def key_mapping_emb(key): + return re.sub(r"^transformer.embeddings.word_embeddings.", "model.embed_tokens.", key) + + state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("model.embed_tokens.weight") + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple + ) + state_dict["model.embed_tokens.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + + # LM head + if getattr(config, "tie_word_embeddings"): + state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"] + else: + output_embeddings = state_dict.pop("lm_head.weight") + vocab_size = ( + math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) + * pad_vocab_size_multiple + ) + state_dict["lm_head.weight"] = F.pad( + output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) + ) + + # MLP + for l in range(config.n_layer): + w3, w1 = torch.chunk( + state_dict.pop(f"transformer.layers.{l}.mlp.fc1.weight"), chunks=2, dim=0 + ) + state_dict[f"model.layers.{l}.mlp.gate_proj.weight"] = w1 + state_dict[f"model.layers.{l}.mlp.up_proj.weight"] = w3 + + def key_mapping_mlp(key): + return re.sub(r"^transformer.layers.(\d+).mlp.fc2.", r"model.layers.\1.mlp.down_proj.", key) + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^transformer.ln_f.", r"model.norm.", key) + key = re.sub(r"^transformer.layers.(\d+).norm1.", r"model.layers.\1.input_layernorm.", key) + key = re.sub( + r"^transformer.layers.(\d+).norm2.", r"model.layers.\1.post_attention_layernorm.", key + ) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + def permute(w): + return ( + w.view(config.n_head, config.n_embd // config.n_head // 2, 2, config.n_embd) + .transpose(1, 2) + .reshape(config.n_embd, config.n_embd) + ) + + n_head = config.n_head + n_head_kv = getattr(config, "n_head_kv", n_head) + + embed_dim = config.hidden_size + head_dim = embed_dim // n_head + + q_dim = n_head * head_dim + k_dim = v_dim = n_head_kv * head_dim + + # Attention + for l in range(config.n_layer): + Wqkv = state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight") + Wq = Wqkv[:q_dim] + Wk = Wqkv[q_dim : q_dim + k_dim] + Wv = Wqkv[q_dim + k_dim : q_dim + k_dim + v_dim] + state_dict[f"model.layers.{l}.self_attn.q_proj.weight"] = permute(Wq) + state_dict[f"model.layers.{l}.self_attn.k_proj.weight"] = permute(Wk) + state_dict[f"model.layers.{l}.self_attn.v_proj.weight"] = Wv + state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None) + + def key_mapping_attn(key): + return re.sub( + r"^transformer.layers.(\d+).mixer.out_proj.", r"model.layers.\1.self_attn.o_proj.", key + ) + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + return state_dict + + def config_from_meta_checkpoint( checkpoint_path: Union[str, os.PathLike], model_name: str ) -> LlamaConfig: diff --git a/tests/models/test_llama.py b/tests/models/test_llama.py index 09b00d1..b54a5a8 100644 --- a/tests/models/test_llama.py +++ b/tests/models/test_llama.py @@ -13,6 +13,7 @@ current_dir = Path(__file__).parent.absolute() import torch import pytest +import shutil from einops import rearrange @@ -20,7 +21,12 @@ from transformers import LlamaTokenizer, LlamaConfig from 
transformers.models.llama.modeling_llama import LlamaForCausalLM from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp -from flash_attn.models.llama import remap_state_dict_meta_llama, llama_config_to_gpt2_config, remap_state_dict_hf_llama +from flash_attn.models.llama import ( + remap_state_dict_meta_llama, + llama_config_to_gpt2_config, + remap_state_dict_hf_llama, + inv_remap_state_dict_hf_llama, +) from flash_attn.models.llama import config_from_checkpoint, state_dicts_from_checkpoint from flash_attn.utils.distributed import all_gather_raw from flash_attn.utils.pretrained import state_dict_from_pretrained @@ -33,37 +39,41 @@ def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts] pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config) else: - pretrained_state_dict = state_dict_from_pretrained(Path(checkpoint_path) / f'{model_name}-hf') + pretrained_state_dict = state_dict_from_pretrained( + Path(checkpoint_path) / f"{model_name}-hf" + ) pretrained_state_dict = remap_state_dict_hf_llama(pretrained_state_dict, config) return pretrained_state_dict -@pytest.mark.parametrize('model_name', ["7B"]) +@pytest.mark.parametrize("model_name", ["7B"]) def test_llama_state_dict(model_name): - checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', - current_dir.parent.parent / 'checkpoints')) / 'llama' + checkpoint_path = ( + Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama" + ) config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name)) ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name) pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config) - model = GPTLMHeadModel(config, device='meta') # Without device='meta' init is very slow + model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow state_dict = model.state_dict() assert state_dict.keys() == pretrained_state_dict.keys() for k in state_dict.keys(): assert state_dict[k].shape == pretrained_state_dict[k].shape -@pytest.mark.parametrize('model_name', ["7B", "13B"]) -@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"]) +@pytest.mark.parametrize("model_name", ["7B", "13B"]) +@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"]) def test_llama_optimized(model_name, checkpoint_format): """Check that our implementation of LLaMa (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF forward pass in fp16, when compared to the HF forward pass in fp32. 
""" - checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', - current_dir.parent.parent / 'checkpoints')) / 'llama' + checkpoint_path = ( + Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama" + ) dtype = torch.float16 - device = 'cuda' + device = "cuda" config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format) config = llama_config_to_gpt2_config(config) config.use_flash_attn = True @@ -83,8 +93,9 @@ def test_llama_optimized(model_name, checkpoint_format): batch_size = 2 max_seqlen = 256 seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device) - input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, - device=device) + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device + ) with torch.no_grad(): out = model.transformer(input_ids) logits = model(input_ids).logits @@ -92,39 +103,43 @@ def test_llama_optimized(model_name, checkpoint_format): # Without device_map, the model is loaded on the CPU, which is very slow # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB - model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf', - device_map='auto') + model_ref = LlamaForCausalLM.from_pretrained( + Path(checkpoint_path) / f"{model_name}-hf", device_map="auto" + ) model_ref.eval() with torch.no_grad(): out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device) logits_ref = model_ref(input_ids).logits.to(device=device) del model_ref - model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf', - torch_dtype=dtype, device_map={"": device}) + model_hf = LlamaForCausalLM.from_pretrained( + Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device} + ) model_hf.eval() with torch.no_grad(): out_hf = model_hf.model(input_ids).last_hidden_state logits_hf = model_hf(input_ids).logits del model_hf - print(f'Output max diff: {(out - out_ref).abs().max().item()}') - print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') - print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}') - print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}') + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}") + print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}") assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item() - print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}') - print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}') - print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}') - print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}') - assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item() + print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") + print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}") + print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}") + print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}") + assert (logits - logits_ref).abs().max().item() < 3 * ( + logits_hf - logits_ref + ).abs().max().item() # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k 
"parallel" -@pytest.mark.parametrize('world_size', [2]) -@pytest.mark.parametrize('model_name', ["13B"]) -@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"]) +@pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("model_name", ["13B"]) +@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"]) def test_llama_parallel(model_name, world_size, checkpoint_format): """Check that our implementation of LLaMa (with all optimizations enabled) matches the HF implementation: the output of our forward pass in fp16 should be around the same as the HF @@ -132,8 +147,9 @@ def test_llama_parallel(model_name, world_size, checkpoint_format): """ from apex.transformer import parallel_state - checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', - current_dir.parent.parent / 'checkpoints')) / 'llama' + checkpoint_path = ( + Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama" + ) dtype = torch.float16 config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format) @@ -145,8 +161,8 @@ def test_llama_parallel(model_name, world_size, checkpoint_format): config.residual_in_fp32 = True if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl', init_method='env://') - device = f'cuda:{torch.distributed.get_rank()}' + torch.distributed.init_process_group(backend="nccl", init_method="env://") + device = f"cuda:{torch.distributed.get_rank()}" assert world_size <= torch.distributed.get_world_size() parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) rank = parallel_state.get_tensor_model_parallel_rank() @@ -163,8 +179,9 @@ def test_llama_parallel(model_name, world_size, checkpoint_format): batch_size = 2 max_seqlen = 256 seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device) - input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, - device=device) + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device + ) with torch.no_grad(): out = model.transformer(input_ids) out, _ = all_gather_raw(out, process_group=process_group) @@ -172,13 +189,13 @@ def test_llama_parallel(model_name, world_size, checkpoint_format): logits = model(input_ids).logits logits = rearrange(logits, "(b s) d -> b s d", b=batch_size) logits, _ = all_gather_raw(logits, process_group) - logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size) + logits = rearrange(logits, "(n b) ... d -> b ... 
(n d)", b=batch_size) del model if rank == 0: # Without device_map, the model is loaded on the CPU, which is very slow model_ref = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f'{model_name}-hf', device_map="auto" + Path(checkpoint_path) / f"{model_name}-hf", device_map="auto" ) model_ref.eval() with torch.no_grad(): @@ -187,7 +204,7 @@ def test_llama_parallel(model_name, world_size, checkpoint_format): del model_ref model_hf = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto" + Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto" ) model_hf.eval() with torch.no_grad(): @@ -195,28 +212,31 @@ def test_llama_parallel(model_name, world_size, checkpoint_format): logits_hf = model_hf(input_ids).logits.to(device=device) del model_hf - print(f'Output max diff: {(out - out_ref).abs().max().item()}') - print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') - print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}') - print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}') + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}") + print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}") assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item() - print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}') - print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}') - print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}') - print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}') - assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item() + print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") + print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}") + print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}") + print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}") + assert (logits - logits_ref).abs().max().item() < 2 * ( + logits_hf - logits_ref + ).abs().max().item() # @pytest.mark.parametrize('model_name', ["7B", "13B"]) -@pytest.mark.parametrize('model_name', ["7B"]) -@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"]) +@pytest.mark.parametrize("model_name", ["7B"]) +@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"]) def test_llama_generation(model_name, checkpoint_format): - checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', - current_dir.parent.parent / 'checkpoints')) / 'llama' + checkpoint_path = ( + Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama" + ) dtype = torch.float16 - device = 'cuda' + device = "cuda" config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format) config = llama_config_to_gpt2_config(config) config.use_flash_attn = True @@ -225,34 +245,38 @@ def test_llama_generation(model_name, checkpoint_format): config.fused_dropout_add_ln = True config.residual_in_fp32 = True - tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf') + tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf") eos_token_id = tokenizer.eos_token_id torch.manual_seed(0) batch_size = 1 seqlen = 100 max_length = 150 - input_ids = torch.randint(0, config.vocab_size, 
(batch_size, seqlen), dtype=torch.long, - device=device) + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device + ) - model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf', - torch_dtype=dtype, device_map={"": device}) + model_hf = LlamaForCausalLM.from_pretrained( + Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device} + ) model_hf.eval() print("HF fp16") torch.cuda.synchronize() start = time.time() - out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length, - return_dict_in_generate=True, output_scores=True) + out_hf = model_hf.generate( + input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True + ) torch.cuda.synchronize() - print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") del model_hf # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB - model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf', - device_map='auto') + model_ref = LlamaForCausalLM.from_pretrained( + Path(checkpoint_path) / f"{model_name}-hf", device_map="auto" + ) model_ref.eval() with torch.no_grad(): - logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device) + logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device) del model_ref pretrained_state_dict = _pretrained_state_dict_from_checkpoint( @@ -262,31 +286,43 @@ def test_llama_generation(model_name, checkpoint_format): model.load_state_dict(pretrained_state_dict) model.eval() - print('Without CUDA graph') + print("Without CUDA graph") torch.cuda.synchronize() start = time.time() - out = model.generate(input_ids=input_ids, max_length=max_length, - eos_token_id=eos_token_id, fused_ft_kernel=True, - return_dict_in_generate=True, output_scores=True, timing=True, - teacher_outputs=out_hf.sequences) + out = model.generate( + input_ids=input_ids, + max_length=max_length, + eos_token_id=eos_token_id, + fused_ft_kernel=True, + return_dict_in_generate=True, + output_scores=True, + timing=True, + teacher_outputs=out_hf.sequences, + ) torch.cuda.synchronize() - print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") # Capture graph outside the timing loop batch_size, seqlen_og = input_ids.shape model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length) - print('With CUDA graph') + print("With CUDA graph") torch.cuda.synchronize() start = time.time() - out_cg = model.generate(input_ids=input_ids, max_length=max_length, - fused_ft_kernel=True, cg=True, - return_dict_in_generate=True, output_scores=True, timing=True, - teacher_outputs=out_hf.sequences) + out_cg = model.generate( + input_ids=input_ids, + max_length=max_length, + fused_ft_kernel=True, + cg=True, + return_dict_in_generate=True, + output_scores=True, + timing=True, + teacher_outputs=out_hf.sequences, + ) torch.cuda.synchronize() - print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") with torch.no_grad(): - logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1] + logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1] logits_hf = 
torch.stack(out_hf.scores, dim=1) logits = torch.stack(out.scores, dim=1) logits_cg = torch.stack(out_cg.scores, dim=1) @@ -295,9 +331,9 @@ def test_llama_generation(model_name, checkpoint_format): hf_error = (logits_hf - logits_ref).abs().max().item() - print(f'HF fp16 logits max diff: {hf_error}') - print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}') - print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}') + print(f"HF fp16 logits max diff: {hf_error}") + print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") + print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}") assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error assert (logits - logits_ref).abs().max().item() < 2 * hf_error @@ -305,9 +341,9 @@ def test_llama_generation(model_name, checkpoint_format): # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation" -@pytest.mark.parametrize('world_size', [2]) -@pytest.mark.parametrize('model_name', ["13B"]) -@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"]) +@pytest.mark.parametrize("world_size", [2]) +@pytest.mark.parametrize("model_name", ["13B"]) +@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"]) def test_llama_parallel_generation(model_name, world_size, checkpoint_format): """Check that our implementation matches the HF implementation: the scores in fp16 should be around the same as the HF scores in fp16, when compared to @@ -315,8 +351,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format): """ from apex.transformer import parallel_state - checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', - current_dir.parent.parent / 'checkpoints')) / 'llama' + checkpoint_path = ( + Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama" + ) dtype = torch.float16 config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format) @@ -331,8 +368,8 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format): os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl', init_method='env://') - device = f'cuda:{torch.distributed.get_rank()}' + torch.distributed.init_process_group(backend="nccl", init_method="env://") + device = f"cuda:{torch.distributed.get_rank()}" assert world_size <= torch.distributed.get_world_size() parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) rank = parallel_state.get_tensor_model_parallel_rank() @@ -342,8 +379,9 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format): batch_size = 1 seqlen = 100 max_length = 150 - input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, - device=device) + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device + ) # Need this, otherwise when we capture the graph the process for GPU 1 would run on both # GPU0 and GPU1 and things would hang @@ -356,23 +394,34 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format): model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)) model.eval() - print('Without CUDA graph') + print("Without CUDA graph") out = model.generate( - input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, - vocab_size=config.vocab_size, fused_ft_kernel=True, + 
input_ids=input_ids, + max_length=max_length, + tensor_parallel=world_size, + vocab_size=config.vocab_size, + fused_ft_kernel=True, # teacher_outputs=out_hf.sequences, - return_dict_in_generate=True, output_scores=True, timing=True + return_dict_in_generate=True, + output_scores=True, + timing=True, ) # Capture graph outside the timing loop batch_size, seqlen_og = input_ids.shape model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length) - print('With CUDA graph') + print("With CUDA graph") out_cg = model.generate( - input_ids=input_ids, max_length=max_length, tensor_parallel=world_size, - vocab_size=config.vocab_size, fused_ft_kernel=True, cg=True, + input_ids=input_ids, + max_length=max_length, + tensor_parallel=world_size, + vocab_size=config.vocab_size, + fused_ft_kernel=True, + cg=True, # teacher_outputs=out_hf.sequences, - return_dict_in_generate=True, output_scores=True, timing=True + return_dict_in_generate=True, + output_scores=True, + timing=True, ) del model parallel_state.destroy_model_parallel() @@ -380,7 +429,7 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format): if rank == 0: # Without device_map, the model is loaded on the CPU, which is very slow model_hf = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto" + Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto" ) model_hf.eval() print("HF fp16") @@ -388,19 +437,21 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format): start = time.time() with torch.inference_mode(): out_hf = model_hf.generate( - input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, - output_scores=True + input_ids=input_ids, + max_length=max_length, + return_dict_in_generate=True, + output_scores=True, ) torch.cuda.synchronize() - print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") del model_hf model_ref = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f'{model_name}-hf', device_map="auto" + Path(checkpoint_path) / f"{model_name}-hf", device_map="auto" ) model_ref.eval() with torch.inference_mode(): - logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1] + logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1] del model_ref logits_hf = torch.stack(out_hf.scores, dim=1) @@ -408,25 +459,27 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format): logits_cg = torch.stack(out_cg.scores, dim=1) hf_error = (logits_hf - logits_ref).abs().max().item() - print(f'HF fp16 logits max diff: {hf_error}') - print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}') + print(f"HF fp16 logits max diff: {hf_error}") + print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") assert (logits - logits_ref).abs().max().item() < 2 * hf_error - print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}') + print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}") assert torch.equal(logits_cg, logits) @torch.no_grad() -@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize("world_size", [2]) def test_llama_parallel_uneven_num_heads(world_size): from apex.transformer import parallel_state - checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', current_dir.parent.parent / 'checkpoints')) / 'llama' + checkpoint_path = ( + 
Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama" + ) num_attention_heads = world_size + 1 - model_name = f'teeny-{num_attention_heads}-heads' + model_name = f"teeny-{num_attention_heads}-heads" if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend='nccl', init_method='env://') - device = f'cuda:{torch.distributed.get_rank()}' + torch.distributed.init_process_group(backend="nccl", init_method="env://") + device = f"cuda:{torch.distributed.get_rank()}" assert world_size <= torch.distributed.get_world_size() parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) rank = parallel_state.get_tensor_model_parallel_rank() @@ -434,7 +487,8 @@ def test_llama_parallel_uneven_num_heads(world_size): dtype = torch.float16 llama_config = LlamaConfig( - hidden_size=256 * num_attention_heads, # ParallelGatedMlp hidden_features must be divisible by 256 + hidden_size=256 + * num_attention_heads, # ParallelGatedMlp hidden_features must be divisible by 256 intermediate_size=256 * num_attention_heads * 4, num_hidden_layers=4, num_attention_heads=num_attention_heads, @@ -451,8 +505,9 @@ def test_llama_parallel_uneven_num_heads(world_size): batch_size = 2 max_seqlen = 256 seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device) - input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, - device=device) + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device + ) # Create a shared test model. if rank == 0: @@ -474,11 +529,11 @@ def test_llama_parallel_uneven_num_heads(world_size): logits = model(input_ids).logits logits = rearrange(logits, "(b s) d -> b s d", b=batch_size) logits, _ = all_gather_raw(logits, process_group) - logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size) + logits = rearrange(logits, "(n b) ... d -> b ... 
(n d)", b=batch_size) if rank == 0: model_ref = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f'{model_name}-hf', device_map="auto" + Path(checkpoint_path) / f"{model_name}-hf", device_map="auto" ) model_ref.eval() out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device) @@ -486,24 +541,65 @@ def test_llama_parallel_uneven_num_heads(world_size): del model_ref model_hf = LlamaForCausalLM.from_pretrained( - Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto" + Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto" ) model_hf.eval() out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device) logits_hf = model_hf(input_ids).logits.to(device=device) del model_hf - print(f'Output max diff: {(out - out_ref).abs().max().item()}') - print(f'Output mean diff: {(out - out_ref).abs().mean().item()}') - print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}') - print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}') + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}") + print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}") assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item() - print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}') - print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}') - print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}') - print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}') - assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item() + print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") + print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}") + print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}") + print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}") + assert (logits - logits_ref).abs().max().item() < 2 * ( + logits_hf - logits_ref + ).abs().max().item() - import shutil - shutil.rmtree(checkpoint_path / f'{model_name}-hf') + if os.path.exists(checkpoint_path / f"{model_name}-hf"): + shutil.rmtree(checkpoint_path / f"{model_name}-hf") + + +@torch.no_grad() +def test_inv_remap_state_dict_hf_llama(): + checkpoint_path = ( + Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama" + ) + model_name = f"teeny" + + llama_config = LlamaConfig( + num_attention_heads=2, + hidden_size=256 * 2, + intermediate_size=256 * 2 * 4, + num_hidden_layers=4, + ) + config = llama_config_to_gpt2_config(llama_config) + config.use_flash_attn = True + config.fused_bias_fc = True + config.fused_mlp = False # We don't have fused GatedMLP yet + config.fused_dropout_add_ln = True + config.residual_in_fp32 = True + + # Set up. 
+ LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf") + + # inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama + state_dict = state_dict_from_pretrained(checkpoint_path / f"{model_name}-hf") + state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key} + pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config) + state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config) + + assert set(state_dict_recover.keys()) == set(state_dict.keys()) + + for key in state_dict_recover.keys(): + torch.testing.assert_close(state_dict_recover[key], state_dict[key]) + + # Tear down. + if os.path.exists(checkpoint_path / f"{model_name}-hf"): + shutil.rmtree(checkpoint_path / f"{model_name}-hf")
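
Round-trip usage, for reference: the sketch below mirrors test_inv_remap_state_dict_hf_llama above and shows how the two new helpers are intended to compose. It is a minimal sketch, assuming flash-attn (with this patch) and transformers are importable; the tiny LlamaConfig values are illustrative only, and the rotary inv_freq buffers are dropped because the remap helpers do not track them.

import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from flash_attn.models.llama import (
    llama_config_to_gpt2_config,
    remap_state_dict_hf_llama,
    inv_remap_state_dict_hf_llama,
)

# Tiny illustrative config; any LLaMA config whose vocab size is already a
# multiple of pad_vocab_size_multiple should round-trip exactly.
llama_config = LlamaConfig(
    num_attention_heads=2,
    hidden_size=256 * 2,
    intermediate_size=256 * 2 * 4,
    num_hidden_layers=4,
)
config = llama_config_to_gpt2_config(llama_config)

# Start from a randomly initialized HF model; drop rotary inv_freq buffers,
# which the remap functions neither consume nor produce.
hf_state_dict = LlamaForCausalLM(config=llama_config).state_dict()
hf_state_dict = {k: v for k, v in hf_state_dict.items() if "rotary_emb.inv_freq" not in k}

# HF format -> flash-attn GPT format. Pass a shallow copy, since the helpers
# are documented as modifying their input in place.
gpt_state_dict = remap_state_dict_hf_llama(dict(hf_state_dict), config)

# flash-attn GPT format -> HF format again.
recovered = inv_remap_state_dict_hf_llama(gpt_state_dict, config)

assert set(recovered.keys()) == set(hf_state_dict.keys())
for key in hf_state_dict:
    torch.testing.assert_close(recovered[key], hf_state_dict[key])

As the docstring of inv_remap_state_dict_hf_llama notes, the inverse holds only up to vocab padding: if the original vocab size is not a multiple of pad_vocab_size_multiple, the recovered embedding and lm_head rows are padded relative to the originals.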