From 684196b8c55d0cf04d1f657dc3d96e8982f7747b Mon Sep 17 00:00:00 2001
From: Kiarash Jamali
Date: Sun, 23 Jul 2023 08:21:45 +0100
Subject: [PATCH] Allow rotary embeddings for Bert (#363)

---
 flash_attn/models/bert.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/flash_attn/models/bert.py b/flash_attn/models/bert.py
index e6632c5..710c7e8 100644
--- a/flash_attn/models/bert.py
+++ b/flash_attn/models/bert.py
@@ -52,10 +52,16 @@ logger = logging.getLogger(__name__)
 def create_mixer_cls(config, cross_attn=False, return_residual=False):
     use_flash_attn = getattr(config, 'use_flash_attn', False)
     fused_bias_fc = getattr(config, 'fused_bias_fc', False)
+    rotary_kwargs = {}
+    if config.position_embedding_type == "rotary":
+        rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
+        rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
+        rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
+        rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
     mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
                         dropout=config.attention_probs_dropout_prob, causal=False,
                         fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
-                        return_residual=return_residual)
+                        return_residual=return_residual, **rotary_kwargs)
     return mixer_cls


@@ -298,7 +304,6 @@ class BertModel(BertPreTrainedModel):
         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
         if self.fused_dropout_add_ln and layer_norm is None:
             raise ImportError('dropout_add_layer_norm is not installed')
-        assert config.position_embedding_type == 'absolute'
         assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
         self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,
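
A minimal usage sketch (not part of the patch) of how a model might be built once this change is applied: the rotary-related attribute names and their defaults mirror the getattr() calls added in create_mixer_cls above, while the transformers.BertConfig import, the concrete sizes, and the CUDA/fp16 setup are illustrative assumptions.

    # Sketch only: enabling rotary position embeddings through config attributes.
    # Attribute names/defaults follow the getattr() calls in the diff above;
    # sizes, dtype, and device choices here are assumptions, not prescribed values.
    import torch
    from transformers import BertConfig
    from flash_attn.models.bert import BertModel

    config = BertConfig(
        hidden_size=768,
        num_attention_heads=12,
        num_hidden_layers=12,
        position_embedding_type="rotary",  # previously rejected by the removed assert
    )
    config.use_flash_attn = True
    config.rotary_emb_dim = 64             # per-head rotary dim; defaults to hidden_size if unset
    config.rotary_emb_base = 10000.0
    config.rotary_emb_scale_base = None
    config.rotary_emb_interleaved = False

    model = BertModel(config).to(device="cuda", dtype=torch.float16).eval()
    input_ids = torch.randint(0, config.vocab_size, (2, 128), device="cuda")
    with torch.no_grad():
        out = model(input_ids)
    print(out.last_hidden_state.shape)  # expected: (2, 128, 768)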