Add test for BTLM init

Tri Dao 2023-12-25 14:22:32 -08:00
parent 7ffba9a501
commit 73df3be7d5
2 changed files with 47 additions and 14 deletions

View File

@@ -396,7 +396,9 @@ def _init_weights(
     mup_init_scale = math.sqrt(mup_width_scale)
     if isinstance(module, nn.Linear):
         nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
-        module.weight._optim = {"lr_multiplier": mup_width_scale}
+        optim_cfg = getattr(module.weight, "_optim", {})
+        optim_cfg.update({"lr_multiplier": mup_width_scale})
+        setattr(module.weight, "_optim", optim_cfg)
         if module.bias is not None:
             nn.init.zeros_(module.bias)
     elif isinstance(module, nn.Embedding):
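
The new getattr/update/setattr pattern merges the lr multiplier into any `_optim` metadata already attached to the weight instead of overwriting it. For context, here is a minimal sketch of how a training script might consume this metadata when building the optimizer; the `build_param_groups` helper and the `base_lr` value are hypothetical illustrations, not part of this repo:

import torch
import torch.nn as nn

def build_param_groups(model: nn.Module, base_lr: float):
    # Group parameters by the per-tensor `_optim` metadata set in _init_weights;
    # `lr_multiplier` rescales the base learning rate (muP-style width scaling).
    groups = {}
    for p in model.parameters():
        mult = getattr(p, "_optim", {}).get("lr_multiplier", 1.0)
        groups.setdefault(mult, []).append(p)
    return [{"params": ps, "lr": base_lr * mult} for mult, ps in groups.items()]

# Usage (hypothetical):
# optimizer = torch.optim.AdamW(build_param_groups(model, base_lr=3e-4))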

View File

@@ -1,13 +1,9 @@
 # Copyright (c) 2023, Tri Dao.
-import os
 import time
-from pathlib import Path
 import torch
 import pytest
-from einops import rearrange
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
 from flash_attn.models.gpt import GPTLMHeadModel
@@ -21,9 +17,7 @@ def test_btlm_state_dict(model_name):
     config = btlm_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
     state_dict = model.state_dict()
     assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
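
The device="meta" comment above refers to PyTorch's meta device, where tensors carry only shape and dtype and allocate no storage, so constructing a multi-billion-parameter module is near-instant. A minimal standalone sketch of the pattern (assuming PyTorch 2.x's torch.device context manager; sizes and init std are illustrative):

import torch
import torch.nn as nn

# On the meta device no memory is allocated; only shape/dtype metadata exists.
with torch.device("meta"):
    block = nn.Linear(2560, 2560)
assert block.weight.is_meta

# To actually use the module, materialize uninitialized storage, then re-init.
block = block.to_empty(device="cpu")
nn.init.normal_(block.weight, std=0.02)
nn.init.zeros_(block.bias)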
@@ -47,9 +41,7 @@ def test_btlm_optimized(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -152,9 +144,7 @@ def test_btlm_generation(model_name):
     logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
     del model_ref
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -212,3 +202,44 @@ def test_btlm_generation(model_name):
     assert torch.equal(logits_cg, logits)
+
+
+@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
+def test_btlm_init(model_name):
+    dtype = torch.float32
+    device = "cuda"
+    btlm_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    config = btlm_config_to_gpt2_config(btlm_config)
+    model = GPTLMHeadModel(config, device=device, dtype=dtype)
+    model_ref = AutoModelForCausalLM.from_config(btlm_config, trust_remote_code=True).to(device)
+    assert model.transformer.embeddings.word_embeddings.weight.mean().abs() < 1e-4
+    assert (
+        model.transformer.embeddings.word_embeddings.weight.std()
+        - model_ref.transformer.wte.weight.std()
+    ).abs() < 1e-4
+    assert model.lm_head.weight.mean().abs() < 1e-4
+    assert (model.lm_head.weight.std() - model_ref.lm_head.weight.std()).abs() < 1e-4
+    for l in range(config.n_layer):
+        assert model.transformer.layers[l].mixer.Wqkv.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mixer.Wqkv.weight.std()
+            - model_ref.transformer.h[l].attn.c_attn.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mixer.Wqkv.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mixer.out_proj.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mixer.out_proj.weight.std()
+            - model_ref.transformer.h[l].attn.c_proj.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mixer.out_proj.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mlp.fc1.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mlp.fc1.weight.std()
+            - model_ref.transformer.h[l].mlp.c_fc.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mlp.fc1.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mlp.fc2.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mlp.fc2.weight.std()
+            - model_ref.transformer.h[l].mlp.c_proj.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mlp.fc2.bias.abs().max() == 0.0
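
For reference, the std values the new test compares follow from the muP rule in the first hunk: linear weights are drawn with std = initializer_range * sqrt(mup_width_scale), and biases are zeroed. A self-contained sketch of that relationship, and of why the 1e-4 tolerance is comfortable at these widths (the concrete numbers are illustrative; BTLM reads both values from its config):

import math
import torch
import torch.nn as nn

initializer_range = 0.02  # illustrative
mup_width_scale = 0.1     # illustrative; BTLM stores this in its config
expected_std = initializer_range * math.sqrt(mup_width_scale)

linear = nn.Linear(2560, 2560)
nn.init.normal_(linear.weight, std=expected_std)
nn.init.zeros_(linear.bias)

# With ~6.5M weight entries the empirical std lands within ~1e-5 of the
# target, well inside the 1e-4 tolerance used by the assertions above.
assert (linear.weight.std() - expected_std).abs() < 1e-4
assert linear.bias.abs().max() == 0.0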