Add BigCode converters (#532)

Kevin Hu 2023-09-10 17:24:50 -07:00 committed by GitHub
parent 8a733cbd53
commit 07005806ff
5 changed files with 496 additions and 28 deletions

View File

@@ -18,11 +18,16 @@ import torch.nn.functional as F
from einops import rearrange
from transformers import BertConfig, PretrainedConfig
from transformers.models.bert.modeling_bert import (
BaseModelOutputWithPoolingAndCrossAttentions, BertForPreTrainingOutput)
BaseModelOutputWithPoolingAndCrossAttentions,
BertForPreTrainingOutput,
)
from flash_attn.bert_padding import (index_first_axis,
index_first_axis_residual, pad_input,
unpad_input)
from flash_attn.bert_padding import (
index_first_axis,
index_first_axis_residual,
pad_input,
unpad_input,
)
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.modules.mha import MHA
@@ -75,11 +80,15 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
inner_dim = config.intermediate_size
fused_mlp = getattr(config, "fused_mlp", False)
if fused_mlp:
assert config.hidden_act in ["gelu_new", "gelu_fast"], (
assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
"fused_mlp only " "supports approximate gelu"
)
if not fused_mlp:
approximate = "tanh" if config.hidden_act in ["gelu_new", "gelu_fast"] else "none"
approximate = (
"tanh"
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
else "none"
)
mlp_cls = partial(
Mlp,
hidden_features=inner_dim,
@@ -232,7 +241,11 @@ class BertPredictionHeadTransform(nn.Module):
raise ImportError("dropout_add_layer_norm is not installed")
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.dense = linear_cls(config.hidden_size, config.hidden_size)
approximate = "tanh" if config.hidden_act in ["gelu_new", "gelu_fast"] else "none"
approximate = (
"tanh"
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
else "none"
)
self.transform_act_fn = nn.GELU(approximate=approximate)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -334,7 +347,7 @@ class BertModel(BertPreTrainedModel):
self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
if self.fused_dropout_add_ln and layer_norm is None:
raise ImportError("dropout_add_layer_norm is not installed")
assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast"]
assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
self.embeddings = BertEmbeddings(
config.hidden_size,
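The BERT and GPT changes above fold "gelu_pytorch_tanh" into the same bucket as "gelu_new" and "gelu_fast": all three name the tanh-approximated GELU, which is what fused_mlp assumes and what the unfused path expresses as nn.GELU(approximate="tanh"). A standalone toy check of that equivalence (the closed-form approximation below is the standard one these activation names refer to, not code from this commit):

import math

import torch
import torch.nn.functional as F

x = torch.randn(8, dtype=torch.float64)

# Tanh-approximated GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
gelu_tanh = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

# Matches PyTorch's built-in approximation, which the converted configs map these names to...
print(torch.allclose(gelu_tanh, F.gelu(x, approximate="tanh")))  # True
# ...and differs slightly from the exact (erf-based) GELU, hence the fused_mlp restriction.
print((gelu_tanh - F.gelu(x)).abs().max())  # small but nonzero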

View File

@@ -0,0 +1,233 @@
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig
def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
"""
Map the state_dict of a Huggingface BigCode model to be flash_attn compatible.
"""
# Word embedding and position embedding
def key_mapping_pos_emb(key):
return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)
state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop("transformer.wte.weight")
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
key = re.sub(
r"^transformer.h.(\d+).ln_(1|2).(weight|bias)",
r"transformer.layers.\1.norm\2.\3",
key,
)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
def key_mapping_mlp(key):
key = re.sub(
r"^transformer.h.(\d+).mlp.c_fc.weight",
r"transformer.layers.\1.mlp.fc1.weight",
key,
)
key = re.sub(
r"^transformer.h.(\d+).mlp.c_proj.weight",
r"transformer.layers.\1.mlp.fc2.weight",
key,
)
key = re.sub(
r"^transformer.h.(\d+).mlp.c_fc.bias",
r"transformer.layers.\1.mlp.fc1.bias",
key,
)
key = re.sub(
r"^transformer.h.(\d+).mlp.c_proj.bias",
r"transformer.layers.\1.mlp.fc2.bias",
key,
)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# TODO: add support for multi-head attention
assert config.multi_query, "Only multi-query attention is supported"
# Attention
for d in range(config.num_hidden_layers):
embed_dim = config.n_embd
head_dim = embed_dim // config.n_head
c_attn_weight = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
# with multi-query attention, the weights have shape (embed_dim, embed_dim + head_dim + head_dim)
# see https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L112
# see also https://github.com/ggerganov/ggml/blob/dd1d575956e54c5bdc07632f25506b3b1884dbd2/examples/starcoder/convert-hf-to-ggml.py#L183
# ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0)
# duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
k = torch.tile(k, (config.n_head, 1))
v = torch.tile(v, (config.n_head, 1))
state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = torch.cat((q, k, v), dim=0)
# same deal with the bias
c_attn_bias = state_dict.pop(f"transformer.h.{d}.attn.c_attn.bias")
# ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
q, k, v = torch.split(c_attn_bias, [embed_dim, head_dim, head_dim], dim=0)
# duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
k = torch.tile(k, (config.n_head,))
v = torch.tile(v, (config.n_head,))
state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = torch.cat((q, k, v), dim=0)
def key_mapping_attn(key):
key = re.sub(
r"^transformer.h.(\d+).attn.c_proj.weight",
r"transformer.layers.\1.mixer.out_proj.weight",
key,
)
key = re.sub(
r"^transformer.h.(\d+).attn.c_proj.bias",
r"transformer.layers.\1.mixer.out_proj.bias",
key,
)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
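The loop above is the heart of the conversion: BigCode checkpoints use multi-query attention, so c_attn packs query rows for every head but only one key head and one value head, while flash_attn's MHA layer wants a dense (3 * embed_dim, embed_dim) Wqkv. A toy sketch of that expansion with made-up sizes (not code from this commit):

import torch

n_head, head_dim = 4, 8
embed_dim = n_head * head_dim  # 32

# Fake multi-query c_attn weight: all query rows, then a single K head, then a single V head.
c_attn_weight = torch.randn((n_head + 2) * head_dim, embed_dim)  # (48, 32)

q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0)
k = torch.tile(k, (n_head, 1))  # (head_dim, embed_dim) -> (embed_dim, embed_dim)
v = torch.tile(v, (n_head, 1))
wqkv = torch.cat((q, k, v), dim=0)

print(wqkv.shape)  # torch.Size([96, 32]), i.e. (3 * embed_dim, embed_dim)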
def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
"""
Map the state_dict of a flash_attn model to be Huggingface BigCode compatible.
This function is meant to be the inverse of remap_state_dict_hf_bigcode.
"""
# Word embedding and position embeddings
def inv_key_mapping_pos_emb(key):
return re.sub(r"^transformer.embeddings.position_embeddings.", "transformer.wpe.", key)
state_dict = OrderedDict((inv_key_mapping_pos_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
word_embeddings = word_embeddings[: config.vocab_size, :]
state_dict["transformer.wte.weight"] = word_embeddings
state_dict["lm_head.weight"] = word_embeddings
# LayerNorm
def inv_key_mapping_ln(key):
key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
key = re.sub(
r"^transformer.layers.(\d+).norm(1|2).(weight|bias)",
r"transformer.h.\1.ln_\2.\3",
key,
)
return key
state_dict = OrderedDict((inv_key_mapping_ln(k), v) for k, v in state_dict.items())
# MLPs
def inv_key_mapping_mlp(key):
key = re.sub(
r"^transformer.layers.(\d+).mlp.fc1.weight",
r"transformer.h.\1.mlp.c_fc.weight",
key,
)
key = re.sub(
r"^transformer.layers.(\d+).mlp.fc2.weight",
r"transformer.h.\1.mlp.c_proj.weight",
key,
)
key = re.sub(
r"^transformer.layers.(\d+).mlp.fc1.bias",
r"transformer.h.\1.mlp.c_fc.bias",
key,
)
key = re.sub(
r"^transformer.layers.(\d+).mlp.fc2.bias",
r"transformer.h.\1.mlp.c_proj.bias",
key,
)
return key
state_dict = OrderedDict((inv_key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for d in range(config.num_hidden_layers):
embed_dim = config.n_embd
head_dim = embed_dim // config.n_head
Wqkv_weight = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
q, k, v = torch.split(
Wqkv_weight, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
)
c_attn_weight = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
state_dict[f"transformer.h.{d}.attn.c_attn.weight"] = c_attn_weight
# Same deal with the bias
Wqkv_bias = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
q, k, v = torch.split(
Wqkv_bias, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
)
c_attn_bias = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
state_dict[f"transformer.h.{d}.attn.c_attn.bias"] = c_attn_bias
def inv_key_mapping_attn(key):
key = re.sub(
r"^transformer.layers.(\d+).mixer.out_proj.weight",
r"transformer.h.\1.attn.c_proj.weight",
key,
)
key = re.sub(
r"^transformer.layers.(\d+).mixer.out_proj.bias",
r"transformer.h.\1.attn.c_proj.bias",
key,
)
return key
state_dict = OrderedDict((inv_key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
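On the embedding side, the forward converter pads the word-embedding rows up to a multiple of pad_vocab_size_multiple (an optional config attribute that defaults to 1), and the inverse above strips those rows back down to config.vocab_size. A standalone sketch of that padding round trip with toy sizes:

import math

import torch
import torch.nn.functional as F

vocab_size, n_embd, multiple = 50257, 64, 8
wte = torch.randn(vocab_size, n_embd)

# Pad the row (vocab) dimension up to the next multiple, as remap_state_dict_hf_bigcode does.
padded_vocab = math.ceil(vocab_size / multiple) * multiple  # 50264
wte_padded = F.pad(wte, (0, 0, 0, padded_vocab - vocab_size))
print(wte_padded.shape)  # torch.Size([50264, 64])

# Inverse: drop the padded rows again to recover the original embedding matrix.
print(torch.equal(wte_padded[:vocab_size, :], wte))  # True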
def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> GPT2Config:
return GPT2Config(
activation_function=bigcode_config.activation_function,
attn_pdrop=bigcode_config.attn_pdrop,
bos_token_id=bigcode_config.bos_token_id,
embd_pdrop=bigcode_config.embd_pdrop,
eos_token_id=bigcode_config.eos_token_id,
initializer_range=bigcode_config.initializer_range,
layer_norm_epsilon=bigcode_config.layer_norm_epsilon,
max_batch_size=bigcode_config.max_batch_size,
max_sequence_length=bigcode_config.max_sequence_length,
model_type=bigcode_config.model_type,
multi_query=bigcode_config.multi_query,
n_embd=bigcode_config.n_embd,
n_head=bigcode_config.n_head,
n_inner=bigcode_config.n_inner,
n_layer=bigcode_config.n_layer,
n_positions=bigcode_config.n_positions,
resid_pdrop=bigcode_config.resid_pdrop,
scale_attn_weights=bigcode_config.scale_attn_weights,
summary_activation=bigcode_config.summary_activation,
summary_first_dropout=bigcode_config.summary_first_dropout,
summary_proj_to_labels=bigcode_config.summary_proj_to_labels,
summary_type=bigcode_config.summary_type,
summary_use_proj=bigcode_config.summary_use_proj,
use_cache=bigcode_config.use_cache,
vocab_size=bigcode_config.vocab_size,
)
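Taken together, the intended flow for the new converters is: translate the GPTBigCodeConfig into the GPT2Config flavor used here, remap the HF checkpoint, and instantiate flash_attn's GPTLMHeadModel against it. A sketch of that flow, mirroring test_bigcode_state_dict in the new test file (checkpoint name taken from those tests):

from transformers import GPTBigCodeConfig

from flash_attn.models.bigcode import bigcode_config_to_gpt2_config, remap_state_dict_hf_bigcode
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.utils.pretrained import state_dict_from_pretrained

model_name = "bigcode/starcoderbase-1b"
config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))

# Remap the HF checkpoint into flash_attn's parameter layout.
pretrained_state_dict = remap_state_dict_hf_bigcode(
    state_dict_from_pretrained(model_name), config
)

# The meta device creates parameters without storage, which is enough for a key/shape check.
model = GPTLMHeadModel(config, device="meta")
assert model.state_dict().keys() == pretrained_state_dict.keys()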

View File

@@ -13,6 +13,7 @@ import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config
from flash_attn.models.bigcode import remap_state_dict_hf_bigcode
from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj
@@ -21,12 +22,16 @@ from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import (FusedMLP, GatedMlp, Mlp, ParallelFusedMLP,
ParallelGatedMlp, ParallelMLP)
from flash_attn.modules.mlp import (
FusedMLP,
GatedMlp,
Mlp,
ParallelFusedMLP,
ParallelGatedMlp,
ParallelMLP,
)
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.utils.distributed import (all_gather_raw,
get_dim_for_local_rank,
sync_shared_params)
from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained
@@ -41,8 +46,7 @@ except ImportError:
dropout_add_layer_norm = None
try:
from flash_attn.ops.layer_norm import \
dropout_add_layer_norm_parallel_residual
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
except ImportError:
dropout_add_layer_norm_parallel_residual = None
@@ -129,6 +133,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
"gelu_new",
"gelu_fast",
"gelu_approx",
"gelu_pytorch_tanh",
"relu",
"sqrelu",
]
@@ -144,6 +149,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
"gelu_new",
"gelu_fast",
"gelu_approx",
"gelu_pytorch_tanh",
"relu",
"sqrelu",
"glu",
@@ -182,7 +188,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
else:
approximate = (
"tanh"
if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"]
if config.activation_function
in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
else "none"
)
activation = partial(F.gelu, approximate=approximate)
@@ -215,7 +222,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
raise ImportError("fused_dense is not installed")
activation = (
"gelu_approx"
if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"]
if config.activation_function
in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
else config.activation_function
)
mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
@@ -352,6 +360,8 @@ class GPTPreTrainedModel(nn.Module):
state_dict = remap_state_dict_hf_falcon(state_dict, config)
elif model_name.startswith("meta-llama/Llama-"):
state_dict = remap_state_dict_hf_llama(state_dict, config)
elif model_name.startswith("bigcode/") or model_name.startswith("WizardLM/"):
state_dict = remap_state_dict_hf_bigcode(state_dict, config)
else:
raise NotImplementedError(f"Model {model_name} not supported")
if world_size > 1:
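With the "bigcode/" and "WizardLM/" prefixes wired into the dispatch above, the conversion can also be driven entirely through GPTLMHeadModel.from_pretrained, which is how the new tests load the model with the optional fused kernels enabled. A sketch of that call, following test_bigcode_optimized (CUDA device and fp16, as in the test):

import torch
from transformers import GPTBigCodeConfig

from flash_attn.models.bigcode import bigcode_config_to_gpt2_config
from flash_attn.models.gpt import GPTLMHeadModel

model_name = "bigcode/starcoderbase-1b"
config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True

# remap_state_dict_hf_bigcode is applied internally, keyed off the "bigcode/" prefix.
model = GPTLMHeadModel.from_pretrained(model_name, config, device="cuda", dtype=torch.float16)
model.eval()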
@@ -394,6 +404,7 @@ class GPTModel(GPTPreTrainedModel):
"gelu_new",
"gelu_fast",
"gelu_approx",
"gelu_pytorch_tanh",
"relu",
"sqrelu",
"glu",
@@ -628,7 +639,9 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
assert input_ids.ndim == 2, f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
assert (
input_ids.ndim == 2
), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
b, slen = input_ids.shape
hidden_states = self.transformer(
input_ids, position_ids=position_ids, inference_params=inference_params
@@ -845,11 +858,11 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: G
nheadqkv=rank_n_head + 2 * rank_n_head_kv,
headdim=headdim,
)
for s, rank_n_head, rank_n_head_kv in zip(state_dicts, n_head_each_rank, n_head_kv_each_rank)
]
wq = torch.cat(
[x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
for s, rank_n_head, rank_n_head_kv in zip(
state_dicts, n_head_each_rank, n_head_kv_each_rank
)
]
wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0)
wk = torch.cat(
[
x[

View File

@@ -6,12 +6,15 @@ import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import BertConfig
from transformers.models.bert.modeling_bert import \
BertForPreTraining as BertForPreTrainingHF
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
from transformers.models.bert.modeling_bert import BertModel as BertModelHF
from flash_attn.models.bert import (BertForPreTraining, BertModel,
inv_remap_state_dict, remap_state_dict)
from flash_attn.models.bert import (
BertForPreTraining,
BertModel,
inv_remap_state_dict,
remap_state_dict,
)
from flash_attn.utils.pretrained import state_dict_from_pretrained
@@ -102,7 +105,7 @@ def test_bert_optimized(model_name):
dtype = torch.float16
config = BertConfig.from_pretrained(model_name)
# Our implementation of fused_mlp assumes the activation is
# nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
# nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh".
# If you just want "gelu", disable fused_mlp.
config.hidden_act = "gelu_new"
config.use_flash_attn = True
@@ -209,7 +212,7 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs
dtype = torch.float16
config = BertConfig.from_pretrained(model_name)
# Our implementation of fused_mlp assumes the activation is
# nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
# nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh".
# If you just want "gelu", disable fused_mlp.
config.hidden_act = "gelu_new"
config.use_flash_attn = True
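The updated comments spell out the constraint: fused_mlp covers only the tanh-approximated GELUs ("gelu_new", "gelu_fast", "gelu_pytorch_tanh"), so a config that really wants exact "gelu" has to leave fused_mlp off. A minimal illustrative config along those lines (the checkpoint name is only a placeholder):

from transformers import BertConfig

config = BertConfig.from_pretrained("bert-base-uncased")  # placeholder checkpoint
config.hidden_act = "gelu"     # exact GELU rather than a tanh approximation
config.use_flash_attn = True   # the attention kernels do not depend on the activation choice
config.fused_mlp = False       # required: fused_mlp only supports the approximate variants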

View File

@@ -0,0 +1,206 @@
import time
import pytest
import torch
from transformers import AutoTokenizer, GPTBigCodeConfig
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM
from flash_attn.models.bigcode import bigcode_config_to_gpt2_config, inv_remap_state_dict_hf_bigcode
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_bigcode
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_bigcode_state_dict(model_name):
config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_hf_bigcode(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, device="meta")
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_bigcode_optimized(model_name):
"""Check that our implementation of BigCode (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = "cuda"
config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
config.use_flash_attn = True # FlashAttention-2 supports headdim 256
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
del model
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={"": device})
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.transformer(input_ids).last_hidden_state
logits_ref = model_ref(input_ids).logits
del model_ref
model_hf = GPTBigCodeForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
out_hf = model_hf.transformer(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_bigcode_generation(model_name):
"""Check that our implementation of BigCode (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = "cuda"
config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
config.use_flash_attn = True # FlashAttention-2 supports headdim 256
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
# Only prenorm supports residual_in_fp32
config.residual_in_fp32 = True
tokenizer = AutoTokenizer.from_pretrained(model_name)
eos_token_id = tokenizer.eos_token_id
torch.manual_seed(0)
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
model_hf = GPTBigCodeForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={"": device})
model_ref.eval()
with torch.no_grad():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
del model_ref
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
out = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
fused_ft_kernel=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
fused_ft_kernel=True,
cg=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
with torch.no_grad():
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
del model
hf_error = (logits_hf - logits_ref).abs().max().item()
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
assert (logits_cg - logits_ref).abs().max().item() < 2 * hf_error
@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_inv_remap_state_dict(model_name: str):
"""
Verify that we can convert a HF BigCode model to flash_attn and back.
"""
state_dict = state_dict_from_pretrained(model_name)
config = GPTBigCodeConfig.from_pretrained(model_name)
flash_state_dict = remap_state_dict_hf_bigcode(state_dict, config)
recovered_state_dict = inv_remap_state_dict_hf_bigcode(flash_state_dict, config)
assert set(state_dict.keys()) == set(recovered_state_dict.keys())
for k in state_dict.keys():
assert state_dict[k].shape == recovered_state_dict[k].shape
torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)
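All of the accuracy assertions in this test file follow one pattern: run the HF model in fp32 as the reference, run both the HF model and the flash_attn model in fp16, and accept the flash_attn output if its maximum error against the reference stays within a small multiple (2-3x) of HF's own fp16 error. A toy illustration of that acceptance criterion, with random tensors standing in for logits:

import torch

torch.manual_seed(0)
logits_ref = torch.randn(2, 100, 1000)                                    # fp32 reference
logits_hf = (logits_ref + 1e-3 * torch.randn_like(logits_ref)).half()     # "HF in fp16"
logits_ours = (logits_ref + 1e-3 * torch.randn_like(logits_ref)).half()   # "flash_attn in fp16"

hf_err = (logits_hf.float() - logits_ref).abs().max().item()
our_err = (logits_ours.float() - logits_ref).abs().max().item()
print(f"HF fp16 max err: {hf_err:.2e}, ours: {our_err:.2e}")
print("within 3x of HF's own fp16 error:", our_err < 3 * hf_err)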