From 2e29dacf0c54e7650831ddd4aea806d0de1ef797 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 24 Dec 2023 20:34:48 -0800
Subject: [PATCH] Implement muParam

---
 flash_attn/models/gpt.py | 28 ++++++++++++++++++++++++----
 1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 04c7674..6939344 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -77,7 +77,9 @@ logger = logging.getLogger(__name__)
 def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
     factory_kwargs = {"device": device, "dtype": dtype}
     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-    softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
+    attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) else 1.0
+    softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power))
+    softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0)
     if config.scale_attn_by_inverse_layer_idx:
         assert layer_idx is not None
         softmax_scale /= float(layer_idx + 1)
@@ -179,12 +181,14 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
                 if process_group is not None
                 else {}
             )
+            mlp_multiple_of = getattr(config, "mlp_multiple_of", 128)
             mlp_cls = partial(
                 mlp_cls,
                 hidden_features=config.n_inner,
                 activation=activation,
                 bias1=mlp_fc1_bias,
                 bias2=mlp_fc2_bias,
+                multiple_of=mlp_multiple_of,
                 **parallel_kwargs,
                 **factory_kwargs,
             )
@@ -386,9 +390,13 @@ class GPTPreTrainedModel(nn.Module):
 
 
 # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
-def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
+def _init_weights(
+    module, n_layer, initializer_range=0.02, mup_width_scale=1.0, rescale_prenorm_residual=True
+):
+    mup_init_scale = math.sqrt(mup_width_scale)
     if isinstance(module, nn.Linear):
-        nn.init.normal_(module.weight, std=initializer_range)
+        nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
+        module.weight._optim = {"lr_multiplier": mup_width_scale}
         if module.bias is not None:
             nn.init.zeros_(module.bias)
     elif isinstance(module, nn.Embedding):
@@ -404,7 +412,9 @@ def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_resid
         for name, p in module.named_parameters():
             if name in ["out_proj.weight", "fc2.weight"]:
                 # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
-                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
+                nn.init.normal_(
+                    p, mean=0.0, std=initializer_range * mup_init_scale / math.sqrt(2 * n_layer)
+                )
 
 
 class GPTModel(GPTPreTrainedModel):
@@ -429,6 +439,7 @@ class GPTModel(GPTPreTrainedModel):
         vocab_size = (
             math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
         )
+        self.embeddings_multiplier = getattr(config, "mup_embeddings_multiplier", 1.0)
         # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
         self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
         # These 2 options are for OPT-350m
@@ -494,6 +505,7 @@ class GPTModel(GPTPreTrainedModel):
                 _init_weights,
                 n_layer=config.num_hidden_layers,
                 initializer_range=config.initializer_range,
+                mup_width_scale=getattr(config, "mup_width_scale", 1.0),
             )
         )
         self.tie_weights()
@@ -518,6 +530,8 @@ class GPTModel(GPTPreTrainedModel):
             else {}
         )
         hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
+        if self.embeddings_multiplier != 1.0:
+            hidden_states = hidden_states * self.embeddings_multiplier
         if self.parallel_block:
             hidden_states2 = None
         residual = None
@@ -612,6 +626,9 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
         else:
             self.project_out = None
+        mup_width_scale = getattr(config, "mup_width_scale", 1.0)
+        mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0)
+        self.output_scale = mup_output_multiplier * mup_width_scale
         if process_group is None:
             self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
         else:
@@ -632,6 +649,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
                 _init_weights,
                 n_layer=config.num_hidden_layers,
                 initializer_range=config.initializer_range,
+                mup_width_scale=mup_width_scale,
             )
         )
         self.tie_weights()
@@ -667,6 +685,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
+        if self.output_scale != 1.0:
+            hidden_states = hidden_states * self.output_scale
         if not self.norm_head:
            lm_logits = self.lm_head(hidden_states)
        else:
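
Usage note (not part of the patch): the sketch below shows one way the muP knobs that
this patch reads via getattr() could be supplied on a GPT2Config, and one way the
`_optim = {"lr_multiplier": ...}` tag that `_init_weights` attaches to Linear weights
might be consumed when building optimizer parameter groups. The patch only sets the tag;
the width ratio, multiplier values, learning rate, and the optimizer-side handling here
are illustrative assumptions, not part of this commit.

    import torch
    from transformers import GPT2Config

    from flash_attn.models.gpt import GPTLMHeadModel

    # Hypothetical width ratio between the target model and a small proxy model
    # (e.g. proxy n_embd=256, target n_embd=1024).
    width_mult = 1024 / 256

    config = GPT2Config(
        n_embd=1024,
        n_head=16,
        n_layer=24,
        # muP attributes read via getattr() in this patch; leaving them unset
        # falls back to the standard parametrization (multipliers of 1.0).
        mup_width_scale=1.0 / width_mult,  # scales init std and per-param lr
        mup_embeddings_multiplier=10.0,    # scales the embedding output
        mup_output_multiplier=1.0,         # combined with mup_width_scale at the head
        mup_attn_multiplier=1.0,           # extra factor on softmax_scale
        mup_scale_qk_dot_by_d=True,        # 1/d attention scaling instead of 1/sqrt(d)
    )
    model = GPTLMHeadModel(config)

    # Assumed consumer of the `_optim` tag: group parameters by lr_multiplier.
    base_lr = 3e-4
    groups = {}
    for p in model.parameters():
        if not p.requires_grad:
            continue
        mult = getattr(p, "_optim", {}).get("lr_multiplier", 1.0)
        groups.setdefault(mult, []).append(p)
    optimizer = torch.optim.AdamW(
        [{"params": params, "lr": base_lr * mult} for mult, params in groups.items()],
        lr=base_lr,
        weight_decay=0.1,
    )

With mup_width_scale = 1 / width_mult as above, Linear weights get their init std scaled
by 1 / sqrt(width_mult) and an lr_multiplier of 1 / width_mult, which is exactly what
`_init_weights` applies in this patch; everything else in the sketch is an assumed
training-script convention.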