diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index 56d76ff..dcd3d62 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -3,32 +3,34 @@ import logging import math import re -from functools import partial - -from collections import namedtuple, OrderedDict +from collections import OrderedDict, namedtuple from collections.abc import Sequence +from functools import partial import torch import torch.nn as nn import torch.nn.functional as F - -from transformers import GPT2Config - from einops import rearrange - -from flash_attn.ops.activations import sqrelu_fwd -from flash_attn.modules.mha import MHA, ParallelMHA -from flash_attn.modules.mlp import Mlp, ParallelMLP, FusedMLP, ParallelFusedMLP -from flash_attn.modules.mlp import GatedMlp, ParallelGatedMlp +from flash_attn.models.falcon import remap_state_dict_hf_falcon +from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox +from flash_attn.models.gptj import remap_state_dict_hf_gptj +from flash_attn.models.opt import remap_state_dict_hf_opt from flash_attn.modules.block import Block, ParallelBlock from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings -from flash_attn.utils.distributed import sync_shared_params, all_gather_raw -from flash_attn.utils.pretrained import state_dict_from_pretrained +from flash_attn.modules.mha import MHA, ParallelMHA +from flash_attn.modules.mlp import ( + FusedMLP, + GatedMlp, + Mlp, + ParallelFusedMLP, + ParallelGatedMlp, + ParallelMLP, +) +from flash_attn.ops.activations import sqrelu_fwd +from flash_attn.utils.distributed import all_gather_raw, sync_shared_params from flash_attn.utils.generation import GenerationMixin -from flash_attn.models.opt import remap_state_dict_hf_opt -from flash_attn.models.gptj import remap_state_dict_hf_gptj -from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox -from flash_attn.models.falcon import remap_state_dict_hf_falcon +from flash_attn.utils.pretrained import state_dict_from_pretrained +from transformers import GPT2Config try: from flash_attn.ops.fused_dense import ColumnParallelLinear @@ -65,158 +67,247 @@ logger = logging.getLogger(__name__) def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} - head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads) + factory_kwargs = {"device": device, "dtype": dtype} + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5) if config.scale_attn_by_inverse_layer_idx: assert layer_idx is not None softmax_scale /= float(layer_idx + 1) - dwconv = getattr(config, 'attn_dwconv', False) + dwconv = getattr(config, "attn_dwconv", False) if dwconv: - assert process_group is None, 'TensorParallel MHA does not support dwconv yet' - qkv_proj_bias = getattr(config, 'qkv_proj_bias', True) - out_proj_bias = getattr(config, 'out_proj_bias', True) - rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim) - rotary_emb_base = getattr(config, 'rotary_emb_base', 10000.0) - rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None) - rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False) - use_flash_attn = getattr(config, 'use_flash_attn', False) - fused_bias_fc = getattr(config, 'fused_bias_fc', False) + assert process_group is None, "TensorParallel MHA does not support dwconv yet" + 
qkv_proj_bias = getattr(config, "qkv_proj_bias", True) + out_proj_bias = getattr(config, "out_proj_bias", True) + rotary_emb_dim = int(getattr(config, "rotary_emb_fraction", 0.0) * head_dim) + rotary_emb_base = getattr(config, "rotary_emb_base", 10000.0) + rotary_emb_scale_base = getattr(config, "rotary_emb_scale_base", None) + rotary_emb_interleaved = getattr(config, "rotary_emb_interleaved", False) + use_flash_attn = getattr(config, "use_flash_attn", False) + fused_bias_fc = getattr(config, "fused_bias_fc", False) if not fused_bias_fc: - assert process_group is None, 'TensorParallel MHA requires fused_bias_fc' + assert process_group is None, "TensorParallel MHA requires fused_bias_fc" mha_cls = MHA if process_group is None else ParallelMHA - serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv} - if process_group is None else {}) - parallel_kwargs = ({'process_group': process_group, - 'sequence_parallel': getattr(config, 'sequence_parallel', True)} - if process_group is not None else {}) + serial_kwargs = ( + {"fused_bias_fc": fused_bias_fc, "dwconv": dwconv} if process_group is None else {} + ) + parallel_kwargs = ( + { + "process_group": process_group, + "sequence_parallel": getattr(config, "sequence_parallel", True), + } + if process_group is not None + else {} + ) num_heads_kv = getattr(config, "n_head_kv", None) - mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, - num_heads_kv=num_heads_kv, - qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias, - dropout=config.attn_pdrop, - softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx, - rotary_emb_dim=rotary_emb_dim, rotary_emb_base=rotary_emb_base, - rotary_emb_scale_base=rotary_emb_scale_base, - rotary_emb_interleaved=rotary_emb_interleaved, - use_flash_attn=use_flash_attn, - **serial_kwargs, **parallel_kwargs, **factory_kwargs) + mixer_cls = partial( + mha_cls, + num_heads=config.num_attention_heads, + num_heads_kv=num_heads_kv, + qkv_proj_bias=qkv_proj_bias, + out_proj_bias=out_proj_bias, + dropout=config.attn_pdrop, + softmax_scale=softmax_scale, + causal=True, + layer_idx=layer_idx, + rotary_emb_dim=rotary_emb_dim, + rotary_emb_base=rotary_emb_base, + rotary_emb_scale_base=rotary_emb_scale_base, + rotary_emb_interleaved=rotary_emb_interleaved, + use_flash_attn=use_flash_attn, + **serial_kwargs, + **parallel_kwargs, + **factory_kwargs, + ) return mixer_cls def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} - mlp_fc1_bias = getattr(config, 'mlp_fc1_bias', True) - mlp_fc2_bias = getattr(config, 'mlp_fc2_bias', True) - fused_mlp = getattr(config, 'fused_mlp', False) + factory_kwargs = {"device": device, "dtype": dtype} + mlp_fc1_bias = getattr(config, "mlp_fc1_bias", True) + mlp_fc2_bias = getattr(config, "mlp_fc2_bias", True) + fused_mlp = getattr(config, "fused_mlp", False) if fused_mlp: - assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu'] - fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False) + assert config.activation_function in [ + "gelu_new", + "gelu_fast", + "gelu_approx", + "relu", + "sqrelu", + ] + fused_dense_sqrelu_dense = getattr(config, "fused_dense_sqrelu_dense", False) if fused_dense_sqrelu_dense: - assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only ' - 'supports approximate activation_function sqrelu') + assert config.activation_function == "sqrelu", ( + "fused_dense_sqrelu_dense only " 
"supports approximate activation_function sqrelu" + ) assert not (fused_dense_sqrelu_dense and fused_mlp) if not fused_mlp and not fused_dense_sqrelu_dense: - assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx', 'relu', - 'sqrelu', 'glu', 'swiglu', 'geglu'] - if config.activation_function in ['glu', 'swiglu', 'geglu']: - activation = (F.sigmoid if config.activation_function == 'glu' - else (F.silu if config.activation_function == 'swiglu' - else F.gelu)) + assert config.activation_function in [ + "gelu", + "gelu_new", + "gelu_fast", + "gelu_approx", + "relu", + "sqrelu", + "glu", + "swiglu", + "geglu", + ] + if config.activation_function in ["glu", "swiglu", "geglu"]: + activation = ( + F.sigmoid + if config.activation_function == "glu" + else (F.silu if config.activation_function == "swiglu" else F.gelu) + ) mlp_cls = GatedMlp if process_group is None else ParallelGatedMlp - parallel_kwargs = ({'process_group': process_group, - 'sequence_parallel': getattr(config, 'sequence_parallel', True)} - if process_group is not None else {}) - mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation, - bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, - **parallel_kwargs, **factory_kwargs) + parallel_kwargs = ( + { + "process_group": process_group, + "sequence_parallel": getattr(config, "sequence_parallel", True), + } + if process_group is not None + else {} + ) + mlp_cls = partial( + mlp_cls, + hidden_features=config.n_inner, + activation=activation, + bias1=mlp_fc1_bias, + bias2=mlp_fc2_bias, + **parallel_kwargs, + **factory_kwargs, + ) else: - if config.activation_function == 'relu': + if config.activation_function == "relu": activation = partial(F.relu, inplace=True) - elif config.activation_function == 'sqrelu': + elif config.activation_function == "sqrelu": activation = sqrelu_fwd else: - approximate = ('tanh' if config.activation_function - in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none') - activation=partial(F.gelu, approximate=approximate) + approximate = ( + "tanh" + if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"] + else "none" + ) + activation = partial(F.gelu, approximate=approximate) mlp_cls = Mlp if process_group is None else ParallelMLP - parallel_kwargs = ({'process_group': process_group, - 'sequence_parallel': getattr(config, 'sequence_parallel', True)} - if process_group is not None else {}) - mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation, - bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, - **parallel_kwargs, **factory_kwargs) + parallel_kwargs = ( + { + "process_group": process_group, + "sequence_parallel": getattr(config, "sequence_parallel", True), + } + if process_group is not None + else {} + ) + mlp_cls = partial( + mlp_cls, + hidden_features=config.n_inner, + activation=activation, + bias1=mlp_fc1_bias, + bias2=mlp_fc2_bias, + **parallel_kwargs, + **factory_kwargs, + ) else: - mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0) + mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0) # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer if isinstance(mlp_checkpoint_lvl, Sequence): assert layer_idx is not None mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] if fused_mlp: if FusedMLP is None: - raise ImportError('fused_dense is not installed') - activation = ('gelu_approx' if config.activation_function - in ['gelu_new', 'gelu_fast', 'gelu_approx'] else config.activation_function) + raise ImportError("fused_dense is not 
installed") + activation = ( + "gelu_approx" + if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"] + else config.activation_function + ) mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP - parallel_kwargs = ({'process_group': process_group, - 'sequence_parallel': getattr(config, 'sequence_parallel', True)} - if process_group is not None else {}) - mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation, - checkpoint_lvl=mlp_checkpoint_lvl, - bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, - **parallel_kwargs, **factory_kwargs) + parallel_kwargs = ( + { + "process_group": process_group, + "sequence_parallel": getattr(config, "sequence_parallel", True), + } + if process_group is not None + else {} + ) + mlp_cls = partial( + mlp_cls, + hidden_features=config.n_inner, + activation=activation, + checkpoint_lvl=mlp_checkpoint_lvl, + bias1=mlp_fc1_bias, + bias2=mlp_fc2_bias, + **parallel_kwargs, + **factory_kwargs, + ) elif fused_dense_sqrelu_dense: if process_group is not None: - assert fused_mlp, 'Tensor Parallel is not implemented for FusedDenseSqreluDense' + assert fused_mlp, "Tensor Parallel is not implemented for FusedDenseSqreluDense" assert FusedDenseSqreluDense is not None - mlp_cls = partial(FusedDenseSqreluDense, hidden_features=config.n_inner, - checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs) + mlp_cls = partial( + FusedDenseSqreluDense, + hidden_features=config.n_inner, + checkpoint_lvl=mlp_checkpoint_lvl, + **factory_kwargs, + ) else: - raise RuntimeError('MLP type not supported') + raise RuntimeError("MLP type not supported") return mlp_cls def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} - sequence_parallel = getattr(config, 'sequence_parallel', True) + factory_kwargs = {"device": device, "dtype": dtype} + sequence_parallel = getattr(config, "sequence_parallel", True) mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs) mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs) - use_rms_norm = getattr(config, 'rms_norm', False) - norm_cls = partial(nn.LayerNorm if not use_rms_norm else RMSNorm, - eps=config.layer_norm_epsilon, **factory_kwargs) + use_rms_norm = getattr(config, "rms_norm", False) + norm_cls = partial( + nn.LayerNorm if not use_rms_norm else RMSNorm, + eps=config.layer_norm_epsilon, + **factory_kwargs, + ) # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable - residual_in_fp32 = getattr(config, 'residual_in_fp32', False) + residual_in_fp32 = getattr(config, "residual_in_fp32", False) resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop - prenorm = getattr(config, 'prenorm', True) - parallel_block = getattr(config, 'parallel_block', False) + prenorm = getattr(config, "prenorm", True) + parallel_block = getattr(config, "parallel_block", False) if not parallel_block: block = Block( - config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls, - prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop, - fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False), + config.hidden_size, + mixer_cls, + mlp_cls, + norm_cls=norm_cls, + prenorm=prenorm, + resid_dropout1=resid_dropout1, + resid_dropout2=config.resid_pdrop, + fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False), residual_in_fp32=residual_in_fp32, 
sequence_parallel=sequence_parallel and process_group is not None, - mark_shared_params=process_group is not None + mark_shared_params=process_group is not None, ) else: assert prenorm block = ParallelBlock( - config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls, - resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop, - tied_norm=getattr(config, 'parallel_block_tied_norm', False), - fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False), + config.hidden_size, + mixer_cls, + mlp_cls, + norm_cls=norm_cls, + resid_dropout1=resid_dropout1, + resid_dropout2=config.resid_pdrop, + tied_norm=getattr(config, "parallel_block_tied_norm", False), + fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False), residual_in_fp32=residual_in_fp32, sequence_parallel=sequence_parallel and process_group is not None, - mark_shared_params=process_group is not None + mark_shared_params=process_group is not None, ) block.layer_idx = layer_idx return block class GPTPreTrainedModel(nn.Module): - """ An abstract class to handle weights initialization and - a simple interface for dowloading and loading pretrained models. + """An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. """ + def __init__(self, config, *inputs, **kwargs): super().__init__() if not isinstance(config, GPT2Config): @@ -225,12 +316,23 @@ class GPTPreTrainedModel(nn.Module): "To create a model from a Google pretrained model use " "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( self.__class__.__name__, self.__class__.__name__ - )) + ) + ) self.config = config @classmethod - def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None, - world_size=1, rank=0, **kwargs): + def from_pretrained( + cls, + model_name, + config, + *args, + strict=True, + device=None, + dtype=None, + world_size=1, + rank=0, + **kwargs, + ): """ Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. 
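For context on the from_pretrained signature reformatted above, a minimal usage sketch (not part of this patch; the checkpoint name, dtype, and the use_flash_attn flag are illustrative assumptions, and a CUDA device with the flash-attn kernels installed is assumed):

# Hypothetical usage sketch, not included in this diff.
import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config.from_pretrained("gpt2")
config.use_flash_attn = True  # assumed optional flag; create_mixer_cls defaults it to False
model = GPTLMHeadModel.from_pretrained(
    "gpt2", config, device="cuda", dtype=torch.float16, world_size=1, rank=0
)
input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
out = model(input_ids, last_token_only=True)
print(out.logits.shape)  # (1, vocab_size); vocab_size is padded if pad_vocab_size_multiple is set

With world_size > 1 the same call loads the full Hugging Face checkpoint and shards it per rank via shard_state_dict_tp before load_state_dict.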
@@ -239,21 +341,19 @@ class GPTPreTrainedModel(nn.Module): model = cls(config, *args, device=device, dtype=dtype, **kwargs) # Load state_dict in cpu because we already initialized the model in GPU, and we don't # want extra stuff taking up more GPU memory - state_dict = state_dict_from_pretrained( - model_name, device='cpu', dtype=dtype - ) - if model_name.startswith('gpt2'): + state_dict = state_dict_from_pretrained(model_name, device="cpu", dtype=dtype) + if model_name.startswith("gpt2"): state_dict = remap_state_dict_hf_gpt2(state_dict, config) - elif model_name.startswith('facebook/opt'): + elif model_name.startswith("facebook/opt"): state_dict = remap_state_dict_hf_opt(state_dict, config) - elif model_name.startswith('EleutherAI/gpt-j-'): + elif model_name.startswith("EleutherAI/gpt-j-"): state_dict = remap_state_dict_hf_gptj(state_dict, config) - elif model_name.startswith('EleutherAI/gpt-neox-'): + elif model_name.startswith("EleutherAI/gpt-neox-"): state_dict = remap_state_dict_hf_gpt_neox(state_dict, config) - elif model_name.startswith('tiiuae/falcon-'): + elif model_name.startswith("tiiuae/falcon-"): state_dict = remap_state_dict_hf_falcon(state_dict, config) else: - raise NotImplementedError(f'Model {model_name} not supported') + raise NotImplementedError(f"Model {model_name} not supported") if world_size > 1: state_dict = shard_state_dict_tp(state_dict, config, world_size, rank) load_return = model.load_state_dict(state_dict, strict=strict) @@ -284,36 +384,51 @@ def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_resid class GPTModel(GPTPreTrainedModel): - def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None): super().__init__(config) - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} self.process_group = process_group - self.sequence_parallel = getattr(config, 'sequence_parallel', True) - assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx', - 'relu', 'sqrelu', 'glu', 'swiglu', 'geglu'] - pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) - vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) - * pad_vocab_size_multiple) + self.sequence_parallel = getattr(config, "sequence_parallel", True) + assert config.activation_function in [ + "gelu", + "gelu_new", + "gelu_fast", + "gelu_approx", + "relu", + "sqrelu", + "glu", + "swiglu", + "geglu", + ] + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + ) # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable - self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False) + self.residual_in_fp32 = getattr(config, "residual_in_fp32", False) # These 2 options are for OPT-350m - self.prenorm = getattr(config, 'prenorm', True) - use_rms_norm = getattr(config, 'rms_norm', False) - word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None) + self.prenorm = getattr(config, "prenorm", True) + use_rms_norm = getattr(config, "rms_norm", False) + word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None) # For GPT-J, GPT-NeoX - self.parallel_block = getattr(config, 'parallel_block', False) + self.parallel_block = getattr(config, "parallel_block", False) if process_group is None: self.embeddings = GPT2Embeddings( - config.hidden_size, vocab_size, config.max_position_embeddings, - 
word_embed_proj_dim=word_embed_proj_dim, **factory_kwargs + config.hidden_size, + vocab_size, + config.max_position_embeddings, + word_embed_proj_dim=word_embed_proj_dim, + **factory_kwargs, ) else: self.embeddings = ParallelGPT2Embeddings( - config.hidden_size, vocab_size, config.max_position_embeddings, - process_group=process_group, sequence_parallel=self.sequence_parallel, - **factory_kwargs + config.hidden_size, + vocab_size, + config.max_position_embeddings, + process_group=process_group, + sequence_parallel=self.sequence_parallel, + **factory_kwargs, ) # We change the order of dropout, residual and layer norm: @@ -322,20 +437,25 @@ class GPTModel(GPTPreTrainedModel): # the main branch (output of MLP). The model definition is unchanged, but the mapping of the # nn.Dropout probabilities are changed. # This is for performance reason: we can fuse dropout + add + layer_norm. - self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group, - **factory_kwargs) - for i in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [ + create_block(config, layer_idx=i, process_group=process_group, **factory_kwargs) + for i in range(config.num_hidden_layers) + ] + ) - self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False) + self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False) if self.fused_dropout_add_ln: - if ((not self.parallel_block and dropout_add_layer_norm is None) - or (self.parallel_block and dropout_add_layer_norm_parallel_residual is None)): - raise ImportError('dropout_layer_norm is not installed') + if (not self.parallel_block and dropout_add_layer_norm is None) or ( + self.parallel_block and dropout_add_layer_norm_parallel_residual is None + ): + raise ImportError("dropout_layer_norm is not installed") if self.prenorm: self.drop_f = nn.Dropout(config.resid_pdrop) norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm - self.ln_f = norm_cls(config.hidden_size, eps=config.layer_norm_epsilon, - **factory_kwargs) + self.ln_f = norm_cls( + config.hidden_size, eps=config.layer_norm_epsilon, **factory_kwargs + ) if process_group is not None: for p in self.ln_f.parameters(): # Mark the norm parameters as "shared_params" so that we sync their values at init. @@ -344,8 +464,13 @@ class GPTModel(GPTPreTrainedModel): if self.sequence_parallel: p._sequence_parallel = True - self.apply(partial(_init_weights, n_layer=config.num_hidden_layers, - initializer_range=config.initializer_range)) + self.apply( + partial( + _init_weights, + n_layer=config.num_hidden_layers, + initializer_range=config.initializer_range, + ) + ) self.tie_weights() def tie_weights(self): @@ -353,28 +478,37 @@ class GPTModel(GPTPreTrainedModel): sync_shared_params(self, self.process_group) def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return {i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) - for i, layer in enumerate(self.layers)} + return { + i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + for i, layer in enumerate(self.layers) + } def forward(self, input_ids, position_ids=None, inference_params=None): # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen # dimensions so that we can split on it easily, in case of small batch size. # Only the attention layers need to know the seqlen. 
- embedding_kwargs = ({'combine_batch_seqlen_dim': True} - if self.process_group is not None and self.sequence_parallel else {}) + embedding_kwargs = ( + {"combine_batch_seqlen_dim": True} + if self.process_group is not None and self.sequence_parallel + else {} + ) hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs) if self.parallel_block: hidden_states2 = None residual = None - mixer_kwargs = ({'seqlen': input_ids.shape[1]} - if self.process_group is not None and self.sequence_parallel else {}) + mixer_kwargs = ( + {"seqlen": input_ids.shape[1]} + if self.process_group is not None and self.sequence_parallel + else {} + ) if inference_params is not None: - mixer_kwargs['inference_params'] = inference_params + mixer_kwargs["inference_params"] = inference_params for layer in self.layers: if self.prenorm: if not self.parallel_block: - hidden_states, residual = layer(hidden_states, residual, - mixer_kwargs=mixer_kwargs) + hidden_states, residual = layer( + hidden_states, residual, mixer_kwargs=mixer_kwargs + ) else: hidden_states, hidden_states2, residual = layer( hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs @@ -388,45 +522,66 @@ class GPTModel(GPTPreTrainedModel): residual = (dropped + residual) if residual is not None else dropped else: dropped2 = self.drop_f(hidden_states2) - residual = ((residual + dropped + dropped2) - if residual is not None else dropped + dropped2) + residual = ( + (residual + dropped + dropped2) + if residual is not None + else dropped + dropped2 + ) hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype)) else: # Set prenorm=False here since we don't need the residual if not self.parallel_block: - fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.ln_f, RMSNorm) - else dropout_add_layer_norm) + fused_add_norm_fn = ( + dropout_add_rms_norm + if isinstance(self.ln_f, RMSNorm) + else dropout_add_layer_norm + ) hidden_states = fused_add_norm_fn( - hidden_states, residual, self.ln_f.weight, self.ln_f.bias, - self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False, - residual_in_fp32=self.residual_in_fp32 + hidden_states, + residual, + self.ln_f.weight, + self.ln_f.bias, + self.drop_f.p if self.training else 0.0, + self.ln_f.eps, + prenorm=False, + residual_in_fp32=self.residual_in_fp32, ) else: - fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual - if isinstance(self.ln_f, RMSNorm) - else dropout_add_layer_norm_parallel_residual) + fused_add_norm_fn = ( + dropout_add_rms_norm_parallel_residual + if isinstance(self.ln_f, RMSNorm) + else dropout_add_layer_norm_parallel_residual + ) hidden_states, _ = fused_add_norm_fn( - hidden_states, hidden_states2, residual, self.ln_f.weight, self.ln_f.bias, - None, None, self.drop_f.p if self.training else 0.0, self.ln_f.eps, - prenorm=False, residual_in_fp32=self.residual_in_fp32 + hidden_states, + hidden_states2, + residual, + self.ln_f.weight, + self.ln_f.bias, + None, + None, + self.drop_f.p if self.training else 0.0, + self.ln_f.eps, + prenorm=False, + residual_in_fp32=self.residual_in_fp32, ) return hidden_states class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): - def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__(config) self.process_group = process_group self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs) - self.tie_word_embeddings 
= getattr(config, 'tie_word_embeddings', True) - lm_head_bias = getattr(config, 'lm_head_bias', False) - pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) - vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) - * pad_vocab_size_multiple) + self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True) + lm_head_bias = getattr(config, "lm_head_bias", False) + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + ) # This option is for OPT-350m - word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None) + word_embed_proj_dim = getattr(config, "word_embed_proj_dim", None) embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim if word_embed_proj_dim is not None: self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs) @@ -436,14 +591,23 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs) else: if ColumnParallelLinear is None: - raise ImportError('fused_dense_lib is not installed') + raise ImportError("fused_dense_lib is not installed") self.lm_head = ColumnParallelLinear( - embed_dim, vocab_size, process_group, bias=lm_head_bias, - sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs + embed_dim, + vocab_size, + process_group, + bias=lm_head_bias, + sequence_parallel=getattr(config, "sequence_parallel", True), + **factory_kwargs, ) # Initialize weights and apply final processing - self.apply(partial(_init_weights, n_layer=config.num_hidden_layers, - initializer_range=config.initializer_range)) + self.apply( + partial( + _init_weights, + n_layer=config.num_hidden_layers, + initializer_range=config.initializer_range, + ) + ) self.tie_weights() def tie_weights(self): @@ -453,18 +617,20 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): sync_shared_params(self, self.process_group) def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): - return self.transformer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, - **kwargs) + return self.transformer.allocate_inference_cache( + batch_size, max_seqlen, dtype=dtype, **kwargs + ) def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False): """ - inference_params: for generation. Adapted from Megatron-LM (and Apex) - https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 - last_token_only: whether to return the logit for the last token only, - of shape (batch_size, vocab_size) + inference_params: for generation. 
Adapted from Megatron-LM (and Apex) + https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 + last_token_only: whether to return the logit for the last token only, + of shape (batch_size, vocab_size) """ - hidden_states = self.transformer(input_ids, position_ids=position_ids, - inference_params=inference_params) + hidden_states = self.transformer( + input_ids, position_ids=position_ids, inference_params=inference_params + ) if last_token_only: hidden_states = hidden_states[:, -1] if self.project_out is not None: @@ -473,34 +639,34 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): # During inference, we want the full logit for sampling if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None: lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group) - lm_logits = rearrange(lm_logits, '(n b) ... d -> b ... (n d)', b=hidden_states.shape[0]) - CausalLMOutput = namedtuple('CausalLMOutput', ['logits']) + lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=hidden_states.shape[0]) + CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) return CausalLMOutput(logits=lm_logits) def load_state_dict(self, state_dict, strict=True): # Remapping from our checkpoints that used a different ordering of layers in the block # Previous: Attn / MLP -> Dropout -> Add -> LN # Current: Dropout -> Add -> LN -> Attn / MLP - if 'transformer.ln_0.weight' in state_dict: + if "transformer.ln_0.weight" in state_dict: n_layers = len(self.transformer.layers) - ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight') - ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias') - state_dict['transformer.ln_f.weight'] = ln_weight - state_dict['transformer.ln_f.bias'] = ln_bias + ln_weight = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.weight") + ln_bias = state_dict.pop(f"transformer.layers.{n_layers - 1}.norm2.bias") + state_dict["transformer.ln_f.weight"] = ln_weight + state_dict["transformer.ln_f.bias"] = ln_bias for l in reversed(range(n_layers)): - ln_weight = state_dict.pop(f'transformer.layers.{l}.norm1.weight') - ln_bias = state_dict.pop(f'transformer.layers.{l}.norm1.bias') - state_dict[f'transformer.layers.{l}.norm2.weight'] = ln_weight - state_dict[f'transformer.layers.{l}.norm2.bias'] = ln_bias + ln_weight = state_dict.pop(f"transformer.layers.{l}.norm1.weight") + ln_bias = state_dict.pop(f"transformer.layers.{l}.norm1.bias") + state_dict[f"transformer.layers.{l}.norm2.weight"] = ln_weight + state_dict[f"transformer.layers.{l}.norm2.bias"] = ln_bias if l > 0: - ln_weight = state_dict.pop(f'transformer.layers.{l - 1}.norm2.weight') - ln_bias = state_dict.pop(f'transformer.layers.{l - 1}.norm2.bias') - state_dict[f'transformer.layers.{l}.norm1.weight'] = ln_weight - state_dict[f'transformer.layers.{l}.norm1.bias'] = ln_bias - ln_weight = state_dict.pop('transformer.ln_0.weight') - ln_bias = state_dict.pop('transformer.ln_0.bias') - state_dict[f'transformer.layers.0.norm1.weight'] = ln_weight - state_dict[f'transformer.layers.0.norm1.bias'] = ln_bias + ln_weight = state_dict.pop(f"transformer.layers.{l - 1}.norm2.weight") + ln_bias = state_dict.pop(f"transformer.layers.{l - 1}.norm2.bias") + state_dict[f"transformer.layers.{l}.norm1.weight"] = ln_weight + state_dict[f"transformer.layers.{l}.norm1.bias"] = ln_bias + ln_weight = state_dict.pop("transformer.ln_0.weight") + ln_bias = state_dict.pop("transformer.ln_0.bias") + 
state_dict[f"transformer.layers.0.norm1.weight"] = ln_weight + state_dict[f"transformer.layers.0.norm1.bias"] = ln_bias return super().load_state_dict(state_dict, strict=strict) @@ -508,8 +674,8 @@ def shard_state_dict_tp(state_dict, config, world_size, rank): """Convert the state_dict of a standard GPT model to the state_dict of a GPT model with tensor parallel. """ - pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) - vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple assert vocab_size % world_size == 0 assert config.hidden_size % world_size == 0 inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size @@ -519,64 +685,84 @@ def shard_state_dict_tp(state_dict, config, world_size, rank): if key in state_dict: x = state_dict[key] dim = x.shape[0] // world_size - state_dict[key] = x[rank * dim:(rank + 1) * dim] + state_dict[key] = x[rank * dim : (rank + 1) * dim] def shard_last_dim(state_dict, key): if key in state_dict: x = state_dict[key] dim = x.shape[-1] // world_size - state_dict[key] = x[..., rank * dim:(rank + 1) * dim] + state_dict[key] = x[..., rank * dim : (rank + 1) * dim] def shard_gatedmlp_fc1_dim(state_dict, key): if key in state_dict: x = state_dict[key] dim = x.shape[0] // world_size // 2 state_dict[key] = rearrange( - rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim:(rank + 1) * dim], - "two o ... -> (two o) ..." + rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim], + "two o ... -> (two o) ...", ) def shard_qkv_headdim(state_dict, key): if key in state_dict: n_head = config.n_head - n_head_kv = getattr(config, 'n_head_kv', n_head) + n_head_kv = getattr(config, "n_head_kv", n_head) assert n_head % world_size == 0 and n_head_kv % world_size == 0 if n_head_kv == n_head: - x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3) + x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3) dim = x.shape[1] // world_size - state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim], - 'three d ... -> (three d) ...') + state_dict[key] = rearrange( + x[:, rank * dim : (rank + 1) * dim], "three d ... -> (three d) ..." + ) else: n_head_per_rank = n_head // world_size n_head_kv_per_rank = n_head_kv // world_size - x = rearrange(state_dict[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...', - nheadqkv=n_head + 2 * n_head_kv) - state_dict[key] = rearrange(torch.cat([ - x[rank * n_head_per_rank:(rank + 1) * n_head_per_rank], - x[n_head + rank * n_head_kv_per_rank:n_head + (rank + 1) * n_head_kv_per_rank], - x[n_head + n_head_kv + rank * n_head_kv_per_rank:n_head + n_head_kv + (rank + 1) * n_head_kv_per_rank], - ], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...") + x = rearrange( + state_dict[key], + "(nheadqkv headdim) ... -> nheadqkv headdim ...", + nheadqkv=n_head + 2 * n_head_kv, + ) + state_dict[key] = rearrange( + torch.cat( + [ + x[rank * n_head_per_rank : (rank + 1) * n_head_per_rank], + x[ + n_head + + rank * n_head_kv_per_rank : n_head + + (rank + 1) * n_head_kv_per_rank + ], + x[ + n_head + + n_head_kv + + rank * n_head_kv_per_rank : n_head + + n_head_kv + + (rank + 1) * n_head_kv_per_rank + ], + ], + dim=0, + ), + "nheadqkv headdim ... 
-> (nheadqkv headdim) ...", + ) - shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight') - if 'lm_head.weight' in state_dict: - shard_first_dim(state_dict, 'lm_head.weight') - if 'transformer.embeddings.position_embeddings.weight' in state_dict: - shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight') + shard_first_dim(state_dict, "transformer.embeddings.word_embeddings.weight") + if "lm_head.weight" in state_dict: + shard_first_dim(state_dict, "lm_head.weight") + if "transformer.embeddings.position_embeddings.weight" in state_dict: + shard_last_dim(state_dict, "transformer.embeddings.position_embeddings.weight") for i in range(config.num_hidden_layers): - shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight') - shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias') - shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight') + shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight") + shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias") + shard_last_dim(state_dict, f"transformer.layers.{i}.mixer.out_proj.weight") if rank != 0: - state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias', None) + state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None) if config.activation_function in ["glu", "swiglu", "geglu"]: - shard_gatedmlp_fc1_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight') - shard_gatedmlp_fc1_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias') + shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight") + shard_gatedmlp_fc1_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias") else: - shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight') - shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias') - shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight') + shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.weight") + shard_first_dim(state_dict, f"transformer.layers.{i}.mlp.fc1.bias") + shard_last_dim(state_dict, f"transformer.layers.{i}.mlp.fc2.weight") if rank != 0: - state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias', None) + state_dict.pop(f"transformer.layers.{i}.mlp.fc2.bias", None) return state_dict @@ -586,8 +772,8 @@ def combine_state_dicts_tp(state_dicts, config): """ world_size = len(state_dicts) keys = state_dicts[0].keys() - pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) - vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple assert vocab_size % world_size == 0 assert config.hidden_size % world_size == 0 inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size @@ -605,90 +791,125 @@ def combine_state_dicts_tp(state_dicts, config): def combine_qkv_headdim(state_dicts, state_dict, key): n_head = config.n_head - n_head_kv = getattr(config, 'n_head_kv', n_head) + n_head_kv = getattr(config, "n_head_kv", n_head) assert n_head % world_size == 0 and n_head_kv % world_size == 0 n_head_per_rank = n_head // world_size n_head_kv_per_rank = n_head_kv // world_size if key in state_dict: if n_head_kv == n_head: - xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts] - state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... 
-> (three d) ...') + xs = [ + rearrange(s[key], "(three d) ... -> three d ...", three=3) for s in state_dicts + ] + state_dict[key] = rearrange(torch.cat(xs, dim=1), "three d ... -> (three d) ...") else: - xs = [rearrange(s[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...', - nheadqkv=n_head + 2 * n_head_kv) for s in state_dicts] - state_dict[key] = rearrange(torch.cat([ - torch.cat([x[:n_head_per_rank] for x in xs], dim=0), - torch.cat([x[n_head_per_rank:n_head_per_rank + n_head_kv_per_rank] for x in xs], dim=0), - torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0), - ], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...") + xs = [ + rearrange( + s[key], + "(nheadqkv headdim) ... -> nheadqkv headdim ...", + nheadqkv=n_head + 2 * n_head_kv, + ) + for s in state_dicts + ] + state_dict[key] = rearrange( + torch.cat( + [ + torch.cat([x[:n_head_per_rank] for x in xs], dim=0), + torch.cat( + [ + x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank] + for x in xs + ], + dim=0, + ), + torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0), + ], + dim=0, + ), + "nheadqkv headdim ... -> (nheadqkv headdim) ...", + ) def combine_gated_mlp(state_dicts, state_dict, key): if key in state_dict: - xs = [rearrange(s[key], '(two d) ... -> two d ...', two=2) for s in state_dicts] - state_dict[key] = rearrange(torch.cat(xs, dim=1), 'two d ... -> (two d) ...') + xs = [rearrange(s[key], "(two d) ... -> two d ...", two=2) for s in state_dicts] + state_dict[key] = rearrange(torch.cat(xs, dim=1), "two d ... -> (two d) ...") state_dict = state_dicts[0].copy() # don't modify state_dict[0] inplace - combine_word_embeddings(state_dicts, state_dict, 'transformer.embeddings.word_embeddings.weight') - if 'lm_head.weight' in state_dict: - combine_word_embeddings(state_dicts, state_dict, 'lm_head.weight') - if 'transformer.embeddings.position_embeddings.weight' in state_dict: - combine_dim(state_dicts, state_dict, 'transformer.embeddings.position_embeddings.weight', -1) - mlp_combine_fn = (combine_gated_mlp if config.activation_function in ['glu', 'swiglu', 'geglu'] - else partial(combine_dim, dim=0)) + combine_word_embeddings( + state_dicts, state_dict, "transformer.embeddings.word_embeddings.weight" + ) + if "lm_head.weight" in state_dict: + combine_word_embeddings(state_dicts, state_dict, "lm_head.weight") + if "transformer.embeddings.position_embeddings.weight" in state_dict: + combine_dim( + state_dicts, state_dict, "transformer.embeddings.position_embeddings.weight", -1 + ) + mlp_combine_fn = ( + combine_gated_mlp + if config.activation_function in ["glu", "swiglu", "geglu"] + else partial(combine_dim, dim=0) + ) for i in range(config.num_hidden_layers): - combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight') - combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias') - combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.out_proj.weight', -1) - mlp_combine_fn(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight') - combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.bias', 0) - combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc2.weight', -1) + combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight") + combine_qkv_headdim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias") + combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", -1) + mlp_combine_fn(state_dicts, state_dict, 
f"transformer.layers.{i}.mlp.fc1.weight") + combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc1.bias", 0) + combine_dim(state_dicts, state_dict, f"transformer.layers.{i}.mlp.fc2.weight", -1) return state_dict def remap_state_dict_hf_gpt2(state_dict, config): # Word embedding and position embedding def key_mapping_pos_emb(key): - return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key) + return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key) + state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) - word_embeddings = state_dict.pop('wte.weight') + word_embeddings = state_dict.pop("wte.weight") # It's possible that vocab_size is padded to be a multiple of 8, for example. - pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) - vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) - state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) ) - state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] + state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] # LayerNorm def key_mapping_ln(key): - key = re.sub(r'^ln_f.(weight|bias)', r'transformer.ln_f.\1', key) - key = re.sub(r'^h.(\d+).ln_(1|2).(weight|bias)', r'transformer.layers.\1.norm\2.\3', key) + key = re.sub(r"^ln_f.(weight|bias)", r"transformer.ln_f.\1", key) + key = re.sub(r"^h.(\d+).ln_(1|2).(weight|bias)", r"transformer.layers.\1.norm\2.\3", key) return key + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) # MLP for d in range(config.num_hidden_layers): - W1 = state_dict.pop(f'h.{d}.mlp.c_fc.weight') - state_dict[f'transformer.layers.{d}.mlp.fc1.weight'] = W1.t() - W2 = state_dict.pop(f'h.{d}.mlp.c_proj.weight') - state_dict[f'transformer.layers.{d}.mlp.fc2.weight'] = W2.t() + W1 = state_dict.pop(f"h.{d}.mlp.c_fc.weight") + state_dict[f"transformer.layers.{d}.mlp.fc1.weight"] = W1.t() + W2 = state_dict.pop(f"h.{d}.mlp.c_proj.weight") + state_dict[f"transformer.layers.{d}.mlp.fc2.weight"] = W2.t() + def key_mapping_mlp(key): - key = re.sub(r'^h.(\d+).mlp.c_fc.bias', r'transformer.layers.\1.mlp.fc1.bias', key) - key = re.sub(r'^h.(\d+).mlp.c_proj.bias', r'transformer.layers.\1.mlp.fc2.bias', key) + key = re.sub(r"^h.(\d+).mlp.c_fc.bias", r"transformer.layers.\1.mlp.fc1.bias", key) + key = re.sub(r"^h.(\d+).mlp.c_proj.bias", r"transformer.layers.\1.mlp.fc2.bias", key) return key + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) # Attention for d in range(config.num_hidden_layers): - state_dict.pop(f'h.{d}.attn.bias') # We don't store this bias - Wqkv = state_dict.pop(f'h.{d}.attn.c_attn.weight') - state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = Wqkv.t() - Wout = state_dict.pop(f'h.{d}.attn.c_proj.weight') - state_dict[f'transformer.layers.{d}.mixer.out_proj.weight'] = Wout.t() + state_dict.pop(f"h.{d}.attn.bias") # We don't store this bias + Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight") + state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t() + Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight") + 
state_dict[f"transformer.layers.{d}.mixer.out_proj.weight"] = Wout.t() + def key_mapping_attn(key): - key = re.sub(r'^h.(\d+).attn.c_attn.bias', r'transformer.layers.\1.mixer.Wqkv.bias', key) - key = re.sub(r'^h.(\d+).attn.c_proj.bias', r'transformer.layers.\1.mixer.out_proj.bias', key) + key = re.sub(r"^h.(\d+).attn.c_attn.bias", r"transformer.layers.\1.mixer.Wqkv.bias", key) + key = re.sub( + r"^h.(\d+).attn.c_proj.bias", r"transformer.layers.\1.mixer.out_proj.bias", key + ) return key + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) return state_dict @@ -696,66 +917,94 @@ def remap_state_dict_hf_gpt2(state_dict, config): def remap_state_dict_megatron(state_dict, config): def key_mapping_transformer(key): - key = re.sub(r'^language_model.encoder.', 'transformer.', key) - key = re.sub(r'^language_model.', 'transformer.', key) + key = re.sub(r"^language_model.encoder.", "transformer.", key) + key = re.sub(r"^language_model.", "transformer.", key) return key + state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items()) # Word embedding and position embedding def key_mapping_pos_emb(key): - return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key) + return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key) + state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) - word_embeddings = state_dict.pop('transformer.embedding.word_embeddings.weight') + word_embeddings = state_dict.pop("transformer.embedding.word_embeddings.weight") # It's possible that vocab_size is padded to be a multiple of 8, for example. - pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) - vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) - * pad_vocab_size_multiple) - state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( + pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1) + vocab_size = ( + math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple + ) + state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad( word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) ) - state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] + state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"] # LayerNorm def key_mapping_ln(key): - key = re.sub(r'^transformer.final_layernorm.(weight|bias)', r'transformer.ln_f.\1', key) - key = re.sub(r'^transformer.layers.(\d+).input_layernorm.(weight|bias)', - r'transformer.layers.\1.norm1.\2', key) - key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)', - r'transformer.layers.\1.norm2.\2', key) + key = re.sub(r"^transformer.final_layernorm.(weight|bias)", r"transformer.ln_f.\1", key) + key = re.sub( + r"^transformer.layers.(\d+).input_layernorm.(weight|bias)", + r"transformer.layers.\1.norm1.\2", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)", + r"transformer.layers.\1.norm2.\2", + key, + ) return key + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) # MLP def key_mapping_mlp(key): - key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)', - r'transformer.layers.\1.mlp.fc1.\2', key) - key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)', - r'transformer.layers.\1.mlp.fc2.\2', key) + key = re.sub( + 
r"^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)", + r"transformer.layers.\1.mlp.fc1.\2", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)", + r"transformer.layers.\1.mlp.fc2.\2", + key, + ) return key + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) # Attention def key_mapping_attn(key): - key = re.sub(r'^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq', - r'transformer.layers.\1.mixer.rotary_emb.inv_freq', key) - key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)', - r'transformer.layers.\1.mixer.Wqkv.\2', key) - key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.(weight|bias)', - r'transformer.layers.\1.mixer.out_proj.\2', key) + key = re.sub( + r"^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq", + r"transformer.layers.\1.mixer.rotary_emb.inv_freq", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)", + r"transformer.layers.\1.mixer.Wqkv.\2", + key, + ) + key = re.sub( + r"^transformer.layers.(\d+).self_attention.dense.(weight|bias)", + r"transformer.layers.\1.mixer.out_proj.\2", + key, + ) return key + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim) # while we store Wqkv as ((3 nheads headdim), hidden_dim) headdim = config.hidden_size // config.num_attention_heads for d in range(config.num_hidden_layers): - Wqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.weight') - state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = rearrange( - Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...', - three=3, headdim=headdim + Wqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.weight") + state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = rearrange( + Wqkv, + "(nheads three headdim) ... -> (three nheads headdim) ...", + three=3, + headdim=headdim, ) - bqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.bias') - state_dict[f'transformer.layers.{d}.mixer.Wqkv.bias'] = rearrange( - bqkv, '(nheads three headdim) -> (three nheads headdim)', - three=3, headdim=headdim + bqkv = state_dict.pop(f"transformer.layers.{d}.mixer.Wqkv.bias") + state_dict[f"transformer.layers.{d}.mixer.Wqkv.bias"] = rearrange( + bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim ) return state_dict