[Bugfix] Weight loading fix for OPT model (#9042)
Co-authored-by: dvres <dvres@fri.uni-lj.si>
This commit is contained in:
parent
91add85ec4
commit
2838d6b38e
@ -353,7 +353,7 @@ class OPTForCausalLM(nn.Module):
|
|||||||
]
|
]
|
||||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "lm_head.weight" in name:
|
if "lm_head.weight" in name and self.config.tie_word_embeddings:
|
||||||
continue
|
continue
|
||||||
if name.startswith("decoder."):
|
if name.startswith("decoder."):
|
||||||
name = "model." + name
|
name = "model." + name
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user