diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 3790f08..6439376 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -341,6 +341,7 @@ class MHA(nn.Module):
                 self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3,
                                             padding=2, groups=3 * embed_dim)
         else:
+            inner_attn_cls = inner_cross_attn_cls
             self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
             if not self.return_residual:
                 self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
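
For readers outside the patch context, a minimal, self-contained sketch of the class-selection logic this hunk touches. The class names mirror flash_attn.modules.mha, but pick_attn_classes and the stub class bodies are hypothetical stand-ins so the snippet runs without the library; the interpretation of the one-line fix (the cross-attention branch should not keep the self-attention class) is inferred from the diff, not stated in it.

# Sketch only: stub classes standing in for the real attention
# implementations defined in flash_attn.modules.mha.
class SelfAttention: ...
class CrossAttention: ...
class FlashSelfAttention: ...
class FlashCrossAttention: ...

def pick_attn_classes(cross_attn: bool, use_flash_attn: bool):
    """Hypothetical helper mirroring the selection logic in MHA.__init__."""
    # Upstream, both class variables are chosen near the top of __init__
    # based solely on the use_flash_attn flag.
    inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
    inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
    if cross_attn:
        # The line added by this diff: in the cross-attention branch
        # (separate Wq / Wkv projections), override inner_attn_cls so a
        # cross-attention implementation is constructed instead.
        inner_attn_cls = inner_cross_attn_cls
    return inner_attn_cls, inner_cross_attn_cls

# With cross_attn=True, both slots now resolve to the cross-attention class.
assert pick_attn_classes(True, False) == (CrossAttention, CrossAttention)
assert pick_attn_classes(True, True) == (FlashCrossAttention, FlashCrossAttention)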