From 8c424156641ceadc9cd1f5de71c8ae144b4db113 Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Thu, 13 Apr 2023 11:08:21 +0800
Subject: [PATCH] make mlp hidden_features defaults to 4*in_features

---
 flash_attn/modules/mlp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flash_attn/modules/mlp.py b/flash_attn/modules/mlp.py
index 5240e3f..902bd3b 100644
--- a/flash_attn/modules/mlp.py
+++ b/flash_attn/modules/mlp.py
@@ -17,7 +17,7 @@ class Mlp(nn.Module):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
+        hidden_features = hidden_features or in_features * 4
         self.return_residual = return_residual
         self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
         self.activation = activation