[Model] Cohere CommandR+ (#3829)
commit 9117f892f0 (parent db2a6a41e2)
@@ -25,6 +25,7 @@ from typing import List, Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.nn.parameter import Parameter
 from transformers import CohereConfig

 from vllm.attention import Attention, AttentionMetadata
@@ -39,8 +40,9 @@ from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -48,11 +50,11 @@ from vllm.sequence import SamplerOutput

 class LayerNorm(nn.Module):

-    def __init__(self, hidden_size, eps=1e-5, bias=False):
+    def __init__(self, param_shape=None, eps=1e-5):
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
+        self.weight = nn.Parameter(torch.ones(param_shape))
         self.variance_epsilon = eps
+        set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})

     def forward(self, hidden_states, residuals=None):
         input_dtype = hidden_states.dtype
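The constructor change above lets one LayerNorm class serve two shapes: the usual 1-D hidden_size weight and the new 2-D (num_heads, head_dim) weights used for QK norm further down in this diff. A small standalone sketch of the broadcast, with toy sizes and plain torch only, mirroring the forward math already in this file:

import torch

# Toy sizes for illustration, not Command R+'s real config.
num_heads, head_dim, eps = 4, 16, 1e-5
weight = torch.ones(num_heads, head_dim)   # as built by LayerNorm(param_shape=(num_heads, head_dim))

x = torch.randn(2, num_heads, head_dim)    # e.g. q reshaped to per-head layout
mean = x.float().mean(-1, keepdim=True)
variance = (x.float() - mean).pow(2).mean(-1, keepdim=True)
normed = weight * (x.float() - mean) * torch.rsqrt(variance + eps)
print(normed.shape)                        # torch.Size([2, 4, 16]); weight broadcasts per head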
@@ -62,10 +64,20 @@ class LayerNorm(nn.Module):
         hidden_states = (hidden_states -
                          mean) * torch.rsqrt(variance + self.variance_epsilon)
         hidden_states = self.weight.to(torch.float32) * hidden_states
-        if self.bias is not None:
-            hidden_states = hidden_states + self.bias.to(torch.float32)
         return hidden_states.to(input_dtype), residuals

+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        tp_rank = get_tensor_model_parallel_rank()
+        shard_dim = 0 if param.dim() != 1 else None
+        param_data = param.data
+        if shard_dim is not None:
+            shard_size = param_data.shape[shard_dim]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
+                                                 shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+

 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):
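Because the per-head QK-norm weights live on each tensor-parallel rank with only that rank's heads, the new weight_loader narrows the full checkpoint tensor along dim 0 before copying, and leaves 1-D weights (the regular layer norms, which are replicated) unsharded. A standalone sketch of that arithmetic with made-up shapes and a hard-coded rank; the real code reads the rank from get_tensor_model_parallel_rank():

import torch

total_heads, head_dim, tp_size = 8, 16, 4
loaded_weight = torch.arange(total_heads * head_dim, dtype=torch.float32).view(
    total_heads, head_dim)                       # full q_norm weight from the checkpoint

tp_rank = 1                                      # stand-in for get_tensor_model_parallel_rank()
heads_per_rank = total_heads // tp_size
param_data = torch.empty(heads_per_rank, head_dim)  # this rank's q_norm.weight

shard_dim = 0 if param_data.dim() != 1 else None
if shard_dim is not None:
    shard_size = param_data.shape[shard_dim]     # 2 heads per rank
    start_idx = tp_rank * shard_size             # rank 1 owns heads 2..3
    loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)

assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
print(param_data[:, 0])                          # tensor([32., 48.]): rows 2 and 3 of the full weight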
@@ -131,6 +143,7 @@ class CohereAttention(nn.Module):
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.rope_scaling = getattr(config, "rope_scaling", None)
+        self.use_qk_norm = getattr(config, "use_qk_norm", False)
         self.qkv_proj = QKVParallelLinear(
             self.hidden_size,
             self.head_dim,
@@ -159,6 +172,22 @@ class CohereAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
         )
+        if self.use_qk_norm:
+            self.q_norm = LayerNorm(param_shape=(self.num_heads,
+                                                 self.head_dim),
+                                    eps=config.layer_norm_eps)
+            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
+                                                 self.head_dim),
+                                    eps=config.layer_norm_eps)
+
+    def _apply_qk_norm(self, q, k):
+        q = q.view(*q.shape[:-1], -1, self.head_dim)
+        k = k.view(*k.shape[:-1], -1, self.head_dim)
+        q, _ = self.q_norm(q)
+        k, _ = self.k_norm(k)
+        q = q.view(*q.shape[:-2], -1)
+        k = k.view(*k.shape[:-2], -1)
+        return q, k

     def forward(
         self,
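_apply_qk_norm above only needs a reshape round trip: unflatten the packed q/k projection into per-head form, apply the per-head norm, then flatten back so RoPE and attention see the original layout. A quick shape check with invented sizes:

import torch

num_tokens, num_heads, head_dim = 3, 4, 16
q = torch.randn(num_tokens, num_heads * head_dim)   # packed output of qkv_proj split

q_heads = q.view(*q.shape[:-1], -1, head_dim)        # [3, 4, 16], ready for the per-head norm
# ... per-head LayerNorm would be applied here ...
q_back = q_heads.view(*q_heads.shape[:-2], -1)       # back to [3, 64]

assert q_back.shape == q.shape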
@@ -169,6 +198,8 @@ class CohereAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.use_qk_norm:
+            q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
@@ -186,7 +217,7 @@ class CohereDecoderLayer(nn.Module):
         self.self_attn = CohereAttention(config, linear_method=linear_method)

         self.mlp = CohereMLP(config, linear_method=linear_method)
-        self.input_layernorm = LayerNorm(config.hidden_size,
+        self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
                                          eps=config.layer_norm_eps)

     def forward(
@@ -229,7 +260,8 @@ class CohereModel(nn.Module):
             CohereDecoderLayer(config, linear_method=linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.norm = LayerNorm(param_shape=(config.hidden_size),
+                              eps=config.layer_norm_eps)

     def forward(
         self,
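With QK norm wired through the attention path, Command R+ should be servable through the usual vLLM entry points. A hedged usage sketch; the Hugging Face repo id and tensor-parallel degree are placeholders, not taken from this commit:

from vllm import LLM, SamplingParams

# Assumed checkpoint id; substitute whatever Command R+ checkpoint you actually use.
llm = LLM(model="CohereForAI/c4ai-command-r-plus", tensor_parallel_size=4)
outputs = llm.generate(["Write a haiku about GPUs."],
                       SamplingParams(temperature=0.7, max_tokens=64))
print(outputs[0].outputs[0].text)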