[Bugfix] Fix lm_head weights tying with lora for llama (#9227)
commit 07c11cf4d4
parent f3a507f1d3
@@ -443,7 +443,7 @@ class ParallelLMHead(VocabParallelEmbedding):
         super().__init__(num_embeddings, embedding_dim, params_dtype,
                          org_num_embeddings, padding_size, quant_config,
                          prefix)
-
+        self.quant_config = quant_config
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
@@ -455,6 +455,15 @@ class ParallelLMHead(VocabParallelEmbedding):
         else:
             self.register_parameter("bias", None)
 
+    def tie_weights(self, embed_tokens: VocabParallelEmbedding):
+        """Tie the weights with word embeddings."""
+        # GGUF quantized embed_tokens.
+        if self.quant_config and self.quant_config.get_name() == "gguf":
+            return embed_tokens
+        else:
+            self.weight = embed_tokens.weight
+            return self
+
     def forward(self, input_):
         del input_
         raise RuntimeError("LMHead's weights should be used in the sampler.")
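The tie_weights method added above keeps the ParallelLMHead module and only shares its weight tensor, except for GGUF, where the quantized embed_tokens module itself is returned as the head. A minimal standalone sketch of the same two branches, using made-up SketchLMHead and FakeGGUFConfig classes rather than the real vLLM layers:

    # Illustrative sketch only; class names here are not part of vLLM.
    import torch
    import torch.nn as nn


    class FakeGGUFConfig:
        def get_name(self) -> str:
            return "gguf"


    class SketchLMHead(nn.Module):
        def __init__(self, vocab_size: int, hidden_size: int, quant_config=None):
            super().__init__()
            self.quant_config = quant_config
            self.weight = nn.Parameter(torch.empty(vocab_size, hidden_size))

        def tie_weights(self, embed_tokens: nn.Module) -> nn.Module:
            if self.quant_config and self.quant_config.get_name() == "gguf":
                # GGUF: reuse the quantized embedding module directly as the head.
                return embed_tokens
            # Otherwise keep this head module and just share the weight tensor.
            self.weight = embed_tokens.weight
            return self


    embed = nn.Embedding(32, 8)
    head = SketchLMHead(32, 8).tie_weights(embed)
    assert head.weight is embed.weight and isinstance(head, SketchLMHead)

    gguf_head = SketchLMHead(32, 8, FakeGGUFConfig()).tie_weights(embed)
    assert gguf_head is embed  # GGUF branch hands back the embedding module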
@@ -524,7 +524,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 quant_config=quant_config,
             )
             if config.tie_word_embeddings:
-                self.lm_head = self.model.embed_tokens
+                self.lm_head = self.lm_head.tie_weights(
+                    self.model.embed_tokens)
 
             logit_scale = getattr(config, "logit_scale", 1.0)
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
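For context, the old call site replaced the whole lm_head module with embed_tokens, while the new one keeps the ParallelLMHead and only ties its weight, which, per the commit title, is what the LoRA path needs. A rough illustration of that difference with plain PyTorch modules (not the vLLM classes):

    # Sketch contrasting the old and new call sites; names are illustrative.
    import torch
    import torch.nn as nn

    embed_tokens = nn.Embedding(32, 8)
    lm_head = nn.Linear(8, 32, bias=False)
    head_seen_by_lora = lm_head  # a LoRA wrapper would hold a reference like this

    # Old behaviour: the head module is swapped out wholesale, so the reference
    # held elsewhere no longer matches the module used to produce logits.
    replaced_head = embed_tokens
    assert replaced_head is not head_seen_by_lora

    # New behaviour: the same head module survives and only the weight is shared.
    lm_head.weight = embed_tokens.weight
    assert lm_head is head_seen_by_lora
    assert lm_head.weight is embed_tokens.weight

    # The tied head computes the same logits as projecting onto the embedding.
    hidden = torch.randn(1, 8)
    assert torch.allclose(lm_head(hidden), hidden @ embed_tokens.weight.t())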