diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index ef6d401b..b448557a 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -443,7 +443,7 @@ class ParallelLMHead(VocabParallelEmbedding):
         super().__init__(num_embeddings, embedding_dim, params_dtype,
                          org_num_embeddings, padding_size, quant_config,
                          prefix)
-
+        self.quant_config = quant_config
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
@@ -455,6 +455,15 @@ class ParallelLMHead(VocabParallelEmbedding):
         else:
             self.register_parameter("bias", None)

+    def tie_weights(self, embed_tokens: VocabParallelEmbedding):
+        """Tie the weights with word embeddings."""
+        # GGUF quantized embed_tokens.
+        if self.quant_config and self.quant_config.get_name() == "gguf":
+            return embed_tokens
+        else:
+            self.weight = embed_tokens.weight
+            return self
+
     def forward(self, input_):
         del input_
         raise RuntimeError("LMHead's weights should be used in the sampler.")
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 0589b581..2a79a9ed 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -524,7 +524,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             quant_config=quant_config,
         )
         if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+            self.lm_head = self.lm_head.tie_weights(
+                self.model.embed_tokens)
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
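
Note: the following is a minimal standalone sketch of the weight-tying pattern introduced above, not vLLM code. The class names (ToyEmbedding, ToyLMHead) and the quant_name attribute are illustrative assumptions; only the control flow mirrors the tie_weights method in the diff: for GGUF-quantized embeddings the embedding module itself is returned, otherwise the head aliases the embedding's weight parameter and returns itself, which is why the caller in llama.py rebinds self.lm_head to the return value.

import torch
import torch.nn as nn


class ToyEmbedding(nn.Module):
    def __init__(self, vocab_size: int, hidden: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(vocab_size, hidden))


class ToyLMHead(nn.Module):
    def __init__(self, vocab_size: int, hidden: int, quant_name: str | None = None):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(vocab_size, hidden))
        # Stand-in for the quant_config check in the real method.
        self.quant_name = quant_name

    def tie_weights(self, embed_tokens: ToyEmbedding) -> nn.Module:
        # GGUF stores the embedding in a quantized layout, so reuse the
        # embedding module directly instead of aliasing its weight.
        if self.quant_name == "gguf":
            return embed_tokens
        # Otherwise share the embedding's weight parameter with the head.
        self.weight = embed_tokens.weight
        return self


embed = ToyEmbedding(32000, 4096)
head = ToyLMHead(32000, 4096)
# The caller must rebind, since tie_weights may return a different module.
head = head.tie_weights(embed)
assert head.weight is embed.weight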