Fix ReplicatedLinear weight loading (#6793)

This commit is contained in:
QQSong 2024-07-25 19:24:58 -07:00 committed by GitHub
parent 2eb9f4ff26
commit 062a1d0fab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -199,12 +199,16 @@ class ReplicatedLinear(LinearBase):
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader,
prefix=prefix)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
set_weight_attrs(self.bias, {"output_dim": 0})
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)