Merge pull request #164 from ZhiyuanChen/patch-1
make mlp hidden_features defaults to 4*in_features
This commit is contained in:
commit
5cee071431
@ -17,7 +17,7 @@ class Mlp(nn.Module):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
hidden_features = hidden_features or in_features * 4
|
||||
self.return_residual = return_residual
|
||||
self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
|
||||
self.activation = activation
|
||||
|
||||
Loading…
Reference in New Issue
Block a user