diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 6939344..135213f 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -396,7 +396,9 @@ def _init_weights(
     mup_init_scale = math.sqrt(mup_width_scale)
     if isinstance(module, nn.Linear):
         nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
-        module.weight._optim = {"lr_multiplier": mup_width_scale}
+        optim_cfg = getattr(module.weight, "_optim", {})
+        optim_cfg.update({"lr_multiplier": mup_width_scale})
+        setattr(module.weight, "_optim", optim_cfg)
         if module.bias is not None:
             nn.init.zeros_(module.bias)
     elif isinstance(module, nn.Embedding):
diff --git a/tests/models/test_btlm.py b/tests/models/test_btlm.py
index e82e842..eb5316b 100644
--- a/tests/models/test_btlm.py
+++ b/tests/models/test_btlm.py
@@ -1,13 +1,9 @@
 # Copyright (c) 2023, Tri Dao.
 
-import os
 import time
-from pathlib import Path
 
 import torch
 import pytest
 
-from einops import rearrange
-
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
 from flash_attn.models.gpt import GPTLMHeadModel
@@ -21,9 +17,7 @@ def test_btlm_state_dict(model_name):
     config = btlm_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
     state_dict = model.state_dict()
     assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
@@ -47,9 +41,7 @@ def test_btlm_optimized(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
 
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -152,9 +144,7 @@ def test_btlm_generation(model_name):
     logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
     del model_ref
 
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -212,3 +202,44 @@ def test_btlm_generation(model_name):
     assert torch.equal(logits_cg, logits)
 
 
+@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
+def test_btlm_init(model_name):
+    dtype = torch.float32
+    device = "cuda"
+    btlm_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    config = btlm_config_to_gpt2_config(btlm_config)
+    model = GPTLMHeadModel(config, device=device, dtype=dtype)
+    model_ref = AutoModelForCausalLM.from_config(btlm_config, trust_remote_code=True).to(device)
+
+    assert model.transformer.embeddings.word_embeddings.weight.mean().abs() < 1e-4
+    assert (
+        model.transformer.embeddings.word_embeddings.weight.std()
+        - model_ref.transformer.wte.weight.std()
+    ).abs() < 1e-4
+    assert model.lm_head.weight.mean().abs() < 1e-4
+    assert (model.lm_head.weight.std() - model_ref.lm_head.weight.std()).abs() < 1e-4
+    for l in range(config.n_layer):
+        assert model.transformer.layers[l].mixer.Wqkv.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mixer.Wqkv.weight.std()
+            - model_ref.transformer.h[l].attn.c_attn.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mixer.Wqkv.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mixer.out_proj.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mixer.out_proj.weight.std()
+            - model_ref.transformer.h[l].attn.c_proj.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mixer.out_proj.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mlp.fc1.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mlp.fc1.weight.std()
+            - model_ref.transformer.h[l].mlp.c_fc.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mlp.fc1.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mlp.fc2.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mlp.fc2.weight.std()
+            - model_ref.transformer.h[l].mlp.c_proj.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mlp.fc2.bias.abs().max() == 0.0
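
Note on the flash_attn/models/gpt.py hunk: the old line overwrote any `_optim` metadata already attached to the weight, so an override set elsewhere (e.g. a per-parameter weight_decay) would be silently lost; the getattr/update/setattr sequence merges the muP `lr_multiplier` into the existing dict instead. The sketch below shows one way a training script might consume these per-parameter `_optim` dicts when building optimizer param groups. It is a hypothetical illustration, not a flash-attn API: `build_param_groups` and `base_lr` are invented names.

    # Hypothetical consumer of the `_optim` metadata set in _init_weights.
    def build_param_groups(model, base_lr):
        groups = {}
        for param in model.parameters():
            # The getattr/update in the hunk above preserves overrides
            # (e.g. weight_decay) attached before _init_weights runs.
            overrides = dict(getattr(param, "_optim", {}))
            key = tuple(sorted(overrides.items()))
            groups.setdefault(key, {"params": [], **overrides})["params"].append(param)
        param_groups = []
        for cfg in groups.values():
            # muP: scale the base learning rate by the stored multiplier.
            cfg["lr"] = base_lr * cfg.pop("lr_multiplier", 1.0)
            param_groups.append(cfg)
        return param_groups

    # e.g. optimizer = torch.optim.AdamW(build_param_groups(model, base_lr=6e-4))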
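
The new test_btlm_init compares distribution statistics (mean/std) rather than exact tensor values, since the flash-attn model and the Hugging Face reference each draw fresh random weights; only the statistics are comparable. A standalone sanity check of the stds involved, assuming the scaling visible in the gpt.py hunk (hidden weights at std initializer_range * sqrt(mup_width_scale)) plus GPT-2-style rescaling of residual projections by 1/sqrt(2 * n_layer), which is an assumption about the surrounding _init_weights code, not shown in this diff:

    import math
    import torch
    import torch.nn as nn

    # Illustrative values; BTLM-3B's actual config may differ.
    initializer_range, n_layer, mup_width_scale = 0.02, 32, 0.1
    mup_init_scale = math.sqrt(mup_width_scale)

    w = torch.empty(2560, 2560)
    nn.init.normal_(w, std=initializer_range * mup_init_scale)
    print(w.std())  # ~= initializer_range * sqrt(mup_width_scale)

    # Residual projections (out_proj / fc2) are shrunk with depth so the
    # residual stream's variance stays bounded across n_layer blocks.
    nn.init.normal_(w, std=initializer_range / math.sqrt(2 * n_layer) * mup_init_scale)
    print(w.std())  # ~= initializer_range * sqrt(mup_width_scale / (2 * n_layer))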