diff --git a/flash_attn/models/bert.py b/flash_attn/models/bert.py
index 1c013d0..88ac4d9 100644
--- a/flash_attn/models/bert.py
+++ b/flash_attn/models/bert.py
@@ -295,7 +295,7 @@ class BertModel(BertPreTrainedModel):
             config.vocab_size += (self.pad_vocab_size_multiple
                                   - (config.vocab_size % self.pad_vocab_size_multiple))
         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
-        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
+        if self.fused_dropout_add_ln and layer_norm is None:
             raise ImportError('dropout_add_layer_norm is not installed')
         assert config.position_embedding_type == 'absolute'
         assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
@@ -320,14 +320,13 @@ class BertModel(BertPreTrainedModel):
         hidden_states = self.embeddings(input_ids, position_ids=position_ids,
                                         token_type_ids=token_type_ids)
         # TD [2022-12:18]: Don't need to force residual in fp32
+        # BERT puts embedding LayerNorm before embedding dropout.
         if not self.fused_dropout_add_ln:
-            hidden_states = self.emb_drop(hidden_states)
             hidden_states = self.emb_ln(hidden_states)
         else:
-            hidden_states = dropout_add_layer_norm(
-                hidden_states, None, self.emb_ln.weight, self.emb_ln.bias,
-                self.emb_drop.p if self.training else 0.0, self.emb_ln.eps, prenorm=False,
-            )
+            hidden_states = layer_norm(hidden_states, self.emb_ln.weight, self.emb_ln.bias,
+                                       self.emb_ln.eps)
+        hidden_states = self.emb_drop(hidden_states)
 
         if masked_tokens_mask is not None:
             batch_size, seqlen = input_ids.shape[:2]
diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index e746fc7..417fac9 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -220,6 +220,9 @@ class GPTModel(GPTPreTrainedModel):
 
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
+        self.tie_weights()
+
+    def tie_weights(self):
         if self.process_group is not None:
             sync_sequence_parallel_params(self, self.process_group)
 
@@ -266,11 +269,11 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
         self.tie_weights()
-        if self.process_group is not None:
-            sync_sequence_parallel_params(self, self.process_group)
 
     def tie_weights(self):
         self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
+        if self.process_group is not None:
+            sync_sequence_parallel_params(self, self.process_group)
 
     def forward(self, input_ids, position_ids=None, inference_params=None):
         """
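For context, a minimal standalone sketch (not part of the patch) of the two patterns this diff moves toward: the embedding LayerNorm applied before embedding dropout, and weight tying done inside a tie_weights() method called from __init__. The TinyLM class, its sizes, and module names below are illustrative assumptions, not taken from flash_attn.

    # Hypothetical toy model; only the ordering of emb_ln/emb_drop and the
    # tie_weights() pattern mirror what the patch establishes.
    import torch
    import torch.nn as nn

    class TinyLM(nn.Module):
        def __init__(self, vocab_size=100, hidden_dim=32, dropout_p=0.1):
            super().__init__()
            self.word_embeddings = nn.Embedding(vocab_size, hidden_dim)
            self.emb_ln = nn.LayerNorm(hidden_dim)
            self.emb_drop = nn.Dropout(dropout_p)
            self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
            self.tie_weights()

        def tie_weights(self):
            # Share the output projection with the input embedding matrix.
            self.lm_head.weight = self.word_embeddings.weight

        def forward(self, input_ids):
            hidden_states = self.word_embeddings(input_ids)
            hidden_states = self.emb_ln(hidden_states)    # LayerNorm first ...
            hidden_states = self.emb_drop(hidden_states)  # ... then dropout
            return self.lm_head(hidden_states)

    logits = TinyLM()(torch.randint(0, 100, (2, 16)))
    print(logits.shape)  # torch.Size([2, 16, 100])

Keeping the tying (and any process-group sync) in tie_weights() means re-tying after operations that replace parameters only requires calling that one method again, which is why the diff routes the sequence-parallel sync through it as well.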