Add rope_scaling to Qwen (#1210)
commit 7bedab5748
parent 20f7cc4cde
@@ -8,7 +8,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -76,13 +76,12 @@ class QWenMLP(nn.Module):
 
 class QWenAttention(nn.Module):
 
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        max_position_embeddings: int,
-        rope_theta: float = 10000,
-    ):
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 max_position_embeddings: int,
+                 rope_theta: float = 10000,
+                 rope_scaling: Optional[Dict[str, Any]] = None):
         super().__init__()
         self.hidden_size = hidden_size
         tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
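
Note: the snippet below is not part of this diff. The new rope_scaling argument defaults to None, so existing callers keep their current behavior. As a minimal sketch, assuming the dict follows the HuggingFace-style {"type", "factor"} convention, a hypothetical validator (check_rope_scaling is a made-up helper, not vLLM code) could look like this:

from typing import Any, Dict, Optional

def check_rope_scaling(rope_scaling: Optional[Dict[str, Any]]) -> None:
    # None means plain RoPE with base=rope_theta and no position scaling.
    if rope_scaling is None:
        return
    scaling_type = rope_scaling.get("type")
    scaling_factor = rope_scaling.get("factor")
    if scaling_type not in ("linear", "dynamic"):
        raise ValueError(f"Unsupported rope_scaling type: {scaling_type}")
    if not isinstance(scaling_factor, (int, float)) or scaling_factor < 1.0:
        raise ValueError(f"rope_scaling factor must be >= 1.0, got {scaling_factor}")
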
@@ -116,7 +115,7 @@ class QWenAttention(nn.Module):
             rotary_dim=self.head_dim,
             base=rope_theta,
             max_position=max_position_embeddings,
-        )
+            rope_scaling=rope_scaling)
 
     def forward(
         self,
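
The dict is simply forwarded to PagedAttentionWithRoPE, which owns the rotary embedding. To illustrate the underlying idea only (this is not vLLM's implementation, and rope_cos_sin is a made-up name): linear RoPE scaling divides position indices by the factor so that a context that many times longer still maps into the trained position range.

import torch

def rope_cos_sin(positions: torch.Tensor,
                 rotary_dim: int,
                 base: float = 10000.0,
                 scaling_factor: float = 1.0):
    # Standard RoPE inverse frequencies over the even dimensions.
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    # Linear scaling step: divide positions by the factor (e.g. rope_scaling["factor"]).
    scaled = positions.float() / scaling_factor
    freqs = torch.outer(scaled, inv_freq)  # [num_tokens, rotary_dim // 2]
    return freqs.cos(), freqs.sin()
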
@@ -144,10 +143,12 @@ class QWenBlock(nn.Module):
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         self.attn = QWenAttention(config.hidden_size,
                                   config.num_attention_heads,
                                   config.max_position_embeddings,
-                                  rope_theta=rope_theta)
+                                  rope_theta=rope_theta,
+                                  rope_scaling=rope_scaling)
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
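
QWenBlock reads rope_scaling from the HuggingFace config with a None fallback, so checkpoints that do not define the field are unaffected. A small self-contained example of that fallback; the config objects and the {"type": "dynamic", "factor": 2.0} values are illustrative placeholders, not taken from a real Qwen config:

from types import SimpleNamespace

cfg_plain = SimpleNamespace(rope_theta=10000)
cfg_scaled = SimpleNamespace(rope_theta=10000,
                             rope_scaling={"type": "dynamic", "factor": 2.0})

print(getattr(cfg_plain, "rope_scaling", None))   # None -> unscaled RoPE
print(getattr(cfg_scaled, "rope_scaling", None))  # {'type': 'dynamic', 'factor': 2.0}
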