From ff34123bd426bcc3ca0d1a11b6173652fb84d033 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 15 Jan 2023 22:14:31 -0800 Subject: [PATCH] Reorder LN in Block, support OPT --- flash_attn/models/bert.py | 3 +- flash_attn/models/gpt.py | 162 +++++++++++++++++++++----------- flash_attn/models/opt.py | 104 ++++++++++++++++++++ flash_attn/modules/block.py | 60 ++++++++---- flash_attn/modules/embedding.py | 17 +++- tests/models/test_gpt.py | 1 + tests/models/test_opt.py | 77 +++++++++++++++ 7 files changed, 345 insertions(+), 79 deletions(-) create mode 100644 flash_attn/models/opt.py create mode 100644 tests/models/test_opt.py diff --git a/flash_attn/models/bert.py b/flash_attn/models/bert.py index 88ac4d9..360eedc 100644 --- a/flash_attn/models/bert.py +++ b/flash_attn/models/bert.py @@ -94,7 +94,8 @@ def create_block(config, layer_idx=None): mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual) norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps) block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls, - prenorm=False, resid_dropout=config.hidden_dropout_prob, + prenorm=False, resid_dropout1=config.hidden_dropout_prob, + resid_dropout2=config.hidden_dropout_prob, fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False), return_residual=return_residual) return block diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index 3f5f115..e2f9e6b 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, Tri Dao. +# Copyright (c) 2023, Tri Dao. import logging import math @@ -23,6 +23,7 @@ from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings from flash_attn.utils.distributed import sync_shared_params, all_gather_raw from flash_attn.utils.pretrained import state_dict_from_pretrained from flash_attn.utils.generation import GenerationMixin +from flash_attn.models.opt import remap_state_dict_opt try: from flash_attn.ops.fused_dense import ColumnParallelLinear @@ -88,9 +89,12 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp if process_group is not None: assert fused_dense_gelu_dense, 'Tensor Parallel is only implemented for FusedDenseGeluDense' if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense: - approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none' - mlp_cls = partial(Mlp, hidden_features=inner_dim, - activation=partial(F.gelu, approximate=approximate), **factory_kwargs) + if config.activation_function == 'relu': + activation = partial(F.relu, inplace=True) + else: + approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none' + activation=partial(F.gelu, approximate=approximate) + mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs) else: mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0) # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer @@ -121,9 +125,14 @@ def create_block(config, layer_idx=None, process_group=None, device=None, dtype= mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs) mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs) norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs) + # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable + residual_in_fp32 = getattr(config, 'residual_in_fp32', False) + 
resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop + prenorm = getattr(config, 'prenorm', True) block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls, - prenorm=True, resid_dropout=config.resid_pdrop, + prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop, fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False), + residual_in_fp32=residual_in_fp32, sequence_parallel=sequence_parallel and process_group is not None, mark_shared_params=process_group is not None) block.layer_idx = layer_idx @@ -154,11 +163,16 @@ class GPTPreTrainedModel(nn.Module): """ # Instantiate model. model = cls(config, *args, device=device, dtype=dtype, **kwargs) - state_dict = remap_state_dict_gpt2( - # If we're going to shard the model, then don't load fp32 weights to GPU. - state_dict_from_pretrained(model_name, device=device if world_size == 1 else None, - dtype=dtype), config + # If we're going to shard the model, then don't load fp32 weights to GPU. + state_dict = state_dict_from_pretrained( + model_name, device=device if world_size == 1 else None, dtype=dtype ) + if model_name.startswith('gpt2'): + state_dict = remap_state_dict_gpt2(state_dict, config) + elif model_name.startswith('facebook/opt'): + state_dict = remap_state_dict_opt(state_dict, config) + else: + raise NotImplementedError(f'Model {model_name} not supported') if world_size > 1: state_dict = shard_state_dict_tp(state_dict, config, world_size, rank) state_dict = {k: v.to(device=device) for k, v in state_dict.items()} @@ -166,6 +180,7 @@ class GPTPreTrainedModel(nn.Module): logger.info(load_return) return model + # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True): if isinstance(module, nn.Linear): @@ -195,47 +210,53 @@ class GPTModel(GPTPreTrainedModel): factory_kwargs = {'device': device, 'dtype': dtype} self.process_group = process_group self.sequence_parallel = getattr(config, 'sequence_parallel', True) - assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu'] + assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'relu', 'sqrelu'] pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable + self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False) + # These 2 options are for OPT-350m + self.prenorm = getattr(config, 'prenorm', True) + word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None) if process_group is None: - self.embeddings = GPT2Embeddings(config.hidden_size, vocab_size, - config.max_position_embeddings, **factory_kwargs) + self.embeddings = GPT2Embeddings( + config.hidden_size, vocab_size, config.max_position_embeddings, + word_embed_proj_dim=word_embed_proj_dim, **factory_kwargs + ) else: self.embeddings = ParallelGPT2Embeddings( config.hidden_size, vocab_size, config.max_position_embeddings, process_group=process_group, sequence_parallel=self.sequence_parallel, **factory_kwargs ) - self.emb_drop = nn.Dropout(config.embd_pdrop) - # We change the order of residual and layer norm: + # We change the order of dropout, residual and layer norm: # Instead of LN -> Attn / MLP -> Dropout -> Add, we do: - # Attn 
/ MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
-        # the main branch (output of LN). The model definition is unchanged, but the mapping of the
-        # nn.LayerNorm weights are changed.
+        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
+        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
+        # nn.Dropout probabilities is changed.
         # This is for performance reason: we can fuse dropout + add + layer_norm.
+        self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
+                                                  **factory_kwargs)
+                                     for i in range(config.num_hidden_layers)])
+
         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
         if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
             raise ImportError('dropout_add_layer_norm is not installed')
-        # self.ln_0 is the first layer norm in the model, while self.ln_f (in the pretrained weight)
-        # is the final layer norm.
-        self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
-                                 **factory_kwargs)
+        if self.prenorm:
+            self.drop_f = nn.Dropout(config.resid_pdrop)
+            self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
+                                     **factory_kwargs)
         if process_group is not None:
-            for p in self.ln_0.parameters():
+            for p in self.ln_f.parameters():
                 # Mark the norm parameters as "shared_params" so that we sync their values at init.
                 p._shared_params = True
                 # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                 if self.sequence_parallel:
                     p._sequence_parallel = True
-        self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
-                                                  **factory_kwargs)
-                                     for i in range(config.num_hidden_layers)])
-
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
         self.tie_weights()
@@ -251,23 +272,28 @@ class GPTModel(GPTPreTrainedModel):
         embedding_kwargs = ({'combine_batch_seqlen_dim': True}
                             if self.process_group is not None and self.sequence_parallel else {})
         hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
-        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
-        if not self.fused_dropout_add_ln:
-            residual = self.emb_drop(hidden_states)
-            hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype))
-            residual = residual.float()
-        else:
-            hidden_states, residual = dropout_add_layer_norm(
-                hidden_states, None, self.ln_0.weight, self.ln_0.bias,
-                self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
-                residual_in_fp32=True
-            )
+        residual = None
         mixer_kwargs = ({'seqlen': input_ids.shape[1]}
                         if self.process_group is not None and self.sequence_parallel else {})
         if inference_params is not None:
             mixer_kwargs['inference_params'] = inference_params
         for layer in self.layers:
-            hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
+            if self.prenorm:
+                hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
+            else:
+                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+        if self.prenorm:
+            if not self.fused_dropout_add_ln:
+                dropped = self.drop_f(hidden_states)
+                residual = (dropped + residual) if residual is not None else dropped
+                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
+            else:
+                # Set prenorm=False here since we don't need the residual
+                hidden_states = dropout_add_layer_norm(
+                    hidden_states, residual, 
self.ln_f.weight, self.ln_f.bias, + self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False, + residual_in_fp32=self.residual_in_fp32 + ) return hidden_states @@ -281,13 +307,20 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + # This option is for OPT-350m + word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None) + embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim + if word_embed_proj_dim is not None: + self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs) + else: + self.project_out = None if process_group is None: - self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False, **factory_kwargs) + self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False, **factory_kwargs) else: if ColumnParallelLinear is None: raise ImportError('fused_dense_lib is not installed') self.lm_head = ColumnParallelLinear( - config.n_embd, vocab_size, process_group, bias=False, + embed_dim, vocab_size, process_group, bias=False, sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs ) # Initialize weights and apply final processing @@ -307,6 +340,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): """ hidden_states = self.transformer(input_ids, position_ids=position_ids, inference_params=inference_params) + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) lm_logits = self.lm_head(hidden_states) # During inference, we want the full logit for sampling if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None: @@ -315,6 +350,32 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): CausalLMOutput = namedtuple('CausalLMOutput', ['logits']) return CausalLMOutput(logits=lm_logits) + def load_state_dict(self, state_dict, strict=True): + # Remapping from our checkpoints that used a different ordering of layers in the block + # Previous: Attn / MLP -> Dropout -> Add -> LN + # Current: Dropout -> Add -> LN -> Attn / MLP + if 'transformer.ln_0.weight' in state_dict: + n_layers = self.config.num_hidden_layers + ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight') + ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias') + state_dict['transformer.ln_f.weight'] = ln_weight + state_dict['transformer.ln_f.bias'] = ln_bias + for l in reversed(range(n_layers)): + ln_weight = state_dict.pop(f'transformer.layers.{l}.norm1.weight') + ln_bias = state_dict.pop(f'transformer.layers.{l}.norm1.bias') + state_dict[f'transformer.layers.{l}.norm2.weight'] = ln_weight + state_dict[f'transformer.layers.{l}.norm2.bias'] = ln_bias + if l > 0: + ln_weight = state_dict.pop(f'transformer.layers.{l - 1}.norm2.weight') + ln_bias = state_dict.pop(f'transformer.layers.{l - 1}.norm2.bias') + state_dict[f'transformer.layers.{l}.norm1.weight'] = ln_weight + state_dict[f'transformer.layers.{l}.norm1.bias'] = ln_bias + ln_weight = state_dict.pop('transformer.ln_0.weight') + ln_bias = state_dict.pop('transformer.ln_0.bias') + state_dict[f'transformer.layers.0.norm1.weight'] = ln_weight + state_dict[f'transformer.layers.0.norm1.bias'] = ln_bias + return super().load_state_dict(state_dict, strict=strict) + def remap_state_dict_gpt2(state_dict, config): # Word embedding and position embedding @@ -331,22 +392,11 @@ def 
remap_state_dict_gpt2(state_dict, config):
     state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
 
     # LayerNorm
-    ln_weight, ln_bias = state_dict.pop('ln_f.weight'), state_dict.pop('ln_f.bias')
-    state_dict[f'transformer.layers.{config.num_hidden_layers - 1}.norm2.weight'] = ln_weight
-    state_dict[f'transformer.layers.{config.num_hidden_layers - 1}.norm2.bias'] = ln_bias
-    ln_weight, ln_bias = state_dict.pop('h.0.ln_1.weight'), state_dict.pop('h.0.ln_1.bias')
-    state_dict['transformer.ln_0.weight'] = ln_weight
-    state_dict['transformer.ln_0.bias'] = ln_bias
-    for d in range(config.num_hidden_layers):
-        ln_weight = state_dict.pop(f'h.{d}.ln_2.weight')
-        ln_bias = state_dict.pop(f'h.{d}.ln_2.bias')
-        state_dict[f'transformer.layers.{d}.norm1.weight'] = ln_weight
-        state_dict[f'transformer.layers.{d}.norm1.bias'] = ln_bias
-        if d > 0:
-            ln_weight = state_dict.pop(f'h.{d}.ln_1.weight')
-            ln_bias = state_dict.pop(f'h.{d}.ln_1.bias')
-            state_dict[f'transformer.layers.{d - 1}.norm2.weight'] = ln_weight
-            state_dict[f'transformer.layers.{d - 1}.norm2.bias'] = ln_bias
+    def key_mapping_ln(key):
+        key = re.sub(r'^ln_f.(weight|bias)', r'transformer.ln_f.\1', key)
+        key = re.sub(r'^h.(\d+).ln_(1|2).(weight|bias)', r'transformer.layers.\1.norm\2.\3', key)
+        return key
+    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
 
     # MLP
     for d in range(config.num_hidden_layers):
diff --git a/flash_attn/models/opt.py b/flash_attn/models/opt.py
new file mode 100644
index 0000000..88d7c52
--- /dev/null
+++ b/flash_attn/models/opt.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2023, Tri Dao.
+
+import math
+import re
+
+from collections import OrderedDict
+
+import torch
+import torch.nn.functional as F
+
+from transformers import GPT2Config, OPTConfig
+
+
+def remap_state_dict_opt(state_dict, config):
+    def key_mapping_model(key):
+        key = re.sub(r'^model.decoder.', 'transformer.', key)
+        # The OPT-350m model uses '^decoder' instead of '^model.decoder'
+        key = re.sub(r'^decoder.', 'transformer.', key)
+        return key
+    state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())
+    # Word embedding and position embedding
+    def key_mapping_emb(key):
+        key = re.sub(r'^transformer.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
+        # The OPT-350m model has project_in and project_out
+        key = re.sub(r'^transformer.project_in.', 'transformer.embeddings.project_in.', key)
+        key = re.sub(r'^transformer.project_out.', 'project_out.', key)
+        key = re.sub(r'^transformer.embed_positions.',
+                     'transformer.embeddings.position_embeddings.', key)
+        return key
+    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
+    # OPT uses the first 2 indices of pos_emb for padding tokens
+    pos_embeddings = state_dict.pop('transformer.embeddings.position_embeddings.weight')
+    state_dict['transformer.embeddings.position_embeddings.weight'] = pos_embeddings[2:]
+    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
+    # It's possible that vocab_size is padded to be a multiple of 8, for example.
+ pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key) + key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.', + r'transformer.layers.\1.norm1.', key) + key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.', + r'transformer.layers.\1.norm2.', key) + return key + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + return re.sub(r'^transformer.layers.(\d+).fc(1|2).', + r'transformer.layers.\1.mlp.fc\2.', key) + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for l in range(config.n_layer): + Wq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.weight') + Wk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.weight') + Wv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.weight') + bq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.bias') + bk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.bias') + bv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.bias') + state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat( + [Wq, Wk, Wv], dim=0 + ) + state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat( + [bq, bk, bv], dim=0 + ) + def key_mapping_attn(key): + return re.sub(r'^transformer.layers.(\d+).self_attn.out_proj.', + r'transformer.layers.\1.mixer.out_proj.', key) + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config: + assert opt_config.layerdrop == 0.0 + assert opt_config.layer_norm_elementwise_affine + word_embed_proj_dim = (None if opt_config.word_embed_proj_dim == opt_config.hidden_size + else opt_config.word_embed_proj_dim) + return GPT2Config( + vocab_size=opt_config.vocab_size, + n_positions=opt_config.max_position_embeddings, + n_embd=opt_config.hidden_size, + n_layer=opt_config.num_hidden_layers, + n_head=opt_config.num_attention_heads, + n_inner=opt_config.ffn_dim, + activation_function=opt_config.activation_function, + resid_pdrop=opt_config.dropout, + # HF's implementation of OPT doesn't seem to have embedding dropout + embd_pdrop=opt_config.dropout, + attn_pdrop=opt_config.attention_dropout, + initializer_range=opt_config.init_std, + bos_token_id=opt_config.bos_token_id, + eos_token_id=opt_config.eos_token_id, + # These are new arguments not in the original GPT2Config + prenorm=opt_config.do_layer_norm_before, + word_embed_proj_dim=word_embed_proj_dim + ) diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py index 5043733..1b855b6 100644 --- a/flash_attn/modules/block.py +++ b/flash_attn/modules/block.py @@ -22,10 +22,22 @@ except ImportError: class Block(nn.Module): def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm, - dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0., - fused_dropout_add_ln=False, return_residual=False, sequence_parallel=False, - mark_shared_params=False): + dropout_cls=nn.Dropout, prenorm=True, 
resid_dropout1=0., resid_dropout2=0., + drop_path1=0., drop_path2=0., fused_dropout_add_ln=False, return_residual=False, + residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False): """ + For prenorm=True, this Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both + the hidden_states (output of the MLP) and the residual. + This is for performance reasons, as we can fuse the dropout, add and LayerNorm. + The residual needs to be provided (except for the very first block). + + For prenorm=False, this Block has the same structure as a regular postnorm Transformer + block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN. + return_residual: whether each of the sub-layers (mixer and mlp) will return the residual. This is for performance reason: for post-norm architecture, returning the input allows us to fuse the backward of nn.Linear with the residual connection. @@ -34,18 +46,21 @@ class Block(nn.Module): self.prenorm = prenorm self.fused_dropout_add_ln = fused_dropout_add_ln self.return_residual = return_residual + self.residual_in_fp32 = residual_in_fp32 + if self.residual_in_fp32: + assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True' if mixer_cls is None: mixer_cls = partial(MHA, num_heads=dim // 64) if mlp_cls is None: mlp_cls = partial(Mlp, hidden_features=4 * dim) self.mixer = mixer_cls(dim) - self.dropout1 = dropout_cls(resid_dropout) - self.drop_path1 = StochasticDepth(drop_path, mode='row') + self.dropout1 = dropout_cls(resid_dropout1) + self.drop_path1 = StochasticDepth(drop_path1, mode='row') self.norm1 = norm_cls(dim) self.mlp = mlp_cls(dim) if not isinstance(self.mlp, nn.Identity): - self.dropout2 = dropout_cls(resid_dropout) - self.drop_path2 = StochasticDepth(drop_path, mode='row') + self.dropout2 = dropout_cls(resid_dropout2) + self.drop_path2 = StochasticDepth(drop_path2, mode='row') self.norm2 = norm_cls(dim) if self.fused_dropout_add_ln: @@ -82,41 +97,48 @@ class Block(nn.Module): residual: if postnorm, residual=None, If prenorm, hidden_states = LayerNorm(residual) """ if self.prenorm: - assert residual is not None - mixer_out = self.mixer(hidden_states, - **(mixer_kwargs if mixer_kwargs is not None else {})) if not self.fused_dropout_add_ln: - residual = self.drop_path1(self.dropout1(mixer_out)) + residual + dropped = self.drop_path1(self.dropout1(hidden_states)) + residual = (dropped + residual) if residual is not None else dropped hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) else: if self.drop_path1.p == 0 or not self.training: rowscale1 = None else: rowscale1 = self.drop_path1(torch.ones( - mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype) + hidden_states.shape[:-1], device=hidden_states.device, + dtype=hidden_states.dtype) ) hidden_states, residual = dropout_add_layer_norm( - mixer_out, residual, self.norm1.weight, self.norm1.bias, + hidden_states, residual, self.norm1.weight, self.norm1.bias, self.dropout1.p if self.training else 0.0, self.norm1.eps, - rowscale=rowscale1, prenorm=True + rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32 ) + hidden_states = self.mixer(hidden_states, + **(mixer_kwargs if mixer_kwargs is not None else 
{}))
         if not isinstance(self.mlp, nn.Identity):
-            mlp_out = self.mlp(hidden_states)
             if not self.fused_dropout_add_ln:
-                residual = self.drop_path2(self.dropout2(mlp_out)) + residual
+                dropped = self.drop_path2(self.dropout2(hidden_states))
+                residual = (dropped + residual) if residual is not None else dropped
                 hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                if self.residual_in_fp32:
+                    residual = residual.to(torch.float32)
             else:
                 if self.drop_path2.p == 0 or not self.training:
                     rowscale2 = None
                 else:
                     rowscale2 = self.drop_path2(torch.ones(
-                        mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype)
+                        hidden_states.shape[:-1], device=hidden_states.device,
+                        dtype=hidden_states.dtype)
                     )
                 hidden_states, residual = dropout_add_layer_norm(
-                    mlp_out, residual, self.norm2.weight, self.norm2.bias,
+                    hidden_states, residual, self.norm2.weight, self.norm2.bias,
                     self.dropout2.p if self.training else 0.0, self.norm2.eps,
-                    rowscale=rowscale2, prenorm=True
+                    rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32
                 )
+            hidden_states = self.mlp(hidden_states)
             return hidden_states, residual
         else:
             assert residual is None
diff --git a/flash_attn/modules/embedding.py b/flash_attn/modules/embedding.py
index eee184f..6a5da2d 100644
--- a/flash_attn/modules/embedding.py
+++ b/flash_attn/modules/embedding.py
@@ -12,14 +12,23 @@ from flash_attn.utils.distributed import reduce_scatter, all_reduce
 class GPT2Embeddings(nn.Module):
 
     def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
-                 device=None, dtype=None):
+                 word_embed_proj_dim=None, device=None, dtype=None):
         """
             If max_position_embeddings <= 0, there's no position embeddings
+            If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
+                then project up to embed_dim
         """
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
-                                            **factory_kwargs)
+        if word_embed_proj_dim is None:
+            self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
+                                                **factory_kwargs)
+            self.project_in = None
+        else:
+            self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim,
+                                                padding_idx=padding_idx, **factory_kwargs)
+            self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
+                                        **factory_kwargs)
         self.max_position_embeddings = max_position_embeddings
         if self.max_position_embeddings > 0:
             self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
@@ -32,6 +41,8 @@ class GPT2Embeddings(nn.Module):
         """
         batch_size, seqlen = input_ids.shape
         embeddings = self.word_embeddings(input_ids)
+        if self.project_in is not None:
+            embeddings = self.project_in(embeddings)
         if self.max_position_embeddings > 0:
             if position_ids is None:
                 position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
diff --git a/tests/models/test_gpt.py b/tests/models/test_gpt.py
index 2f5e777..2a73c27 100644
--- a/tests/models/test_gpt.py
+++ b/tests/models/test_gpt.py
@@ -84,6 +84,7 @@ def test_gpt2_optimized(model_name):
     config.fused_bias_fc = True
     config.fused_dense_gelu_dense = True
     config.fused_dropout_add_ln = True
+    config.residual_in_fp32 = True
     config.pad_vocab_size_multiple = 8
 
     model = GPTLMHeadModel.from_pretrained(model_name, config)
diff --git a/tests/models/test_opt.py b/tests/models/test_opt.py
new file mode 100644
index 0000000..b0fc4f2
--- /dev/null
+++ b/tests/models/test_opt.py
@@ -0,0 +1,77 @@
+import re
+
+import torch
+import pytest
+
+from transformers import OPTConfig
+from transformers.models.opt.modeling_opt import OPTForCausalLM
+
+from flash_attn.models.gpt import GPTLMHeadModel
+from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
+from flash_attn.utils.pretrained import state_dict_from_pretrained
+
+
+@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
+# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
+def test_opt_state_dict(model_name):
+    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
+    pretrained_state_dict = remap_state_dict_opt(state_dict_from_pretrained(model_name), config)
+    model = GPTLMHeadModel(config)
+    state_dict = model.state_dict()
+    assert state_dict.keys() == pretrained_state_dict.keys()
+    for k in state_dict.keys():
+        assert state_dict[k].shape == pretrained_state_dict[k].shape
+
+
+@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
+# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
+def test_opt_optimized(model_name):
+    """Check that our implementation of OPT (with all optimizations enabled) matches the
+    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
+    forward pass in fp16, when compared to the HF forward pass in fp32.
+    """
+    dtype = torch.float16
+    device = 'cuda'
+    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
+    config.use_flash_attn = True
+    config.fused_bias_fc = True
+    config.fused_dropout_add_ln = True
+    # Only prenorm supports residual_in_fp32
+    config.residual_in_fp32 = getattr(config, 'prenorm', True)
+    config.pad_vocab_size_multiple = 8
+
+    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
+
+    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
+    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)
+
+    model.eval()
+    model_ref.eval()
+    model_hf.eval()
+
+    torch.manual_seed(0)
+    batch_size = 2
+    max_seqlen = 256
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
+                              device='cuda')
+    if model_name != 'facebook/opt-350m':  # The OPT-350m projects the embeddings to dimension 512
+        out = model.transformer(input_ids)
+        out_hf = model_hf.model(input_ids).last_hidden_state
+        out_ref = model_ref.model(input_ids).last_hidden_state
+
+        print(f'Output max diff: {(out - out_ref).abs().max().item()}')
+        print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
+        print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
+        print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
+        assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
+
+    logits = model(input_ids).logits
+    logits_hf = model_hf(input_ids).logits
+    logits_ref = model_ref(input_ids).logits
+
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
+    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
+    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
+    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
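
Usage note (illustrative only, not applied by this patch): a minimal sketch of how the OPT support added above fits together, mirroring the calls in tests/models/test_opt.py. It assumes a CUDA device with the transformers package available; the optional fused flags should only be enabled when the corresponding extensions (the flash-attn CUDA kernels, dropout_add_layer_norm) are installed.

    import torch
    from transformers import OPTConfig
    from flash_attn.models.gpt import GPTLMHeadModel
    from flash_attn.models.opt import opt_config_to_gpt2_config

    # Convert the HF OPT config into the GPT2Config format used by flash_attn.
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained('facebook/opt-125m'))
    config.use_flash_attn = True          # optional: requires the flash-attn CUDA kernels
    config.fused_dropout_add_ln = False   # keep off unless dropout_add_layer_norm is installed
    # residual_in_fp32 only applies to prenorm models (OPT-350m is the postnorm exception).
    config.residual_in_fp32 = getattr(config, 'prenorm', True)

    # from_pretrained downloads the HF weights and remaps them via remap_state_dict_opt,
    # since the model name starts with 'facebook/opt'.
    model = GPTLMHeadModel.from_pretrained('facebook/opt-125m', config,
                                           device='cuda', dtype=torch.float16)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 32), dtype=torch.long, device='cuda')
    logits = model(input_ids).logits      # CausalLMOutput namedtuple with a .logits field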