[Model] add rope_scaling support for qwen2 (#4930)
This commit is contained in:
parent
65ae8c2c8f
commit
d130b573a0
@ -89,7 +89,8 @@ class Qwen2Attention(nn.Module):
|
|||||||
use_sliding_window: bool = False,
|
use_sliding_window: bool = False,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
sliding_window: Optional[int] = None) -> None:
|
sliding_window: Optional[int] = None,
|
||||||
|
rope_scaling: Optional[Tuple] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
@ -133,6 +134,7 @@ class Qwen2Attention(nn.Module):
|
|||||||
rotary_dim=self.head_dim,
|
rotary_dim=self.head_dim,
|
||||||
max_position=max_position,
|
max_position=max_position,
|
||||||
base=self.rope_theta,
|
base=self.rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
)
|
)
|
||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@ -169,6 +171,7 @@ class Qwen2DecoderLayer(nn.Module):
|
|||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
# Requires transformers > 4.32.0
|
# Requires transformers > 4.32.0
|
||||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||||
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
use_sliding_window = (config.use_sliding_window
|
use_sliding_window = (config.use_sliding_window
|
||||||
and layer_idx < config.max_window_layers)
|
and layer_idx < config.max_window_layers)
|
||||||
self.self_attn = Qwen2Attention(
|
self.self_attn = Qwen2Attention(
|
||||||
@ -180,7 +183,8 @@ class Qwen2DecoderLayer(nn.Module):
|
|||||||
use_sliding_window=use_sliding_window,
|
use_sliding_window=use_sliding_window,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
sliding_window=config.sliding_window)
|
sliding_window=config.sliding_window,
|
||||||
|
rope_scaling=rope_scaling)
|
||||||
self.mlp = Qwen2MLP(
|
self.mlp = Qwen2MLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user