Implement BTLM model
This commit is contained in:
parent
2e29dacf0c
commit
7ffba9a501
102
flash_attn/models/btlm.py
Normal file
102
flash_attn/models/btlm.py
Normal file
@ -0,0 +1,102 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import math
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
from transformers import GPT2Config, AutoConfig, PretrainedConfig
|
||||
|
||||
|
||||
def remap_state_dict_hf_btlm(state_dict, config):
|
||||
# Word embedding and position embedding
|
||||
def key_mapping_pos_emb(key):
|
||||
return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)
|
||||
|
||||
if "transformer.wpe.weight" in state_dict:
|
||||
state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
|
||||
word_embeddings = state_dict.pop("transformer.wte.weight")
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
||||
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
|
||||
key = re.sub(r"^transformer.h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
for d in range(config.num_hidden_layers):
|
||||
W1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.weight")
|
||||
W3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.weight")
|
||||
state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = torch.cat([W1.t(), W3.t()], dim=0)
|
||||
b1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.bias")
|
||||
b3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.bias")
|
||||
state_dict[f"transformer.layers.{d}.mlp.fc1.bias"] = torch.cat([b1, b3], dim=0)
|
||||
W2 = state_dict.pop(f"transformer.h.{d}.mlp.c_proj.weight")
|
||||
state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t()
|
||||
|
||||
def key_mapping_mlp(key):
|
||||
key = re.sub(r"^transformer.h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Attention
|
||||
for d in range(config.num_hidden_layers):
|
||||
Wqkv = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
|
||||
state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t()
|
||||
Wout = state_dict.pop(f"transformer.h.{d}.attn.c_proj.weight")
|
||||
state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t()
|
||||
state_dict.pop(f"transformer.relative_pe.slopes") # We don't store the Alibi slopes
|
||||
|
||||
def key_mapping_attn(key):
|
||||
key = re.sub(r"^transformer.h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key)
|
||||
key = re.sub(
|
||||
r"^transformer.h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key
|
||||
)
|
||||
return key
|
||||
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:
|
||||
return GPT2Config(
|
||||
vocab_size=btlm_config.vocab_size,
|
||||
n_positions=0 if btlm_config.position_embedding_type == "alibi" else btlm_config.n_positions,
|
||||
n_embd=btlm_config.hidden_size,
|
||||
n_layer=btlm_config.num_hidden_layers,
|
||||
n_head=btlm_config.num_attention_heads,
|
||||
n_inner=btlm_config.n_inner,
|
||||
activation_function=btlm_config.activation_function,
|
||||
resid_pdrop=btlm_config.resid_pdrop,
|
||||
embd_pdrop=btlm_config.embd_pdrop,
|
||||
attn_pdrop=btlm_config.attn_pdrop,
|
||||
layer_norm_epsilon=btlm_config.layer_norm_epsilon,
|
||||
initializer_range=btlm_config.initializer_range,
|
||||
bos_token_id=btlm_config.bos_token_id,
|
||||
eos_token_id=btlm_config.eos_token_id,
|
||||
# These are new arguments not in the original GPT2Config
|
||||
use_alibi=btlm_config.position_embedding_type == "alibi",
|
||||
use_flash_attn=btlm_config.position_embedding_type == "alibi", # Alibi code path requires flash_attn
|
||||
mup_width_scale=btlm_config.mup_width_scale,
|
||||
mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,
|
||||
mup_output_multiplier=btlm_config.mup_output_alpha,
|
||||
mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,
|
||||
mlp_multiple_of=1,
|
||||
)
|
||||
214
tests/models/test_btlm.py
Normal file
214
tests/models/test_btlm.py
Normal file
@ -0,0 +1,214 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
from flash_attn.models.gpt import GPTLMHeadModel
|
||||
from flash_attn.models.btlm import btlm_config_to_gpt2_config, remap_state_dict_hf_btlm
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
from flash_attn.utils.generation import update_graph_cache
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
|
||||
def test_btlm_state_dict(model_name):
|
||||
config = btlm_config_to_gpt2_config(
|
||||
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
)
|
||||
pretrained_state_dict = remap_state_dict_hf_btlm(
|
||||
state_dict_from_pretrained(model_name), config
|
||||
)
|
||||
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
|
||||
state_dict = model.state_dict()
|
||||
assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
|
||||
assert state_dict.keys() == pretrained_state_dict.keys()
|
||||
for k in state_dict.keys():
|
||||
assert state_dict[k].shape == pretrained_state_dict[k].shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
|
||||
def test_btlm_optimized(model_name):
|
||||
"""Check that our implementation of Btlm (with all optimizations enabled) matches the
|
||||
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
|
||||
forward pass in fp16, when compared to the HF forward pass in fp32.
|
||||
"""
|
||||
dtype = torch.float16
|
||||
device = "cuda"
|
||||
config = btlm_config_to_gpt2_config(
|
||||
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
)
|
||||
config.fused_bias_fc = True
|
||||
config.fused_dropout_add_ln = True
|
||||
config.residual_in_fp32 = True
|
||||
|
||||
pretrained_state_dict = remap_state_dict_hf_btlm(
|
||||
state_dict_from_pretrained(model_name), config
|
||||
)
|
||||
model = GPTLMHeadModel(config, device=device, dtype=dtype)
|
||||
model.load_state_dict(pretrained_state_dict)
|
||||
model.eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
batch_size = 2
|
||||
max_seqlen = 256
|
||||
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
|
||||
)
|
||||
with torch.no_grad():
|
||||
out = model.transformer(input_ids)
|
||||
logits = model(input_ids).logits
|
||||
del model
|
||||
|
||||
# Without device_map, the model is loaded on the CPU, which is very slow
|
||||
# Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
|
||||
model_ref = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, device_map="auto", trust_remote_code=True
|
||||
)
|
||||
model_ref.eval()
|
||||
with torch.no_grad():
|
||||
out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
|
||||
logits_ref = model_ref(input_ids).logits.to(device=device)
|
||||
del model_ref
|
||||
|
||||
model_hf = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=dtype,
|
||||
device_map={"": device},
|
||||
trust_remote_code=True,
|
||||
)
|
||||
model_hf.eval()
|
||||
with torch.no_grad():
|
||||
out_hf = model_hf.transformer(input_ids).last_hidden_state
|
||||
logits_hf = model_hf(input_ids).logits
|
||||
del model_hf
|
||||
|
||||
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
|
||||
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
|
||||
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
|
||||
|
||||
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
|
||||
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
|
||||
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
|
||||
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
|
||||
assert (logits - logits_ref).abs().max().item() < 3 * (
|
||||
logits_hf - logits_ref
|
||||
).abs().max().item()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
|
||||
def test_btlm_generation(model_name):
|
||||
dtype = torch.float16
|
||||
device = "cuda"
|
||||
config = btlm_config_to_gpt2_config(
|
||||
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
)
|
||||
config.fused_bias_fc = True
|
||||
config.fused_dropout_add_ln = True
|
||||
config.residual_in_fp32 = True
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
torch.manual_seed(0)
|
||||
batch_size = 1
|
||||
seqlen = 2048
|
||||
max_length = 2048 + 150
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
|
||||
)
|
||||
|
||||
model_hf = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
|
||||
)
|
||||
model_hf.eval()
|
||||
print("HF fp16")
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
out_hf = model_hf.generate(
|
||||
input_ids=input_ids,
|
||||
max_length=max_length,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
|
||||
del model_hf
|
||||
|
||||
# Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
|
||||
model_ref = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, device_map="auto", trust_remote_code=True
|
||||
)
|
||||
model_ref.eval()
|
||||
with torch.no_grad():
|
||||
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
|
||||
del model_ref
|
||||
|
||||
pretrained_state_dict = remap_state_dict_hf_btlm(
|
||||
state_dict_from_pretrained(model_name), config
|
||||
)
|
||||
model = GPTLMHeadModel(config, device=device, dtype=dtype)
|
||||
model.load_state_dict(pretrained_state_dict)
|
||||
model.eval()
|
||||
|
||||
model(input_ids) # Warm up
|
||||
print("Without CUDA graph")
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
out = model.generate(
|
||||
input_ids=input_ids,
|
||||
max_length=max_length,
|
||||
eos_token_id=eos_token_id,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
enable_timing=True,
|
||||
teacher_outputs=out_hf.sequences,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
|
||||
|
||||
# Capture graph outside the timing loop
|
||||
batch_size, seqlen_og = input_ids.shape
|
||||
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
|
||||
print("With CUDA graph")
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
out_cg = model.generate(
|
||||
input_ids=input_ids,
|
||||
max_length=max_length,
|
||||
cg=True,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
enable_timing=True,
|
||||
teacher_outputs=out_hf.sequences,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
|
||||
|
||||
with torch.no_grad():
|
||||
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
|
||||
logits_hf = torch.stack(out_hf.scores, dim=1)
|
||||
logits = torch.stack(out.scores, dim=1)
|
||||
logits_cg = torch.stack(out_cg.scores, dim=1)
|
||||
|
||||
del model
|
||||
|
||||
hf_error = (logits_hf - logits_ref).abs().max().item()
|
||||
|
||||
print(f"HF fp16 logits max diff: {hf_error}")
|
||||
print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
|
||||
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
|
||||
|
||||
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
|
||||
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
|
||||
assert torch.equal(logits_cg, logits)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user