Add BigCode converters (#532)

parent 8a733cbd53
commit 07005806ff
flash_attn/models/bert.py
@@ -18,11 +18,16 @@ import torch.nn.functional as F
 from einops import rearrange
 from transformers import BertConfig, PretrainedConfig
 from transformers.models.bert.modeling_bert import (
-    BaseModelOutputWithPoolingAndCrossAttentions, BertForPreTrainingOutput)
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    BertForPreTrainingOutput,
+)
 
-from flash_attn.bert_padding import (index_first_axis,
-                                     index_first_axis_residual, pad_input,
-                                     unpad_input)
+from flash_attn.bert_padding import (
+    index_first_axis,
+    index_first_axis_residual,
+    pad_input,
+    unpad_input,
+)
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
 from flash_attn.modules.mha import MHA
@@ -75,11 +80,15 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
     inner_dim = config.intermediate_size
     fused_mlp = getattr(config, "fused_mlp", False)
     if fused_mlp:
-        assert config.hidden_act in ["gelu_new", "gelu_fast"], (
+        assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
             "fused_mlp only " "supports approximate gelu"
         )
     if not fused_mlp:
-        approximate = "tanh" if config.hidden_act in ["gelu_new", "gelu_fast"] else "none"
+        approximate = (
+            "tanh"
+            if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
+            else "none"
+        )
         mlp_cls = partial(
             Mlp,
             hidden_features=inner_dim,
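Note on the activation change above (not part of the diff): HF's "gelu_pytorch_tanh" is the same tanh approximation of GELU that "gelu_new" and "gelu_fast" denote, which is what nn.GELU(approximate="tanh") and the fused MLP kernel compute, so accepting it here is safe. A minimal sketch in plain PyTorch checking that equivalence numerically:

import torch
import torch.nn.functional as F

# Closed-form tanh approximation of GELU (the definition behind
# "gelu_new" / "gelu_fast" / "gelu_pytorch_tanh" in HF configs).
def gelu_tanh_reference(x: torch.Tensor) -> torch.Tensor:
    # 0.7978845608... = sqrt(2 / pi)
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * x.pow(3))))

x = torch.randn(4, 8)
out_torch = F.gelu(x, approximate="tanh")  # what flash_attn's Mlp assumes
out_ref = gelu_tanh_reference(x)
torch.testing.assert_close(out_torch, out_ref)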
@@ -232,7 +241,11 @@ class BertPredictionHeadTransform(nn.Module):
             raise ImportError("dropout_add_layer_norm is not installed")
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
-        approximate = "tanh" if config.hidden_act in ["gelu_new", "gelu_fast"] else "none"
+        approximate = (
+            "tanh"
+            if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
+            else "none"
+        )
         self.transform_act_fn = nn.GELU(approximate=approximate)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
@@ -334,7 +347,7 @@ class BertModel(BertPreTrainedModel):
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
         if self.fused_dropout_add_ln and layer_norm is None:
             raise ImportError("dropout_add_layer_norm is not installed")
-        assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast"]
+        assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
 
         self.embeddings = BertEmbeddings(
             config.hidden_size,
flash_attn/models/bigcode.py (new file, 233 lines)
@@ -0,0 +1,233 @@
import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig


def remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a Huggingface BigCode model to be flash_attn compatible.
    """

    # Word embedding and position embedding
    def key_mapping_pos_emb(key):
        return re.sub(r"^transformer.wpe.", "transformer.embeddings.position_embeddings.", key)

    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.wte.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.h.(\d+).ln_(1|2).(weight|bias)",
            r"transformer.layers.\1.norm\2.\3",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_fc.weight",
            r"transformer.layers.\1.mlp.fc1.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_proj.weight",
            r"transformer.layers.\1.mlp.fc2.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_fc.bias",
            r"transformer.layers.\1.mlp.fc1.bias",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).mlp.c_proj.bias",
            r"transformer.layers.\1.mlp.fc2.bias",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # TODO: add support for multi-head attention
    assert config.multi_query, "Only multi-query attention is supported"

    # Attention
    for d in range(config.num_hidden_layers):
        embed_dim = config.n_embd
        head_dim = embed_dim // config.n_head

        c_attn_weight = state_dict.pop(f"transformer.h.{d}.attn.c_attn.weight")
        # with multi-query attention, the weights have shape (embed_dim, embed_dim + head_dim + head_dim)
        # see https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L112
        # see also https://github.com/ggerganov/ggml/blob/dd1d575956e54c5bdc07632f25506b3b1884dbd2/examples/starcoder/convert-hf-to-ggml.py#L183
        # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
        q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0)
        # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
        k = torch.tile(k, (config.n_head, 1))
        v = torch.tile(v, (config.n_head, 1))
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = torch.cat((q, k, v), dim=0)

        # same deal with the bias
        c_attn_bias = state_dict.pop(f"transformer.h.{d}.attn.c_attn.bias")
        # ((n_head + 2) * head_dim, embed_dim) -> (3 * n_heads * head_dim, hidden_dim)
        q, k, v = torch.split(c_attn_bias, [embed_dim, head_dim, head_dim], dim=0)
        # duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
        k = torch.tile(k, (config.n_head,))
        v = torch.tile(v, (config.n_head,))
        state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = torch.cat((q, k, v), dim=0)

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.weight",
            r"transformer.layers.\1.mixer.out_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.h.(\d+).attn.c_proj.bias",
            r"transformer.layers.\1.mixer.out_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
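The multi-query handling above is the only non-trivial part of the converter: the HF checkpoint stores all Q heads plus a single shared K head and V head packed into c_attn, while the flash_attn GPT model expects full per-head K/V in Wqkv. A small shape sketch in plain PyTorch, with made-up toy dimensions (not real model sizes), illustrating the expansion:

import torch

n_head, head_dim = 4, 8
embed_dim = n_head * head_dim

# HF GPTBigCode packs Q (all heads) followed by one shared K and one shared V head.
c_attn_weight = torch.randn(embed_dim + 2 * head_dim, embed_dim)

q, k, v = torch.split(c_attn_weight, [embed_dim, head_dim, head_dim], dim=0)
# Broadcast the shared K/V head to every query head, as the converter does.
k = torch.tile(k, (n_head, 1))
v = torch.tile(v, (n_head, 1))
Wqkv_weight = torch.cat((q, k, v), dim=0)

assert Wqkv_weight.shape == (3 * embed_dim, embed_dim)  # the shape MHA's Wqkv expects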

def inv_remap_state_dict_hf_bigcode(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a flash_attn model to be Huggingface BigCode compatible.

    This function is meant to be the inverse of remap_state_dict_hf_bigcode.
    """

    # Word embedding and position embeddings
    def inv_key_mapping_pos_emb(key):
        return re.sub(r"^transformer.embeddings.position_embeddings.", "transformer.wpe.", key)

    state_dict = OrderedDict((inv_key_mapping_pos_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")

    word_embeddings = word_embeddings[:, : config.vocab_size]
    state_dict["transformer.wte.weight"] = word_embeddings
    state_dict["lm_head.weight"] = word_embeddings

    # LayerNorm
    def inv_key_mapping_ln(key):
        key = re.sub(r"^transformer.ln_f.(weight|bias)", r"transformer.ln_f.\1", key)
        key = re.sub(
            r"^transformer.layers.(\d+).norm(1|2).(weight|bias)",
            r"transformer.h.\1.ln_\2.\3",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLPs
    def inv_key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc1.weight",
            r"transformer.h.\1.mlp.c_fc.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc2.weight",
            r"transformer.h.\1.mlp.c_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc1.bias",
            r"transformer.h.\1.mlp.c_fc.bias",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc2.bias",
            r"transformer.h.\1.mlp.c_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        embed_dim = config.n_embd
        head_dim = embed_dim // config.n_head

        Wqkv_weight = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight")
        q, k, v = torch.split(
            Wqkv_weight, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
        )
        c_attn_weight = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
        state_dict[f"transformer.h.{d}.attn.c_attn.weight"] = c_attn_weight

        # Same deal with the bias
        Wqkv_bias = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias")
        q, k, v = torch.split(
            Wqkv_bias, [embed_dim, head_dim * config.n_head, head_dim * config.n_head], dim=0
        )
        c_attn_bias = torch.cat((q, k[:head_dim], v[:head_dim]), dim=0)
        state_dict[f"transformer.h.{d}.attn.c_attn.bias"] = c_attn_bias

    def inv_key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mixer.out_proj.weight",
            r"transformer.h.\1.attn.c_proj.weight",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mixer.out_proj.bias",
            r"transformer.h.\1.attn.c_proj.bias",
            key,
        )
        return key

    state_dict = OrderedDict((inv_key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
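The inverse conversion just keeps the first head_dim rows of the tiled K/V blocks, so the attention weights round-trip losslessly. A toy check in plain PyTorch with made-up sizes, mirroring what test_inv_remap_state_dict (added below) does on real checkpoints:

import torch

n_head, head_dim = 4, 8
embed_dim = n_head * head_dim
c_attn = torch.randn(embed_dim + 2 * head_dim, embed_dim)

# Forward direction: expand the shared K/V head to per-head K/V.
q, k, v = torch.split(c_attn, [embed_dim, head_dim, head_dim], dim=0)
Wqkv = torch.cat((q, torch.tile(k, (n_head, 1)), torch.tile(v, (n_head, 1))), dim=0)

# Inverse direction: keep only the first (shared) K/V head.
q2, k2, v2 = torch.split(Wqkv, [embed_dim, embed_dim, embed_dim], dim=0)
c_attn_roundtrip = torch.cat((q2, k2[:head_dim], v2[:head_dim]), dim=0)

torch.testing.assert_close(c_attn, c_attn_roundtrip)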

def bigcode_config_to_gpt2_config(bigcode_config: GPTBigCodeConfig) -> GPT2Config:
    return GPT2Config(
        activation_function=bigcode_config.activation_function,
        attn_pdrop=bigcode_config.attn_pdrop,
        bos_token_id=bigcode_config.bos_token_id,
        embd_pdrop=bigcode_config.embd_pdrop,
        eos_token_id=bigcode_config.eos_token_id,
        initializer_range=bigcode_config.initializer_range,
        layer_norm_epsilon=bigcode_config.layer_norm_epsilon,
        max_batch_size=bigcode_config.max_batch_size,
        max_sequence_length=bigcode_config.max_sequence_length,
        model_type=bigcode_config.model_type,
        multi_query=bigcode_config.multi_query,
        n_embd=bigcode_config.n_embd,
        n_head=bigcode_config.n_head,
        n_inner=bigcode_config.n_inner,
        n_layer=bigcode_config.n_layer,
        n_positions=bigcode_config.n_positions,
        resid_pdrop=bigcode_config.resid_pdrop,
        scale_attn_weights=bigcode_config.scale_attn_weights,
        summary_activation=bigcode_config.summary_activation,
        summary_first_dropout=bigcode_config.summary_first_dropout,
        summary_proj_to_labels=bigcode_config.summary_proj_to_labels,
        summary_type=bigcode_config.summary_type,
        summary_use_proj=bigcode_config.summary_use_proj,
        use_cache=bigcode_config.use_cache,
        vocab_size=bigcode_config.vocab_size,
    )
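Typical use of the new converters, drawn from the tests added below (the model name and optimization flags are just the examples used there; running this requires a CUDA GPU and the optional fused kernels):

import torch
from transformers import GPTBigCodeConfig

from flash_attn.models.bigcode import bigcode_config_to_gpt2_config
from flash_attn.models.gpt import GPTLMHeadModel

model_name = "bigcode/starcoderbase-1b"  # example checkpoint
config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True

# from_pretrained dispatches "bigcode/" and "WizardLM/" names to remap_state_dict_hf_bigcode.
model = GPTLMHeadModel.from_pretrained(model_name, config, device="cuda", dtype=torch.float16)
model.eval()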
flash_attn/models/gpt.py
@@ -13,6 +13,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from transformers import GPT2Config
 
+from flash_attn.models.bigcode import remap_state_dict_hf_bigcode
 from flash_attn.models.falcon import remap_state_dict_hf_falcon
 from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
 from flash_attn.models.gptj import remap_state_dict_hf_gptj
@@ -21,12 +22,16 @@ from flash_attn.models.opt import remap_state_dict_hf_opt
 from flash_attn.modules.block import Block, ParallelBlock
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.modules.mha import MHA, ParallelMHA
-from flash_attn.modules.mlp import (FusedMLP, GatedMlp, Mlp, ParallelFusedMLP,
-                                    ParallelGatedMlp, ParallelMLP)
+from flash_attn.modules.mlp import (
+    FusedMLP,
+    GatedMlp,
+    Mlp,
+    ParallelFusedMLP,
+    ParallelGatedMlp,
+    ParallelMLP,
+)
 from flash_attn.ops.activations import sqrelu_fwd
-from flash_attn.utils.distributed import (all_gather_raw,
-                                          get_dim_for_local_rank,
-                                          sync_shared_params)
+from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
 from flash_attn.utils.generation import GenerationMixin
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 
@@ -41,8 +46,7 @@ except ImportError:
     dropout_add_layer_norm = None
 
 try:
-    from flash_attn.ops.layer_norm import \
-        dropout_add_layer_norm_parallel_residual
+    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
 except ImportError:
     dropout_add_layer_norm_parallel_residual = None
 
@@ -129,6 +133,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
             "gelu_new",
             "gelu_fast",
             "gelu_approx",
+            "gelu_pytorch_tanh",
             "relu",
             "sqrelu",
         ]
@@ -144,6 +149,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
             "gelu_new",
             "gelu_fast",
             "gelu_approx",
+            "gelu_pytorch_tanh",
             "relu",
             "sqrelu",
             "glu",
@@ -182,7 +188,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
         else:
             approximate = (
                 "tanh"
-                if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"]
+                if config.activation_function
+                in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
                 else "none"
             )
             activation = partial(F.gelu, approximate=approximate)
@@ -215,7 +222,8 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
             raise ImportError("fused_dense is not installed")
         activation = (
             "gelu_approx"
-            if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"]
+            if config.activation_function
+            in ["gelu_new", "gelu_fast", "gelu_approx", "gelu_pytorch_tanh"]
             else config.activation_function
         )
         mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
@@ -352,6 +360,8 @@ class GPTPreTrainedModel(nn.Module):
             state_dict = remap_state_dict_hf_falcon(state_dict, config)
         elif model_name.startswith("meta-llama/Llama-"):
             state_dict = remap_state_dict_hf_llama(state_dict, config)
+        elif model_name.startswith("bigcode/") or model_name.startswith("WizardLM/"):
+            state_dict = remap_state_dict_hf_bigcode(state_dict, config)
         else:
             raise NotImplementedError(f"Model {model_name} not supported")
         if world_size > 1:
@@ -394,6 +404,7 @@ class GPTModel(GPTPreTrainedModel):
             "gelu_new",
             "gelu_fast",
             "gelu_approx",
+            "gelu_pytorch_tanh",
             "relu",
             "sqrelu",
             "glu",
@@ -628,7 +639,9 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
             num_last_tokens: if > 0, only return the logits for the last n tokens
         """
-        assert input_ids.ndim == 2, f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
+        assert (
+            input_ids.ndim == 2
+        ), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
         b, slen = input_ids.shape
         hidden_states = self.transformer(
             input_ids, position_ids=position_ids, inference_params=inference_params
@@ -845,11 +858,11 @@ def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: G
                 nheadqkv=rank_n_head + 2 * rank_n_head_kv,
                 headdim=headdim,
             )
-            for s, rank_n_head, rank_n_head_kv in zip(state_dicts, n_head_each_rank, n_head_kv_each_rank)
+            for s, rank_n_head, rank_n_head_kv in zip(
+                state_dicts, n_head_each_rank, n_head_kv_each_rank
+            )
         ]
-        wq = torch.cat(
-            [x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0
-        )
+        wq = torch.cat([x[: n_head_each_rank[rank]] for rank, x in enumerate(xs)], dim=0)
         wk = torch.cat(
             [
                 x[
tests/models/test_bert.py
@@ -6,12 +6,15 @@ import torch
 import torch.nn.functional as F
 from einops import rearrange
 from transformers import BertConfig
-from transformers.models.bert.modeling_bert import \
-    BertForPreTraining as BertForPreTrainingHF
+from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
 from transformers.models.bert.modeling_bert import BertModel as BertModelHF
 
-from flash_attn.models.bert import (BertForPreTraining, BertModel,
-                                    inv_remap_state_dict, remap_state_dict)
+from flash_attn.models.bert import (
+    BertForPreTraining,
+    BertModel,
+    inv_remap_state_dict,
+    remap_state_dict,
+)
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 
 
@@ -102,7 +105,7 @@ def test_bert_optimized(model_name):
     dtype = torch.float16
     config = BertConfig.from_pretrained(model_name)
     # Our implementation of fused_mlp assumes the activation is
-    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
+    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh".
     # If you just want "gelu", disable fused_mlp.
     config.hidden_act = "gelu_new"
     config.use_flash_attn = True
@@ -209,7 +212,7 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs
     dtype = torch.float16
     config = BertConfig.from_pretrained(model_name)
     # Our implementation of fused_mlp assumes the activation is
-    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh".
+    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh".
     # If you just want "gelu", disable fused_mlp.
     config.hidden_act = "gelu_new"
     config.use_flash_attn = True
tests/models/test_bigcode.py (new file, 206 lines)
@@ -0,0 +1,206 @@
import time

import pytest
import torch
from transformers import AutoTokenizer, GPTBigCodeConfig
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM

from flash_attn.models.bigcode import bigcode_config_to_gpt2_config, inv_remap_state_dict_hf_bigcode
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_bigcode
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained


@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_bigcode_state_dict(model_name):
    config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_hf_bigcode(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, device="meta")
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape
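Aside (not part of the diff): test_bigcode_state_dict builds the model on the meta device so only parameter shapes are materialized, which keeps the key/shape comparison cheap. A tiny illustration of the idea in plain PyTorch (requires PyTorch 2.0+ for the device context manager), using an unrelated module:

import torch
import torch.nn as nn

# Parameters on the meta device carry shapes and dtypes but no storage,
# so constructing even a large module is essentially free.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)

print(layer.weight.shape)    # torch.Size([4096, 4096])
print(layer.weight.is_meta)  # True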

@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_bigcode_optimized(model_name):
    """Check that our implementation of BigCode (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={"": device})
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.transformer(input_ids).last_hidden_state
        logits_ref = model_ref(input_ids).logits
    del model_ref

    model_hf = GPTBigCodeForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_bigcode_generation(model_name):
    """Check that our implementation of BigCode (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
    config.fused_bias_fc = True
    config.fused_mlp = True
    config.fused_dropout_add_ln = True
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = True

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    model_hf = GPTBigCodeForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={"": device})
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    del model_ref

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        fused_ft_kernel=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        fused_ft_kernel=True,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)

    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()
    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error

    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
    assert (logits_cg - logits_ref).abs().max().item() < 2 * hf_error

@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_inv_remap_state_dict(model_name: str):
    """
    Verify that we can convert a HF BigCode model to flash_attn and back.
    """

    state_dict = state_dict_from_pretrained(model_name)
    config = GPTBigCodeConfig.from_pretrained(model_name)

    flash_state_dict = remap_state_dict_hf_bigcode(state_dict, config)
    recovered_state_dict = inv_remap_state_dict_hf_bigcode(flash_state_dict, config)

    assert set(state_dict.keys()) == set(recovered_state_dict.keys())

    for k in state_dict.keys():
        assert state_dict[k].shape == recovered_state_dict[k].shape
        torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)