From af9e53496fc4dfc01b4680c1f16e38687cb3a91a Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sun, 24 Mar 2024 06:34:01 -0700
Subject: [PATCH] [BugFix] Fix Falcon tied embeddings (#3590)

Co-authored-by: 44670 <44670@users.noreply.github.com>
---
 vllm/model_executor/models/falcon.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index 7626dbe6..0a01796a 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
+    VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_reduce)
 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -370,10 +370,7 @@ class FalconForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.transformer = FalconModel(config, linear_method)
-        self.lm_head = ParallelLMHead(
-            config.vocab_size,
-            config.hidden_size,
-        )
+        self.lm_head_weight = self.transformer.word_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -394,7 +391,7 @@ class FalconForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                        sampling_metadata)
         return logits
 
@@ -419,9 +416,12 @@ class FalconForCausalLM(nn.Module):
         else:
             total_num_kv_heads = total_num_heads
         num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
-        params_dict = dict(self.named_parameters())
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
+            if name == "lm_head.weight":
+                # Falcon uses tied embeddings.
+                continue
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue