Merge pull request #102 from Lamikins/main
fixed cross attention typeerror
commit 8d9674ed08
@@ -341,6 +341,7 @@ class MHA(nn.Module):
                 self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2,
                                             groups=3 * embed_dim)
         else:
+            inner_attn_cls = inner_cross_attn_cls
             self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
             if not self.return_residual:
                 self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
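For context, a minimal sketch of how the one-line change plausibly resolves the TypeError: inside MHA's constructor, the cross-attention branch (the else: block above) now rebinds inner_attn_cls to inner_cross_attn_cls before the inner attention module is built. The cross_attn flag, the placeholder SelfAttention/CrossAttention classes, and the final inner_attn_cls(...) call below are assumptions added for illustration; they are not part of the diff.

import torch.nn as nn


class SelfAttention(nn.Module):
    # Placeholder: the real module would run attention over a packed qkv tensor.
    def __init__(self, causal=False):
        super().__init__()
        self.causal = causal


class CrossAttention(nn.Module):
    # Placeholder: the real module would take separate q and kv tensors.
    def __init__(self, causal=False):
        super().__init__()
        self.causal = causal


class MHA(nn.Module):
    def __init__(self, embed_dim, cross_attn=False, bias=True, causal=False):
        super().__init__()
        self.cross_attn = cross_attn
        inner_attn_cls = SelfAttention
        inner_cross_attn_cls = CrossAttention
        if not cross_attn:
            # Self-attention path: one packed projection for q, k and v.
            self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        else:
            # The one-line fix: make the cross-attention class the one that
            # gets instantiated below, instead of the self-attention class.
            inner_attn_cls = inner_cross_attn_cls
            self.Wq = nn.Linear(embed_dim, embed_dim, bias=bias)
            self.Wkv = nn.Linear(embed_dim, 2 * embed_dim, bias=bias)
        # Without the reassignment, a cross-attention MHA would construct the
        # self-attention inner module here; a signature mismatch between the two
        # classes is a plausible source of the reported TypeError.
        self.inner_attn = inner_attn_cls(causal=causal)


# Usage: MHA(512, cross_attn=True) now builds CrossAttention as its inner module.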