Add rope_scaling to Aquila model (#1457)

Author: Qing (committed by GitHub)
Date:   2023-10-29 19:25:21 +08:00
parent 1f24755bf8
commit 28b47d1e49

@@ -25,7 +25,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -110,6 +110,7 @@ class AquilaAttention(nn.Module):
         num_kv_heads: int,
         rope_theta: float = 10000,
         max_position_embeddings: int = 8192,
+        rope_scaling: Optional[Dict[str, Any]] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -148,6 +149,7 @@ class AquilaAttention(nn.Module):
             base=self.rope_theta,
             max_position=self.max_position_embeddings,
             num_kv_heads=self.num_kv_heads,
+            rope_scaling=rope_scaling,
         )
 
     def forward(
@@ -173,6 +175,7 @@ class AquilaDecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
         self.self_attn = AquilaAttention(
@@ -181,6 +184,7 @@ class AquilaDecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
             max_position_embeddings=max_position_embeddings,
+            rope_scaling=rope_scaling,
         )
         self.mlp = AquilaMLP(
             hidden_size=self.hidden_size,
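
For reference, rope_scaling here follows the Llama-style Hugging Face convention: it is either absent/None or a dict describing how rotary position embeddings are stretched beyond the trained context length. The snippet below is a minimal, self-contained sketch of that pattern only; the SimpleNamespace stands in for a real Aquila config, and the scaling values are illustrative assumptions, not part of this commit.

# Minimal sketch (assumption: Aquila's config exposes rope_scaling in the
# same {"type": ..., "factor": ...} form used by Llama-style configs).
from types import SimpleNamespace

# Stand-in for the Hugging Face config object; the values are illustrative.
config = SimpleNamespace(
    rope_theta=10000,
    max_position_embeddings=8192,
    rope_scaling={"type": "linear", "factor": 2.0},  # 2x linear interpolation
)

# The same defensive pattern the patched AquilaDecoderLayer uses before
# handing the value down to AquilaAttention / PagedAttentionWithRoPE:
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)

print(rope_scaling)  # {'type': 'linear', 'factor': 2.0}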