[Bugfix] Weight loading fix for OPT model (#9042)

Co-authored-by: dvres <dvres@fri.uni-lj.si>
This commit is contained in:
Domen Vreš 2024-10-04 01:53:29 +02:00 committed by GitHub
parent 91add85ec4
commit 2838d6b38e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -353,7 +353,7 @@ class OPTForCausalLM(nn.Module):
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "lm_head.weight" in name: if "lm_head.weight" in name and self.config.tie_word_embeddings:
continue continue
if name.startswith("decoder."): if name.startswith("decoder."):
name = "model." + name name = "model." + name