[Bert] Fix embedding layer norm before embedding dropout

This commit is contained in:
Tri Dao 2023-01-01 10:37:00 -08:00
parent ef1ba918c6
commit 714c1b4f0f
2 changed files with 10 additions and 8 deletions

View File

@ -295,7 +295,7 @@ class BertModel(BertPreTrainedModel):
config.vocab_size += (self.pad_vocab_size_multiple
- (config.vocab_size % self.pad_vocab_size_multiple))
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
if self.fused_dropout_add_ln and layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
assert config.position_embedding_type == 'absolute'
assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
@ -320,14 +320,13 @@ class BertModel(BertPreTrainedModel):
hidden_states = self.embeddings(input_ids, position_ids=position_ids,
token_type_ids=token_type_ids)
# TD [2022-12:18]: Don't need to force residual in fp32
# BERT puts embedding LayerNorm before embedding dropout.
if not self.fused_dropout_add_ln:
hidden_states = self.emb_drop(hidden_states)
hidden_states = self.emb_ln(hidden_states)
else:
hidden_states = dropout_add_layer_norm(
hidden_states, None, self.emb_ln.weight, self.emb_ln.bias,
self.emb_drop.p if self.training else 0.0, self.emb_ln.eps, prenorm=False,
)
hidden_states = layer_norm(hidden_states, self.emb_ln.weight, self.emb_ln.bias,
self.emb_ln.eps)
hidden_states = self.emb_drop(hidden_states)
if masked_tokens_mask is not None:
batch_size, seqlen = input_ids.shape[:2]

View File

@ -220,6 +220,9 @@ class GPTModel(GPTPreTrainedModel):
self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
initializer_range=config.initializer_range))
self.tie_weights()
def tie_weights(self):
if self.process_group is not None:
sync_sequence_parallel_params(self, self.process_group)
@ -266,11 +269,11 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
initializer_range=config.initializer_range))
self.tie_weights()
if self.process_group is not None:
sync_sequence_parallel_params(self, self.process_group)
def tie_weights(self):
self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
if self.process_group is not None:
sync_sequence_parallel_params(self, self.process_group)
def forward(self, input_ids, position_ids=None, inference_params=None):
"""