Add test for BTLM init

Tri Dao 2023-12-25 14:22:32 -08:00
parent 7ffba9a501
commit 73df3be7d5
2 changed files with 47 additions and 14 deletions


@@ -396,7 +396,9 @@ def _init_weights(
     mup_init_scale = math.sqrt(mup_width_scale)
     if isinstance(module, nn.Linear):
         nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
-        module.weight._optim = {"lr_multiplier": mup_width_scale}
+        optim_cfg = getattr(module.weight, "_optim", {})
+        optim_cfg.update({"lr_multiplier": mup_width_scale})
+        setattr(module.weight, "_optim", optim_cfg)
         if module.bias is not None:
             nn.init.zeros_(module.bias)
     elif isinstance(module, nn.Embedding):
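
The change above merges the lr multiplier into any _optim dict already attached to the weight, instead of overwriting whatever an earlier initializer recorded there. The attribute is metadata only; it takes effect when a training script folds it into optimizer parameter groups. A minimal sketch of such a consumer (the helper name and base learning rate are hypothetical, not part of this commit):

import torch

def build_param_groups(model, base_lr):
    # Bucket parameters by the lr_multiplier stored in their _optim metadata
    # (attached by _init_weights above), so the muP width scaling becomes a
    # per-group learning rate.
    default_params, by_multiplier = [], {}
    for param in model.parameters():
        multiplier = getattr(param, "_optim", {}).get("lr_multiplier")
        if multiplier is None:
            default_params.append(param)
        else:
            by_multiplier.setdefault(multiplier, []).append(param)
    groups = [{"params": default_params, "lr": base_lr}]
    groups += [{"params": ps, "lr": base_lr * m} for m, ps in by_multiplier.items()]
    return groups

# e.g. optimizer = torch.optim.AdamW(build_param_groups(model, base_lr=6e-4))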


@@ -1,13 +1,9 @@
 # Copyright (c) 2023, Tri Dao.
-import os
-import time
-from pathlib import Path
 import torch
 import pytest
 from einops import rearrange
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
 from flash_attn.models.gpt import GPTLMHeadModel
@@ -21,9 +17,7 @@ def test_btlm_state_dict(model_name):
     config = btlm_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
     state_dict = model.state_dict()
     assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
@@ -47,9 +41,7 @@ def test_btlm_optimized(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -152,9 +144,7 @@ def test_btlm_generation(model_name):
     logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
     del model_ref
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -212,3 +202,44 @@ def test_btlm_generation(model_name):
     assert torch.equal(logits_cg, logits)
 
+
+@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
+def test_btlm_init(model_name):
+    dtype = torch.float32
+    device = "cuda"
+    btlm_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    config = btlm_config_to_gpt2_config(btlm_config)
+    model = GPTLMHeadModel(config, device=device, dtype=dtype)
+    model_ref = AutoModelForCausalLM.from_config(btlm_config, trust_remote_code=True).to(device)
+    assert model.transformer.embeddings.word_embeddings.weight.mean().abs() < 1e-4
+    assert (
+        model.transformer.embeddings.word_embeddings.weight.std()
+        - model_ref.transformer.wte.weight.std()
+    ).abs() < 1e-4
+    assert model.lm_head.weight.mean().abs() < 1e-4
+    assert (model.lm_head.weight.std() - model_ref.lm_head.weight.std()).abs() < 1e-4
+    for l in range(config.n_layer):
+        assert model.transformer.layers[l].mixer.Wqkv.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mixer.Wqkv.weight.std()
+            - model_ref.transformer.h[l].attn.c_attn.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mixer.Wqkv.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mixer.out_proj.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mixer.out_proj.weight.std()
+            - model_ref.transformer.h[l].attn.c_proj.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mixer.out_proj.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mlp.fc1.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mlp.fc1.weight.std()
+            - model_ref.transformer.h[l].mlp.c_fc.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mlp.fc1.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mlp.fc2.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mlp.fc2.weight.std()
+            - model_ref.transformer.h[l].mlp.c_proj.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mlp.fc2.bias.abs().max() == 0.0
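
The new test compares only the first two moments of each weight tensor against the reference Hugging Face module, which is enough to catch a missing sqrt(mup_width_scale) factor in the init. A standalone sketch of the statistic being asserted, using illustrative numbers rather than values from the BTLM config:

import math
import torch
import torch.nn as nn

# _init_weights draws Linear weights from N(0, (initializer_range * sqrt(mup_width_scale))^2).
# With illustrative values initializer_range = 0.02 and mup_width_scale = 0.1,
# the target std is 0.02 * sqrt(0.1) ≈ 0.00632.
initializer_range, mup_width_scale = 0.02, 0.1
weight = torch.empty(2560, 2560)
nn.init.normal_(weight, std=initializer_range * math.sqrt(mup_width_scale))
assert weight.mean().abs() < 1e-4  # mean ≈ 0, the first check in test_btlm_init
print(weight.std().item())  # ≈ 0.00632; the test asserts this matches the reference model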