diff --git a/flash_attn/models/btlm.py b/flash_attn/models/btlm.py new file mode 100644 index 0000000..295e120 --- /dev/null +++ b/flash_attn/models/btlm.py @@ -0,0 +1,102 @@ +# Copyright (c) 2023, Tri Dao. + +import math +import json +import re +from pathlib import Path + +from collections import OrderedDict + +import torch +import torch.nn.functional as F + +from einops import rearrange +from transformers import GPT2Config, AutoConfig, PretrainedConfig + + +def remap_state_dict_hf_btlm(state_dict, config): + # Word embedding and position embedding + def key_mapping_pos_emb(key): + return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key) + + if "transformer.wpe.weight" in state_dict: + state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop("transformer.wte.weight") + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key) + key = re.sub(r"^transformer.h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key) + return key + + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + for d in range(config.num_hidden_layers): + W1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.weight") + W3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.weight") + state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = torch.cat([W1.t(), W3.t()], dim=0) + b1 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc.bias") + b3 = state_dict.pop(f"transformer.h.{d}.mlp.c_fc2.bias") + state_dict[f"transformer.layers.{d}.mlp.fc1.bias"] = torch.cat([b1, b3], dim=0) + W2 = state_dict.pop(f"transformer.h.{d}.mlp.c_proj.weight") + state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t() + + def key_mapping_mlp(key): + key = re.sub(r"^transformer.h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key) + return key + + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for d in range(config.num_hidden_layers): + Wqkv = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight") + state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t() + Wout = state_dict.pop(f"transformer.h.{d}.attn.c_proj.weight") + state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t() + state_dict.pop(f"transformer.relative_pe.slopes") # We don't store the Alibi slopes + + def key_mapping_attn(key): + key = re.sub(r"^transformer.h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key) + key = re.sub( + r"^transformer.h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key + ) + return key + + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config: + return GPT2Config( + vocab_size=btlm_config.vocab_size, + n_positions=0 if btlm_config.position_embedding_type == "alibi" else btlm_config.n_positions, + n_embd=btlm_config.hidden_size, + n_layer=btlm_config.num_hidden_layers, + n_head=btlm_config.num_attention_heads, + n_inner=btlm_config.n_inner, + activation_function=btlm_config.activation_function, + resid_pdrop=btlm_config.resid_pdrop, + embd_pdrop=btlm_config.embd_pdrop, + attn_pdrop=btlm_config.attn_pdrop, + layer_norm_epsilon=btlm_config.layer_norm_epsilon, + initializer_range=btlm_config.initializer_range, + bos_token_id=btlm_config.bos_token_id, + eos_token_id=btlm_config.eos_token_id, + # These are new arguments not in the original GPT2Config + use_alibi=btlm_config.position_embedding_type == "alibi", + use_flash_attn=btlm_config.position_embedding_type == "alibi", # Alibi code path requires flash_attn + mup_width_scale=btlm_config.mup_width_scale, + mup_embeddings_multiplier=btlm_config.mup_embeddings_scale, + mup_output_multiplier=btlm_config.mup_output_alpha, + mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d, + mlp_multiple_of=1, + ) diff --git a/tests/models/test_btlm.py b/tests/models/test_btlm.py new file mode 100644 index 0000000..e82e842 --- /dev/null +++ b/tests/models/test_btlm.py @@ -0,0 +1,214 @@ +# Copyright (c) 2023, Tri Dao. +import os +import time +from pathlib import Path + +import torch +import pytest + +from einops import rearrange + +from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM + +from flash_attn.models.gpt import GPTLMHeadModel +from flash_attn.models.btlm import btlm_config_to_gpt2_config, remap_state_dict_hf_btlm +from flash_attn.utils.pretrained import state_dict_from_pretrained +from flash_attn.utils.generation import update_graph_cache + + +@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"]) +def test_btlm_state_dict(model_name): + config = btlm_config_to_gpt2_config( + AutoConfig.from_pretrained(model_name, trust_remote_code=True) + ) + pretrained_state_dict = remap_state_dict_hf_btlm( + state_dict_from_pretrained(model_name), config + ) + model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow + state_dict = model.state_dict() + assert len(state_dict.keys()) == len(pretrained_state_dict.keys()) + assert state_dict.keys() == pretrained_state_dict.keys() + for k in state_dict.keys(): + assert state_dict[k].shape == pretrained_state_dict[k].shape + + +@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"]) +def test_btlm_optimized(model_name): + """Check that our implementation of Btlm (with all optimizations enabled) matches the + HF implementation: the output of our forward pass in fp16 should be around the same as the HF + forward pass in fp16, when compared to the HF forward pass in fp32. + """ + dtype = torch.float16 + device = "cuda" + config = btlm_config_to_gpt2_config( + AutoConfig.from_pretrained(model_name, trust_remote_code=True) + ) + config.fused_bias_fc = True + config.fused_dropout_add_ln = True + config.residual_in_fp32 = True + + pretrained_state_dict = remap_state_dict_hf_btlm( + state_dict_from_pretrained(model_name), config + ) + model = GPTLMHeadModel(config, device=device, dtype=dtype) + model.load_state_dict(pretrained_state_dict) + model.eval() + + torch.manual_seed(0) + batch_size = 2 + max_seqlen = 256 + seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device) + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device + ) + with torch.no_grad(): + out = model.transformer(input_ids) + logits = model(input_ids).logits + del model + + # Without device_map, the model is loaded on the CPU, which is very slow + # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB + model_ref = AutoModelForCausalLM.from_pretrained( + model_name, device_map="auto", trust_remote_code=True + ) + model_ref.eval() + with torch.no_grad(): + out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device) + logits_ref = model_ref(input_ids).logits.to(device=device) + del model_ref + + model_hf = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=dtype, + device_map={"": device}, + trust_remote_code=True, + ) + model_hf.eval() + with torch.no_grad(): + out_hf = model_hf.transformer(input_ids).last_hidden_state + logits_hf = model_hf(input_ids).logits + del model_hf + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}") + print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}") + assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item() + + print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}") + print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}") + print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}") + print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}") + assert (logits - logits_ref).abs().max().item() < 3 * ( + logits_hf - logits_ref + ).abs().max().item() + + +@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"]) +def test_btlm_generation(model_name): + dtype = torch.float16 + device = "cuda" + config = btlm_config_to_gpt2_config( + AutoConfig.from_pretrained(model_name, trust_remote_code=True) + ) + config.fused_bias_fc = True + config.fused_dropout_add_ln = True + config.residual_in_fp32 = True + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + eos_token_id = tokenizer.eos_token_id + + torch.manual_seed(0) + batch_size = 1 + seqlen = 2048 + max_length = 2048 + 150 + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device + ) + + model_hf = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True + ) + model_hf.eval() + print("HF fp16") + torch.cuda.synchronize() + start = time.time() + out_hf = model_hf.generate( + input_ids=input_ids, + max_length=max_length, + return_dict_in_generate=True, + output_scores=True, + ) + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") + del model_hf + + # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB + model_ref = AutoModelForCausalLM.from_pretrained( + model_name, device_map="auto", trust_remote_code=True + ) + model_ref.eval() + with torch.no_grad(): + logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device) + del model_ref + + pretrained_state_dict = remap_state_dict_hf_btlm( + state_dict_from_pretrained(model_name), config + ) + model = GPTLMHeadModel(config, device=device, dtype=dtype) + model.load_state_dict(pretrained_state_dict) + model.eval() + + model(input_ids) # Warm up + print("Without CUDA graph") + torch.cuda.synchronize() + start = time.time() + out = model.generate( + input_ids=input_ids, + max_length=max_length, + eos_token_id=eos_token_id, + return_dict_in_generate=True, + output_scores=True, + enable_timing=True, + teacher_outputs=out_hf.sequences, + ) + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") + + # Capture graph outside the timing loop + batch_size, seqlen_og = input_ids.shape + model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length) + print("With CUDA graph") + torch.cuda.synchronize() + start = time.time() + out_cg = model.generate( + input_ids=input_ids, + max_length=max_length, + cg=True, + return_dict_in_generate=True, + output_scores=True, + enable_timing=True, + teacher_outputs=out_hf.sequences, + ) + torch.cuda.synchronize() + print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms") + + with torch.no_grad(): + logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1] + logits_hf = torch.stack(out_hf.scores, dim=1) + logits = torch.stack(out.scores, dim=1) + logits_cg = torch.stack(out_cg.scores, dim=1) + + del model + + hf_error = (logits_hf - logits_ref).abs().max().item() + + print(f"HF fp16 logits max diff: {hf_error}") + print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }") + print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }") + + assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error + assert (logits - logits_ref).abs().max().item() < 2 * hf_error + assert torch.equal(logits_cg, logits) + +