diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 3e4f843e..12e0fedd 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -299,7 +299,11 @@ class Qwen2ForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.model = Qwen2Model(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+
+        if not config.tie_word_embeddings:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size)
+
         self.sampler = Sampler(config.vocab_size)
 
     def forward(
@@ -318,7 +322,11 @@ class Qwen2ForCausalLM(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+        if self.config.tie_word_embeddings:
+            lm_head_weight = self.model.embed_tokens.weight
+        else:
+            lm_head_weight = self.lm_head.weight
+        next_tokens = self.sampler(lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
 
@@ -340,6 +348,8 @@ class Qwen2ForCausalLM(nn.Module):
                 model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
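
The pattern this diff implements is weight tying: when config.tie_word_embeddings is set, the output logits projection reuses the input embedding matrix instead of a separately allocated lm_head, so no lm_head parameter is created and any lm_head.weight tensor in the checkpoint is skipped during load_weights. Below is a minimal self-contained sketch of the same pattern outside vLLM; TinyLM and its fields are hypothetical illustration, not vLLM or HF APIs.

    # A minimal sketch of weight tying, assuming plain PyTorch.
    # TinyLM is hypothetical; only the tying logic mirrors the diff.
    import torch
    import torch.nn as nn

    class TinyLM(nn.Module):
        def __init__(self, vocab_size: int, hidden_size: int,
                     tie_word_embeddings: bool):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
            self.tie_word_embeddings = tie_word_embeddings
            if not tie_word_embeddings:
                # Untied: allocate a separate output projection,
                # analogous to ParallelLMHead in the diff.
                self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        def logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Tied: reuse the input embedding matrix as the output
            # projection, mirroring how the diff hands
            # model.embed_tokens.weight to the sampler.
            weight = (self.embed_tokens.weight if self.tie_word_embeddings
                      else self.lm_head.weight)
            return hidden_states @ weight.t()

    model = TinyLM(vocab_size=32, hidden_size=8, tie_word_embeddings=True)
    h = torch.randn(1, 8)
    print(model.logits(h).shape)  # torch.Size([1, 32])

One consequence of this design, reflected in the load_weights hunk: a tied checkpoint may still ship an lm_head.weight tensor, and loading it into a model that never allocated that parameter would fail, hence the early continue.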