[Bugfix] Fix default weight loading for scalars (#7534)
commit 21313e09e3 (parent f4da5f7b6d)
@@ -516,11 +516,17 @@ def default_weight_loader(param: torch.Tensor,
                           loaded_weight: torch.Tensor) -> None:
     """Default weight loader."""
-    assert param.size() == loaded_weight.size(), (
-        f"Attempted to load weight ({loaded_weight.size()}) "
-        f"into parameter ({param.size()})")
-
-    param.data.copy_(loaded_weight)
+    try:
+        if param.numel() == 1 and loaded_weight.numel() == 1:
+            # Sometimes scalar values aren't considered tensors with shapes
+            # so if both param and loaded_weight are a scalar,
+            # "broadcast" instead of copy
+            param.data.fill_(loaded_weight.item())
+        else:
+            assert param.size() == loaded_weight.size(), (
+                f"Attempted to load weight ({loaded_weight.size()}) "
+                f"into parameter ({param.size()})")
+
+            param.data.copy_(loaded_weight)
+    except Exception:
+        # NOTE: This exception is added for the purpose of setting breakpoint to
+        # debug weight loading issues.
+        raise
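For illustration, here is a minimal sketch of the failure mode this commit fixes (not part of the commit itself; the concrete tensor values are hypothetical). A rank-0 checkpoint tensor has size torch.Size([]), which never compares equal to a parameter's torch.Size([1]) even though both hold exactly one element, so the old unconditional assert rejected valid scalar loads; fill_() with the scalar's item() sidesteps the shape comparison entirely.

    import torch

    param = torch.nn.Parameter(torch.zeros(1))  # size: torch.Size([1])
    loaded_weight = torch.tensor(0.5)           # rank-0 scalar: torch.Size([])

    # Old behavior: the strict shape check fails for this pair,
    # because torch.Size([1]) != torch.Size([]).
    assert param.size() != loaded_weight.size()

    # New behavior: both tensors hold a single element, so "broadcast"
    # the scalar value into the parameter instead of copying.
    if param.numel() == 1 and loaded_weight.numel() == 1:
        param.data.fill_(loaded_weight.item())

    print(param.data)  # tensor([0.5000])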