[Bugfix] support tie_word_embeddings for all models (#5724)

Author: Zijian Hu, 2024-08-19 20:00:04 -07:00 (committed by GitHub)
parent 0df7ec0b2d
commit f4fc7337bf
30 changed files with 90 additions and 16 deletions
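
With `tie_word_embeddings` enabled, a model reuses its input embedding matrix as the output projection, so the checkpoint typically ships no separate lm_head tensor. The hunks below make each model honor that config flag instead of assuming one behavior per architecture. As a rough sketch of the pattern in plain PyTorch (illustrative names, not the actual vLLM layers, which are VocabParallelEmbedding and ParallelLMHead):

import torch.nn as nn

class TiedHeadSketch(nn.Module):
    # Illustration only: tie the LM head to the token embedding when the config asks for it.
    def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        if tie_word_embeddings:
            # Share a single Parameter; loading the embedding also populates the head.
            self.lm_head.weight = self.embed_tokens.weight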


@@ -414,6 +414,8 @@ class ArcticForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.num_experts = config.num_local_experts
self.num_experts_per_tok = config.num_experts_per_tok
self.unpadded_vocab_size = config.vocab_size


@@ -331,6 +331,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()


@@ -821,6 +821,8 @@ class BartForConditionalGeneration(nn.Module):
lora_config: Optional[LoRAConfig] = None):
super().__init__()
# currently all existing BART models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.config = config
self.model = BartModel(config,
cache_config,


@@ -494,6 +494,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
super().__init__()
# currently all existing BLIP-2 models have `tie_word_embeddings`
# enabled
assert config.tie_word_embeddings
self.config = config
self.multimodal_config = multimodal_config


@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
- VocabParallelEmbedding)
+ ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -276,7 +276,12 @@ class BloomForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.transformer = BloomModel(config, cache_config, quant_config)
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.word_embeddings
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
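
The branch above keys off the Hugging Face model config. As a quick, illustrative check of the flag for a given checkpoint (the model name here is just an example):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("bigscience/bloom-560m")  # example checkpoint
print(config.tie_word_embeddings)  # expected True for BLOOM-style models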


@@ -356,6 +356,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
self.transformer = ChatGLMModel(config, cache_config, quant_config)
if self.config.tie_word_embeddings:
self.transformer.output_layer.weight = (
self.transformer.embedding.weight)
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()


@@ -321,6 +321,9 @@ class CohereForCausalLM(nn.Module):
) -> None:
super().__init__()
self.config = config
# currently all existing command R models have `tie_word_embeddings`
# enabled
assert config.tie_word_embeddings
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size


@@ -362,6 +362,9 @@ class DbrxForCausalLM(nn.Module):
):
super().__init__()
self.config = config
if config.tie_word_embeddings:
raise ValueError(
"tie_word_embeddings is not supported for Dbrx models.")
self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, cache_config, quant_config)


@@ -380,6 +380,8 @@ class DeepseekForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()


@@ -331,6 +331,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
super().__init__()
self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.lora_config = lora_config
self.quant_config = quant_config


@@ -323,6 +323,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
del lora_config # Unused.
super().__init__()
self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.quant_config = quant_config
self.model = Gemma2Model(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(


@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
- VocabParallelEmbedding)
+ ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -249,7 +249,11 @@ class GPT2LMHeadModel(nn.Module):
cache_config,
quant_config,
prefix="transformer")
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()


@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
- VocabParallelEmbedding)
+ ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -259,7 +259,13 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
lora_config)
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(
self.transformer.vocab_size,
self.transformer.embed_dim,
org_num_embeddings=self.config.vocab_size)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size


@@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module):
def __init__(
self,
- config,
+ config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
@@ -243,6 +243,8 @@ class GPTNeoXForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()


@@ -264,6 +264,8 @@ class InternLM2ForCausalLM(nn.Module):
self.output = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.output.weight = self.model.tok_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()


@@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
- VocabParallelEmbedding)
+ ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -291,7 +291,11 @@ class JAISLMHeadModel(nn.Module):
self.config = config
self.quant_config = quant_config
self.transformer = JAISModel(config, cache_config, quant_config)
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
if hasattr(config, "width_scale"):
self.output_logits_scale = config.width_scale
else:


@@ -496,6 +496,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# All MiniCPM-V models disable `tie_word_embeddings` but
# `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
# check `tie_word_embeddings` until vLLM integrate MiniCPM-V model
# and config class
self.config = config
self.multimodal_config = multimodal_config
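
As the comment above explains, the assert is skipped here because `PretrainedConfig` defaults the flag to True even for models that never tie their weights. A one-line illustration of that default:

from transformers import PretrainedConfig

print(PretrainedConfig().tie_word_embeddings)  # True unless a model config overrides it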


@@ -359,6 +359,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
if not lora_config else lora_config.lora_vocab_padding_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()


@@ -347,6 +347,8 @@ class MixtralForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()


@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
- VocabParallelEmbedding)
+ ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
@@ -307,7 +307,11 @@ class OPTForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = OPTModel(config, cache_config, quant_config)
if self.config.tie_word_embeddings:
self.lm_head = self.model.decoder.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.word_embed_proj_dim)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()


@@ -262,6 +262,8 @@ class OrionForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()


@@ -260,6 +260,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
super().__init__()
self.config = config
# lm_head use bias, cannot share word embeddings
assert not config.tie_word_embeddings
self.lora_config = lora_config
self.quant_config = quant_config


@@ -368,6 +368,8 @@ class Phi3SmallForCausalLM(nn.Module):
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
@@ -449,4 +451,3 @@ class Phi3SmallForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
- self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data)
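
The second hunk drops the explicit copy into lm_head.weight after weight loading: with the weight tied at construction time, both modules hold the same tensor, so loading the embedding already updates the head. A small plain-PyTorch sketch of that behavior (not the vLLM classes):

import torch
import torch.nn as nn

embed = nn.Embedding(8, 4)
head = nn.Linear(4, 8, bias=False)
head.weight = embed.weight  # tie: both modules now reference one Parameter

with torch.no_grad():
    embed.weight.copy_(torch.randn(8, 4))  # simulate loading embed_tokens from a checkpoint
print(head.weight.data_ptr() == embed.weight.data_ptr())  # True: no post-load copy needed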


@@ -477,6 +477,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()


@@ -252,6 +252,8 @@ class QWenLMHeadModel(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()


@@ -385,6 +385,8 @@ class Qwen2MoeForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()


@@ -243,6 +243,8 @@ class StablelmForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()


@@ -313,6 +313,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()