Add test for BTLM init

parent 7ffba9a501
commit 73df3be7d5
@@ -396,7 +396,9 @@ def _init_weights(
     mup_init_scale = math.sqrt(mup_width_scale)
     if isinstance(module, nn.Linear):
         nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
-        module.weight._optim = {"lr_multiplier": mup_width_scale}
+        optim_cfg = getattr(module.weight, "_optim", {})
+        optim_cfg.update({"lr_multiplier": mup_width_scale})
+        setattr(module.weight, "_optim", optim_cfg)
         if module.bias is not None:
             nn.init.zeros_(module.bias)
     elif isinstance(module, nn.Embedding):
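The change above merges the muP learning-rate multiplier into whatever `_optim` metadata is already attached to the weight, instead of overwriting it. A minimal sketch of how a training script could consume that per-parameter metadata when building optimizer param groups (an assumed consumer for illustration, not code from this commit; the helper name is hypothetical):

# Sketch only: group parameters by their ``_optim`` lr_multiplier so that
# muP-scaled weights train with base_lr * lr_multiplier.
import torch

def param_groups_from_optim_metadata(model: torch.nn.Module, base_lr: float):
    by_multiplier = {}
    for param in model.parameters():
        multiplier = getattr(param, "_optim", {}).get("lr_multiplier", 1.0)
        by_multiplier.setdefault(multiplier, []).append(param)
    return [
        {"params": params, "lr": base_lr * multiplier}
        for multiplier, params in by_multiplier.items()
    ]

# e.g. optimizer = torch.optim.AdamW(param_groups_from_optim_metadata(model, 3e-4))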
@@ -1,13 +1,9 @@
 # Copyright (c) 2023, Tri Dao.
-import os
-import time
-from pathlib import Path
-
 import torch
 import pytest
 
 from einops import rearrange
 
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
 
 from flash_attn.models.gpt import GPTLMHeadModel
@@ -21,9 +17,7 @@ def test_btlm_state_dict(model_name):
     config = btlm_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
     state_dict = model.state_dict()
     assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
@@ -47,9 +41,7 @@ def test_btlm_optimized(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
 
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -152,9 +144,7 @@ def test_btlm_generation(model_name):
     logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
     del model_ref
 
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -212,3 +202,44 @@ def test_btlm_generation(model_name):
     assert torch.equal(logits_cg, logits)
 
 
+@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
+def test_btlm_init(model_name):
+    dtype = torch.float32
+    device = "cuda"
+    btlm_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    config = btlm_config_to_gpt2_config(btlm_config)
+    model = GPTLMHeadModel(config, device=device, dtype=dtype)
+    model_ref = AutoModelForCausalLM.from_config(btlm_config, trust_remote_code=True).to(device)
+
+    assert model.transformer.embeddings.word_embeddings.weight.mean().abs() < 1e-4
+    assert (
+        model.transformer.embeddings.word_embeddings.weight.std()
+        - model_ref.transformer.wte.weight.std()
+    ).abs() < 1e-4
+    assert model.lm_head.weight.mean().abs() < 1e-4
+    assert (model.lm_head.weight.std() - model_ref.lm_head.weight.std()).abs() < 1e-4
+    for l in range(config.n_layer):
+        assert model.transformer.layers[l].mixer.Wqkv.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mixer.Wqkv.weight.std()
+            - model_ref.transformer.h[l].attn.c_attn.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mixer.Wqkv.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mixer.out_proj.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mixer.out_proj.weight.std()
+            - model_ref.transformer.h[l].attn.c_proj.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mixer.out_proj.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mlp.fc1.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mlp.fc1.weight.std()
+            - model_ref.transformer.h[l].mlp.c_fc.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mlp.fc1.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mlp.fc2.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mlp.fc2.weight.std()
+            - model_ref.transformer.h[l].mlp.c_proj.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mlp.fc2.bias.abs().max() == 0.0
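The new test compares initialization statistics (near-zero weight means, matching stds, exactly-zero biases) against the reference Hugging Face implementation rather than exact tensors, since both inits draw independent random samples. A quick sketch of the quantity the per-layer asserts are matching, following the muP-style rule from _init_weights above (the numeric values here are illustrative assumptions, not read from BTLM's actual config):

# Illustrative only: per the rule in _init_weights, linear weights are drawn
# with std = initializer_range * sqrt(mup_width_scale).
import math

initializer_range = 0.02   # assumed GPT-2-style default, for illustration
mup_width_scale = 0.1      # assumed value, for illustration
expected_std = initializer_range * math.sqrt(mup_width_scale)
print(f"expected weight std ~ {expected_std:.5f}")

Like the other tests in this file, the new one can be selected on its own with pytest's keyword filter, e.g. pytest -k test_btlm_init.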