From 9117f892f0e4d3b0f07bf0b9b409321bc743dabc Mon Sep 17 00:00:00 2001 From: Saurabh Dash <111897126+saurabhdash2512@users.noreply.github.com> Date: Fri, 5 Apr 2024 02:01:49 +0530 Subject: [PATCH] [Model] Cohere CommandR+ (#3829) --- vllm/model_executor/models/commandr.py | 48 +++++++++++++++++++++----- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index ee6d36f6..620d6313 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -25,6 +25,7 @@ from typing import List, Optional, Tuple import torch import torch.utils.checkpoint from torch import nn +from torch.nn.parameter import Parameter from transformers import CohereConfig from vllm.attention import Attention, AttentionMetadata @@ -39,8 +40,9 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -48,11 +50,11 @@ from vllm.sequence import SamplerOutput class LayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-5, bias=False): + def __init__(self, param_shape=None, eps=1e-5): super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None + self.weight = nn.Parameter(torch.ones(param_shape)) self.variance_epsilon = eps + set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) def forward(self, hidden_states, residuals=None): input_dtype = hidden_states.dtype @@ -62,10 +64,20 @@ class LayerNorm(nn.Module): hidden_states = (hidden_states - mean) * torch.rsqrt(variance + self.variance_epsilon) hidden_states = self.weight.to(torch.float32) * hidden_states - if self.bias is not None: - hidden_states = hidden_states + self.bias.to(torch.float32) return hidden_states.to(input_dtype), residuals + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + shard_dim = 0 if param.dim() != 1 else None + param_data = param.data + if shard_dim is not None: + shard_size = param_data.shape[shard_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_dim, start_idx, + shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere class CohereMLP(nn.Module): @@ -131,6 +143,7 @@ class CohereAttention(nn.Module): self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.rope_scaling = getattr(config, "rope_scaling", None) + self.use_qk_norm = getattr(config, "use_qk_norm", False) self.qkv_proj = QKVParallelLinear( self.hidden_size, self.head_dim, @@ -159,6 +172,22 @@ class CohereAttention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, ) + if self.use_qk_norm: + self.q_norm = LayerNorm(param_shape=(self.num_heads, + self.head_dim), + eps=config.layer_norm_eps) + self.k_norm = LayerNorm(param_shape=(self.num_kv_heads, + self.head_dim), + eps=config.layer_norm_eps) + + def _apply_qk_norm(self, q, k): + q = q.view(*q.shape[:-1], -1, self.head_dim) + k = k.view(*k.shape[:-1], -1, self.head_dim) + q, _ = self.q_norm(q) + k, _ = self.k_norm(k) + q = q.view(*q.shape[:-2], -1) + k = k.view(*k.shape[:-2], -1) + return q, k def forward( self, @@ -169,6 +198,8 @@ class CohereAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_qk_norm: + q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) @@ -186,7 +217,7 @@ class CohereDecoderLayer(nn.Module): self.self_attn = CohereAttention(config, linear_method=linear_method) self.mlp = CohereMLP(config, linear_method=linear_method) - self.input_layernorm = LayerNorm(config.hidden_size, + self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), eps=config.layer_norm_eps) def forward( @@ -229,7 +260,8 @@ class CohereModel(nn.Module): CohereDecoderLayer(config, linear_method=linear_method) for _ in range(config.num_hidden_layers) ]) - self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm = LayerNorm(param_shape=(config.hidden_size), + eps=config.layer_norm_eps) def forward( self,