From 07feecde1a69859d565786a7ad64c0f604f17b28 Mon Sep 17 00:00:00 2001
From: sergey-tinkoff <167607910+sergey-tinkoff@users.noreply.github.com>
Date: Tue, 18 Jun 2024 21:01:21 +0300
Subject: [PATCH] [Model] LoRA support added for command-r (#5178)

---
 csrc/punica/bgmv/bgmv_config.h         |  6 ++++
 tests/lora/test_punica.py              |  2 ++
 vllm/model_executor/models/commandr.py | 48 ++++++++++++++++++++++----
 3 files changed, 50 insertions(+), 6 deletions(-)
 mode change 100644 => 100755 csrc/punica/bgmv/bgmv_config.h

diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h
old mode 100644
new mode 100755
index 0456b4bc..c38db2dc
--- a/csrc/punica/bgmv/bgmv_config.h
+++ b/csrc/punica/bgmv/bgmv_config.h
@@ -69,6 +69,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, narrow, 36864) \
   f(in_T, out_T, W_T, narrow, 43264) \
   f(in_T, out_T, W_T, narrow, 49152) \
+  f(in_T, out_T, W_T, narrow, 60544) \
+  f(in_T, out_T, W_T, narrow, 60672) \
   f(in_T, out_T, W_T, narrow, 64000) \
   f(in_T, out_T, W_T, narrow, 64256) \
   f(in_T, out_T, W_T, narrow, 64512) \
@@ -78,6 +80,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, narrow, 128000) \
   f(in_T, out_T, W_T, narrow, 128256) \
   f(in_T, out_T, W_T, narrow, 128512) \
+
+
 // Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
 // and vllm/tests/lora/test_punica.py
 
@@ -144,6 +148,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, 36864, narrow) \
   f(in_T, out_T, W_T, 43264, narrow) \
   f(in_T, out_T, W_T, 49152, narrow) \
+  f(in_T, out_T, W_T, 60544, narrow) \
+  f(in_T, out_T, W_T, 60672, narrow) \
   f(in_T, out_T, W_T, 64000, narrow) \
   f(in_T, out_T, W_T, 64256, narrow) \
   f(in_T, out_T, W_T, 64512, narrow) \
diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py
index d87658e5..dae1d568 100644
--- a/tests/lora/test_punica.py
+++ b/tests/lora/test_punica.py
@@ -94,6 +94,8 @@ H1 = H2 = [
     36864,
     43264,
     49152,
+    60544,
+    60672,
     64000,
     64256,
     102400,
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index 11d88d45..600c2990 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -29,7 +29,7 @@ from torch.nn.parameter import Parameter
 from transformers import CohereConfig
 
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -265,10 +265,14 @@ class CohereModel(nn.Module):
         config: CohereConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.vocab_size = config.vocab_size
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size)
         self.layers = nn.ModuleList([
@@ -302,18 +306,44 @@ class CohereModel(nn.Module):
 class CohereForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
+    ]
+    embedding_modules = {"embed_tokens": "input_embeddings"}
+    embedding_padding_modules = []
+
     def __init__(
         self,
         config: CohereConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.quant_config = quant_config
-        self.logits_processor = LogitsProcessor(config.vocab_size,
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size,
                                                 scale=config.logit_scale)
-        self.model = CohereModel(config, cache_config, quant_config)
+        self.model = CohereModel(config,
+                                 cache_config,
+                                 quant_config,
+                                 lora_config=lora_config)
         self.sampler = Sampler()
 
     @torch.no_grad()
@@ -330,8 +360,14 @@ class CohereForCausalLM(nn.Module):
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.model.embed_tokens.weight,
-                                       hidden_states, sampling_metadata)
+        is_not_lora = hasattr(self.model.embed_tokens, 'weight')
+        if is_not_lora:
+            embedding_weights = self.model.embed_tokens.weight
+        else:
+            embedding_weights = self.model.embed_tokens.base_layer.weight
+
+        logits = self.logits_processor(embedding_weights, hidden_states,
+                                       sampling_metadata)
         return logits
 
     def sample(
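
Below is a minimal usage sketch (not part of the patch) showing how the Command-R LoRA support added here can be exercised through vLLM's offline API. The base checkpoint "CohereForAI/c4ai-command-r-v01" and the adapter path "./commandr-lora" are placeholders for illustration; any Command-R base model and a compatible LoRA adapter should work, assuming enable_lora is set when the engine is created.

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # Placeholder model name and adapter directory -- substitute your own.
    llm = LLM(model="CohereForAI/c4ai-command-r-v01",
              enable_lora=True,
              max_loras=1,
              max_lora_rank=16)

    prompts = ["Write a haiku about tensor parallelism."]
    sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

    # LoRARequest takes an adapter name, an integer id, and the local path
    # to the adapter weights.
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest("commandr-lora", 1, "./commandr-lora"))
    print(outputs[0].outputs[0].text)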