Implement muParam
This commit is contained in:
parent 3f7d5786ba
commit 2e29dacf0c
```diff
@@ -77,7 +77,9 @@ logger = logging.getLogger(__name__)
 def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
     factory_kwargs = {"device": device, "dtype": dtype}
     head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-    softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
+    attn_scale_power = 0.5 if not getattr(config, "mup_scale_qk_dot_by_d", False) else 1.0
+    softmax_scale = 1.0 if not config.scale_attn_weights else (head_dim ** (-attn_scale_power))
+    softmax_scale *= getattr(config, "mup_attn_multiplier", 1.0)
     if config.scale_attn_by_inverse_layer_idx:
         assert layer_idx is not None
         softmax_scale /= float(layer_idx + 1)
```
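For intuition, the hunk above switches the attention-score scaling from the usual `1/sqrt(head_dim)` to `1/head_dim` when `mup_scale_qk_dot_by_d` is set, and exposes a separate multiplier on top. A small illustrative calculation (the head size and multiplier are example values, not taken from the commit):

```python
# Reproduces the scale computation from create_mixer_cls above for one head size.
head_dim = 64

# Standard parametrization: scale q @ k^T by 1/sqrt(head_dim).
standard_scale = head_dim ** (-0.5)                   # 0.125

# muP setting (mup_scale_qk_dot_by_d=True): scale by 1/head_dim instead,
# optionally combined with a tunable attention multiplier.
mup_attn_multiplier = 1.0                             # assumed value for illustration
mup_scale = head_dim ** (-1.0) * mup_attn_multiplier  # 0.015625
```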
```diff
@@ -179,12 +181,14 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
             if process_group is not None
             else {}
         )
+        mlp_multiple_of = getattr(config, "mlp_multiple_of", 128)
         mlp_cls = partial(
             mlp_cls,
             hidden_features=config.n_inner,
             activation=activation,
             bias1=mlp_fc1_bias,
             bias2=mlp_fc2_bias,
+            multiple_of=mlp_multiple_of,
             **parallel_kwargs,
             **factory_kwargs,
         )
```
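The new `multiple_of` keyword is simply forwarded to the MLP class; how the callee uses it is not shown in this diff. The usual convention for such an argument (an assumption here, not part of the commit) is to round the hidden width up to a hardware-friendly multiple:

```python
def round_up_to_multiple(hidden_features: int, multiple_of: int = 128) -> int:
    # Assumed convention: round the MLP hidden width up to the nearest multiple,
    # e.g. 2730 -> 2816 for multiple_of=128.
    return ((hidden_features + multiple_of - 1) // multiple_of) * multiple_of

assert round_up_to_multiple(2730, 128) == 2816
```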
```diff
@@ -386,9 +390,13 @@ class GPTPreTrainedModel(nn.Module):
 
 
 # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
-def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
+def _init_weights(
+    module, n_layer, initializer_range=0.02, mup_width_scale=1.0, rescale_prenorm_residual=True
+):
+    mup_init_scale = math.sqrt(mup_width_scale)
     if isinstance(module, nn.Linear):
-        nn.init.normal_(module.weight, std=initializer_range)
+        nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
+        module.weight._optim = {"lr_multiplier": mup_width_scale}
         if module.bias is not None:
             nn.init.zeros_(module.bias)
     elif isinstance(module, nn.Embedding):
```
```diff
@@ -404,7 +412,9 @@ def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_resid
         for name, p in module.named_parameters():
             if name in ["out_proj.weight", "fc2.weight"]:
                 # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
-                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
+                nn.init.normal_(
+                    p, mean=0.0, std=initializer_range * mup_init_scale / math.sqrt(2 * n_layer)
+                )
 
 
 class GPTModel(GPTPreTrainedModel):
```
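The `_optim = {"lr_multiplier": mup_width_scale}` tag attached to each `nn.Linear` weight only has an effect if the training script reads it when building optimizer parameter groups; the commit itself does not show that consumer. A minimal sketch of what such a consumer could look like (`build_param_groups` is a hypothetical helper, not part of this codebase):

```python
import torch

def build_param_groups(model: torch.nn.Module, base_lr: float):
    """Hypothetical helper: group parameters by the per-tensor `_optim` hint
    set in _init_weights and scale their learning rate accordingly."""
    default, by_multiplier = [], {}
    for p in model.parameters():
        hint = getattr(p, "_optim", None)
        if hint is None:
            default.append(p)
        else:
            by_multiplier.setdefault(hint.get("lr_multiplier", 1.0), []).append(p)
    groups = [{"params": default, "lr": base_lr}]
    groups += [{"params": ps, "lr": base_lr * mult} for mult, ps in by_multiplier.items()]
    return groups

# optimizer = torch.optim.AdamW(build_param_groups(model, base_lr=6e-4), weight_decay=0.1)
```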
```diff
@@ -429,6 +439,7 @@ class GPTModel(GPTPreTrainedModel):
         vocab_size = (
             math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
         )
+        self.embeddings_multiplier = getattr(config, "mup_embeddings_multiplier", 1.0)
         # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
         self.residual_in_fp32 = getattr(config, "residual_in_fp32", False)
         # These 2 options are for OPT-350m
```
```diff
@@ -494,6 +505,7 @@ class GPTModel(GPTPreTrainedModel):
                 _init_weights,
                 n_layer=config.num_hidden_layers,
                 initializer_range=config.initializer_range,
+                mup_width_scale=getattr(config, "mup_width_scale", 1.0),
             )
         )
         self.tie_weights()
```
```diff
@@ -518,6 +530,8 @@ class GPTModel(GPTPreTrainedModel):
             else {}
         )
         hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
+        if self.embeddings_multiplier != 1.0:
+            hidden_states = hidden_states * self.embeddings_multiplier
         if self.parallel_block:
             hidden_states2 = None
         residual = None
```
```diff
@@ -612,6 +626,9 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
         else:
             self.project_out = None
+        mup_width_scale = getattr(config, "mup_width_scale", 1.0)
+        mup_output_multiplier = getattr(config, "mup_output_multiplier", 1.0)
+        self.output_scale = mup_output_multiplier * mup_width_scale
         if process_group is None:
             self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
         else:
```
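`output_scale` folds the tunable output multiplier together with the width scale, so the logits are scaled down as the model widens. A quick illustrative calculation, assuming the common muP choice `mup_width_scale = base_width / width` (the widths are examples, not from the commit):

```python
# Illustrative only: how output_scale behaves if mup_width_scale = base_width / width.
base_width = 256
mup_output_multiplier = 1.0  # assumed default
for width in (256, 512, 1024, 2048):
    output_scale = mup_output_multiplier * (base_width / width)
    print(width, output_scale)  # 1.0, 0.5, 0.25, 0.125
```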
```diff
@@ -632,6 +649,7 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
                 _init_weights,
                 n_layer=config.num_hidden_layers,
                 initializer_range=config.initializer_range,
+                mup_width_scale=mup_width_scale,
             )
         )
         self.tie_weights()
```
```diff
@@ -667,6 +685,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
             hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
             hidden_states = self.project_out(hidden_states)
+        if self.output_scale != 1.0:
+            hidden_states = hidden_states * self.output_scale
         if not self.norm_head:
             lm_logits = self.lm_head(hidden_states)
         else:
```
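Taken together, the commit reads a handful of optional config fields, all defaulting to the pre-existing behaviour. A sketch of how they might be set on the GPT2-style config this model file consumes (fields such as `n_embd` and `scale_attn_weights` in the diff suggest `GPT2Config`), for a model widened 4x from a tuned base; all values here are illustrative assumptions, not prescribed by the commit:

```python
from transformers import GPT2Config

base_width, width = 256, 1024
config = GPT2Config(n_embd=width, n_head=width // 64, n_layer=24)

# Optional fields read via getattr(...) in the hunks above.
config.mup_width_scale = base_width / width   # scales Linear init std (by its sqrt) and lr multiplier
config.mup_scale_qk_dot_by_d = True           # attention logits scaled by 1/d instead of 1/sqrt(d)
config.mup_attn_multiplier = 1.0              # extra tunable attention-scale knob
config.mup_embeddings_multiplier = 10.0       # multiplies token embeddings in GPTModel.forward
config.mup_output_multiplier = 1.0            # combined with mup_width_scale to scale the logits
config.mlp_multiple_of = 128                  # rounding hint forwarded to the MLP class
```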