fix: revert code to avoid no attribute problem (#827)
commit eedac9dba0
parent 14f9c72bfd
@@ -259,6 +259,8 @@ class GPTBigCodeForCausalLM(nn.Module):
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      use_np_cache: bool = False):
+        tensor_model_parallel_world_size = (
+            get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
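The hunk above restores the local variable that the rest of load_weights expects: the tensor-parallel world size is read once from the module-level helper get_tensor_model_parallel_world_size() rather than from self.tensor_model_parallel_world_size, an attribute the class never sets, which is the "no attribute" problem named in the commit title. Below is a minimal, self-contained sketch of the difference; the stub helper and the sketch class are illustrative stand-ins, not vLLM's real implementations.

def get_tensor_model_parallel_world_size() -> int:
    # Stand-in for vLLM's helper; the real one returns the size of the
    # tensor-parallel process group.
    return 2


class GPTBigCodeSketch:
    """Illustrative stand-in for GPTBigCodeForCausalLM."""

    def world_size_pre_revert(self) -> int:
        # The pre-revert code read an attribute that is never assigned, so this
        # raises AttributeError: 'GPTBigCodeSketch' object has no attribute
        # 'tensor_model_parallel_world_size'.
        return self.tensor_model_parallel_world_size

    def world_size_post_revert(self) -> int:
        # The reverted code reads the value into a local variable instead.
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        return tensor_model_parallel_world_size


m = GPTBigCodeSketch()
assert m.world_size_post_revert() == 2
try:
    m.world_size_pre_revert()
except AttributeError as exc:
    print(exc)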
@@ -288,8 +290,7 @@ class GPTBigCodeForCausalLM(nn.Module):
         hidden_size = self.config.hidden_size
         head_size = hidden_size // total_num_heads
         total_kv_size = head_size * total_num_kv_heads
-        num_heads = (total_num_heads //
-                     self.tensor_model_parallel_world_size)
+        num_heads = total_num_heads // tensor_model_parallel_world_size
         head_start = tensor_model_parallel_rank * num_heads
         head_end = (tensor_model_parallel_rank + 1) * num_heads
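This hunk also covers the per-rank head partitioning used when sharding the attention weights. A short worked example of that arithmetic follows; the concrete sizes (hidden_size, total_num_heads, the TP degree) are illustrative assumptions, not values read from the model config.

# Worked example of the head slice each tensor-parallel rank loads.
hidden_size = 4096                      # assumed, for illustration
total_num_heads = 32                    # assumed, for illustration
tensor_model_parallel_world_size = 4    # assumed TP degree

head_size = hidden_size // total_num_heads                        # 128
num_heads = total_num_heads // tensor_model_parallel_world_size   # 8 heads per rank

for tensor_model_parallel_rank in range(tensor_model_parallel_world_size):
    head_start = tensor_model_parallel_rank * num_heads
    head_end = (tensor_model_parallel_rank + 1) * num_heads
    # Each rank keeps only heads [head_start, head_end) of the full weight.
    print(tensor_model_parallel_rank, head_start, head_end)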
@@ -329,7 +330,7 @@ class GPTBigCodeForCausalLM(nn.Module):
             if name == "transformer.wte.weight":
                 # Consider padding in the vocab size.
                 padded_vocab_size = param.shape[
-                    0] * self.tensor_model_parallel_world_size
+                    0] * tensor_model_parallel_world_size
                 num_extra_rows = padded_vocab_size - self.config.vocab_size
                 extra_rows = torch.empty(num_extra_rows,
                                          loaded_weight.shape[1])
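The last hunk touches the vocab-padding path for the token embedding: the per-rank embedding shard size times the world size gives the padded vocab size, and the checkpoint weight is short of it by num_extra_rows. The sketch below walks through that arithmetic and appends the extra rows with torch.cat; all sizes and the concatenation step are assumptions for illustration, only the padded_vocab_size and num_extra_rows computation comes from the hunk itself.

import torch

vocab_size = 50257                      # assumed config.vocab_size
tensor_model_parallel_world_size = 2    # assumed TP degree
per_rank_rows = 25152                   # assumed param.shape[0] on each rank

padded_vocab_size = per_rank_rows * tensor_model_parallel_world_size   # 50304
num_extra_rows = padded_vocab_size - vocab_size                        # 47

loaded_weight = torch.randn(vocab_size, 768)   # checkpoint embedding (width assumed)
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
padded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
assert padded_weight.shape[0] == padded_vocab_size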