[Bert] Fix embedding layer norm before embedding dropout
parent ef1ba918c6
commit 714c1b4f0f
@@ -295,7 +295,7 @@ class BertModel(BertPreTrainedModel):
             config.vocab_size += (self.pad_vocab_size_multiple
                                   - (config.vocab_size % self.pad_vocab_size_multiple))
         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
-        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
+        if self.fused_dropout_add_ln and layer_norm is None:
             raise ImportError('dropout_add_layer_norm is not installed')
         assert config.position_embedding_type == 'absolute'
         assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
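For context, a minimal sketch of the optional-import guard this check relies on. The try/except shape and the module path are assumptions for illustration; only the two kernel names appear in the diff itself.

# Sketch only: the exact import location is an assumption.
try:
    # Fused CUDA kernels: `dropout_add_layer_norm` fuses dropout + residual + LayerNorm,
    # while `layer_norm` is the plain LayerNorm variant (no dropout, no residual).
    from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
except ImportError:
    dropout_add_layer_norm, layer_norm = None, None

After the fix, enabling fused_dropout_add_ln requires the plain layer_norm kernel as well, since the embedding path below now calls it directly.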
@@ -320,14 +320,13 @@ class BertModel(BertPreTrainedModel):
         hidden_states = self.embeddings(input_ids, position_ids=position_ids,
                                         token_type_ids=token_type_ids)
         # TD [2022-12:18]: Don't need to force residual in fp32
+        # BERT puts embedding LayerNorm before embedding dropout.
         if not self.fused_dropout_add_ln:
-            hidden_states = self.emb_drop(hidden_states)
             hidden_states = self.emb_ln(hidden_states)
         else:
-            hidden_states = dropout_add_layer_norm(
-                hidden_states, None, self.emb_ln.weight, self.emb_ln.bias,
-                self.emb_drop.p if self.training else 0.0, self.emb_ln.eps, prenorm=False,
-            )
+            hidden_states = layer_norm(hidden_states, self.emb_ln.weight, self.emb_ln.bias,
+                                       self.emb_ln.eps)
+        hidden_states = self.emb_drop(hidden_states)
 
         if masked_tokens_mask is not None:
             batch_size, seqlen = input_ids.shape[:2]
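A standalone sketch of why the ordering matters (plain PyTorch with hypothetical sizes, not the repo's code): with dropout active, applying it before LayerNorm yields a different result than the standard BERT order of LayerNorm first, then dropout.

import torch
import torch.nn as nn

torch.manual_seed(0)
emb_ln = nn.LayerNorm(768)
emb_drop = nn.Dropout(0.1)
x = torch.randn(2, 16, 768)  # (batch, seqlen, hidden)

# Corrected order (matches BERT): LayerNorm on the raw embeddings,
# then dropout on the normalized output.
new_out = emb_drop(emb_ln(x))

# Old order: dropout first, so LayerNorm re-normalizes activations that were
# already randomly zeroed and rescaled -- a different result whenever the
# module is in training mode (in eval mode dropout is a no-op and both agree).
old_out = emb_ln(emb_drop(x))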
@@ -220,6 +220,9 @@ class GPTModel(GPTPreTrainedModel):
 
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
+        self.tie_weights()
+
+    def tie_weights(self):
         if self.process_group is not None:
             sync_sequence_parallel_params(self, self.process_group)
 
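An illustrative sketch of what a "sync parameters across the process group" helper generally does; this is not the repo's sync_sequence_parallel_params, just the broadcast pattern its call site suggests (assumes torch.distributed is already initialized).

import torch
import torch.distributed as dist

def broadcast_params(module: torch.nn.Module, process_group):
    # Make every rank in `process_group` hold identical parameter values by
    # broadcasting from the group's first rank (illustrative only).
    src = dist.get_global_rank(process_group, 0)
    for param in module.parameters():
        dist.broadcast(param.detach(), src=src, group=process_group)

Hoisting the sync into a tie_weights method here presumably mirrors the GPTLMHeadModel change below, giving callers a single entry point for re-establishing consistent weights across ranks.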
@@ -266,11 +269,11 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
         self.tie_weights()
-        if self.process_group is not None:
-            sync_sequence_parallel_params(self, self.process_group)
 
     def tie_weights(self):
         self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
+        if self.process_group is not None:
+            sync_sequence_parallel_params(self, self.process_group)
 
     def forward(self, input_ids, position_ids=None, inference_params=None):
         """
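For reference, a minimal weight-tying sketch in plain PyTorch (hypothetical module and sizes, not the repo's GPTLMHeadModel):

import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size=1000, hidden=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        self.tie_weights()

    def tie_weights(self):
        # The output projection reuses the input embedding matrix: both modules
        # now point at the same tensor, so their weights can never drift apart.
        self.lm_head.weight = self.embed.weight

Moving the process-group sync inside tie_weights keeps tying and cross-rank synchronization in one place, presumably so that any later re-tie also restores consistency across the tensor/sequence-parallel group.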