diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 0e6e71ec..595c7b64 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -259,6 +259,8 @@ class GPTBigCodeForCausalLM(nn.Module):
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()

@@ -288,8 +290,7 @@ class GPTBigCodeForCausalLM(nn.Module):
             hidden_size = self.config.hidden_size
             head_size = hidden_size // total_num_heads
             total_kv_size = head_size * total_num_kv_heads
-            num_heads = (total_num_heads //
-                         self.tensor_model_parallel_world_size)
+            num_heads = total_num_heads // tensor_model_parallel_world_size
             head_start = tensor_model_parallel_rank * num_heads
             head_end = (tensor_model_parallel_rank + 1) * num_heads

@@ -329,7 +330,7 @@ class GPTBigCodeForCausalLM(nn.Module):
             if name == "transformer.wte.weight":
                 # Consider padding in the vocab size.
                 padded_vocab_size = param.shape[
-                    0] * self.tensor_model_parallel_world_size
+                    0] * tensor_model_parallel_world_size
                 num_extra_rows = padded_vocab_size - self.config.vocab_size
                 extra_rows = torch.empty(num_extra_rows,
                                          loaded_weight.shape[1])
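
For context, the patch reads the tensor-parallel world size into a local variable via get_tensor_model_parallel_world_size() rather than going through an instance attribute, and that value drives both the attention-head sharding (second hunk) and the vocab-padding arithmetic (third hunk). Below is a minimal sketch of that arithmetic; it is not vLLM's actual loader, and every concrete number (48 heads, hidden size 6144, world size 4, vocab size 49152, 12300 rows per rank) is an illustrative assumption, not a value taken from the diff.

import torch

total_num_heads = 48          # assumed config.num_attention_heads
hidden_size = 6144            # assumed config.hidden_size
head_size = hidden_size // total_num_heads        # 128

world_size = 4    # stands in for get_tensor_model_parallel_world_size()
rank = 1          # stands in for get_tensor_model_parallel_rank()

# Second hunk: each rank owns a contiguous block of query heads and
# slices its rows out of the checkpoint's full weight along dim 0.
num_heads = total_num_heads // world_size         # 12 heads per rank
head_start = rank * num_heads                     # 12
head_end = (rank + 1) * num_heads                 # 24
q_rows = slice(head_start * head_size, head_end * head_size)

# Third hunk: the embedding table is padded so its row count divides
# evenly across ranks; the padding rows are left uninitialized.
vocab_size = 49152            # assumed config.vocab_size
rows_per_rank = 12300         # assumed param.shape[0] on each rank
padded_vocab_size = rows_per_rank * world_size    # 49200
num_extra_rows = padded_vocab_size - vocab_size   # 48
extra_rows = torch.empty(num_extra_rows, hidden_size)

print(num_heads, q_rows, num_extra_rows)          # 12 slice(1536, 3072, None) 48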