Add BigCode converters (#532)
parent 8a733cbd53
commit 07005806ff
@@ -18,11 +18,16 @@ import torch.nn.functional as F
 from einops import rearrange
 from transformers import BertConfig, PretrainedConfig
 from transformers.models.bert.modeling_bert import (
-    BaseModelOutputWithPoolingAndCrossAttentions, BertForPreTrainingOutput)
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    BertForPreTrainingOutput,
+)
 
-from flash_attn.bert_padding import (index_first_axis,
-                                     index_first_axis_residual, pad_input,
-                                     unpad_input)
+from flash_attn.bert_padding import (
+    index_first_axis,
+    index_first_axis_residual,
+    pad_input,
+    unpad_input,
+)
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.modules.mha import MHA
@@ -75,11 +80,15 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
     inner_dim = config.intermediate_size
     fused_mlp = getattr(config, "fused_mlp", False)
     if fused_mlp:
-        assert config.hidden_act in ["gelu_new", "gelu_fast"], (
+        assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
             "fused_mlp only " "supports approximate gelu"
         )
     if not fused_mlp:
-        approximate = "tanh" if config.hidden_act in ["gelu_new", "gelu_fast"] else "none"
+        approximate = (
+            "tanh"
+            if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
+            else "none"
+        )
         mlp_cls = partial(
             Mlp,
             hidden_features=inner_dim,
@ -232,7 +241,11 @@ class BertPredictionHeadTransform(nn.Module):
|
||||
raise ImportError("dropout_add_layer_norm is not installed")
|
||||
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
||||
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
||||
approximate = "tanh" if config.hidden_act in ["gelu_new", "gelu_fast"] else "none"
|
||||
approximate = (
|
||||
"tanh"
|
||||
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
||||
else "none"
|
||||
)
|
||||
self.transform_act_fn = nn.GELU(approximate=approximate)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
@@ -334,7 +347,7 @@ class BertModel(BertPreTrainedModel):
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
         if self.fused_dropout_add_ln and layer_norm is None:
             raise ImportError("dropout_add_layer_norm is not installed")
-        assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast"]
+        assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
 
         self.embeddings = BertEmbeddings(
             config.hidden_size,
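
The changes above (and the matching ones in the GPT model further down) rest on one fact: HF's "gelu_pytorch_tanh" names the same tanh approximation of GELU that "gelu_new" and "gelu_fast" already map to, so it can reuse nn.GELU(approximate="tanh") and the fused MLP path. A quick illustrative check, a standalone sketch with toy values rather than code from this commit:

import math

import torch
import torch.nn.functional as F

x = torch.randn(8, dtype=torch.float64)
# Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))).
# This is what nn.GELU(approximate="tanh") / F.gelu(approximate="tanh") compute, and it is the
# formula behind HF's "gelu_new" and "gelu_pytorch_tanh" ("gelu_fast" hard-codes the constant).
ref = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))
print(torch.allclose(F.gelu(x, approximate="tanh"), ref))  # True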

new file: flash_attn/models/bigcode.py (233 lines)

import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig


def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a Huggingface BigCode model to be flash_attn compatible.
    """

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.h.(\d+).ln_(1|2).(weight|bias)",
            r"transformer.layers.\1.norm\2.\3",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_fc.weight",
            r"transformer.layers.\1.mlp.fc1.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_proj.weight",
            r"transformer.layers.\1.mlp.fc2.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_fc.bias",
            r"transformer.layers.\1.mlp.fc1.bias",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_proj.bias",
            r"transformer.layers.\1.mlp.fc2.bias",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # TODO: add support for multi-head attention
    assert config.multi_query, "Only multi-query attention is supported"

    # Attention
    for d in range(config.num_hidden_layers):
        embed_dim = config.n_embd
        head_dim = embed_dim // config.n_head

        c_attn_weight = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
        # with multi-query attention, the weights have shape (embed_dim, embed_dim + head_dim + head_dim)
        # see https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L112
        # see also https://github.com/ggerganov/ggml/blob/dd1d575956e54c5bdc07632f25506b3b1884dbd2/examples/starcoder/convert-hf-to-ggml.py#L183
        # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
        q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0)
        # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
        k = torch.tile(k, (config.n_head, 1))
        v = torch.tile(v, (config.n_head, 1))
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = torch.cat((q, k, v), dim=0)

        # same deal with the bias
        c_attn_bias = state_dict.pop(f"transformer.h.{d}.attn.c_attn.bias")
        # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
        q, k, v = torch.split(c_attn_bias, [embed_dim, head_dim, head_dim], dim=0)
        # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
        k = torch.tile(k, (config.n_head,))
        v = torch.tile(v, (config.n_head,))
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = torch.cat((q, k, v), dim=0)

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.weight",
            r"transformer.layers.\1.mixer.out_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.bias",
            r"transformer.layers.\1.mixer.out_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
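
The packed c_attn layout handled above is easier to see with toy sizes. A minimal sketch with hypothetical n_head=2 and head_dim=3, not tied to any real checkpoint:

import torch

n_head, head_dim = 2, 3
embed_dim = n_head * head_dim  # 6
# GPTBigCode with multi_query=True packs [all query heads | one shared k head | one shared v head]
c_attn_weight = torch.randn(embed_dim + 2 * head_dim, embed_dim)  # (12, 6)
q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0)
# flash_attn's Wqkv stores per-head k and v, so the shared head is tiled n_head times
k = torch.tile(k, (n_head, 1))
v = torch.tile(v, (n_head, 1))
Wqkv_weight = torch.cat((q, k, v), dim=0)
print(Wqkv_weight.shape)  # torch.Size([18, 6]), i.e. (3 * n_head * head_dim, embed_dim)
# The inverse conversion below simply keeps the first copy: cat((q, k[:head_dim], v[:head_dim]))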

def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a flash_attn model to be Huggingface BigCode compatible.

    This function is meant to be the inverse of remap_state_dict_hf_bigcode.
    """

    # Word embedding and position embeddings
    def inv_key_mapping_pos_emb(key):
        return re.sub(r"^transformer.embeddings.position_embeddings.", "transformer.wpe.", key)

    state_dict = OrderedDict((inv_key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")

    word_embeddings = word_embeddings[:, : config.vocab_size]
    state_dict["transformer.wte.weight"] = word_embeddings
    state_dict["lm_head.weight"] = word_embeddings

    # LayerNorm
    def inv_key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.layers.(\d+).norm(1|2).(weight|bias)",
            r"transformer.h.\1.ln_\2.\3",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLPs
    def inv_key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc1.weight",
            r"transformer.h.\1.mlp.c_fc.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc2.weight",
            r"transformer.h.\1.mlp.c_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc1.bias",
            r"transformer.h.\1.mlp.c_fc.bias",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc2.bias",
            r"transformer.h.\1.mlp.c_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        embed_dim = config.n_embd
        head_dim = embed_dim // config.n_head

        Wqkv_weight = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
        q, k, v = torch.split(
            Wqkv_weight, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
        )
        c_attn_weight = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
        state_dict[f"transformer.h.{d}.attn.c_attn.weight"] = c_attn_weight

        # Same deal with the bias
        Wqkv_bias = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
        q, k, v = torch.split(
            Wqkv_bias, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
        )
        c_attn_bias = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
        state_dict[f"transformer.h.{d}.attn.c_attn.bias"] = c_attn_bias

    def inv_key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mixer.out_proj.weight",
            r"transformer.h.\1.attn.c_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mixer.out_proj.bias",
            r"transformer.h.\1.attn.c_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> GPT2Config:
    return GPT2Config(
        activation_function=bigcode_config.activation_function,
        attn_pdrop=bigcode_config.attn_pdrop,
        bos_token_id=bigcode_config.bos_token_id,
        embd_pdrop=bigcode_config.embd_pdrop,
        eos_token_id=bigcode_config.eos_token_id,
        initializer_range=bigcode_config.initializer_range,
        layer_norm_epsilon=bigcode_config.layer_norm_epsilon,
        max_batch_size=bigcode_config.max_batch_size,
        max_sequence_length=bigcode_config.max_sequence_length,
        model_type=bigcode_config.model_type,
        multi_query=bigcode_config.multi_query,
        n_embd=bigcode_config.n_embd,
        n_head=bigcode_config.n_head,
        n_inner=bigcode_config.n_inner,
        n_layer=bigcode_config.n_layer,
        n_positions=bigcode_config.n_positions,
        resid_pdrop=bigcode_config.resid_pdrop,
        scale_attn_weights=bigcode_config.scale_attn_weights,
        summary_activation=bigcode_config.summary_activation,
        summary_first_dropout=bigcode_config.summary_first_dropout,
        summary_proj_to_labels=bigcode_config.summary_proj_to_labels,
        summary_type=bigcode_config.summary_type,
        summary_use_proj=bigcode_config.summary_use_proj,
        use_cache=bigcode_config.use_cache,
        vocab_size=bigcode_config.vocab_size,
    )
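
With the converters in place, loading a BigCode checkpoint into flash_attn looks roughly like the following. This is a minimal sketch mirroring the tests added below; the model name, device, and dtype are only examples:

import torch
from transformers import GPTBigCodeConfig

from flash_attn.models.bigcode import bigcode_config_to_gpt2_config
from flash_attn.models.gpt import GPTLMHeadModel

model_name = "bigcode/starcoderbase-1b"
config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
config.use_flash_attn = True
# from_pretrained routes "bigcode/" and "WizardLM/" model names through remap_state_dict_hf_bigcode
model = GPTLMHeadModel.from_pretrained(model_name, config, device="cuda", dtype=torch.float16)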

@@ -13,6 +13,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from transformers import GPT2Config
 
+from flash_attn.models.bigcode import remap_state_dict_hf_bigcode
 from flash_attn.models.falcon import remap_state_dict_hf_falcon
 from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
 from flash_attn.models.gptj import remap_state_dict_hf_gptj
@@ -21,12 +22,16 @@ from flash_attn.models.opt import remap_state_dict_hf_opt
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import (FusedMLP, GatedMlp, Mlp, ParallelFusedMLP,
-                                    ParallelGatedMlp, ParallelMLP)
+from flash_attn.modules.mlp import (
+    FusedMLP,
+    GatedMlp,
+    Mlp,
+    ParallelFusedMLP,
+    ParallelGatedMlp,
+    ParallelMLP,
+)
 from flash_attn.ops.activations import sqrelu_fwd
-from flash_attn.utils.distributed import (all_gather_raw,
-                                          get_dim_for_local_rank,
-                                          sync_shared_params)
+from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
 from flash_attn.utils.generation import GenerationMixin
 from flash_attn.utils.pretrained import state_dict_from_pretrained
@@ -41,8 +46,7 @@ except ImportError:
     dropout_add_layer_norm = None
 
 try:
-    from flash_attn.ops.layer_norm import \
-        dropout_add_layer_norm_parallel_residual
+    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
 except ImportError:
     dropout_add_layer_norm_parallel_residual = None
@@ -129,6 +133,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
             "gelu_new",
             "gelu_fast",
             "gelu_approx",
+            "gelu_pytorch_tanh",
             "relu",
             "sqrelu",
         ]
@@ -144,6 +149,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
             "gelu_new",
             "gelu_fast",
             "gelu_approx",
+            "gelu_pytorch_tanh",
             "relu",
             "sqrelu",
             "glu",
@@ -182,7 +188,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
         else:
             approximate = (
                 "tanh"
-                if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"]
+                if config.activation_function
+                in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
                 else "none"
             )
             activation = partial(F.gelu, approximate=approximate)
@@ -215,7 +222,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
             raise ImportError("fused_dense is not installed")
         activation = (
             "gelu_approx"
-            if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"]
+            if config.activation_function
+            in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
             else config.activation_function
         )
         mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
@@ -352,6 +360,8 @@ class GPTPreTrainedModel(nn.Module):
             state_dict = remap_state_dict_hf_falcon(state_dict, config)
         elif model_name.startswith("meta-llama/Llama-"):
             state_dict = remap_state_dict_hf_llama(state_dict, config)
+        elif model_name.startswith("bigcode/") or model_name.startswith("WizardLM/"):
+            state_dict = remap_state_dict_hf_bigcode(state_dict, config)
         else:
             raise NotImplementedError(f"Model {model_name} not supported")
         if world_size > 1:
@@ -394,6 +404,7 @@ class GPTModel(GPTPreTrainedModel):
             "gelu_new",
             "gelu_fast",
             "gelu_approx",
+            "gelu_pytorch_tanh",
             "relu",
             "sqrelu",
             "glu",
@@ -628,7 +639,9 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
         num_last_tokens: if > 0, only return the logits for the last n tokens
         """
-        assert input_ids.ndim == 2, f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
+        assert (
+            input_ids.ndim == 2
+        ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
         b, slen = input_ids.shape
         hidden_states = self.transformer(
             input_ids, position_ids=position_ids, inference_params=inference_params
@@ -845,11 +858,11 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config):
                 nheadqkv=rank_n_head + 2 * rank_n_head_kv,
                 headdim=headdim,
             )
-            for s, rank_n_head, rank_n_head_kv in zip(state_dicts, n_head_each_rank, n_head_kv_each_rank)
-        ]
-        wq = torch.cat(
-            [x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
-        )
+            for s, rank_n_head, rank_n_head_kv in zip(
+                state_dicts, n_head_each_rank, n_head_kv_each_rank
+            )
+        ]
+        wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0)
         wk = torch.cat(
             [
                 x[
@@ -6,12 +6,15 @@ import torch
 import torch.nn.functional as F
 from einops import rearrange
 from transformers import BertConfig
-from transformers.models.bert.modeling_bert import \
-    BertForPreTraining as BertForPreTrainingHF
+from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
 from transformers.models.bert.modeling_bert import BertModel as BertModelHF
 
-from flash_attn.models.bert import (BertForPreTraining, BertModel,
-                                    inv_remap_state_dict, remap_state_dict)
+from flash_attn.models.bert import (
+    BertForPreTraining,
+    BertModel,
+    inv_remap_state_dict,
+    remap_state_dict,
+)
 from flash_attn.utils.pretrained import state_dict_from_pretrained
@@ -102,7 +105,7 @@ def test_bert_optimized(model_name):
     dtype = torch.float16
     config = BertConfig.from_pretrained(model_name)
     # Our implementation of fused_mlp assumes the activation is
-    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
+    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh".
     # If you just want "gelu", disable fused_mlp.
     config.hidden_act = "gelu_new"
     config.use_flash_attn = True
@@ -209,7 +212,7 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
     dtype = torch.float16
     config = BertConfig.from_pretrained(model_name)
     # Our implementation of fused_mlp assumes the activation is
-    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
+    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh".
     # If you just want "gelu", disable fused_mlp.
     config.hidden_act = "gelu_new"
     config.use_flash_attn = True

new file: tests/models/test_bigcode.py (206 lines)

import time

import pytest
import torch
from transformers import AutoTokenizer, GPTBigCodeConfig
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM

from flash_attn.models.bigcode import bigcode_config_to_gpt2_config, inv_remap_state_dict_hf_bigcode
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_bigcode
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained


@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_bigcode_state_dict(model_name):
    config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_hf_bigcode(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, device="meta")
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_bigcode_optimized(model_name):
    """Check that our implementation of BigCode (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={"": device})
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.transformer(input_ids).last_hidden_state
        logits_ref = model_ref(input_ids).logits
    del model_ref

    model_hf = GPTBigCodeForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_bigcode_generation(model_name):
    """Check that our implementation of BigCode (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = True

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    model_hf = GPTBigCodeForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={"": device})
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    del model_ref

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        fused_ft_kernel=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        fused_ft_kernel=True,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)

    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()
    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error

    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
    assert (logits_cg - logits_ref).abs().max().item() < 2 * hf_error


@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_inv_remap_state_dict(model_name: str):
    """
    Verify that we can convert a HF BigCode model to flash_attn and back.
    """

    state_dict = state_dict_from_pretrained(model_name)
    config = GPTBigCodeConfig.from_pretrained(model_name)

    flash_state_dict = remap_state_dict_hf_bigcode(state_dict, config)
    recovered_state_dict = inv_remap_state_dict_hf_bigcode(flash_state_dict, config)

    assert set(state_dict.keys()) == set(recovered_state_dict.keys())

    for k in state_dict.keys():
        assert state_dict[k].shape == recovered_state_dict[k].shape
        torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)