Fixed cross-attention TypeError

Darius Lam 2023-01-07 12:58:41 -08:00
parent ce26d3d73d
commit aec35fd67c


@@ -341,6 +341,7 @@ class MHA(nn.Module):
                 self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                             groups=3 * embed_dim)
         else:
+            inner_attn_cls = inner_cross_attn_cls
             self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
             if not self.return_residual:
                 self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
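
For context: the hunk adds inner_attn_cls = inner_cross_attn_cls in the cross-attention branch, so the inner attention module built later on from inner_attn_cls uses the cross-attention class. Below is a minimal, self-contained sketch of why the old code would raise a TypeError, under the assumption (not shown by the diff alone) that MHA instantiates inner_attn_cls once and calls it on both paths; SelfAttention and CrossAttention here are toy stand-ins for the flash-attn classes, not the real implementations.

# Sketch of the bug this commit fixes. SelfAttention and CrossAttention
# are simplified stand-ins: they share the names from flash-attn's mha.py
# but only model the difference in call signatures.
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    # Self-attention takes one packed tensor of shape (batch, seq, 3, dim).
    def forward(self, qkv):
        q, k, v = qkv.unbind(dim=2)
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        return scores.softmax(dim=-1) @ v

class CrossAttention(nn.Module):
    # Cross-attention takes a separate query (batch, seq_q, dim) plus a
    # packed kv tensor (batch, seq_k, 2, dim): a different call signature.
    def forward(self, q, kv):
        k, v = kv.unbind(dim=2)
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
        return scores.softmax(dim=-1) @ v

cross_attn = True
inner_attn_cls = SelfAttention
inner_cross_attn_cls = CrossAttention

if cross_attn:
    # The line this commit adds: without it, inner_attn_cls stays
    # SelfAttention, and the two-argument call below raises a TypeError
    # (forward() takes 2 positional arguments but 3 were given).
    inner_attn_cls = inner_cross_attn_cls

inner_attn = inner_attn_cls()
q = torch.randn(2, 4, 16)
kv = torch.randn(2, 8, 2, 16)
print(inner_attn(q, kv).shape)  # torch.Size([2, 4, 16])

Reassigning the class once in __init__ keeps the rest of the module agnostic to which attention variant it holds, which is presumably why the fix is a single line rather than a branch at every call site.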