[Bert] Fix embedding layer norm before embedding dropout

parent ef1ba918c6
commit 714c1b4f0f
@@ -295,7 +295,7 @@ class BertModel(BertPreTrainedModel):
             config.vocab_size += (self.pad_vocab_size_multiple
                                   - (config.vocab_size % self.pad_vocab_size_multiple))
         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
-        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
+        if self.fused_dropout_add_ln and layer_norm is None:
             raise ImportError('dropout_add_layer_norm is not installed')
         assert config.position_embedding_type == 'absolute'
         assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
@@ -320,14 +320,13 @@ class BertModel(BertPreTrainedModel):
         hidden_states = self.embeddings(input_ids, position_ids=position_ids,
                                         token_type_ids=token_type_ids)
         # TD [2022-12:18]: Don't need to force residual in fp32
+        # BERT puts embedding LayerNorm before embedding dropout.
         if not self.fused_dropout_add_ln:
-            hidden_states = self.emb_drop(hidden_states)
             hidden_states = self.emb_ln(hidden_states)
         else:
-            hidden_states = dropout_add_layer_norm(
-                hidden_states, None, self.emb_ln.weight, self.emb_ln.bias,
-                self.emb_drop.p if self.training else 0.0, self.emb_ln.eps, prenorm=False,
-            )
+            hidden_states = layer_norm(hidden_states, self.emb_ln.weight, self.emb_ln.bias,
+                                       self.emb_ln.eps)
+        hidden_states = self.emb_drop(hidden_states)

         if masked_tokens_mask is not None:
             batch_size, seqlen = input_ids.shape[:2]
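Note on the BertModel hunk above: after this change the embedding output is normalized first and dropped out second, in both the unfused and the fused path, matching the LayerNorm-then-dropout ordering of the original BERT. The sketch below reproduces that ordering with stock PyTorch only; it is illustrative rather than the repository's code, and torch.nn.functional.layer_norm stands in for the fused layer_norm kernel that the diff calls.

# Minimal sketch (assumption: plain-PyTorch stand-in for the fused kernel; names mirror the diff).
import torch
import torch.nn as nn
import torch.nn.functional as F

class EmbeddingPostProcess(nn.Module):
    """Applies LayerNorm *before* dropout, as the fixed BertModel forward does."""
    def __init__(self, hidden_size, dropout_p=0.1, fused_dropout_add_ln=False):
        super().__init__()
        self.fused_dropout_add_ln = fused_dropout_add_ln
        self.emb_ln = nn.LayerNorm(hidden_size)
        self.emb_drop = nn.Dropout(dropout_p)

    def forward(self, hidden_states):
        # BERT puts embedding LayerNorm before embedding dropout.
        if not self.fused_dropout_add_ln:
            hidden_states = self.emb_ln(hidden_states)
        else:
            # Stand-in for the fused layer_norm call in the diff.
            hidden_states = F.layer_norm(hidden_states, hidden_states.shape[-1:],
                                         self.emb_ln.weight, self.emb_ln.bias,
                                         self.emb_ln.eps)
        return self.emb_drop(hidden_states)

x = torch.randn(2, 16, 768)
out = EmbeddingPostProcess(768)(x)   # LayerNorm first, then dropout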
@@ -220,6 +220,9 @@ class GPTModel(GPTPreTrainedModel):

         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
+        self.tie_weights()
+
+    def tie_weights(self):
         if self.process_group is not None:
             sync_sequence_parallel_params(self, self.process_group)

@@ -266,11 +269,11 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
         self.tie_weights()
-        if self.process_group is not None:
-            sync_sequence_parallel_params(self, self.process_group)

     def tie_weights(self):
         self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
+        if self.process_group is not None:
+            sync_sequence_parallel_params(self, self.process_group)

     def forward(self, input_ids, position_ids=None, inference_params=None):
         """
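Note on the GPT hunks above: tie_weights() now first points lm_head.weight at the word-embedding weight and only then, when a process group is set, syncs sequence-parallel parameters; GPTModel also gains its own tie_weights() hook. Below is a minimal, self-contained sketch of the weight-tying step only; the sequence-parallel sync is omitted and the class is illustrative, not the repository's.

# Minimal sketch (assumption: names are illustrative; no tensor/sequence parallelism).
import torch.nn as nn

class TinyLMHeadModel(nn.Module):
    def __init__(self, vocab_size=50257, hidden_size=768):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.tie_weights()

    def tie_weights(self):
        # Input embedding and output projection share one nn.Parameter,
        # so updates to either are reflected in both.
        self.lm_head.weight = self.word_embeddings.weight

model = TinyLMHeadModel()
assert model.lm_head.weight is model.word_embeddings.weight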