Implement norm head for Baichuan2

Tri Dao 2023-12-22 16:08:08 -08:00
parent 68f178aa4b
commit 2c7d7b7396
3 changed files with 64 additions and 40 deletions

flash_attn/models/baichuan.py

@@ -116,6 +116,9 @@ def remap_state_dict_hf_baichuan(state_dict, config):
def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
# HACK: the config doesn't say whether it's rotary or alibi.
# So we have to infer from the hidden size (7B -> rotary, 13B -> alibi).
# HACK: the config doesn't say whether it uses norm head.
# So we have to infer from the vocab size
# (v1, vocab size 64k, no norm head; v2, vocab size 128k, norm head).
use_rotary = baichuan_config.hidden_size < 5000
return GPT2Config(
vocab_size=baichuan_config.vocab_size,
@@ -141,6 +144,7 @@ def baichuan_config_to_gpt2_config(baichuan_config: PretrainedConfig) -> GPT2Config:
use_alibi=not use_rotary,
use_flash_attn=not use_rotary, # Alibi code path requires flash_attn
tie_word_embeddings=False,
norm_head=baichuan_config.vocab_size > 70000,
qkv_proj_bias=False,
out_proj_bias=False,
mlp_fc1_bias=False,
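
For reference, a minimal sketch of what the two heuristics above resolve to for the released checkpoints; the hidden sizes and vocab sizes are assumptions taken from the public Baichuan configs rather than from this diff, and infer_baichuan_flags is a hypothetical helper used only for illustration.

# Assumed checkpoint values (not part of this commit):
#   Baichuan(-2)-7B : hidden_size 4096  -> rotary embeddings
#   Baichuan(-2)-13B: hidden_size 5120  -> ALiBi
#   Baichuan v1     : vocab_size 64000  -> no norm head
#   Baichuan2       : vocab_size ~125k  -> norm head
def infer_baichuan_flags(hidden_size: int, vocab_size: int) -> dict:
    return {
        "use_rotary": hidden_size < 5000,
        "norm_head": vocab_size > 70000,
    }

assert infer_baichuan_flags(4096, 64000) == {"use_rotary": True, "norm_head": False}   # Baichuan-7B (v1)
assert infer_baichuan_flags(5120, 125696) == {"use_rotary": False, "norm_head": True}  # Baichuan2-13B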

flash_attn/models/gpt.py

@@ -32,7 +32,12 @@ from flash_attn.modules.mlp import (
ParallelMLP,
)
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
from flash_attn.utils.distributed import (
all_gather,
all_gather_raw,
get_dim_for_local_rank,
sync_shared_params,
)
from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained
@@ -355,9 +360,8 @@ class GPTPreTrainedModel(nn.Module):
state_dict = remap_state_dict_hf_gpt2(state_dict, config)
elif model_name.startswith("facebook/opt"):
state_dict = remap_state_dict_hf_opt(state_dict, config)
elif (
model_name.startswith("EleutherAI/gpt-j-")
or model_name.startswith("togethercomputer/GPT-JT-")
elif model_name.startswith("EleutherAI/gpt-j-") or model_name.startswith(
"togethercomputer/GPT-JT-"
):
state_dict = remap_state_dict_hf_gptj(state_dict, config)
elif (
@@ -621,6 +625,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
sequence_parallel=getattr(config, "sequence_parallel", True),
**factory_kwargs,
)
self.norm_head = getattr(config, "norm_head", False)
# Initialize weights and apply final processing
self.apply(
partial(
@@ -662,7 +667,13 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
hidden_states = hidden_states[:, -num_last_tokens:]
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
lm_logits = self.lm_head(hidden_states)
if not self.norm_head:
lm_logits = self.lm_head(hidden_states)
else:
lm_head_weight = F.normalize(self.lm_head.weight)
if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
hidden_states = all_gather(hidden_states, self.lm_head.process_group)
lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
# During inference, we want the full logit for sampling
if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
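
What the norm-head branch above computes, in isolation: Baichuan2 L2-normalizes each row of the lm_head weight before projecting to logits, and because the normalized weight is applied with a plain F.linear instead of the fused ColumnParallelLinear forward, the sequence-parallel hidden states are gathered explicitly first. A minimal single-device sketch of the logit computation (toy shapes are illustrative assumptions; parallelism omitted):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seqlen, hidden, vocab = 2, 4, 8, 16           # toy shapes, for illustration only
hidden_states = torch.randn(batch, seqlen, hidden)
lm_head_weight = torch.randn(vocab, hidden)

# F.normalize defaults to p=2, dim=1, so each vocab row of the weight is scaled
# to unit L2 norm; logits then depend on the direction of each weight row rather
# than its magnitude.
normed_weight = F.normalize(lm_head_weight)
lm_logits = F.linear(hidden_states, normed_weight)   # (batch, seqlen, vocab)

manual = hidden_states @ (lm_head_weight / lm_head_weight.norm(dim=1, keepdim=True)).t()
assert torch.allclose(lm_logits, manual, atol=1e-6)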

tests/models/test_baichuan.py

@@ -23,7 +23,15 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
@pytest.mark.parametrize(
"model_name",
[
"baichuan-inc/Baichuan-7B",
"baichuan-inc/Baichuan-13B-Base",
"baichuan-inc/Baichuan2-7B-Base",
"baichuan-inc/Baichuan2-13B-Base",
],
)
def test_baichuan_state_dict(model_name):
config = baichuan_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -39,7 +47,15 @@ def test_baichuan_state_dict(model_name):
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
@pytest.mark.parametrize(
"model_name",
[
"baichuan-inc/Baichuan-7B",
"baichuan-inc/Baichuan-13B-Base",
"baichuan-inc/Baichuan2-7B-Base",
"baichuan-inc/Baichuan2-13B-Base",
],
)
def test_baichuan_optimized(model_name):
"""Check that our implementation of Baichuan (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -66,9 +82,7 @@ def test_baichuan_optimized(model_name):
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(
max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device
)
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
@@ -89,7 +103,10 @@ def test_baichuan_optimized(model_name):
del model_ref
model_hf = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True,
model_name,
torch_dtype=dtype,
device_map={"": device},
trust_remote_code=True,
)
model_hf.eval()
with torch.no_grad():
@@ -101,9 +118,7 @@ def test_baichuan_optimized(model_name):
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (
out_hf - out_ref
).abs().max().item()
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
@@ -116,7 +131,15 @@ def test_baichuan_optimized(model_name):
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
@pytest.mark.parametrize(
"model_name",
[
"baichuan-inc/Baichuan-7B",
"baichuan-inc/Baichuan-13B-Base",
"baichuan-inc/Baichuan2-7B-Base",
"baichuan-inc/Baichuan2-13B-Base",
],
)
def test_baichuan_parallel_forward(model_name, world_size):
"""Check that our implementation of Baichuan (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -146,20 +169,14 @@ def test_baichuan_parallel_forward(model_name, world_size):
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(
config, process_group=process_group, device=device, dtype=dtype
)
model.load_state_dict(
shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(
max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device
)
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
@@ -198,9 +215,7 @@ def test_baichuan_parallel_forward(model_name, world_size):
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 2 * (
out_hf - out_ref
).abs().max().item()
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
@@ -211,7 +226,9 @@ def test_baichuan_parallel_forward(model_name, world_size):
).abs().max().item()
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"])
@pytest.mark.parametrize(
"model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"]
)
def test_baichuan_generation(model_name):
dtype = torch.float16
device = "cuda"
@@ -258,9 +275,7 @@ def test_baichuan_generation(model_name):
)
model_ref.eval()
with torch.no_grad():
logits_ref = (
model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
)
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
del model_ref
pretrained_state_dict = remap_state_dict_hf_baichuan(
@@ -370,12 +385,8 @@ def test_baichuan_parallel_generation(model_name, world_size):
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(
config, process_group=process_group, device=device, dtype=dtype
)
model.load_state_dict(
shard_state_dict_tp(pretrained_state_dict, config, world_size, rank)
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
print("Without CUDA graph")
@@ -425,9 +436,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
output_scores=True,
)
torch.cuda.synchronize()
print(
f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms"
)
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = AutoModelForCausalLM.from_pretrained(