Reorder LN in Block, support OPT

Tri Dao 2023-01-15 22:14:31 -08:00
parent f1e01c27ba
commit ff34123bd4
7 changed files with 345 additions and 79 deletions


@@ -94,7 +94,8 @@ def create_block(config, layer_idx=None):
mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
prenorm=False, resid_dropout=config.hidden_dropout_prob,
prenorm=False, resid_dropout1=config.hidden_dropout_prob,
resid_dropout2=config.hidden_dropout_prob,
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
return_residual=return_residual)
return block


@@ -1,4 +1,4 @@
# Copyright (c) 2022, Tri Dao.
# Copyright (c) 2023, Tri Dao.
import logging
import math
@@ -23,6 +23,7 @@ from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import GenerationMixin
from flash_attn.models.opt import remap_state_dict_opt
try:
from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -88,9 +89,12 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
if process_group is not None:
assert fused_dense_gelu_dense, 'Tensor Parallel is only implemented for FusedDenseGeluDense'
if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none'
mlp_cls = partial(Mlp, hidden_features=inner_dim,
activation=partial(F.gelu, approximate=approximate), **factory_kwargs)
if config.activation_function == 'relu':
activation = partial(F.relu, inplace=True)
else:
approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none'
activation = partial(F.gelu, approximate=approximate)
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
else:
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
@@ -121,9 +125,14 @@ def create_block(config, layer_idx=None, process_group=None, device=None, dtype=
mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
prenorm = getattr(config, 'prenorm', True)
block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
prenorm=True, resid_dropout=config.resid_pdrop,
prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
residual_in_fp32=residual_in_fp32,
sequence_parallel=sequence_parallel and process_group is not None,
mark_shared_params=process_group is not None)
block.layer_idx = layer_idx
@@ -154,11 +163,16 @@ class GPTPreTrainedModel(nn.Module):
"""
# Instantiate model.
model = cls(config, *args, device=device, dtype=dtype, **kwargs)
state_dict = remap_state_dict_gpt2(
# If we're going to shard the model, then don't load fp32 weights to GPU.
state_dict_from_pretrained(model_name, device=device if world_size == 1 else None,
dtype=dtype), config
# If we're going to shard the model, then don't load fp32 weights to GPU.
state_dict = state_dict_from_pretrained(
model_name, device=device if world_size == 1 else None, dtype=dtype
)
if model_name.startswith('gpt2'):
state_dict = remap_state_dict_gpt2(state_dict, config)
elif model_name.startswith('facebook/opt'):
state_dict = remap_state_dict_opt(state_dict, config)
else:
raise NotImplementedError(f'Model {model_name} not supported')
if world_size > 1:
state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
@@ -166,6 +180,7 @@ class GPTPreTrainedModel(nn.Module):
logger.info(load_return)
return model
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
if isinstance(module, nn.Linear):
@@ -195,47 +210,53 @@ class GPTModel(GPTPreTrainedModel):
factory_kwargs = {'device': device, 'dtype': dtype}
self.process_group = process_group
self.sequence_parallel = getattr(config, 'sequence_parallel', True)
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu']
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'relu', 'sqrelu']
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
# These 2 options are for OPT-350m
self.prenorm = getattr(config, 'prenorm', True)
word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
if process_group is None:
self.embeddings = GPT2Embeddings(config.hidden_size, vocab_size,
config.max_position_embeddings, **factory_kwargs)
self.embeddings = GPT2Embeddings(
config.hidden_size, vocab_size, config.max_position_embeddings,
word_embed_proj_dim=word_embed_proj_dim, **factory_kwargs
)
else:
self.embeddings = ParallelGPT2Embeddings(
config.hidden_size, vocab_size, config.max_position_embeddings,
process_group=process_group, sequence_parallel=self.sequence_parallel,
**factory_kwargs
)
self.emb_drop = nn.Dropout(config.embd_pdrop)
# We change the order of residual and layer norm:
# We change the order of dropout, residual and layer norm:
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
# Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
# the main branch (output of LN). The model definition is unchanged, but the mapping of the
# nn.LayerNorm weights are changed.
# Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
# the main branch (output of MLP). The model definition is unchanged, but the mapping of the
# nn.Dropout probabilities is changed.
# This is for performance reasons: we can fuse dropout + add + layer_norm.
self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
**factory_kwargs)
for i in range(config.num_hidden_layers)])
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
# self.ln_0 is the first layer norm in the model, while self.ln_f (in the pretrained weight)
# is the final layer norm.
self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
**factory_kwargs)
if self.prenorm:
self.drop_f = nn.Dropout(config.resid_pdrop)
self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
**factory_kwargs)
if process_group is not None:
for p in self.ln_0.parameters():
for p in self.ln_f.parameters():
# Mark the norm parameters as "shared_params" so that we sync their values at init.
p._shared_params = True
# Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
if self.sequence_parallel:
p._sequence_parallel = True
self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
**factory_kwargs)
for i in range(config.num_hidden_layers)])
self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
initializer_range=config.initializer_range))
self.tie_weights()
@@ -251,23 +272,28 @@ class GPTModel(GPTPreTrainedModel):
embedding_kwargs = ({'combine_batch_seqlen_dim': True}
if self.process_group is not None and self.sequence_parallel else {})
hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
if not self.fused_dropout_add_ln:
residual = self.emb_drop(hidden_states)
hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype))
residual = residual.float()
else:
hidden_states, residual = dropout_add_layer_norm(
hidden_states, None, self.ln_0.weight, self.ln_0.bias,
self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
residual_in_fp32=True
)
residual = None
mixer_kwargs = ({'seqlen': input_ids.shape[1]}
if self.process_group is not None and self.sequence_parallel else {})
if inference_params is not None:
mixer_kwargs['inference_params'] = inference_params
for layer in self.layers:
hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
if self.prenorm:
hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
else:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
if self.prenorm:
if not self.fused_dropout_add_ln:
dropped = self.drop_f(hidden_states)
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
else:
# Set prenorm=False here since we don't need the residual
hidden_states = dropout_add_layer_norm(
hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
residual_in_fp32=self.residual_in_fp32
)
return hidden_states
@@ -281,13 +307,20 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
# This option is for OPT-350m
word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
if word_embed_proj_dim is not None:
self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
else:
self.project_out = None
if process_group is None:
self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False, **factory_kwargs)
self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False, **factory_kwargs)
else:
if ColumnParallelLinear is None:
raise ImportError('fused_dense_lib is not installed')
self.lm_head = ColumnParallelLinear(
config.n_embd, vocab_size, process_group, bias=False,
embed_dim, vocab_size, process_group, bias=False,
sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
)
# Initialize weights and apply final processing
@@ -307,6 +340,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
"""
hidden_states = self.transformer(input_ids, position_ids=position_ids,
inference_params=inference_params)
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
lm_logits = self.lm_head(hidden_states)
# During inference, we want the full logit for sampling
if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
@@ -315,6 +350,32 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
return CausalLMOutput(logits=lm_logits)
def load_state_dict(self, state_dict, strict=True):
# Remapping from our checkpoints that used a different ordering of layers in the block
# Previous: Attn / MLP -> Dropout -> Add -> LN
# Current: Dropout -> Add -> LN -> Attn / MLP
if 'transformer.ln_0.weight' in state_dict:
n_layers = self.config.num_hidden_layers
ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
state_dict['transformer.ln_f.weight'] = ln_weight
state_dict['transformer.ln_f.bias'] = ln_bias
for l in reversed(range(n_layers)):
ln_weight = state_dict.pop(f'transformer.layers.{l}.norm1.weight')
ln_bias = state_dict.pop(f'transformer.layers.{l}.norm1.bias')
state_dict[f'transformer.layers.{l}.norm2.weight'] = ln_weight
state_dict[f'transformer.layers.{l}.norm2.bias'] = ln_bias
if l > 0:
ln_weight = state_dict.pop(f'transformer.layers.{l - 1}.norm2.weight')
ln_bias = state_dict.pop(f'transformer.layers.{l - 1}.norm2.bias')
state_dict[f'transformer.layers.{l}.norm1.weight'] = ln_weight
state_dict[f'transformer.layers.{l}.norm1.bias'] = ln_bias
ln_weight = state_dict.pop('transformer.ln_0.weight')
ln_bias = state_dict.pop('transformer.ln_0.bias')
state_dict[f'transformer.layers.0.norm1.weight'] = ln_weight
state_dict[f'transformer.layers.0.norm1.bias'] = ln_bias
return super().load_state_dict(state_dict, strict=strict)
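# To make the remapping above concrete: for a hypothetical 2-layer checkpoint saved with the
# previous ordering, the LayerNorm entries move as follows (illustrative keys only):
#   transformer.ln_0.*            -> transformer.layers.0.norm1.*
#   transformer.layers.0.norm1.*  -> transformer.layers.0.norm2.*
#   transformer.layers.0.norm2.*  -> transformer.layers.1.norm1.*
#   transformer.layers.1.norm1.*  -> transformer.layers.1.norm2.*
#   transformer.layers.1.norm2.*  -> transformer.ln_f.*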
def remap_state_dict_gpt2(state_dict, config):
# Word embedding and position embedding
@@ -331,22 +392,11 @@ def remap_state_dict_gpt2(state_dict, config):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
# LayerNorm
ln_weight, ln_bias = state_dict.pop('ln_f.weight'), state_dict.pop('ln_f.bias')
state_dict[f'transformer.layers.{config.num_hidden_layers - 1}.norm2.weight'] = ln_weight
state_dict[f'transformer.layers.{config.num_hidden_layers - 1}.norm2.bias'] = ln_bias
ln_weight, ln_bias = state_dict.pop('h.0.ln_1.weight'), state_dict.pop('h.0.ln_1.bias')
state_dict['transformer.ln_0.weight'] = ln_weight
state_dict['transformer.ln_0.bias'] = ln_bias
for d in range(config.num_hidden_layers):
ln_weight = state_dict.pop(f'h.{d}.ln_2.weight')
ln_bias = state_dict.pop(f'h.{d}.ln_2.bias')
state_dict[f'transformer.layers.{d}.norm1.weight'] = ln_weight
state_dict[f'transformer.layers.{d}.norm1.bias'] = ln_bias
if d > 0:
ln_weight = state_dict.pop(f'h.{d}.ln_1.weight')
ln_bias = state_dict.pop(f'h.{d}.ln_1.bias')
state_dict[f'transformer.layers.{d - 1}.norm2.weight'] = ln_weight
state_dict[f'transformer.layers.{d - 1}.norm2.bias'] = ln_bias
def key_mapping_ln(key):
key = re.sub(r'^ln_f.(weight|bias)', r'transformer.ln_f.\1', key)
key = re.sub(r'^h.(\d+).ln_(1|2).(weight|bias)', r'transformer.layers.\1.norm\2.\3', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
for d in range(config.num_hidden_layers):

flash_attn/models/opt.py (new file, 104 lines)

@@ -0,0 +1,104 @@
# Copyright (c) 2023, Tri Dao.
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from transformers import GPT2Config, OPTConfig
def remap_state_dict_opt(state_dict, config):
def key_mapping_model(key):
key = re.sub(r'^model.decoder.', 'transformer.', key)
# The OPT-350m model uses '^decoder' instead of '^model.decoder'
key = re.sub(r'^decoder.', 'transformer.', key)
return key
state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())
# Word embedding and position embedding
def key_mapping_emb(key):
key = re.sub(r'^transformer.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
# The OPT-350m model has project_in and project_out
key = re.sub(r'^transformer.project_in.', 'transformer.embeddings.project_in.', key)
key = re.sub(r'^transformer.project_out.', 'project_out.', key)
key = re.sub(r'^transformer.embed_positions.',
'transformer.embeddings.position_embeddings.', key)
return key
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
# OPT uses the first 2 indices of pos_emb for padding tokens
pos_embeddings = state_dict.pop('transformer.embeddings.position_embeddings.weight')
state_dict['transformer.embeddings.position_embeddings.weight'] = pos_embeddings[2:]
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
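# As an illustrative calculation (hypothetical size, not from this checkpoint): a vocabulary of
# 50265 with pad_vocab_size_multiple = 8 would be padded to math.ceil(50265 / 8) * 8 = 50272,
# i.e. 7 zero rows appended; OPT's actual vocabulary of 50272 is already a multiple of 8.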
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
r'transformer.layers.\1.norm2.', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
return re.sub(r'^transformer.layers.(\d+).fc(1|2).',
r'transformer.layers.\1.mlp.fc\2.', key)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.weight')
bq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.bias')
bk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.bias')
bv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.bias')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat(
[Wq, Wk, Wv], dim=0
)
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat(
[bq, bk, bv], dim=0
)
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).self_attn.out_proj.',
r'transformer.layers.\1.mixer.out_proj.', key)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
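# A standalone sketch of the packing convention assumed above: the fused Wqkv weight stacks the
# q, k, v projections along the output dimension, so the fused output splits back into q, k, v
# as three equal chunks (toy sizes, independent of flash_attn's MHA module).
hidden = 8
Wq, Wk, Wv = (torch.randn(hidden, hidden) for _ in range(3))
Wqkv = torch.cat([Wq, Wk, Wv], dim=0)              # shape (3 * hidden, hidden)
x = torch.randn(2, hidden)
q, k, v = (x @ Wqkv.t()).chunk(3, dim=-1)
assert torch.allclose(q, x @ Wq.t()) and torch.allclose(v, x @ Wv.t())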
def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
assert opt_config.layerdrop == 0.0
assert opt_config.layer_norm_elementwise_affine
word_embed_proj_dim = (None if opt_config.word_embed_proj_dim == opt_config.hidden_size
else opt_config.word_embed_proj_dim)
return GPT2Config(
vocab_size=opt_config.vocab_size,
n_positions=opt_config.max_position_embeddings,
n_embd=opt_config.hidden_size,
n_layer=opt_config.num_hidden_layers,
n_head=opt_config.num_attention_heads,
n_inner=opt_config.ffn_dim,
activation_function=opt_config.activation_function,
resid_pdrop=opt_config.dropout,
# HF's implementation of OPT doesn't seem to have embedding dropout
embd_pdrop=opt_config.dropout,
attn_pdrop=opt_config.attention_dropout,
initializer_range=opt_config.init_std,
bos_token_id=opt_config.bos_token_id,
eos_token_id=opt_config.eos_token_id,
# These are new arguments not in the original GPT2Config
prenorm=opt_config.do_layer_norm_before,
word_embed_proj_dim=word_embed_proj_dim
)
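For reference, a minimal usage sketch of the new OPT support, mirroring the tests below (the checkpoint name, device, and dtype are illustrative; the fused kernels expect a CUDA device):

import torch
from transformers import OPTConfig
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import opt_config_to_gpt2_config

config = opt_config_to_gpt2_config(OPTConfig.from_pretrained('facebook/opt-125m'))
model = GPTLMHeadModel.from_pretrained('facebook/opt-125m', config,
                                       device='cuda', dtype=torch.float16)
model.eval()
input_ids = torch.randint(0, config.vocab_size, (1, 128), dtype=torch.long, device='cuda')
logits = model(input_ids).logits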


@@ -22,10 +22,22 @@ except ImportError:
class Block(nn.Module):
def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0.,
fused_dropout_add_ln=False, return_residual=False, sequence_parallel=False,
mark_shared_params=False):
dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0.,
drop_path1=0., drop_path2=0., fused_dropout_add_ln=False, return_residual=False,
residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False):
"""
For prenorm=True, this Block has a slightly different structure compared to a regular
prenorm Transformer block.
The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
[Ref: https://arxiv.org/abs/2002.04745]
Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
the hidden_states (output of the MLP) and the residual.
This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
The residual needs to be provided (except for the very first block).
For prenorm=False, this Block has the same structure as a regular postnorm Transformer
block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
This is for performance reasons: for the post-norm architecture, returning the input allows us
to fuse the backward of nn.Linear with the residual connection.
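The prenorm reordering described in the docstring can be sketched in plain PyTorch as follows (a minimal, unfused illustration with toy Linear stand-ins for MHA and the MLP; drop_path and rowscale are omitted):

import torch
import torch.nn as nn

dim = 16
norm1, norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
dropout1, dropout2 = nn.Dropout(0.1), nn.Dropout(0.1)
mixer = nn.Linear(dim, dim)   # stand-in for MHA
mlp = nn.Linear(dim, dim)     # stand-in for the MLP

def block_step(hidden_states, residual):
    # Dropout -> Add -> LN -> Mixer
    dropped = dropout1(hidden_states)
    residual = dropped if residual is None else dropped + residual
    hidden_states = mixer(norm1(residual))
    # Dropout -> Add -> LN -> MLP
    dropped = dropout2(hidden_states)
    residual = dropped + residual
    hidden_states = mlp(norm2(residual))
    return hidden_states, residual

x = torch.randn(2, 3, dim)
hidden_states, residual = block_step(x, None)                  # first block: no residual yet
hidden_states, residual = block_step(hidden_states, residual)  # later blocks reuse the residual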
@@ -34,18 +46,21 @@ class Block(nn.Module):
self.prenorm = prenorm
self.fused_dropout_add_ln = fused_dropout_add_ln
self.return_residual = return_residual
self.residual_in_fp32 = residual_in_fp32
if self.residual_in_fp32:
assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True'
if mixer_cls is None:
mixer_cls = partial(MHA, num_heads=dim // 64)
if mlp_cls is None:
mlp_cls = partial(Mlp, hidden_features=4 * dim)
self.mixer = mixer_cls(dim)
self.dropout1 = dropout_cls(resid_dropout)
self.drop_path1 = StochasticDepth(drop_path, mode='row')
self.dropout1 = dropout_cls(resid_dropout1)
self.drop_path1 = StochasticDepth(drop_path1, mode='row')
self.norm1 = norm_cls(dim)
self.mlp = mlp_cls(dim)
if not isinstance(self.mlp, nn.Identity):
self.dropout2 = dropout_cls(resid_dropout)
self.drop_path2 = StochasticDepth(drop_path, mode='row')
self.dropout2 = dropout_cls(resid_dropout2)
self.drop_path2 = StochasticDepth(drop_path2, mode='row')
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
@@ -82,41 +97,48 @@ class Block(nn.Module):
residual: if postnorm, residual=None; if prenorm, hidden_states = LayerNorm(residual)
"""
if self.prenorm:
assert residual is not None
mixer_out = self.mixer(hidden_states,
**(mixer_kwargs if mixer_kwargs is not None else {}))
if not self.fused_dropout_add_ln:
residual = self.drop_path1(self.dropout1(mixer_out)) + residual
dropped = self.drop_path1(self.dropout1(hidden_states))
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path1.p == 0 or not self.training:
rowscale1 = None
else:
rowscale1 = self.drop_path1(torch.ones(
mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype)
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
)
hidden_states, residual = dropout_add_layer_norm(
mixer_out, residual, self.norm1.weight, self.norm1.bias,
hidden_states, residual, self.norm1.weight, self.norm1.bias,
self.dropout1.p if self.training else 0.0, self.norm1.eps,
rowscale=rowscale1, prenorm=True
rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
)
hidden_states = self.mixer(hidden_states,
**(mixer_kwargs if mixer_kwargs is not None else {}))
if not isinstance(self.mlp, nn.Identity):
mlp_out = self.mlp(hidden_states)
if not self.fused_dropout_add_ln:
residual = self.drop_path2(self.dropout2(mlp_out)) + residual
dropped = self.drop_path2(self.dropout2(hidden_states))
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(torch.ones(
mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype)
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
)
hidden_states, residual = dropout_add_layer_norm(
mlp_out, residual, self.norm2.weight, self.norm2.bias,
hidden_states, residual, self.norm2.weight, self.norm2.bias,
self.dropout2.p if self.training else 0.0, self.norm2.eps,
rowscale=rowscale2, prenorm=True
rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32
)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
else:
assert residual is None


@@ -12,14 +12,23 @@ from flash_attn.utils.distributed import reduce_scatter, all_reduce
class GPT2Embeddings(nn.Module):
def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
device=None, dtype=None):
word_embed_proj_dim=None, device=None, dtype=None):
"""
If max_position_embeddings <= 0, there are no position embeddings
If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
then project up to embed_dim
"""
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
**factory_kwargs)
if word_embed_proj_dim is None:
self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
**factory_kwargs)
self.project_in = None
else:
self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim,
padding_idx=padding_idx, **factory_kwargs)
self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
**factory_kwargs)
self.max_position_embeddings = max_position_embeddings
if self.max_position_embeddings > 0:
self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
@@ -32,6 +41,8 @@ class GPT2Embeddings(nn.Module):
"""
batch_size, seqlen = input_ids.shape
embeddings = self.word_embeddings(input_ids)
if self.project_in is not None:
embeddings = self.project_in(embeddings)
if self.max_position_embeddings > 0:
if position_ids is None:
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
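A short illustration of the new argument (toy inputs; OPT-350m's real sizes are word_embed_proj_dim=512 with hidden size 1024):

import torch
from flash_attn.modules.embedding import GPT2Embeddings

emb = GPT2Embeddings(embed_dim=1024, vocab_size=50272, max_position_embeddings=2048,
                     word_embed_proj_dim=512)
input_ids = torch.randint(0, 50272, (2, 16))
hidden = emb(input_ids)       # tokens embedded at dim 512, then projected up to 1024
assert hidden.shape == (2, 16, 1024)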


@@ -84,6 +84,7 @@ def test_gpt2_optimized(model_name):
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
config.pad_vocab_size_multiple = 8
model = GPTLMHeadModel.from_pretrained(model_name, config)

tests/models/test_opt.py (new file, 77 lines)

@@ -0,0 +1,77 @@
import re
import torch
import pytest
from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_state_dict(model_name):
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_opt(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config)
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_optimized(model_name):
"""Check that our implementation of OPT (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = 'cuda'
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dropout_add_ln = True
# Only prenorm supports residual_in_fp32
config.residual_in_fp32 = getattr(config, 'prenorm', True)
config.pad_vocab_size_multiple = 8
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
model.eval()
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
device='cuda')
if model_name != 'facebook/opt-350m': # The OPT-350m projects the embeddings to dimension 512
out = model.transformer(input_ids)
out_hf = model_hf.model(input_ids).last_hidden_state
out_ref = model_ref.model(input_ids).last_hidden_state
print(f'Output max diff: {(out - out_ref).abs().max().item()}')
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
logits = model(input_ids).logits
logits_hf = model_hf(input_ids).logits
logits_ref = model_ref(input_ids).logits
print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()