diff --git a/flash_attn/models/bert.py b/flash_attn/models/bert.py index e6632c5..710c7e8 100644 --- a/flash_attn/models/bert.py +++ b/flash_attn/models/bert.py @@ -52,10 +52,16 @@ logger = logging.getLogger(__name__) def create_mixer_cls(config, cross_attn=False, return_residual=False): use_flash_attn = getattr(config, 'use_flash_attn', False) fused_bias_fc = getattr(config, 'fused_bias_fc', False) + rotary_kwargs = {} + if config.position_embedding_type == "rotary": + rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size) + rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0) + rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None) + rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False) mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn, dropout=config.attention_probs_dropout_prob, causal=False, fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn, - return_residual=return_residual) + return_residual=return_residual, **rotary_kwargs) return mixer_cls @@ -298,7 +304,6 @@ class BertModel(BertPreTrainedModel): self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False) if self.fused_dropout_add_ln and layer_norm is None: raise ImportError('dropout_add_layer_norm is not installed') - assert config.position_embedding_type == 'absolute' assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast'] self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,