make mlp hidden_features defaults to 4*in_features

This commit is contained in:
Zhiyuan Chen 2023-04-13 11:08:21 +08:00 committed by GitHub
parent 853ff72963
commit 8c42415664
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,7 +17,7 @@ class Mlp(nn.Module):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
hidden_features = hidden_features or in_features * 4
self.return_residual = return_residual
self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
self.activation = activation