[Model] LoRA support added for command-r (#5178)
This commit is contained in:
parent
19091efc44
commit
07feecde1a
6
csrc/punica/bgmv/bgmv_config.h
Normal file → Executable file
6
csrc/punica/bgmv/bgmv_config.h
Normal file → Executable file
@ -69,6 +69,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
|||||||
f(in_T, out_T, W_T, narrow, 36864) \
|
f(in_T, out_T, W_T, narrow, 36864) \
|
||||||
f(in_T, out_T, W_T, narrow, 43264) \
|
f(in_T, out_T, W_T, narrow, 43264) \
|
||||||
f(in_T, out_T, W_T, narrow, 49152) \
|
f(in_T, out_T, W_T, narrow, 49152) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 60544) \
|
||||||
|
f(in_T, out_T, W_T, narrow, 60672) \
|
||||||
f(in_T, out_T, W_T, narrow, 64000) \
|
f(in_T, out_T, W_T, narrow, 64000) \
|
||||||
f(in_T, out_T, W_T, narrow, 64256) \
|
f(in_T, out_T, W_T, narrow, 64256) \
|
||||||
f(in_T, out_T, W_T, narrow, 64512) \
|
f(in_T, out_T, W_T, narrow, 64512) \
|
||||||
@ -78,6 +80,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
|||||||
f(in_T, out_T, W_T, narrow, 128000) \
|
f(in_T, out_T, W_T, narrow, 128000) \
|
||||||
f(in_T, out_T, W_T, narrow, 128256) \
|
f(in_T, out_T, W_T, narrow, 128256) \
|
||||||
f(in_T, out_T, W_T, narrow, 128512) \
|
f(in_T, out_T, W_T, narrow, 128512) \
|
||||||
|
|
||||||
|
|
||||||
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
|
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
|
||||||
// and vllm/tests/lora/test_punica.py
|
// and vllm/tests/lora/test_punica.py
|
||||||
|
|
||||||
@ -144,6 +148,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
|||||||
f(in_T, out_T, W_T, 36864, narrow) \
|
f(in_T, out_T, W_T, 36864, narrow) \
|
||||||
f(in_T, out_T, W_T, 43264, narrow) \
|
f(in_T, out_T, W_T, 43264, narrow) \
|
||||||
f(in_T, out_T, W_T, 49152, narrow) \
|
f(in_T, out_T, W_T, 49152, narrow) \
|
||||||
|
f(in_T, out_T, W_T, 60544, narrow) \
|
||||||
|
f(in_T, out_T, W_T, 60672, narrow) \
|
||||||
f(in_T, out_T, W_T, 64000, narrow) \
|
f(in_T, out_T, W_T, 64000, narrow) \
|
||||||
f(in_T, out_T, W_T, 64256, narrow) \
|
f(in_T, out_T, W_T, 64256, narrow) \
|
||||||
f(in_T, out_T, W_T, 64512, narrow) \
|
f(in_T, out_T, W_T, 64512, narrow) \
|
||||||
|
|||||||
@ -94,6 +94,8 @@ H1 = H2 = [
|
|||||||
36864,
|
36864,
|
||||||
43264,
|
43264,
|
||||||
49152,
|
49152,
|
||||||
|
60544,
|
||||||
|
60672,
|
||||||
64000,
|
64000,
|
||||||
64256,
|
64256,
|
||||||
102400,
|
102400,
|
||||||
|
|||||||
@ -29,7 +29,7 @@ from torch.nn.parameter import Parameter
|
|||||||
from transformers import CohereConfig
|
from transformers import CohereConfig
|
||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig, LoRAConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size)
|
get_tensor_model_parallel_world_size)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
@ -265,10 +265,14 @@ class CohereModel(nn.Module):
|
|||||||
config: CohereConfig,
|
config: CohereConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.vocab_size = config.vocab_size
|
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||||
|
(lora_config.max_loras or 1)) if lora_config else 0
|
||||||
|
self.vocab_size = config.vocab_size + lora_vocab
|
||||||
|
self.org_vocab_size = config.vocab_size
|
||||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||||
config.hidden_size)
|
config.hidden_size)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList([
|
||||||
@ -302,18 +306,44 @@ class CohereModel(nn.Module):
|
|||||||
|
|
||||||
class CohereForCausalLM(nn.Module):
|
class CohereForCausalLM(nn.Module):
|
||||||
|
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
"gate_up_proj": [
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
# LoRA specific attributes
|
||||||
|
supported_lora_modules = [
|
||||||
|
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
|
||||||
|
]
|
||||||
|
embedding_modules = {"embed_tokens": "input_embeddings"}
|
||||||
|
embedding_padding_modules = []
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: CohereConfig,
|
config: CohereConfig,
|
||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
|
if lora_config:
|
||||||
|
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size,
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
|
config.vocab_size,
|
||||||
scale=config.logit_scale)
|
scale=config.logit_scale)
|
||||||
self.model = CohereModel(config, cache_config, quant_config)
|
self.model = CohereModel(config,
|
||||||
|
cache_config,
|
||||||
|
quant_config,
|
||||||
|
lora_config=lora_config)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -330,8 +360,14 @@ class CohereForCausalLM(nn.Module):
|
|||||||
|
|
||||||
def compute_logits(self, hidden_states: torch.Tensor,
|
def compute_logits(self, hidden_states: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||||
logits = self.logits_processor(self.model.embed_tokens.weight,
|
is_not_lora = hasattr(self.model.embed_tokens, 'weight')
|
||||||
hidden_states, sampling_metadata)
|
if is_not_lora:
|
||||||
|
embedding_weights = self.model.embed_tokens.weight
|
||||||
|
else:
|
||||||
|
embedding_weights = self.model.embed_tokens.base_layer.weight
|
||||||
|
|
||||||
|
logits = self.logits_processor(embedding_weights, hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user