diff --git a/flash_attn/models/baichuan.py b/flash_attn/models/baichuan.py
index 2ca9ac1..be7320b 100644
--- a/flash_attn/models/baichuan.py
+++ b/flash_attn/models/baichuan.py
@@ -116,6 +116,9 @@ def remap_state_dict_hf_baichuan(state_dict, config):
 def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
     # HACK: the config doesn't have say whether it's rotary or alibi.
     # So we have to infer from the hidden size (7B -> rotary, 13B -> alibi).
+    # HACK: the config doesn't say whether it uses norm head.
+    # So we have to infer from the vocab size
+    # (v1, vocab size 64k, no norm head; v2, vocab size 128k, norm head).
     use_rotary = baichuan_config.hidden_size < 5000
     return GPT2Config(
         vocab_size=baichuan_config.vocab_size,
@@ -141,6 +144,7 @@ def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Con
         use_alibi=not use_rotary,
         use_flash_attn=not use_rotary,  # Alibi code path requires flash_attn
         tie_word_embeddings=False,
+        norm_head=baichuan_config.vocab_size > 70000,
         qkv_proj_bias=False,
         out_proj_bias=False,
         mlp_fc1_bias=False,
diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 97d555d..04c7674 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -32,7 +32,12 @@ from flash_attn.modules.mlp import (
     ParallelMLP,
 )
 from flash_attn.ops.activations import sqrelu_fwd
-from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
+from flash_attn.utils.distributed import (
+    all_gather,
+    all_gather_raw,
+    get_dim_for_local_rank,
+    sync_shared_params,
+)
 from flash_attn.utils.generation import GenerationMixin
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 
@@ -355,9 +360,8 @@ class GPTPreTrainedModel(nn.Module):
             state_dict = remap_state_dict_hf_gpt2(state_dict, config)
         elif model_name.startswith("facebook/opt"):
             state_dict = remap_state_dict_hf_opt(state_dict, config)
-        elif (
-            model_name.startswith("EleutherAI/gpt-j-")
-            or model_name.startswith("togethercomputer/GPT-JT-")
+        elif model_name.startswith("EleutherAI/gpt-j-") or model_name.startswith(
+            "togethercomputer/GPT-JT-"
         ):
             state_dict = remap_state_dict_hf_gptj(state_dict, config)
         elif (
@@ -621,6 +625,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
                 sequence_parallel=getattr(config, "sequence_parallel", True),
                 **factory_kwargs,
             )
+        self.norm_head = getattr(config, "norm_head", False)
         # Initialize weights and apply final processing
         self.apply(
             partial(
@@ -662,7 +667,13 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
-        lm_logits = self.lm_head(hidden_states)
+        if not self.norm_head:
+            lm_logits = self.lm_head(hidden_states)
+        else:
+            lm_head_weight = F.normalize(self.lm_head.weight)
+            if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
+                hidden_states = all_gather(hidden_states, self.lm_head.process_group)
+            lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
         # During inference, we want the full logit for sampling
         if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
             lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
diff --git a/tests/models/test_baichuan.py b/tests/models/test_baichuan.py
index 4f04c2c..1fc550a 100644
--- a/tests/models/test_baichuan.py
+++ b/tests/models/test_baichuan.py
@@ -23,7 +23,15 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import update_graph_cache
 
 
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_state_dict(model_name):
     config = baichuan_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -39,7 +47,15 @@ def test_baichuan_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape
 
 
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_optimized(model_name):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the HF
     implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -66,9 +82,7 @@ def test_baichuan_optimized(model_name):
     torch.manual_seed(0)
     batch_size = 2
     max_seqlen = 256
-    seqlens = torch.randint(
-        max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device
-    )
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
     input_ids = torch.randint(
         0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
     )
@@ -89,7 +103,10 @@ def test_baichuan_optimized(model_name):
     del model_ref
 
     model_hf = AutoModelForCausalLM.from_pretrained(
-        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True,
+        model_name,
+        torch_dtype=dtype,
+        device_map={"": device},
+        trust_remote_code=True,
     )
     model_hf.eval()
     with torch.no_grad():
@@ -101,9 +118,7 @@ def test_baichuan_optimized(model_name):
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
     print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
-    assert (out - out_ref).abs().max().item() < 3 * (
-        out_hf - out_ref
-    ).abs().max().item()
+    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
 
     print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
     print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
@@ -116,7 +131,15 @@ def test_baichuan_optimized(model_name):
 
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_parallel_forward(model_name, world_size):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the HF
     implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -146,20 +169,14 @@ def test_baichuan_parallel_forward(model_name, world_size):
         state_dict_from_pretrained(model_name), config
     )
 
-    model = GPTLMHeadModel(
-        config, process_group=process_group, device=device, dtype=dtype
-    )
-    model.load_state_dict(
-        shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)
-    )
+    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()
 
     torch.manual_seed(0)
     batch_size = 2
     max_seqlen = 256
-    seqlens = torch.randint(
-        max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device
-    )
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
     input_ids = torch.randint(
         0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
     )
@@ -198,9 +215,7 @@ def test_baichuan_parallel_forward(model_name, world_size):
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
     print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
-    assert (out - out_ref).abs().max().item() < 2 * (
-        out_hf - out_ref
-    ).abs().max().item()
+    assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
 
     print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
     print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
@@ -211,7 +226,9 @@ def test_baichuan_parallel_forward(model_name, world_size):
     ).abs().max().item()
 
 
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"]
+)
 def test_baichuan_generation(model_name):
     dtype = torch.float16
     device = "cuda"
@@ -258,9 +275,7 @@ def test_baichuan_generation(model_name):
     )
     model_ref.eval()
     with torch.no_grad():
-        logits_ref = (
-            model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
-        )
+        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
     del model_ref
 
     pretrained_state_dict = remap_state_dict_hf_baichuan(
@@ -370,12 +385,8 @@ def test_baichuan_parallel_generation(model_name, world_size):
         state_dict_from_pretrained(model_name), config
     )
 
-    model = GPTLMHeadModel(
-        config, process_group=process_group, device=device, dtype=dtype
-    )
-    model.load_state_dict(
-        shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)
-    )
+    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()
 
     print("Without CUDA graph")
@@ -425,9 +436,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         output_scores=True,
     )
     torch.cuda.synchronize()
-    print(
-        f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms"
-    )
+    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
     del model_hf
 
     model_ref = AutoModelForCausalLM.from_pretrained(
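
Note for reviewers: the norm_head branch added to GPTLMHeadModel.forward mirrors Baichuan2's NormHead, which L2-normalizes each vocab row of the lm_head weight before the output projection. Below is a minimal single-device sketch of that path in plain PyTorch, for illustration only; the toy sizes and variable names (hidden_size, vocab_size, lm_head, hidden_states) are assumptions and not part of this patch.

    import torch
    import torch.nn.functional as F

    # Toy sizes for illustration; the real checkpoints use a much larger hidden size
    # and a ~128k vocabulary for Baichuan2 (see the comment added in baichuan.py).
    hidden_size, vocab_size = 16, 32
    lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
    hidden_states = torch.randn(2, 5, hidden_size)  # (batch, seqlen, hidden)

    # Standard head: logits = hidden_states @ W^T
    logits_plain = lm_head(hidden_states)

    # Norm head: normalize each vocab row of W to unit L2 norm first,
    # matching F.normalize(self.lm_head.weight) in the gpt.py change above.
    logits_norm = F.linear(hidden_states, F.normalize(lm_head.weight), bias=None)

    assert logits_norm.shape == logits_plain.shape == (2, 5, vocab_size)

In the tensor-parallel branch the lm_head weight is sharded along the vocab dimension, so the per-row normalization is unaffected by sharding; however, because the code bypasses ColumnParallelLinear's own forward, it has to all_gather the sequence-parallel hidden_states itself before calling F.linear, which is what the added all_gather in the diff does.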