Allow rotary embeddings for Bert (#363)

Kiarash Jamali 2023-07-23 08:21:45 +01:00 committed by GitHub
parent cbf982afa5
commit 684196b8c5


@@ -52,10 +52,16 @@ logger = logging.getLogger(__name__)
 def create_mixer_cls(config, cross_attn=False, return_residual=False):
     use_flash_attn = getattr(config, 'use_flash_attn', False)
     fused_bias_fc = getattr(config, 'fused_bias_fc', False)
+    rotary_kwargs = {}
+    if config.position_embedding_type == "rotary":
+        rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
+        rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
+        rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
+        rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
     mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
                         dropout=config.attention_probs_dropout_prob, causal=False,
                         fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
-                        return_residual=return_residual)
+                        return_residual=return_residual, **rotary_kwargs)
     return mixer_cls
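
For reference, a minimal sketch (not part of the diff) of how these settings are expected to be supplied. The attribute names come from the getattr() calls above; the use of transformers.BertConfig and the example values are assumptions.

    from transformers import BertConfig

    config = BertConfig(hidden_size=768, num_attention_heads=12,
                        position_embedding_type="rotary")  # triggers the new branch above
    # Optional overrides read via getattr() in the diff; omitting one falls back to the
    # default shown there (e.g. rotary_emb_dim defaults to config.hidden_size).
    config.rotary_emb_dim = 64
    config.rotary_emb_base = 10000.0
    config.rotary_emb_scale_base = None
    config.rotary_emb_interleaved = False
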
@@ -298,7 +304,6 @@ class BertModel(BertPreTrainedModel):
         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
         if self.fused_dropout_add_ln and layer_norm is None:
             raise ImportError('dropout_add_layer_norm is not installed')
-        assert config.position_embedding_type == 'absolute'
         assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
         self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,
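
With the 'absolute'-only assertion dropped, a rotary config can be passed straight to BertModel. A rough usage sketch, assuming the import path flash_attn.models.bert and the config object from the previous snippet:

    from flash_attn.models.bert import BertModel

    model = BertModel(config)  # no longer rejects position_embedding_type == "rotary"
    # The rotary kwargs collected in create_mixer_cls() are forwarded to each MHA layer.
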