From 7d761fe3c12e87df37383467c43c97dec2bb8470 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 20 Nov 2023 23:56:48 -0800 Subject: [PATCH] [FIX] Fix the case when `input_is_parallel=False` for `ScaledActivation` (#1737) --- vllm/model_executor/layers/activation.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index caa0e319..ecab0c8d 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -61,6 +61,7 @@ class ScaledActivation(nn.Module): ): super().__init__() self.act = act_module + self.input_is_parallel = input_is_parallel if input_is_parallel: tp_size = get_tensor_model_parallel_world_size() intermediate_size_per_partition = divide(intermediate_size, @@ -79,11 +80,12 @@ class ScaledActivation(nn.Module): return self.act(x) / self.scales def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() param_data = param.data - shard_size = param_data.shape[0] - start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(0, start_idx, shard_size) + if self.input_is_parallel: + tp_rank = get_tensor_model_parallel_rank() + shard_size = param_data.shape[0] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(0, start_idx, shard_size) assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight)