Allow rotary embeddings for Bert (#363)
This commit is contained in:
parent
cbf982afa5
commit
684196b8c5
@ -52,10 +52,16 @@ logger = logging.getLogger(__name__)
|
|||||||
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
||||||
use_flash_attn = getattr(config, 'use_flash_attn', False)
|
use_flash_attn = getattr(config, 'use_flash_attn', False)
|
||||||
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
|
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
|
||||||
|
rotary_kwargs = {}
|
||||||
|
if config.position_embedding_type == "rotary":
|
||||||
|
rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
|
||||||
|
rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
|
||||||
|
rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
|
||||||
|
rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
|
||||||
mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
|
mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
|
||||||
dropout=config.attention_probs_dropout_prob, causal=False,
|
dropout=config.attention_probs_dropout_prob, causal=False,
|
||||||
fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
|
fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
|
||||||
return_residual=return_residual)
|
return_residual=return_residual, **rotary_kwargs)
|
||||||
return mixer_cls
|
return mixer_cls
|
||||||
|
|
||||||
|
|
||||||
@ -298,7 +304,6 @@ class BertModel(BertPreTrainedModel):
|
|||||||
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
|
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
|
||||||
if self.fused_dropout_add_ln and layer_norm is None:
|
if self.fused_dropout_add_ln and layer_norm is None:
|
||||||
raise ImportError('dropout_add_layer_norm is not installed')
|
raise ImportError('dropout_add_layer_norm is not installed')
|
||||||
assert config.position_embedding_type == 'absolute'
|
|
||||||
assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
|
assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
|
||||||
|
|
||||||
self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,
|
self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user