[Bugfix] load fc bias from config for eagle (#8790)
This commit is contained in:
parent
c23953675f
commit
3e073e66f1
@ -44,7 +44,7 @@ class EAGLE(nn.Module):
|
|||||||
self.model = model_cls(self.config.model, *args, **kwargs)
|
self.model = model_cls(self.config.model, *args, **kwargs)
|
||||||
self.fc = nn.Linear(config.model.hidden_size * 2,
|
self.fc = nn.Linear(config.model.hidden_size * 2,
|
||||||
config.model.hidden_size,
|
config.model.hidden_size,
|
||||||
bias=False)
|
bias=getattr(self.config, "bias", False))
|
||||||
|
|
||||||
self.orig_vocab_size = config.vocab_size
|
self.orig_vocab_size = config.vocab_size
|
||||||
self.truncated_vocab_size = config.truncated_vocab_size
|
self.truncated_vocab_size = config.truncated_vocab_size
|
||||||
@ -136,10 +136,18 @@ class EAGLE(nn.Module):
|
|||||||
if self.config.truncated_vocab_size < self.config.vocab_size:
|
if self.config.truncated_vocab_size < self.config.vocab_size:
|
||||||
self.token_map = nn.Parameter(loaded_weight,
|
self.token_map = nn.Parameter(loaded_weight,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
elif name.startswith("fc."):
|
elif name.startswith("fc.weight"):
|
||||||
weight_loader = getattr(self.fc.weight, "weight_loader",
|
weight_loader = getattr(self.fc.weight, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
weight_loader(self.fc.weight, loaded_weight)
|
weight_loader(self.fc.weight, loaded_weight)
|
||||||
|
elif name.startswith("fc.bias"):
|
||||||
|
if self.fc.bias is not None:
|
||||||
|
weight_loader = getattr(self.fc.bias, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(self.fc.bias, loaded_weight)
|
||||||
|
else:
|
||||||
|
raise ValueError("Found bias in the loaded weights "
|
||||||
|
"but the model config doesn't have bias")
|
||||||
elif name.startswith("model.lm_head.") or name.startswith(
|
elif name.startswith("model.lm_head.") or name.startswith(
|
||||||
"model.model."):
|
"model.model."):
|
||||||
model_weights[name.split("model.", 1)[-1]] = loaded_weight
|
model_weights[name.split("model.", 1)[-1]] = loaded_weight
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user