diff --git a/flash_attn/modules/mlp.py b/flash_attn/modules/mlp.py index 5240e3f..902bd3b 100644 --- a/flash_attn/modules/mlp.py +++ b/flash_attn/modules/mlp.py @@ -17,7 +17,7 @@ class Mlp(nn.Module): factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() out_features = out_features or in_features - hidden_features = hidden_features or in_features + hidden_features = hidden_features or in_features * 4 self.return_residual = return_residual self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs) self.activation = activation