Implement norm head for Baichuan2
parent 68f178aa4b
commit 2c7d7b7396
@@ -116,6 +116,9 @@ def remap_state_dict_hf_baichuan(state_dict, config):
 def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
     # HACK: the config doesn't say whether it's rotary or alibi.
     # So we have to infer from the hidden size (7B -> rotary, 13B -> alibi).
+    # HACK: the config doesn't say whether it uses norm head.
+    # So we have to infer from the vocab size
+    # (v1, vocab size 64k, no norm head; v2, vocab size 128k, norm head).
     use_rotary = baichuan_config.hidden_size < 5000
     return GPT2Config(
         vocab_size=baichuan_config.vocab_size,
@@ -141,6 +144,7 @@ def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
         use_alibi=not use_rotary,
         use_flash_attn=not use_rotary,  # Alibi code path requires flash_attn
         tie_word_embeddings=False,
+        norm_head=baichuan_config.vocab_size > 70000,
         qkv_proj_bias=False,
         out_proj_bias=False,
         mlp_fc1_bias=False,
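The two hunks above are the whole config-side change: the upstream Baichuan config does not expose a norm-head flag, so it is inferred from the vocabulary size, just as rotary vs. alibi is inferred from the hidden size. A minimal usage sketch (assuming `transformers` is installed, the checkpoint config is reachable on the HF hub, and these helpers live in `flash_attn.models.baichuan`):

```python
from transformers import AutoConfig

from flash_attn.models.baichuan import baichuan_config_to_gpt2_config

config = baichuan_config_to_gpt2_config(
    AutoConfig.from_pretrained("baichuan-inc/Baichuan2-7B-Base", trust_remote_code=True)
)
# Baichuan2's vocabulary exceeds the 70k threshold, so norm_head is enabled;
# the 7B hidden size is below 5000, so it gets rotary embeddings rather than alibi.
print(config.norm_head, config.use_alibi)  # expected: True False
```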
@@ -32,7 +32,12 @@ from flash_attn.modules.mlp import (
     ParallelMLP,
 )
 from flash_attn.ops.activations import sqrelu_fwd
-from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
+from flash_attn.utils.distributed import (
+    all_gather,
+    all_gather_raw,
+    get_dim_for_local_rank,
+    sync_shared_params,
+)
 from flash_attn.utils.generation import GenerationMixin
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 
@@ -355,9 +360,8 @@ class GPTPreTrainedModel(nn.Module):
             state_dict = remap_state_dict_hf_gpt2(state_dict, config)
         elif model_name.startswith("facebook/opt"):
             state_dict = remap_state_dict_hf_opt(state_dict, config)
-        elif (
-            model_name.startswith("EleutherAI/gpt-j-")
-            or model_name.startswith("togethercomputer/GPT-JT-")
+        elif model_name.startswith("EleutherAI/gpt-j-") or model_name.startswith(
+            "togethercomputer/GPT-JT-"
         ):
             state_dict = remap_state_dict_hf_gptj(state_dict, config)
         elif (
@@ -621,6 +625,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
                 sequence_parallel=getattr(config, "sequence_parallel", True),
                 **factory_kwargs,
             )
+        self.norm_head = getattr(config, "norm_head", False)
         # Initialize weights and apply final processing
         self.apply(
             partial(
@@ -662,7 +667,13 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
-        lm_logits = self.lm_head(hidden_states)
+        if not self.norm_head:
+            lm_logits = self.lm_head(hidden_states)
+        else:
+            lm_head_weight = F.normalize(self.lm_head.weight)
+            if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
+                hidden_states = all_gather(hidden_states, self.lm_head.process_group)
+            lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
         # During inference, we want the full logit for sampling
         if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
             lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
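The forward-path change is the `else` branch above: when `norm_head` is set, logits are computed against a row-wise L2-normalized copy of the output-projection weight (mirroring Baichuan2's NormHead), and under sequence parallelism the hidden states are gathered first so every rank sees the full sequence. A minimal single-device sketch of the normalized-head computation, with toy dimensions and the tensor-parallel gather left out:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_dim, vocab_size = 64, 1000              # toy sizes for illustration
hidden_states = torch.randn(2, 8, hidden_dim)  # (batch, seqlen, hidden)
lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False)

# F.normalize defaults to dim=1, so each (vocab_size, hidden_dim) weight row is
# rescaled to unit L2 norm before the projection.
lm_head_weight = F.normalize(lm_head.weight)
lm_logits = F.linear(hidden_states, lm_head_weight, bias=lm_head.bias)  # bias is None here
print(lm_logits.shape)  # torch.Size([2, 8, 1000])
```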
@@ -23,7 +23,15 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import update_graph_cache
 
 
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_state_dict(model_name):
     config = baichuan_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -39,7 +47,15 @@ def test_baichuan_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape
 
 
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_optimized(model_name):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -66,9 +82,7 @@ def test_baichuan_optimized(model_name):
     torch.manual_seed(0)
     batch_size = 2
     max_seqlen = 256
-    seqlens = torch.randint(
-        max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device
-    )
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
     input_ids = torch.randint(
         0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
     )
@@ -89,7 +103,10 @@ def test_baichuan_optimized(model_name):
     del model_ref
 
     model_hf = AutoModelForCausalLM.from_pretrained(
-        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True,
+        model_name,
+        torch_dtype=dtype,
+        device_map={"": device},
+        trust_remote_code=True,
     )
     model_hf.eval()
     with torch.no_grad():
@@ -101,9 +118,7 @@ def test_baichuan_optimized(model_name):
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
     print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
-    assert (out - out_ref).abs().max().item() < 3 * (
-        out_hf - out_ref
-    ).abs().max().item()
+    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
 
     print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
     print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
@@ -116,7 +131,15 @@ def test_baichuan_optimized(model_name):
 
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name",
+    [
+        "baichuan-inc/Baichuan-7B",
+        "baichuan-inc/Baichuan-13B-Base",
+        "baichuan-inc/Baichuan2-7B-Base",
+        "baichuan-inc/Baichuan2-13B-Base",
+    ],
+)
 def test_baichuan_parallel_forward(model_name, world_size):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
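As the `torchrun` comment above notes, the parallel tests need a distributed launcher; the single-GPU tests can be run with plain pytest. A hypothetical invocation that selects only the newly added Baichuan2 cases of the optimized-forward test (the keyword filter is assumed to match the default parametrize IDs):

```python
# Equivalent shell form: pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_optimized and Baichuan2"
import pytest

pytest.main(
    ["-q", "-s", "tests/models/test_baichuan.py", "-k", "test_baichuan_optimized and Baichuan2"]
)
```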
@@ -146,20 +169,14 @@ def test_baichuan_parallel_forward(model_name, world_size):
         state_dict_from_pretrained(model_name), config
     )
 
-    model = GPTLMHeadModel(
-        config, process_group=process_group, device=device, dtype=dtype
-    )
-    model.load_state_dict(
-        shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)
-    )
+    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()
 
     torch.manual_seed(0)
     batch_size = 2
     max_seqlen = 256
-    seqlens = torch.randint(
-        max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device
-    )
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
     input_ids = torch.randint(
         0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
     )
@@ -198,9 +215,7 @@ def test_baichuan_parallel_forward(model_name, world_size):
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
     print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
-    assert (out - out_ref).abs().max().item() < 2 * (
-        out_hf - out_ref
-    ).abs().max().item()
+    assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
 
     print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
     print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
@@ -211,7 +226,9 @@ def test_baichuan_parallel_forward(model_name, world_size):
     ).abs().max().item()
 
 
-@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
+@pytest.mark.parametrize(
+    "model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"]
+)
 def test_baichuan_generation(model_name):
     dtype = torch.float16
     device = "cuda"
@@ -258,9 +275,7 @@ def test_baichuan_generation(model_name):
     )
     model_ref.eval()
     with torch.no_grad():
-        logits_ref = (
-            model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
-        )
+        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
     del model_ref
 
     pretrained_state_dict = remap_state_dict_hf_baichuan(
@@ -370,12 +385,8 @@ def test_baichuan_parallel_generation(model_name, world_size):
         state_dict_from_pretrained(model_name), config
     )
 
-    model = GPTLMHeadModel(
-        config, process_group=process_group, device=device, dtype=dtype
-    )
-    model.load_state_dict(
-        shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)
-    )
+    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()
 
     print("Without CUDA graph")
@@ -425,9 +436,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
                 output_scores=True,
             )
         torch.cuda.synchronize()
-        print(
-            f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms"
-        )
+        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
         del model_hf
 
         model_ref = AutoModelForCausalLM.from_pretrained(