Fix ReplicatedLinear weight loading (#6793)
commit 062a1d0fab
parent 2eb9f4ff26
@@ -199,12 +199,16 @@ class ReplicatedLinear(LinearBase):
                                          self.input_size,
                                          self.output_size,
                                          self.params_dtype,
+                                         weight_loader=self.weight_loader,
                                          prefix=prefix)
 
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size, dtype=self.params_dtype))
-            set_weight_attrs(self.bias, {"output_dim": 0})
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
         else:
             self.register_parameter("bias", None)
 
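For context on why the new weight_loader registrations matter: model load_weights implementations typically dispatch each checkpoint tensor to a per-parameter "weight_loader" attribute when one is registered, and fall back to a plain copy otherwise. The sketch below is illustrative only, not the exact vLLM code; load_checkpoint_weights is a hypothetical helper, and default_weight_loader here merely mirrors the role of the fallback loader.

# Illustrative sketch only, not the exact vLLM implementation.
from typing import Dict

import torch
from torch.nn import Parameter


def default_weight_loader(param: Parameter, loaded_weight: torch.Tensor) -> None:
    # Fallback loader: a plain element-wise copy into the parameter.
    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)


def load_checkpoint_weights(module: torch.nn.Module,
                            weights: Dict[str, torch.Tensor]) -> None:
    # Hypothetical helper: dispatch each checkpoint tensor to the loader
    # registered on the matching parameter. Registering
    # weight_loader=self.weight_loader in create_weights() and in the bias
    # attrs (as this commit does) is what lets getattr() find the layer's
    # own weight_loader instead of the fallback.
    params = dict(module.named_parameters())
    for name, loaded_weight in weights.items():
        param = params[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)

With the registrations added in this commit, both the quant-method-created weight and the bias of ReplicatedLinear carry the weight_loader attribute, so a dispatch like the one sketched above routes them through the layer's loader rather than the default copy path.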