Add rope_scaling to Aquila model (#1457)
parent 1f24755bf8
commit 28b47d1e49
@@ -25,7 +25,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import torch
 from torch import nn
@@ -110,6 +110,7 @@ class AquilaAttention(nn.Module):
         num_kv_heads: int,
         rope_theta: float = 10000,
         max_position_embeddings: int = 8192,
+        rope_scaling: Optional[Dict[str, Any]] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
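For context, the new rope_scaling argument follows the Hugging Face LLaMA-style convention of describing RoPE scaling as a small dict; the concrete keys and values below are an illustrative assumption, not something taken from this diff.

from typing import Any, Dict, Optional

# Illustrative only: a LLaMA-style rope_scaling dict. None (the parameter's
# default) means the rotary embedding is used without position scaling.
rope_scaling: Optional[Dict[str, Any]] = {
    "type": "linear",  # scaling strategy (assumed key/value)
    "factor": 2.0,     # positions are compressed by this factor (assumed)
}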
@@ -148,6 +149,7 @@ class AquilaAttention(nn.Module):
             base=self.rope_theta,
             max_position=self.max_position_embeddings,
             num_kv_heads=self.num_kv_heads,
+            rope_scaling=rope_scaling,
         )

     def forward(
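The dict is forwarded unchanged to the rotary-embedding attention layer. As a rough illustration of what "linear" scaling does there, the sketch below builds a cos/sin cache with positions divided by the scaling factor; this is a simplified stand-in for the idea, not the code this diff calls into.

import torch

def linear_scaled_rope_cache(head_dim: int,
                             max_position: int,
                             base: float = 10000.0,
                             scaling_factor: float = 1.0):
    # Standard RoPE inverse frequencies, one per pair of channels.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # Linear scaling: divide positions by the factor before computing the
    # rotary angles, so a longer context reuses the trained angle range.
    positions = torch.arange(max_position).float() / scaling_factor
    angles = torch.outer(positions, inv_freq)  # (max_position, head_dim // 2)
    return angles.cos(), angles.sin()

# e.g. factor 2.0 lets 16384 positions reuse angles trained for 8192.
cos, sin = linear_scaled_rope_cache(head_dim=128, max_position=16384,
                                    scaling_factor=2.0)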
@@ -173,6 +175,7 @@ class AquilaDecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
         self.self_attn = AquilaAttention(
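The getattr fallbacks mean older Aquila configs without a rope_scaling attribute keep working with no scaling applied. A minimal sketch of that lookup, using a stand-in object with assumed values in place of the real Hugging Face config class:

from types import SimpleNamespace

# Stand-in configs (illustrative values, not from a real checkpoint).
old_config = SimpleNamespace(hidden_size=4096, rope_theta=10000)
new_config = SimpleNamespace(hidden_size=4096, rope_theta=10000,
                             rope_scaling={"type": "linear", "factor": 2.0})

for config in (old_config, new_config):
    rope_scaling = getattr(config, "rope_scaling", None)
    max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
    print(rope_scaling, max_position_embeddings)
# -> None 8192
# -> {'type': 'linear', 'factor': 2.0} 8192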
@@ -181,6 +184,7 @@ class AquilaDecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
             max_position_embeddings=max_position_embeddings,
+            rope_scaling=rope_scaling,
         )
         self.mlp = AquilaMLP(
             hidden_size=self.hidden_size,