From 462ae5220aeeb2135b841107d2c841f85fc348bd Mon Sep 17 00:00:00 2001 From: WRH <12756472+wangruohui@users.noreply.github.com> Date: Sat, 12 Aug 2023 02:40:37 +0800 Subject: [PATCH] [Fix] unwanted bias in InternLM Model (#740) --- vllm/model_executor/models/internlm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index e2fb3f2f..19983233 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -7,15 +7,15 @@ from transformers import LlamaConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import (hf_model_weights_iterator, - load_tensor_parallel_weights) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) + ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding) +from vllm.model_executor.weight_utils import (hf_model_weights_iterator, + load_tensor_parallel_weights) from vllm.sequence import SequenceOutputs KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -32,12 +32,12 @@ class InternLMMLP(nn.Module): super().__init__() self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size, - bias=True, + bias=False, gather_output=False, perform_initialization=False) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, - bias=True, + bias=False, input_is_parallel=True, perform_initialization=False) if hidden_act != "silu":