[Model] Cohere CommandR+ (#3829)
This commit is contained in:
parent
db2a6a41e2
commit
9117f892f0
@ -25,6 +25,7 @@ from typing import List, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
from transformers import CohereConfig
|
from transformers import CohereConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
@ -39,8 +40,9 @@ from vllm.model_executor.layers.sampler import Sampler
|
|||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||||
hf_model_weights_iterator)
|
hf_model_weights_iterator)
|
||||||
from vllm.sequence import SamplerOutput
|
from vllm.sequence import SamplerOutput
|
||||||
@ -48,11 +50,11 @@ from vllm.sequence import SamplerOutput
|
|||||||
|
|
||||||
class LayerNorm(nn.Module):
|
class LayerNorm(nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-5, bias=False):
|
def __init__(self, param_shape=None, eps=1e-5):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
self.weight = nn.Parameter(torch.ones(param_shape))
|
||||||
self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
|
|
||||||
self.variance_epsilon = eps
|
self.variance_epsilon = eps
|
||||||
|
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
|
||||||
|
|
||||||
def forward(self, hidden_states, residuals=None):
|
def forward(self, hidden_states, residuals=None):
|
||||||
input_dtype = hidden_states.dtype
|
input_dtype = hidden_states.dtype
|
||||||
@ -62,10 +64,20 @@ class LayerNorm(nn.Module):
|
|||||||
hidden_states = (hidden_states -
|
hidden_states = (hidden_states -
|
||||||
mean) * torch.rsqrt(variance + self.variance_epsilon)
|
mean) * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
hidden_states = self.weight.to(torch.float32) * hidden_states
|
hidden_states = self.weight.to(torch.float32) * hidden_states
|
||||||
if self.bias is not None:
|
|
||||||
hidden_states = hidden_states + self.bias.to(torch.float32)
|
|
||||||
return hidden_states.to(input_dtype), residuals
|
return hidden_states.to(input_dtype), residuals
|
||||||
|
|
||||||
|
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||||
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
shard_dim = 0 if param.dim() != 1 else None
|
||||||
|
param_data = param.data
|
||||||
|
if shard_dim is not None:
|
||||||
|
shard_size = param_data.shape[shard_dim]
|
||||||
|
start_idx = tp_rank * shard_size
|
||||||
|
loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
|
||||||
|
shard_size)
|
||||||
|
assert param_data.shape == loaded_weight.shape
|
||||||
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
|
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
|
||||||
class CohereMLP(nn.Module):
|
class CohereMLP(nn.Module):
|
||||||
@ -131,6 +143,7 @@ class CohereAttention(nn.Module):
|
|||||||
self.max_position_embeddings = config.max_position_embeddings
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
self.rope_theta = config.rope_theta
|
self.rope_theta = config.rope_theta
|
||||||
self.rope_scaling = getattr(config, "rope_scaling", None)
|
self.rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
|
self.use_qk_norm = getattr(config, "use_qk_norm", False)
|
||||||
self.qkv_proj = QKVParallelLinear(
|
self.qkv_proj = QKVParallelLinear(
|
||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
@ -159,6 +172,22 @@ class CohereAttention(nn.Module):
|
|||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
)
|
)
|
||||||
|
if self.use_qk_norm:
|
||||||
|
self.q_norm = LayerNorm(param_shape=(self.num_heads,
|
||||||
|
self.head_dim),
|
||||||
|
eps=config.layer_norm_eps)
|
||||||
|
self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
|
||||||
|
self.head_dim),
|
||||||
|
eps=config.layer_norm_eps)
|
||||||
|
|
||||||
|
def _apply_qk_norm(self, q, k):
|
||||||
|
q = q.view(*q.shape[:-1], -1, self.head_dim)
|
||||||
|
k = k.view(*k.shape[:-1], -1, self.head_dim)
|
||||||
|
q, _ = self.q_norm(q)
|
||||||
|
k, _ = self.k_norm(k)
|
||||||
|
q = q.view(*q.shape[:-2], -1)
|
||||||
|
k = k.view(*k.shape[:-2], -1)
|
||||||
|
return q, k
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -169,6 +198,8 @@ class CohereAttention(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
if self.use_qk_norm:
|
||||||
|
q, k = self._apply_qk_norm(q, k)
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
@ -186,7 +217,7 @@ class CohereDecoderLayer(nn.Module):
|
|||||||
self.self_attn = CohereAttention(config, linear_method=linear_method)
|
self.self_attn = CohereAttention(config, linear_method=linear_method)
|
||||||
|
|
||||||
self.mlp = CohereMLP(config, linear_method=linear_method)
|
self.mlp = CohereMLP(config, linear_method=linear_method)
|
||||||
self.input_layernorm = LayerNorm(config.hidden_size,
|
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
|
||||||
eps=config.layer_norm_eps)
|
eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -229,7 +260,8 @@ class CohereModel(nn.Module):
|
|||||||
CohereDecoderLayer(config, linear_method=linear_method)
|
CohereDecoderLayer(config, linear_method=linear_method)
|
||||||
for _ in range(config.num_hidden_layers)
|
for _ in range(config.num_hidden_layers)
|
||||||
])
|
])
|
||||||
self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.norm = LayerNorm(param_shape=(config.hidden_size),
|
||||||
|
eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user