From aec35fd67c0dd691b9357f8e7235aa6e994ec4ce Mon Sep 17 00:00:00 2001
From: Darius Lam
Date: Sat, 7 Jan 2023 12:58:41 -0800
Subject: [PATCH] fixed cross attention typeerror

---
 flash_attn/modules/mha.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 129d233..5295e49 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -341,6 +341,7 @@ class MHA(nn.Module):
                 self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                             groups=3 * embed_dim)
         else:
+            inner_attn_cls = inner_cross_attn_cls
             self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
             if not self.return_residual:
                 self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
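
For reference, below is a minimal, self-contained sketch of the failure mode this one-liner presumably addresses: in the cross-attention branch, inner_attn_cls still pointed at the self-attention class, so a module instantiated from it could later be called with cross-attention arguments and raise a TypeError. The SelfAttention, CrossAttention, and ToyMHA classes in the sketch are illustrative stand-ins, not the real implementations in flash_attn/modules/mha.py.

# Minimal sketch of the class-selection pattern this patch restores.
# All class names below are illustrative stand-ins, not the flash_attn classes.
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    # Self-attention convention: one packed (batch, seqlen, 3, head_dim) tensor.
    def forward(self, qkv):
        q, k, v = qkv.unbind(dim=2)
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        return torch.softmax(scores, dim=-1) @ v


class CrossAttention(nn.Module):
    # Cross-attention convention: separate q and packed (batch, seqlen_k, 2, head_dim) kv.
    def forward(self, q, kv):
        k, v = kv.unbind(dim=2)
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        return torch.softmax(scores, dim=-1) @ v


class ToyMHA(nn.Module):
    def __init__(self, cross_attn=False):
        super().__init__()
        inner_attn_cls = SelfAttention
        inner_cross_attn_cls = CrossAttention
        if cross_attn:
            # The assignment added by this patch: rebind the class so the module
            # built below matches the (q, kv) calling convention of the
            # cross-attention forward path. Without it, SelfAttention would be
            # instantiated and then called with two tensors, raising
            # "TypeError: forward() takes 2 positional arguments but 3 were given".
            inner_attn_cls = inner_cross_attn_cls
        self.cross_attn = cross_attn
        self.inner_attn = inner_attn_cls()

    def forward(self, q_or_qkv, kv=None):
        if self.cross_attn:
            return self.inner_attn(q_or_qkv, kv)
        return self.inner_attn(q_or_qkv)


if __name__ == "__main__":
    q = torch.randn(2, 5, 64)       # (batch, seqlen_q, head_dim)
    kv = torch.randn(2, 7, 2, 64)   # (batch, seqlen_k, 2, head_dim)
    out = ToyMHA(cross_attn=True)(q, kv)
    print(out.shape)                # torch.Size([2, 5, 64])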