[Model] Cohere CommandR+ (#3829)

Author: Saurabh Dash, 2024-04-05 02:01:49 +05:30 (committed by GitHub)
parent db2a6a41e2
commit 9117f892f0
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)


@@ -25,6 +25,7 @@ from typing import List, Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.nn.parameter import Parameter
 from transformers import CohereConfig
 
 from vllm.attention import Attention, AttentionMetadata
@@ -39,8 +40,9 @@ from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_world_size)
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -48,11 +50,11 @@ from vllm.sequence import SamplerOutput
 
 class LayerNorm(nn.Module):
 
-    def __init__(self, hidden_size, eps=1e-5, bias=False):
+    def __init__(self, param_shape=None, eps=1e-5):
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None
+        self.weight = nn.Parameter(torch.ones(param_shape))
         self.variance_epsilon = eps
+        set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
 
     def forward(self, hidden_states, residuals=None):
         input_dtype = hidden_states.dtype
@@ -62,10 +64,20 @@ class LayerNorm(nn.Module):
         hidden_states = (hidden_states -
                          mean) * torch.rsqrt(variance + self.variance_epsilon)
         hidden_states = self.weight.to(torch.float32) * hidden_states
-        if self.bias is not None:
-            hidden_states = hidden_states + self.bias.to(torch.float32)
         return hidden_states.to(input_dtype), residuals
 
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        tp_rank = get_tensor_model_parallel_rank()
+        shard_dim = 0 if param.dim() != 1 else None
+        param_data = param.data
+        if shard_dim is not None:
+            shard_size = param_data.shape[shard_dim]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
+                                                 shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
 
 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):
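
Note: the new LayerNorm.weight_loader shards any non-1D norm weight (e.g. the per-head QK-norm weight of shape (num_heads, head_dim)) along dim 0, so each tensor-parallel rank keeps only its own heads. A minimal standalone sketch of that slicing in plain PyTorch, with tp_rank/tp_size as stand-ins for vLLM's parallel-state helpers:

import torch

# Stand-ins for get_tensor_model_parallel_rank()/world_size().
tp_size, tp_rank = 4, 1

# Full checkpoint weight for a per-head QK norm: (num_heads, head_dim).
num_heads, head_dim = 8, 16
full_weight = torch.arange(num_heads * head_dim,
                           dtype=torch.float32).view(num_heads, head_dim)

# The local parameter only holds this rank's slice of the heads (dim 0).
param_data = torch.empty(num_heads // tp_size, head_dim)

# Same logic as the weight_loader above: narrow dim 0 to this rank's shard.
shard_size = param_data.shape[0]
start_idx = tp_rank * shard_size
shard = full_weight.narrow(0, start_idx, shard_size)
assert param_data.shape == shard.shape
param_data.copy_(shard)
print(param_data[:, 0])  # tensor([32., 48.]) -> heads 2 and 3 of the full weight
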
@@ -131,6 +143,7 @@ class CohereAttention(nn.Module):
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.rope_scaling = getattr(config, "rope_scaling", None)
+        self.use_qk_norm = getattr(config, "use_qk_norm", False)
         self.qkv_proj = QKVParallelLinear(
             self.hidden_size,
             self.head_dim,
@@ -159,6 +172,22 @@ class CohereAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
         )
+        if self.use_qk_norm:
+            self.q_norm = LayerNorm(param_shape=(self.num_heads,
+                                                 self.head_dim),
+                                    eps=config.layer_norm_eps)
+            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
+                                                 self.head_dim),
+                                    eps=config.layer_norm_eps)
+
+    def _apply_qk_norm(self, q, k):
+        q = q.view(*q.shape[:-1], -1, self.head_dim)
+        k = k.view(*k.shape[:-1], -1, self.head_dim)
+        q, _ = self.q_norm(q)
+        k, _ = self.k_norm(k)
+        q = q.view(*q.shape[:-2], -1)
+        k = k.view(*k.shape[:-2], -1)
+        return q, k
 
     def forward(
         self,
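
Note: as a shape sanity check for _apply_qk_norm above: the flat query of shape (num_tokens, num_heads * head_dim) is unflattened into heads, normalized over head_dim (the (num_heads, head_dim) weight broadcasting one scale per head), and flattened back before RoPE. A rough standalone sketch with made-up sizes, mirroring the bias-free LayerNorm math:

import torch

num_tokens, num_heads, head_dim = 5, 8, 16
q = torch.randn(num_tokens, num_heads * head_dim)

# Unflatten into (num_tokens, num_heads, head_dim), as _apply_qk_norm does.
q_heads = q.view(*q.shape[:-1], -1, head_dim)

# Bias-free LayerNorm over head_dim; a (num_heads, head_dim) weight
# broadcasts so each head gets its own learned scale.
weight = torch.ones(num_heads, head_dim)
mean = q_heads.mean(-1, keepdim=True)
var = (q_heads - mean).pow(2).mean(-1, keepdim=True)
q_normed = weight * (q_heads - mean) * torch.rsqrt(var + 1e-5)

# Flatten the heads back so rotary embeddings and attention see the
# original (num_tokens, num_heads * head_dim) layout.
q_out = q_normed.view(*q_normed.shape[:-2], -1)
assert q_out.shape == q.shape
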
@@ -169,6 +198,8 @@ class CohereAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.use_qk_norm:
+            q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
@@ -186,7 +217,7 @@ class CohereDecoderLayer(nn.Module):
         self.self_attn = CohereAttention(config, linear_method=linear_method)
         self.mlp = CohereMLP(config, linear_method=linear_method)
-        self.input_layernorm = LayerNorm(config.hidden_size,
+        self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
                                          eps=config.layer_norm_eps)
 
     def forward(
@@ -229,7 +260,8 @@ class CohereModel(nn.Module):
             CohereDecoderLayer(config, linear_method=linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.norm = LayerNorm(param_shape=(config.hidden_size),
+                              eps=config.layer_norm_eps)
 
     def forward(
         self,
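
Note: for reference, once merged the model should be usable through vLLM's normal entry points; a minimal, hypothetical snippet (the model id and tensor_parallel_size are illustrative, and Command R+ is large enough that it typically needs several GPUs):

from vllm import LLM, SamplingParams

# Illustrative only: adjust the model id and parallelism to your hardware.
llm = LLM(model="CohereForAI/c4ai-command-r-plus", tensor_parallel_size=4)
params = SamplingParams(temperature=0.3, max_tokens=128)
outputs = llm.generate(["Write a haiku about tensor parallelism."], params)
print(outputs[0].outputs[0].text)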