Add rope_scaling to Qwen (#1210)

Authored by Qing on 2023-09-28 15:49:23 +08:00, committed by GitHub
parent 20f7cc4cde
commit 7bedab5748

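For context, rope_scaling is the optional dictionary that Hugging Face-style model configs use to describe RoPE context-length extension. A hedged illustration of the value the new getattr(config, "rope_scaling", None) call returns (the concrete type and factor below are made up for the example; older Qwen configs simply omit the field and the getattr falls back to None):

    # Illustrative only: the shape of the rope_scaling entry a long-context
    # checkpoint might carry in its config.json; absent field -> None.
    rope_scaling = {"type": "dynamic", "factor": 2.0}

With that in mind, the change itself: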

@@ -8,7 +8,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -76,13 +76,12 @@ class QWenMLP(nn.Module):
 
 class QWenAttention(nn.Module):
 
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        max_position_embeddings: int,
-        rope_theta: float = 10000,
-    ):
+    def __init__(self,
+                 hidden_size: int,
+                 num_heads: int,
+                 max_position_embeddings: int,
+                 rope_theta: float = 10000,
+                 rope_scaling: Optional[Dict[str, Any]] = None):
         super().__init__()
         self.hidden_size = hidden_size
         tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
@@ -116,7 +115,7 @@ class QWenAttention(nn.Module):
             rotary_dim=self.head_dim,
             base=rope_theta,
             max_position=max_position_embeddings,
-        )
+            rope_scaling=rope_scaling)
 
     def forward(
         self,
@@ -144,10 +143,12 @@ class QWenBlock(nn.Module):
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         self.attn = QWenAttention(config.hidden_size,
                                   config.num_attention_heads,
                                   config.max_position_embeddings,
-                                  rope_theta=rope_theta)
+                                  rope_theta=rope_theta,
+                                  rope_scaling=rope_scaling)
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
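The dict itself is only threaded through here; it is consumed where the rotary cos/sin cache is built. Below is a minimal sketch of the underlying technique for the "linear" (position-interpolation) case, using a hypothetical build_rope_cache helper for illustration, not vLLM's actual rotary-embedding code:

    from typing import Any, Dict, Optional

    import torch

    def build_rope_cache(head_dim: int,
                         max_position: int,
                         base: float = 10000.0,
                         rope_scaling: Optional[Dict[str, Any]] = None):
        # Standard RoPE inverse frequencies over even dims 0, 2, ..., head_dim - 2.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_position).float()
        if rope_scaling is not None and rope_scaling.get("type") == "linear":
            # Position interpolation: compress position ids by the factor so a
            # longer context maps back onto the range seen during training.
            t = t / float(rope_scaling["factor"])
        freqs = torch.outer(t, inv_freq)   # shape: (max_position, head_dim // 2)
        return freqs.cos(), freqs.sin()

    # Example: cos, sin = build_rope_cache(
    #     128, 16384, rope_scaling={"type": "linear", "factor": 2.0})

With factor = 2.0 and max_position doubled, the positions fed into the frequencies stay within the range the model saw during training, which is the point of passing rope_scaling down to the attention layer.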