Reorder LN in Block, support OPT
parent f1e01c27ba
commit ff34123bd4
@@ -94,7 +94,8 @@ def create_block(config, layer_idx=None):
mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
prenorm=False, resid_dropout=config.hidden_dropout_prob,
prenorm=False, resid_dropout1=config.hidden_dropout_prob,
resid_dropout2=config.hidden_dropout_prob,
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
return_residual=return_residual)
return block
@@ -1,4 +1,4 @@
# Copyright (c) 2022, Tri Dao.
# Copyright (c) 2023, Tri Dao.

import logging
import math
@@ -23,6 +23,7 @@ from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import GenerationMixin
from flash_attn.models.opt import remap_state_dict_opt

try:
from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -88,9 +89,12 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
if process_group is not None:
assert fused_dense_gelu_dense, 'Tensor Parallel is only implemented for FusedDenseGeluDense'
if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none'
mlp_cls = partial(Mlp, hidden_features=inner_dim,
activation=partial(F.gelu, approximate=approximate), **factory_kwargs)
if config.activation_function == 'relu':
activation = partial(F.relu, inplace=True)
else:
approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none'
activation=partial(F.gelu, approximate=approximate)
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
else:
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
@@ -121,9 +125,14 @@ def create_block(config, layer_idx=None, process_group=None, device=None, dtype=
mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
prenorm = getattr(config, 'prenorm', True)
block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
prenorm=True, resid_dropout=config.resid_pdrop,
prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
residual_in_fp32=residual_in_fp32,
sequence_parallel=sequence_parallel and process_group is not None,
mark_shared_params=process_group is not None)
block.layer_idx = layer_idx
@@ -154,11 +163,16 @@ class GPTPreTrainedModel(nn.Module):
"""
# Instantiate model.
model = cls(config, *args, device=device, dtype=dtype, **kwargs)
state_dict = remap_state_dict_gpt2(
# If we're going to shard the model, then don't load fp32 weights to GPU.
state_dict_from_pretrained(model_name, device=device if world_size == 1 else None,
dtype=dtype), config
# If we're going to shard the model, then don't load fp32 weights to GPU.
state_dict = state_dict_from_pretrained(
model_name, device=device if world_size == 1 else None, dtype=dtype
)
if model_name.startswith('gpt2'):
state_dict = remap_state_dict_gpt2(state_dict, config)
elif model_name.startswith('facebook/opt'):
state_dict = remap_state_dict_opt(state_dict, config)
else:
raise NotImplementedError(f'Model {model_name} not supported')
if world_size > 1:
state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
@@ -166,6 +180,7 @@ class GPTPreTrainedModel(nn.Module):
logger.info(load_return)
return model


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
if isinstance(module, nn.Linear):
@@ -195,47 +210,53 @@ class GPTModel(GPTPreTrainedModel):
factory_kwargs = {'device': device, 'dtype': dtype}
self.process_group = process_group
self.sequence_parallel = getattr(config, 'sequence_parallel', True)
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu']
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'relu', 'sqrelu']
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
# These 2 options are for OPT-350m
self.prenorm = getattr(config, 'prenorm', True)
word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)

if process_group is None:
self.embeddings = GPT2Embeddings(config.hidden_size, vocab_size,
config.max_position_embeddings, **factory_kwargs)
self.embeddings = GPT2Embeddings(
config.hidden_size, vocab_size, config.max_position_embeddings,
word_embed_proj_dim=word_embed_proj_dim, **factory_kwargs
)
else:
self.embeddings = ParallelGPT2Embeddings(
config.hidden_size, vocab_size, config.max_position_embeddings,
process_group=process_group, sequence_parallel=self.sequence_parallel,
**factory_kwargs
)
self.emb_drop = nn.Dropout(config.embd_pdrop)

# We change the order of residual and layer norm:
# We change the order of dropout, residual and layer norm:
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
# Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
# the main branch (output of LN). The model definition is unchanged, but the mapping of the
# nn.LayerNorm weights are changed.
# Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
# the main branch (output of MLP). The model definition is unchanged, but the mapping of the
# nn.Dropout probabilities are changed.
# This is for performance reasons: we can fuse dropout + add + layer_norm.
self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
**factory_kwargs)
for i in range(config.num_hidden_layers)])

self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
# self.ln_0 is the first layer norm in the model, while self.ln_f (in the pretrained weight)
# is the final layer norm.
self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
**factory_kwargs)
if self.prenorm:
self.drop_f = nn.Dropout(config.resid_pdrop)
self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
**factory_kwargs)
if process_group is not None:
for p in self.ln_0.parameters():
for p in self.ln_f.parameters():
# Mark the norm parameters as "shared_params" so that we sync their values at init.
p._shared_params = True
# Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
if self.sequence_parallel:
p._sequence_parallel = True

self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
**factory_kwargs)
for i in range(config.num_hidden_layers)])

self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
initializer_range=config.initializer_range))
self.tie_weights()
@@ -251,23 +272,28 @@ class GPTModel(GPTPreTrainedModel):
embedding_kwargs = ({'combine_batch_seqlen_dim': True}
if self.process_group is not None and self.sequence_parallel else {})
hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
if not self.fused_dropout_add_ln:
residual = self.emb_drop(hidden_states)
hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype))
residual = residual.float()
else:
hidden_states, residual = dropout_add_layer_norm(
hidden_states, None, self.ln_0.weight, self.ln_0.bias,
self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
residual_in_fp32=True
)
residual = None
mixer_kwargs = ({'seqlen': input_ids.shape[1]}
if self.process_group is not None and self.sequence_parallel else {})
if inference_params is not None:
mixer_kwargs['inference_params'] = inference_params
for layer in self.layers:
hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
if self.prenorm:
hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
else:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
if self.prenorm:
if not self.fused_dropout_add_ln:
dropped = self.drop_f(hidden_states)
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
else:
# Set prenorm=False here since we don't need the residual
hidden_states = dropout_add_layer_norm(
hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
residual_in_fp32=self.residual_in_fp32
)
return hidden_states
@@ -281,13 +307,20 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
# This option is for OPT-350m
word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
if word_embed_proj_dim is not None:
self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
else:
self.project_out = None
if process_group is None:
self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False, **factory_kwargs)
self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False, **factory_kwargs)
else:
if ColumnParallelLinear is None:
raise ImportError('fused_dense_lib is not installed')
self.lm_head = ColumnParallelLinear(
config.n_embd, vocab_size, process_group, bias=False,
embed_dim, vocab_size, process_group, bias=False,
sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
)
# Initialize weights and apply final processing
@@ -307,6 +340,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
"""
hidden_states = self.transformer(input_ids, position_ids=position_ids,
inference_params=inference_params)
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
lm_logits = self.lm_head(hidden_states)
# During inference, we want the full logit for sampling
if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
@@ -315,6 +350,32 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
return CausalLMOutput(logits=lm_logits)

def load_state_dict(self, state_dict, strict=True):
# Remapping from our checkpoints that used a different ordering of layers in the block
# Previous: Attn / MLP -> Dropout -> Add -> LN
# Current: Dropout -> Add -> LN -> Attn / MLP
if 'transformer.ln_0.weight' in state_dict:
n_layers = self.config.num_hidden_layers
ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
state_dict['transformer.ln_f.weight'] = ln_weight
state_dict['transformer.ln_f.bias'] = ln_bias
for l in reversed(range(n_layers)):
ln_weight = state_dict.pop(f'transformer.layers.{l}.norm1.weight')
ln_bias = state_dict.pop(f'transformer.layers.{l}.norm1.bias')
state_dict[f'transformer.layers.{l}.norm2.weight'] = ln_weight
state_dict[f'transformer.layers.{l}.norm2.bias'] = ln_bias
if l > 0:
ln_weight = state_dict.pop(f'transformer.layers.{l - 1}.norm2.weight')
ln_bias = state_dict.pop(f'transformer.layers.{l - 1}.norm2.bias')
state_dict[f'transformer.layers.{l}.norm1.weight'] = ln_weight
state_dict[f'transformer.layers.{l}.norm1.bias'] = ln_bias
ln_weight = state_dict.pop('transformer.ln_0.weight')
ln_bias = state_dict.pop('transformer.ln_0.bias')
state_dict[f'transformer.layers.0.norm1.weight'] = ln_weight
state_dict[f'transformer.layers.0.norm1.bias'] = ln_bias
return super().load_state_dict(state_dict, strict=strict)
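As a worked example, for a hypothetical 2-layer checkpoint in the old format the remapping above shifts each LayerNorm down by one slot (old key on the left, new key expected by the reordered layout on the right):

    transformer.ln_0            -> transformer.layers.0.norm1
    transformer.layers.0.norm1  -> transformer.layers.0.norm2
    transformer.layers.0.norm2  -> transformer.layers.1.norm1
    transformer.layers.1.norm1  -> transformer.layers.1.norm2
    transformer.layers.1.norm2  -> transformer.ln_f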


def remap_state_dict_gpt2(state_dict, config):
# Word embedding and position embedding
@@ -331,22 +392,11 @@ def remap_state_dict_gpt2(state_dict, config):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']

# LayerNorm
ln_weight, ln_bias = state_dict.pop('ln_f.weight'), state_dict.pop('ln_f.bias')
state_dict[f'transformer.layers.{config.num_hidden_layers - 1}.norm2.weight'] = ln_weight
state_dict[f'transformer.layers.{config.num_hidden_layers - 1}.norm2.bias'] = ln_bias
ln_weight, ln_bias = state_dict.pop('h.0.ln_1.weight'), state_dict.pop('h.0.ln_1.bias')
state_dict['transformer.ln_0.weight'] = ln_weight
state_dict['transformer.ln_0.bias'] = ln_bias
for d in range(config.num_hidden_layers):
ln_weight = state_dict.pop(f'h.{d}.ln_2.weight')
ln_bias = state_dict.pop(f'h.{d}.ln_2.bias')
state_dict[f'transformer.layers.{d}.norm1.weight'] = ln_weight
state_dict[f'transformer.layers.{d}.norm1.bias'] = ln_bias
if d > 0:
ln_weight = state_dict.pop(f'h.{d}.ln_1.weight')
ln_bias = state_dict.pop(f'h.{d}.ln_1.bias')
state_dict[f'transformer.layers.{d - 1}.norm2.weight'] = ln_weight
state_dict[f'transformer.layers.{d - 1}.norm2.bias'] = ln_bias
def key_mapping_ln(key):
key = re.sub(r'^ln_f.(weight|bias)', r'transformer.ln_f.\1', key)
key = re.sub(r'^h.(\d+).ln_(1|2).(weight|bias)', r'transformer.layers.\1.norm\2.\3', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

# MLP
for d in range(config.num_hidden_layers):
flash_attn/models/opt.py (new file, 104 lines)
@@ -0,0 +1,104 @@
# Copyright (c) 2023, Tri Dao.

import math
import re

from collections import OrderedDict

import torch
import torch.nn.functional as F

from transformers import GPT2Config, OPTConfig


def remap_state_dict_opt(state_dict, config):
def key_mapping_model(key):
key = re.sub(r'^model.decoder.', 'transformer.', key)
# The OPT-350m model uses '^decoder' instead of '^model.decoder'
key = re.sub(r'^decoder.', 'transformer.', key)
return key
state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())
# Word embedding and position embedding
def key_mapping_emb(key):
key = re.sub(r'^transformer.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
# The OPT-350m model has project_in and project_out
key = re.sub(r'^transformer.project_in.', 'transformer.embeddings.project_in.', key)
key = re.sub(r'^transformer.project_out.', 'project_out.', key)
key = re.sub(r'^transformer.embed_positions.',
'transformer.embeddings.position_embeddings.', key)
return key
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
# OPT uses the first 2 indices of pos_emb for padding tokens
pos_embeddings = state_dict.pop('transformer.embeddings.position_embeddings.weight')
state_dict['transformer.embeddings.position_embeddings.weight'] = pos_embeddings[2:]
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']

# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
r'transformer.layers.\1.norm2.', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

# MLP
def key_mapping_mlp(key):
return re.sub(r'^transformer.layers.(\d+).fc(1|2).',
r'transformer.layers.\1.mlp.fc\2.', key)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.weight')
bq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.bias')
bk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.bias')
bv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.bias')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat(
[Wq, Wk, Wv], dim=0
)
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat(
[bq, bk, bv], dim=0
)
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).self_attn.out_proj.',
r'transformer.layers.\1.mixer.out_proj.', key)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

return state_dict


def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
assert opt_config.layerdrop == 0.0
assert opt_config.layer_norm_elementwise_affine
word_embed_proj_dim = (None if opt_config.word_embed_proj_dim == opt_config.hidden_size
else opt_config.word_embed_proj_dim)
return GPT2Config(
vocab_size=opt_config.vocab_size,
n_positions=opt_config.max_position_embeddings,
n_embd=opt_config.hidden_size,
n_layer=opt_config.num_hidden_layers,
n_head=opt_config.num_attention_heads,
n_inner=opt_config.ffn_dim,
activation_function=opt_config.activation_function,
resid_pdrop=opt_config.dropout,
# HF's implementation of OPT doesn't seem to have embedding dropout
embd_pdrop=opt_config.dropout,
attn_pdrop=opt_config.attention_dropout,
initializer_range=opt_config.init_std,
bos_token_id=opt_config.bos_token_id,
eos_token_id=opt_config.eos_token_id,
# These are new arguments not in the original GPT2Config
prenorm=opt_config.do_layer_norm_before,
word_embed_proj_dim=word_embed_proj_dim
)
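As a usage sketch (it mirrors tests/models/test_opt.py added below; the checkpoint name is just an example), an OPT model is loaded by converting its Hugging Face config and letting GPTLMHeadModel.from_pretrained dispatch to remap_state_dict_opt:

import torch
from transformers import OPTConfig
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import opt_config_to_gpt2_config

model_name = 'facebook/opt-125m'  # example checkpoint
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
# from_pretrained remaps the HF OPT weights via remap_state_dict_opt for 'facebook/opt*' names
model = GPTLMHeadModel.from_pretrained(model_name, config, device='cuda', dtype=torch.float16)
model.eval()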
@@ -22,10 +22,22 @@ except ImportError:
class Block(nn.Module):

def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0.,
fused_dropout_add_ln=False, return_residual=False, sequence_parallel=False,
mark_shared_params=False):
dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0.,
drop_path1=0., drop_path2=0., fused_dropout_add_ln=False, return_residual=False,
residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False):
"""
For prenorm=True, this Block has a slightly different structure compared to a regular
prenorm Transformer block.
The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
[Ref: https://arxiv.org/abs/2002.04745]
Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
the hidden_states (output of the MLP) and the residual.
This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
The residual needs to be provided (except for the very first block).

For prenorm=False, this Block has the same structure as a regular postnorm Transformer
block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.

return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
This is for performance reasons: for post-norm architecture, returning the input allows us
to fuse the backward of nn.Linear with the residual connection.
@@ -34,18 +46,21 @@ class Block(nn.Module):
self.prenorm = prenorm
self.fused_dropout_add_ln = fused_dropout_add_ln
self.return_residual = return_residual
self.residual_in_fp32 = residual_in_fp32
if self.residual_in_fp32:
assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True'
if mixer_cls is None:
mixer_cls = partial(MHA, num_heads=dim // 64)
if mlp_cls is None:
mlp_cls = partial(Mlp, hidden_features=4 * dim)
self.mixer = mixer_cls(dim)
self.dropout1 = dropout_cls(resid_dropout)
self.drop_path1 = StochasticDepth(drop_path, mode='row')
self.dropout1 = dropout_cls(resid_dropout1)
self.drop_path1 = StochasticDepth(drop_path1, mode='row')
self.norm1 = norm_cls(dim)
self.mlp = mlp_cls(dim)
if not isinstance(self.mlp, nn.Identity):
self.dropout2 = dropout_cls(resid_dropout)
self.drop_path2 = StochasticDepth(drop_path, mode='row')
self.dropout2 = dropout_cls(resid_dropout2)
self.drop_path2 = StochasticDepth(drop_path2, mode='row')
self.norm2 = norm_cls(dim)

if self.fused_dropout_add_ln:
@@ -82,41 +97,48 @@ class Block(nn.Module):
residual: if postnorm, residual=None. If prenorm, hidden_states = LayerNorm(residual)
"""
if self.prenorm:
assert residual is not None
mixer_out = self.mixer(hidden_states,
**(mixer_kwargs if mixer_kwargs is not None else {}))
if not self.fused_dropout_add_ln:
residual = self.drop_path1(self.dropout1(mixer_out)) + residual
dropped = self.drop_path1(self.dropout1(hidden_states))
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path1.p == 0 or not self.training:
rowscale1 = None
else:
rowscale1 = self.drop_path1(torch.ones(
mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype)
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
)
hidden_states, residual = dropout_add_layer_norm(
mixer_out, residual, self.norm1.weight, self.norm1.bias,
hidden_states, residual, self.norm1.weight, self.norm1.bias,
self.dropout1.p if self.training else 0.0, self.norm1.eps,
rowscale=rowscale1, prenorm=True
rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
)
hidden_states = self.mixer(hidden_states,
**(mixer_kwargs if mixer_kwargs is not None else {}))
if not isinstance(self.mlp, nn.Identity):
mlp_out = self.mlp(hidden_states)
if not self.fused_dropout_add_ln:
residual = self.drop_path2(self.dropout2(mlp_out)) + residual
dropped = self.drop_path2(self.dropout2(hidden_states))
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(torch.ones(
mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype)
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
)
hidden_states, residual = dropout_add_layer_norm(
mlp_out, residual, self.norm2.weight, self.norm2.bias,
hidden_states, residual, self.norm2.weight, self.norm2.bias,
self.dropout2.p if self.training else 0.0, self.norm2.eps,
rowscale=rowscale2, prenorm=True
rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32
)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
else:
assert residual is None
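To make the new ordering concrete, here is a minimal sketch (plain PyTorch, assuming no fused kernels, no drop_path, and no fp32 residual) of the per-sub-layer pattern that the prenorm branch above implements:

import torch
import torch.nn as nn

def prenorm_step(sublayer, norm, dropout, hidden_states, residual=None):
    # Residual branch: Dropout -> Add (residual is None only before the first block)
    dropped = dropout(hidden_states)
    residual = dropped + residual if residual is not None else dropped
    # Main branch: LN -> sub-layer (mixer or MLP)
    return sublayer(norm(residual)), residual

# Toy wiring with identity sub-layers, just to show how hidden_states and
# residual are threaded through one Block (mixer step, then MLP step):
x = torch.randn(2, 16, 64)
h, res = prenorm_step(nn.Identity(), nn.LayerNorm(64), nn.Dropout(0.1), x)
h, res = prenorm_step(nn.Identity(), nn.LayerNorm(64), nn.Dropout(0.1), h, res)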
@@ -12,14 +12,23 @@ from flash_attn.utils.distributed import reduce_scatter, all_reduce
class GPT2Embeddings(nn.Module):

def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
device=None, dtype=None):
word_embed_proj_dim=None, device=None, dtype=None):
"""
If max_position_embeddings <= 0, there's no position embeddings
If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
then project up to embed_dim
"""
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
**factory_kwargs)
if word_embed_proj_dim is None:
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
**factory_kwargs)
self.project_in = None
else:
self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim,
padding_idx=padding_idx, **factory_kwargs)
self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
**factory_kwargs)
self.max_position_embeddings = max_position_embeddings
if self.max_position_embeddings > 0:
self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
@@ -32,6 +41,8 @@ class GPT2Embeddings(nn.Module):
"""
batch_size, seqlen = input_ids.shape
embeddings = self.word_embeddings(input_ids)
if self.project_in is not None:
embeddings = self.project_in(embeddings)
if self.max_position_embeddings > 0:
if position_ids is None:
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
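For the OPT-350m case, the test below notes that the token embeddings live in dimension 512 while the transformer runs at a larger hidden size; a minimal sketch of that path (the sizes here are assumptions for illustration only):

import torch
from flash_attn.modules.embedding import GPT2Embeddings

# Assumed sizes for illustration: hidden_size 1024, word_embed_proj_dim 512 (OPT-350m-like)
emb = GPT2Embeddings(embed_dim=1024, vocab_size=50272, max_position_embeddings=2048,
                     word_embed_proj_dim=512)
input_ids = torch.randint(0, 50272, (2, 16))
hidden = emb(input_ids)  # tokens embed to 512, then project_in maps them up to 1024
print(hidden.shape)      # torch.Size([2, 16, 1024])

On the output side, GPTLMHeadModel.project_out (added above) maps hidden states back down to word_embed_proj_dim before the lm_head, which is sized to that dimension.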
@@ -84,6 +84,7 @@ def test_gpt2_optimized(model_name):
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
config.pad_vocab_size_multiple = 8

model = GPTLMHeadModel.from_pretrained(model_name, config)
tests/models/test_opt.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import re

import torch
import pytest

from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained


@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_state_dict(model_name):
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_opt(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config)
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_optimized(model_name):
"""Check that our implementation of OPT (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dropout_add_ln = True
# Only prenorm supports residual_in_fp32
config.residual_in_fp32 = getattr(config, 'prenorm', True)
config.pad_vocab_size_multiple = 8

model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)

model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)

model.eval()
model_ref.eval()
model_hf.eval()

torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
if model_name != 'facebook/opt-350m':  # The OPT-350m projects the embeddings to dimension 512
out = model.transformer(input_ids)
out_hf = model_hf.model(input_ids).last_hidden_state
out_ref = model_ref.model(input_ids).last_hidden_state

print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

logits = model(input_ids).logits
logits_hf = model_hf(input_ids).logits
logits_ref = model_ref(input_ids).logits

print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()