[Bugfix] load fc bias from config for eagle (#8790)
This commit is contained in:
parent
c23953675f
commit
3e073e66f1
@ -44,7 +44,7 @@ class EAGLE(nn.Module):
|
||||
self.model = model_cls(self.config.model, *args, **kwargs)
|
||||
self.fc = nn.Linear(config.model.hidden_size * 2,
|
||||
config.model.hidden_size,
|
||||
bias=False)
|
||||
bias=getattr(self.config, "bias", False))
|
||||
|
||||
self.orig_vocab_size = config.vocab_size
|
||||
self.truncated_vocab_size = config.truncated_vocab_size
|
||||
@ -136,10 +136,18 @@ class EAGLE(nn.Module):
|
||||
if self.config.truncated_vocab_size < self.config.vocab_size:
|
||||
self.token_map = nn.Parameter(loaded_weight,
|
||||
requires_grad=False)
|
||||
elif name.startswith("fc."):
|
||||
elif name.startswith("fc.weight"):
|
||||
weight_loader = getattr(self.fc.weight, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(self.fc.weight, loaded_weight)
|
||||
elif name.startswith("fc.bias"):
|
||||
if self.fc.bias is not None:
|
||||
weight_loader = getattr(self.fc.bias, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(self.fc.bias, loaded_weight)
|
||||
else:
|
||||
raise ValueError("Found bias in the loaded weights "
|
||||
"but the model config doesn't have bias")
|
||||
elif name.startswith("model.lm_head.") or name.startswith(
|
||||
"model.model."):
|
||||
model_weights[name.split("model.", 1)[-1]] = loaded_weight
|
||||
|
||||
Loading…
Reference in New Issue
Block a user