[Bugfix] load fc bias from config for eagle (#8790)

This commit is contained in:
sohamparikh 2024-09-25 02:16:30 -04:00 committed by GitHub
parent c23953675f
commit 3e073e66f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -44,7 +44,7 @@ class EAGLE(nn.Module):
self.model = model_cls(self.config.model, *args, **kwargs)
self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size,
bias=False)
bias=getattr(self.config, "bias", False))
self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
@ -136,10 +136,18 @@ class EAGLE(nn.Module):
if self.config.truncated_vocab_size < self.config.vocab_size:
self.token_map = nn.Parameter(loaded_weight,
requires_grad=False)
elif name.startswith("fc."):
elif name.startswith("fc.weight"):
weight_loader = getattr(self.fc.weight, "weight_loader",
default_weight_loader)
weight_loader(self.fc.weight, loaded_weight)
elif name.startswith("fc.bias"):
if self.fc.bias is not None:
weight_loader = getattr(self.fc.bias, "weight_loader",
default_weight_loader)
weight_loader(self.fc.bias, loaded_weight)
else:
raise ValueError("Found bias in the loaded weights "
"but the model config doesn't have bias")
elif name.startswith("model.lm_head.") or name.startswith(
"model.model."):
model_weights[name.split("model.", 1)[-1]] = loaded_weight